
Conversation

@slimfrkha
Contributor

@slimfrkha slimfrkha commented Jun 9, 2025

Problem:

When evaluating reasoning models, we use n_repeat to generate several DIFFERENT answers for the same prompt. To ensure variability across generations of the same prompt, a different seed is assigned to each of the n_repeat samples.
A recent PR in evalchemy switched to calling vllm once instead of once per n_repeat, which makes evaluations much faster.
The problem: because Collator currently splits requests on differing seed values, that optimization is negated.

collator = Collator(...)
chunks = ...  # split requests with the collator by differing gen_kwargs values, seed included
for chunk in chunks:  # n_repeat chunks instead of 1, because each repeat carries a different seed
    self._model_generate(chunk)  # very slow; with data parallelism enabled, models get loaded dp * n_repeat times

Solution:

In Collator, ignore seed when splitting the batch into chunks with groupby. The seed has no bearing on the correctness of generations in vLLM, unlike other arguments such as temperature.
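
For illustration only, here is a minimal sketch of the grouping idea; this is not the actual Collator code, and chunk_key / chunk_requests are hypothetical helper names.

from itertools import groupby

def chunk_key(gen_kwargs: dict) -> tuple:
    # Build the grouping key from every generation kwarg except "seed":
    # the seed only changes which sample is drawn, not whether two requests
    # can safely share one vLLM call. repr() keeps the key comparable
    # regardless of value types.
    return tuple(sorted((k, repr(v)) for k, v in gen_kwargs.items() if k != "seed"))

def chunk_requests(requests: list[dict]) -> list[list[dict]]:
    # Sort then group, so n_repeat copies of the same prompt (differing only
    # by seed) land in a single chunk and vLLM is invoked once per chunk.
    ordered = sorted(requests, key=lambda r: chunk_key(r["gen_kwargs"]))
    return [list(g) for _, g in groupby(ordered, key=lambda r: chunk_key(r["gen_kwargs"]))]

For example, 64 repeats of the same prompt with seeds 0..63 and temperature 0.6 now form one chunk instead of 64.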

Key Changes

With n_repeat=64, a ~15x speedup when evaluating AIME25 with a ~7B/10B model.

Tests

Reproduced AIME25 results for an open source reasoning model (skywork-or1) in evalchemy.

@CLAassistant

CLAassistant commented Jun 9, 2025

CLA assistant check
All committers have signed the CLA.

@slimfrkha
Contributor Author

Hi @baberabb @StellaAthena
Could you take a quick look at this PR when you get a chance? Just need your feedback to decide whether to keep it open or close it. Thanks!

@baberabb
Contributor

baberabb commented Aug 25, 2025

Hi @baberabb @StellaAthena Could you take a quick look at this PR when you get a chance? Just need your feedback to decide whether to keep it open or close it. Thanks!

Sorry, missed this! Will take a closer look.

Does this pass an individual sampling_params for each sample now? If so we can set group_by=None in the Collator, and it should still batch and sort the requests by context length.

@slimfrkha
Contributor Author

slimfrkha commented Aug 26, 2025

Yes, in the AIME25 benchmark for instance, we pass a sampling_params per sample.
I didn't want to opt for group_by=None because you would still want to group samples by the remaining gen kwargs. Imagine a batch with different temperatures and seeds: we would want to split by temperature but ignore the seed.
If we hard-code group_by to None, then a batch with different temperatures would give incorrect output, unless we expose group_by in generate_until for vllm_causallm. But that would break the standardized generate_until definition shared by all model classes.

@baberabb
Contributor

baberabb commented Aug 27, 2025

My understanding is that vllm's continuous batching does allow each request to have independent generation params (for example temp=1 for one request and temp=0.6, seed=2 for another). And since we are now passing a list of sampling_params, grouping the requests seems unnecessary.
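
For concreteness, a minimal sketch of what passing per-request sampling params to vLLM looks like; the model name and parameter values are placeholders, not taken from this PR.

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-7B-Instruct")  # placeholder model

prompts = ["Solve: 1+1=", "Solve: 2+2="]
# One SamplingParams per prompt; continuous batching handles the mix.
params = [
    SamplingParams(temperature=1.0, seed=1, max_tokens=64),
    SamplingParams(temperature=0.6, seed=2, max_tokens=64),
]

# sampling_params may be a single object or a list with one entry per prompt.
outputs = llm.generate(prompts, sampling_params=params)
for out in outputs:
    print(out.outputs[0].text)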

And it wouldn't really break the generate_until API, as each item does have its own respective gen_kwargs dict. Previously it wasn't technically feasible with HF, where we require the same params for all items in the batch (unless using batch size 1).

Please correct me if I'm misunderstanding anything here!

@slimfrkha
Contributor Author

Hmm, so you're suggesting the following?

  1. Replace the gen_kwargs value with None here to stop splitting on different sampling params values
  2. Handle passing a list of sampling_params instead of a single one by fixing this, which this PR does

Should we also update other models (sglang?) to use the same logic if they support continuous batching?

@baberabb
Contributor

Hmm, so you're suggesting the following?

  1. Replace the gen_kwargs value with None here to stop splitting on different sampling params values
  2. Handle passing a list of sampling_params instead of a single one by fixing this, which this PR does

Should we also update other models (sglang?) to use the same logic if they support continuous batching?

Yes! I left a comment, if that makes it clearer. And I would definitely appreciate it if we could use the same approach for sglang as well, if they support it. They use something similar to continuous batching, but I'm not as familiar with their API.

@slimfrkha
Contributor Author

@baberabb
I made the changes; should be OK now.
FYI:

  • had a look at the other model APIs; only vllm and sglang would support the non-splitting
  • sglang accepts a list of sampling params, as seen here
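
A hedged sketch of the sglang side, assuming the offline Engine API where sampling_params may be a single dict or a list of dicts, one per prompt; the model path and values are placeholders.

import sglang as sgl

llm = sgl.Engine(model_path="Qwen/Qwen2.5-7B-Instruct")  # placeholder model

prompts = ["Solve: 1+1=", "Solve: 2+2="]
# A list of per-prompt sampling-param dicts, mirroring the vLLM path above.
sampling_params = [
    {"temperature": 1.0, "max_new_tokens": 64},
    {"temperature": 0.6, "max_new_tokens": 64},
]

outputs = llm.generate(prompts, sampling_params)
for out in outputs:
    print(out["text"])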

@baberabb
Contributor

baberabb commented Sep 8, 2025

@baberabb I made the changes; should be OK now. FYI:

  • had a look at the other model APIs; only vllm and sglang would support the non-splitting
  • sglang accepts a list of sampling params, as seen here

Thanks for the PR! This is great! Tested on my end as well!

@baberabb baberabb merged commit 4439847 into EleutherAI:main Sep 8, 2025
6 checks passed
@fxmarty-amd
Contributor

You should add quotes ("") around type hints that come from conditional imports, so they are treated as forward references.
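
A small sketch of what that suggestion means in practice; the names below are illustrative, not from this PR.

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only imported for type checkers; the module still loads when vllm is absent.
    from vllm import SamplingParams

def build_params(temperature: float) -> "SamplingParams":
    # Quoting the annotation makes it a forward reference, so it is not
    # evaluated at import time when the conditional import did not run.
    from vllm import SamplingParams
    return SamplingParams(temperature=temperature)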

@Dornavineeth
Contributor

Dornavineeth commented Sep 16, 2025

@baberabb @slimfrkha Thanks for the good work. I am thinking of working on this PR to migrate the lm_eval version in evalchemy to this commit. Could you share the test script you used to verify this PR with data_parallel_size > 1?

I set VLLM_USE_V1=0 and vllm==0.7.0 and get the following error:

2025-09-16 23:35:09,781 INFO worker.py:1942 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265 
(pid=2507019) INFO 09-16 23:35:32 __init__.py:183] Automatically detected platform cuda.
(run_inference_one_model pid=2507019) Model Args: {'model': 'open-thoughts/OpenThinker3-1.5B', 'gpu_memory_utilization': 0.9, 'revision': None, 'dtype': 'auto', 'tokenizer': None, 'tokenizer_mode': 'auto', 'tokenizer_revision': None, 'trust_remote_code': False, 'tensor_parallel_size': 1, 'max_model_len': None, 'max_num_seqs': None, 'swap_space': 4, 'quantization': None, 'seed': 1234, 'enable_lora': False, 'max_lora_rank': 16, 'distributed_executor_backend': 'ray'}
(run_inference_one_model pid=2507022) `torch_dtype` is deprecated! Use `dtype` instead!
(run_inference_one_model pid=2507019) INFO 09-16 23:35:57 config.py:520] This model supports multiple tasks: {'generate', 'score', 'embed', 'reward', 'classify'}. Defaulting to 'generate'.
(run_inference_one_model pid=2507019) INFO 09-16 23:35:57 llm_engine.py:232] Initializing an LLM engine (v0.7.0) with config: model='open-thoughts/OpenThinker3-1.5B', speculative_config=None, tokenizer='open-thoughts/OpenThinker3-1.5B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=1234, served_model_name=open-thoughts/OpenThinker3-1.5B, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=False, chunked_prefill_enabled=False, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":[],"compile_sizes":[],"cudagraph_capture_sizes":[256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":256}, use_cached_outputs=False, 
(pid=2507022) INFO 09-16 23:35:32 __init__.py:183] Automatically detected platform cuda.
(run_inference_one_model pid=2507022) Model Args: {'model': 'open-thoughts/OpenThinker3-1.5B', 'gpu_memory_utilization': 0.9, 'revision': None, 'dtype': 'auto', 'tokenizer': None, 'tokenizer_mode': 'auto', 'tokenizer_revision': None, 'trust_remote_code': False, 'tensor_parallel_size': 1, 'max_model_len': None, 'max_num_seqs': None, 'swap_space': 4, 'quantization': None, 'seed': 1234, 'enable_lora': False, 'max_lora_rank': 16, 'distributed_executor_backend': 'ray'}
(run_inference_one_model pid=2507022) INFO 09-16 23:35:57 config.py:520] This model supports multiple tasks: {'score', 'classify', 'reward', 'embed', 'generate'}. Defaulting to 'generate'.
(run_inference_one_model pid=2507019) INFO 09-16 23:35:59 cuda.py:174] Cannot use FlashAttention-2 backend for Volta and Turing GPUs.
(run_inference_one_model pid=2507019) INFO 09-16 23:35:59 cuda.py:222] Using XFormers backend.
Traceback (most recent call last):
  File "<PATH>/miniconda3/envs/evalchemy3/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "<PATH>/miniconda3/envs/evalchemy3/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "<PATH>/reasoning_refactor/evalchemy/eval/eval.py", line 627, in <module>
    cli_evaluate()
  File "<PATH>/reasoning_refactor/evalchemy/eval/eval.py", line 412, in cli_evaluate
    results = evaluate(
  File "<PATH>/reasoning_refactor/evalchemy/eval/eval.py", line 193, in evaluate
    result = method(lm)
  File "<PATH>/reasoning_refactor/evalchemy/eval/chat_benchmarks/AIME24/eval_instruct.py", line 102, in generate_responses
    all_outputs = self.compute(model, all_instances)
  File "<PATH>/reasoning_refactor/evalchemy/eval/task.py", line 94, in compute
    results = model.generate_until(prompts)
  File "<PATH>/miniconda3/envs/evalchemy3/lib/python3.10/site-packages/lm_eval/models/vllm_causallms.py", line 642, in generate_until
    cont = self._model_generate(
  File "<PATH>/miniconda3/envs/evalchemy3/lib/python3.10/site-packages/lm_eval/models/vllm_causallms.py", line 409, in _model_generate
    results = ray.get(object_refs)
  File "<PATH>/miniconda3/envs/evalchemy3/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 22, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "<PATH>/miniconda3/envs/evalchemy3/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 104, in wrapper
    return func(*args, **kwargs)
  File "<PATH>/miniconda3/envs/evalchemy3/lib/python3.10/site-packages/ray/_private/worker.py", line 2882, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "<PATH>/miniconda3/envs/evalchemy3/lib/python3.10/site-packages/ray/_private/worker.py", line 968, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): ray::run_inference_one_model() (pid=2507022, ip=10.1.15.201)
  File "<PATH>/miniconda3/envs/evalchemy3/lib/python3.10/site-packages/lm_eval/models/vllm_causallms.py", line 391, in run_inference_one_model
    llm = LLM(**model_args)
  File "<PATH>/miniconda3/envs/evalchemy3/lib/python3.10/site-packages/vllm/utils.py", line 1039, in inner
    return fn(*args, **kwargs)
  File "<PATH>/miniconda3/envs/evalchemy3/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 239, in __init__
    self.llm_engine = self.engine_class.from_engine_args(
  File "<PATH>/miniconda3/envs/evalchemy3/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 482, in from_engine_args
    engine = cls(
  File "<PATH>/miniconda3/envs/evalchemy3/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 271, in __init__
    self.model_executor = executor_class(vllm_config=vllm_config, )
  File "<PATH>/miniconda3/envs/evalchemy3/lib/python3.10/site-packages/vllm/executor/executor_base.py", line 49, in __init__
    self._init_executor()
  File "<PATH>/miniconda3/envs/evalchemy3/lib/python3.10/site-packages/vllm/executor/uniproc_executor.py", line 39, in _init_executor
    self.collective_rpc("init_device")
  File "<PATH>/miniconda3/envs/evalchemy3/lib/python3.10/site-packages/vllm/executor/uniproc_executor.py", line 49, in collective_rpc
    answer = run_method(self.driver_worker, method, args, kwargs)
  File "<PATH>/miniconda3/envs/evalchemy3/lib/python3.10/site-packages/vllm/utils.py", line 2208, in run_method
    return func(*args, **kwargs)
  File "<PATH>/miniconda3/envs/evalchemy3/lib/python3.10/site-packages/vllm/worker/worker.py", line 154, in init_device
    torch.cuda.set_device(self.device)
  File "<PATH>/miniconda3/envs/evalchemy3/lib/python3.10/site-packages/torch/cuda/__init__.py", line 478, in set_device
    torch._C._cuda_setDevice(device)
  File "<PATH>/miniconda3/envs/evalchemy3/lib/python3.10/site-packages/torch/cuda/__init__.py", line 319, in _lazy_init
    torch._C._cuda_init()
RuntimeError: No CUDA GPUs are available
(run_inference_one_model pid=2507022) INFO 09-16 23:35:57 llm_engine.py:232] Initializing an LLM engine (v0.7.0) with config: model='open-thoughts/OpenThinker3-1.5B', speculative_config=None, tokenizer='open-thoughts/OpenThinker3-1.5B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=1234, served_model_name=open-thoughts/OpenThinker3-1.5B, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=False, chunked_prefill_enabled=False, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":[],"compile_sizes":[],"cudagraph_capture_sizes":[256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":256}, use_cached_outputs=False, 
(run_inference_one_model pid=2507022) INFO 09-16 23:35:59 cuda.py:174] Cannot use FlashAttention-2 backend for Volta and Turing GPUs.
(run_inference_one_model pid=2507022) INFO 09-16 23:35:59 cuda.py:222] Using XFormers backend.
(run_inference_one_model pid=2507019) `torch_dtype` is deprecated! Use `dtype` instead!
Running generate_until requests:   0%|                                                                                                       | 0/300 [01:03<?, ?it/s]

@slimfrkha
Contributor Author

That PR is from a colleague of mine who also opened another PR to enable CUDA Graphs with vLLM Data Parallel when dp > 1.
We use these 3 fixes (2 PRs in lm_eval + the one in evalchemy) to test benchmarks like AIME end to end.

@Dornavineeth
Contributor

That PR is from a colleague of mine who also opened another PR to enable CUDA Graphs with vLLM Data Parallel when dp > 1. We use these 3 fixes (2 PRs in lm_eval + the one in evalchemy) to test benchmarks like AIME end to end.

Okay, thanks for the quick reply @slimfrkha.

Also, when you say "PR in evalchemy," are you referring to this PR? It seems incomplete to me: just minor fixes like updating the lm_eval version in pyproject.toml and a few other small changes.

P.S. I resolved my issue by setting vllm==0.10.0.

@slimfrkha
Contributor Author

slimfrkha commented Sep 17, 2025

Yeah, that PR.
Evalchemy, it seems, is no longer maintained; the maintainers don't respond anymore, so the PR is working but was never finalized.

JessicaOjo pushed a commit to JessicaOjo/lm-evaluation-harness that referenced this pull request Dec 10, 2025

* feat(vllm_causallms): make collator ignore seed when splitting batch into chunks

* fix(collator): revert PR changes

* fix(vllm-causallm): update collator call with groupby None

* feat(sglang-causallms): make generation accept a list of sampling params

---------

Co-authored-by: Baber <[email protected]>