Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
631 changes: 508 additions & 123 deletions run.py

Large diffs are not rendered by default.

397 changes: 0 additions & 397 deletions run_api.py

This file was deleted.

4 changes: 4 additions & 0 deletions vlmeval/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import ssl
import warnings

# Ignore the pkg_resources deprecation warning, which is triggered because jieba depends on it.
warnings.filterwarnings("ignore", category=UserWarning, message="pkg_resources is deprecated")

# Temporarily bypass SSL certificate verification to download files from oss.
ssl._create_default_https_context = ssl._create_unverified_context
Expand Down
2 changes: 1 addition & 1 deletion vlmeval/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,7 +1066,7 @@

interns1_mini = {
"Intern-S1-mini": partial(
vlm.InternS1Chat, model_path="/mnt/shared-storage-user/mllm/lijinsong/models/Intern-S1-mini/"
vlm.InternS1Chat, model_path="internlm/Intern-S1-mini"
),
}

Expand Down
57 changes: 37 additions & 20 deletions vlmeval/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from tqdm import tqdm

from vlmeval.config import supported_VLM
from vlmeval.smp import (dump, get_pred_file_format, get_pred_file_path, get_rank_and_world_size,
load)
from vlmeval.smp import (dump, get_logger, get_pred_file_format, get_pred_file_path,
get_rank_and_world_size, load)
from vlmeval.utils import track_progress_rich

logger = get_logger(__name__)
FAIL_MSG = 'Failed to obtain answer via API.'


Expand All @@ -26,7 +27,7 @@ def parse_args():


# Only API model is accepted
def infer_data_api(model, work_dir, model_name, dataset, index_set=None, api_nproc=4, ignore_failed=False):
def infer_data_api(model, work_dir, model_name, dataset, index_set=None, api_nproc=4, retry_failed=True):
rank, world_size = get_rank_and_world_size()
assert rank == 0 and world_size == 1
dataset_name = dataset.dataset_name
Expand All @@ -40,6 +41,8 @@ def infer_data_api(model, work_dir, model_name, dataset, index_set=None, api_npr
model.set_dump_image(dataset.dump_image)

lt, indices = len(data), list(data['index'])
# Build str→orig mapping for checkpoint key conversion
index_str_to_orig = {str(i): i for i in indices}

structs = []
for i in range(lt):
Expand All @@ -53,7 +56,7 @@ def infer_data_api(model, work_dir, model_name, dataset, index_set=None, api_npr
struct = dataset.build_prompt(item)
structs.append(struct)

out_file = f'{work_dir}/{model_name}_{dataset_name}_supp.pkl'
out_file = f'{work_dir}/{model_name}_{dataset_name}_checkpoint.pkl'

# To reuse records in MMBench_V11
if dataset_name in ['MMBench', 'MMBench_CN']:
Expand All @@ -62,35 +65,41 @@ def infer_data_api(model, work_dir, model_name, dataset, index_set=None, api_npr
if osp.exists(v11_pred):
try:
reuse_inds = load('https://opencompass.openxlab.space/utils/mmb_reuse.pkl')
data = load(v11_pred)
ans_map = {x: y for x, y in zip(data['index'], data['prediction']) if x in reuse_inds}
data_v11 = load(v11_pred)
ans_map = {str(x): y for x, y in zip(data_v11['index'], data_v11['prediction']) if x in reuse_inds}
dump(ans_map, out_file)
except Exception as err:
print(type(err), err)

res = {}
if osp.exists(out_file):
res = load(out_file)
if ignore_failed:
if retry_failed:
res = {k: v for k, v in res.items() if FAIL_MSG not in v}
logger.info(f'Reuse {len(res)} inference results from previous run.')

structs = [s for i, s in zip(indices, structs) if i not in res]
indices = [i for i in indices if i not in res]
structs = [s for i, s in zip(indices, structs) if str(i) not in res]
indices = [i for i in indices if str(i) not in res]

gen_func = model.generate
structs = [dict(message=struct, dataset=dataset_name) for struct in structs]

if len(structs):
track_progress_rich(gen_func, structs, nproc=api_nproc, chunksize=api_nproc, save=out_file, keys=indices)
str_indices = [str(i) for i in indices]
track_progress_rich(gen_func, structs, nproc=api_nproc, chunksize=api_nproc, save=out_file, keys=str_indices)

res = load(out_file)
# Load the full accumulated results (str keys)
if osp.exists(out_file):
res = load(out_file)
# Convert str keys back to original types for caller compatibility
result = {index_str_to_orig[k]: v for k, v in res.items() if k in index_str_to_orig}
if index_set is not None:
res = {k: v for k, v in res.items() if k in index_set}
os.remove(out_file)
return res
result = {k: v for k, v in result.items() if k in index_set}
return result


def infer_data(model, model_name, work_dir, dataset, out_file, verbose=False, api_nproc=4, use_vllm=False):
def infer_data(model, model_name, work_dir, dataset, out_file, verbose=False, api_nproc=4, use_vllm=False,
retry_failed=True):
dataset_name = dataset.dataset_name
prev_file = f'{work_dir}/{model_name}_{dataset_name}_PREV.pkl'
res = load(prev_file) if osp.exists(prev_file) else {}
Expand Down Expand Up @@ -144,7 +153,8 @@ def infer_data(model, model_name, work_dir, dataset, out_file, verbose=False, ap
model_name=model_name,
dataset=dataset,
index_set=set(indices),
api_nproc=api_nproc)
api_nproc=api_nproc,
retry_failed=retry_failed)
for idx in indices:
assert idx in supp
res.update(supp)
Expand Down Expand Up @@ -198,7 +208,7 @@ def _is_structured_record(v):

# A wrapper for infer_data, do the pre & post processing
def infer_data_job(
model, work_dir, model_name, dataset, verbose=False, api_nproc=4, ignore_failed=False, use_vllm=False
model, work_dir, model_name, dataset, verbose=False, api_nproc=4, retry_failed=True, use_vllm=False
):
rank, world_size = get_rank_and_world_size()
dataset_name = dataset.dataset_name
Expand All @@ -209,9 +219,8 @@ def infer_data_job(
if osp.exists(result_file):
if rank == 0:
data = load(result_file)
# breakpoint()
results = {k: v for k, v in zip(data['index'], data['prediction'])}
Copy link

Copilot AI Mar 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

load(result_file) can return a list of records when PRED_FORMAT=json (because dump() serializes DataFrames to to_dict('records') for JSON). This code assumes data is a DataFrame/dict with data['index'] and data['prediction'], which will raise at runtime when resuming from a JSON result file. Consider normalizing data to a DataFrame (or reusing the shared prediction-table parsing logic in vlmeval.smp.file._prediction_table) before zipping index/prediction.

Suggested change
results = {k: v for k, v in zip(data['index'], data['prediction'])}
# Normalize loaded prediction table to a dict[index] -> prediction
if isinstance(data, list):
# JSON format: list of record dicts, e.g. [{'index': ..., 'prediction': ...}, ...]
results = {
rec['index']: rec['prediction']
for rec in data
if isinstance(rec, dict) and 'index' in rec and 'prediction' in rec
}
else:
# DataFrame-like or dict-like with 'index' and 'prediction' columns/keys
try:
indices = data['index']
predictions = data['prediction']
results = {k: v for k, v in zip(indices, predictions)}
except (TypeError, KeyError):
# Unexpected structure; fall back to empty results to avoid crashing
results = {}

Copilot uses AI. Check for mistakes.
if not ignore_failed:
if retry_failed:
results = {k: v for k, v in results.items() if FAIL_MSG not in str(v)}
dump(results, prev_file)
if world_size > 1:
Expand All @@ -222,7 +231,8 @@ def infer_data_job(

model = infer_data(
model=model, work_dir=work_dir, model_name=model_name, dataset=dataset,
out_file=out_file, verbose=verbose, api_nproc=api_nproc, use_vllm=use_vllm)
out_file=out_file, verbose=verbose, api_nproc=api_nproc, use_vllm=use_vllm,
retry_failed=retry_failed)
if world_size > 1:
dist.barrier()

Expand Down Expand Up @@ -275,6 +285,13 @@ def split_thinking(s):
dump(data, result_file)
for i in range(world_size):
os.remove(tmpl.format(i))
# Clean up API checkpoint file
checkpoint_file = f'{work_dir}/{model_name}_{dataset_name}_checkpoint.pkl'
if osp.exists(checkpoint_file):
os.remove(checkpoint_file)
# Clean up PREV file
if osp.exists(prev_file):
os.remove(prev_file)
if world_size > 1:
dist.barrier()
return model
100 changes: 83 additions & 17 deletions vlmeval/inference_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def __init__(
run_infer: bool = True,
run_eval: bool = True,
debug: bool = False,
retry_failed: bool = True,
):
"""
Args:
Expand All @@ -159,13 +160,15 @@ def __init__(
run_infer: Whether to inference.
run_eval: Whether to eval.
debug: Debug mode (Evaluate in the main process).
retry_failed: Whether to retry previously failed samples.
"""
self.dataset_configs = dataset_configs
self.concurrency = concurrency
self.monitor_interval = monitor_interval
self.run_infer = run_infer
self.run_eval = run_eval
self.debug = debug
self.retry_failed = retry_failed
self.all_infer_done = False

self.infer_executor = ThreadPoolExecutor(max_workers=concurrency)
Expand Down Expand Up @@ -206,22 +209,43 @@ def _release_dataset_memory(self, cfg: DatasetConfig):
logger.warning(f" [{dataset_name}] Failed to release dataset memory: {e}")

def _shutdown_executors(self):
"""Shutdown all Executor."""
"""Shutdown all executors and terminate child processes."""
try:
self.infer_executor.shutdown(wait=False)
self.infer_executor.shutdown(wait=False, cancel_futures=True)
logger.debug("Shutdown infer_executor")

self.eval_executor.shutdown(wait=False)
logger.debug("Shutdown eval_executor")
for name, executor in [
("eval_executor", self.eval_executor),
("producer_executor", self.producer_executor),
]:
# Must grab the process references before calling shutdown(): shutdown() wakes
# the management thread to perform cleanup, which may set _processes to None,
# after which the workers can no longer be reached.
processes = getattr(executor, '_processes', None) or {}
alive = [p for p in processes.values() if p.is_alive()]
executor.shutdown(wait=False, cancel_futures=True)
self._terminate_workers(name, alive)

self.producer_executor.shutdown(wait=False)
logger.debug("Shutdown producer_executor")

logger.info("🧹 All executors shutdown")
logger.info("All executors shutdown")

except Exception as e:
logger.warning(f"Failed to shutdown executors: {e}")

@staticmethod
def _terminate_workers(name, alive, timeout=5):
    """Terminate a list of live worker processes.

    First politely requests termination (SIGTERM) from every worker,
    then waits up to *timeout* seconds per process and force-kills
    (SIGKILL) any that are still running afterwards.
    """
    # Phase 1: ask every worker to stop gracefully.
    for proc in alive:
        logger.debug(f"Terminating {name} worker (pid={proc.pid})")
        proc.terminate()
    # Phase 2: wait for each worker; escalate to a hard kill if needed.
    for proc in alive:
        proc.join(timeout=timeout)
        if not proc.is_alive():
            continue
        logger.debug(f"Force killing {name} worker (pid={proc.pid})")
        proc.kill()

def _get_checkpoint_file(self, dataset_name: str) -> Path:
cfg = self.states[dataset_name]
return Path(cfg.work_dir) / f"{cfg.model_name}_{dataset_name}_checkpoint.pkl"
Expand All @@ -236,6 +260,8 @@ def _load_checkpoint(self, dataset_name: str) -> Dict[str, Any]:
if checkpoint_file.exists():
try:
results = load(str(checkpoint_file))
if self.retry_failed:
results = {k: v for k, v in results.items() if FAIL_MSG not in str(v)}
logger.info(f" [{dataset_name}] Loaded {len(results)} results from checkpoint")
except Exception as e:
logger.warning(f" [{dataset_name}] Failed to load checkpoint: {e}")
Expand All @@ -246,11 +272,17 @@ def _load_checkpoint(self, dataset_name: str) -> Dict[str, Any]:
try:
data = load(str(result_path))
if isinstance(data, pd.DataFrame):
existing_results = {
str(idx): pred
for idx, pred in zip(data['index'], data['prediction'])
if FAIL_MSG not in str(pred)
}
if self.retry_failed:
existing_results = {
str(idx): pred
for idx, pred in zip(data['index'], data['prediction'])
if FAIL_MSG not in str(pred)
}
else:
existing_results = {
str(idx): pred
for idx, pred in zip(data['index'], data['prediction'])
}
results.update(existing_results)
logger.info(f" [{dataset_name}] Loaded {len(existing_results)} "
"results from result file")
Expand Down Expand Up @@ -309,8 +341,37 @@ def _save_final_result(self, dataset_name: str) -> bool:

return True

def _create_symlinks(self, dataset_name: str):
    """Create symbolic links for dataset results in the model base directory.

    Link targets are relative paths, so moving the whole output root
    directory keeps the links valid.
    """
    cfg = self.states[dataset_name]
    work_dir = Path(cfg.work_dir)
    base_dir = work_dir.parent
    marker = f'{cfg.model_name}_{dataset_name}'
    # Temporary intermediate artifacts that must never be linked.
    skip_suffixes = ('_checkpoint.pkl', '_PREV.pkl', '_structs.pkl')

    try:
        if not work_dir.exists():
            return
        for entry in work_dir.iterdir():
            wanted = (
                entry.is_file()
                and marker in entry.name
                and not entry.name.endswith(skip_suffixes)
            )
            if not wanted:
                continue
            link_path = base_dir / entry.name
            # Drop a stale file or link first (covers broken symlinks too).
            if link_path.exists() or link_path.is_symlink():
                link_path.unlink()
            link_path.symlink_to(entry.relative_to(base_dir))
    except Exception as e:
        logger.warning(f" [{dataset_name}] Failed to create symlinks: {e}")

async def _producer(self):
"""Genearte all sampels to inference."""
"""Generate all samples to inference."""
logger.info("📦 Initializing tasks and checking checkpoints...")

for cfg in self.dataset_configs:
Expand Down Expand Up @@ -339,6 +400,7 @@ async def _producer(self):
# Save result file if not exists.
if not Path(cfg.result_file).exists():
self._save_final_result(dataset_name)
self._create_symlinks(dataset_name)
# Trigger evaluation.
asyncio.create_task(self._trigger_eval(dataset_name))
continue
Expand Down Expand Up @@ -566,9 +628,11 @@ def inference_call():
f"[{task.dataset_name}] Sample {task.sample_index}: "
f"{output_preview} (took {inference_time:.2f}s)")

# Trigger evaluation if all tasks of the dataset is done.
if cfg.processed == cfg.total_samples and cfg.eval_status == EvalStatus.Pending:
if self._save_final_result(task.dataset_name):
# Save final result and create symlinks when all samples are done.
if cfg.processed == cfg.total_samples:
self._save_final_result(task.dataset_name)
self._create_symlinks(task.dataset_name)
if cfg.eval_status == EvalStatus.Pending:
asyncio.create_task(self._trigger_eval(task.dataset_name))

except Exception as e:
Expand Down Expand Up @@ -677,6 +741,8 @@ def _handle_eval_result(self,
f" Check log file: {eval_log_path}"
)

# Update symlinks to capture evaluation output files.
self._create_symlinks(dataset_name)
# Release dataset data after evaluation.
self._release_dataset_memory(cfg)

Expand Down
Loading
Loading