Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions skythought/evals/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .base import ConversationType, TaskConfig, TaskHandler
from .gpqa_diamond.gpqa_diamond_handler import GPQADiamondTaskHandler
from .gsm8k.gsm8k_handler import GSM8KTaskHandler
from .liveaops.liveaops_handler import LiveAOPSTaskHandler
from .livecodebench.livecodebench_handler import LiveCodeBenchTaskHandler
from .math.math_handler import MathTaskHandler
from .minervamath.minervamath_handler import MinervaMathTaskHandler
Expand All @@ -33,6 +34,7 @@
"minervamath": MinervaMathTaskHandler,
"olympiadbench_math": OlympiadBenchMathTaskHandler,
"omni_math": OMNIMathTaskHandler,
"liveaops": LiveAOPSTaskHandler,
}
TASK_NAMES_TO_YAML = get_tasks(os.path.dirname(__file__))

Expand Down
31 changes: 24 additions & 7 deletions skythought/evals/tasks/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse

import pandas as pd
import yaml
Expand All @@ -14,7 +15,7 @@ class TaskConfig(BaseModel):
handler: str
dataset_path: str
dataset_subset: Optional[str] = None
dataset_split: str
dataset_split: Optional[str] = None
dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
question_key: str
# Optional answer key for datasets with a single correct answer
Expand Down Expand Up @@ -82,12 +83,28 @@ def make_conversations(
return conversations

def load_dataset(self, subset=None, split=None, **kwargs) -> HFDataset:
    """Load the dataset named by ``task_config.dataset_path`` and index it.

    A ``dataset_path`` without a URL scheme is passed to ``load_dataset``
    as a dataset path, together with the configured (or overridden)
    subset, split and ``dataset_kwargs``. A ``dataset_path`` that carries
    a scheme (for example ``https://``) is loaded as a JSON data file.

    Args:
        subset: Overrides ``task_config.dataset_subset`` when truthy.
            Must be ``None`` for URL-based datasets.
        split: Overrides ``task_config.dataset_split`` when truthy.
            Must be ``None`` for URL-based datasets.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The loaded dataset after mapping ``add_idx_map`` over it with
        indices.

    Raises:
        ValueError: If ``subset`` or ``split`` is given for a URL-based
            dataset.
    """
    dataset_path = self.task_config.dataset_path

    # Only the presence of a scheme is checked; the URL itself is not
    # validated before being handed to the loader.
    if not urlparse(dataset_path).scheme:
        # HF dataset
        dataset = load_dataset(
            path=dataset_path,
            name=subset if subset else self.task_config.dataset_subset,
            split=split if split else self.task_config.dataset_split,
            **self.task_config.dataset_kwargs,
        )
    else:
        # URL-based dataset. Only JSON is supported for now.
        if split is not None or subset is not None:
            raise ValueError(
                "URL-based dataset does not support loading arguments like `split`, `subset`"
            )
        # By default, Huggingface will create a DatasetDict object with
        # a "train" split, so select it explicitly.
        dataset = load_dataset("json", data_files=[dataset_path])["train"]

    # add an index column efficiently with map
    dataset = dataset.map(add_idx_map, with_indices=True)
    return dataset
Expand Down
8 changes: 8 additions & 0 deletions skythought/evals/tasks/liveaops/liveaops.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
handler: liveaops
dataset_path: https://livemathbench.github.io/data/LiveAoPSBench-2024.jsonl
dataset_subset: null # which subset on huggingface. Not applicable for a URL dataset
dataset_split: null # which split on huggingface. Not applicable for a URL dataset
question_key: question
answer_key: answer
templating_parameters:
template: "Return your final response within \\boxed{{}}. {question}"
26 changes: 26 additions & 0 deletions skythought/evals/tasks/liveaops/liveaops_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from skythought.evals.util.math_parsing_util import (
extract_answer,
math_equal,
strip_answer_string,
)

from ..math.math_handler import MathTaskHandler


class LiveAOPSTaskHandler(MathTaskHandler):
    """Math-style task handler for the LiveAOPS task."""

    def generate_prompt(self, problem):
        """Fill the configured prompt template with the fields of `problem`."""
        prompt_template = self.task_config.templating_parameters["template"]
        return prompt_template.format(**problem)

    def check_correctness(self, problem, generation):
        """Compare the answer extracted from `generation` with the reference.

        The reference answer is read from `problem` under the configured
        answer key and used as-is (no preprocessing needed).
        """
        reference = problem[self.task_config.answer_key]
        predicted = strip_answer_string(extract_answer(generation))
        return math_equal(predicted, reference)

    def load_and_filter_dataset(
        self, start, end, split=None, subset=None, difficulty=None
    ):
        """Load the dataset as a DataFrame and return the rows from `start`.

        Rows `start:end` are returned when `end` is positive; otherwise
        every row from `start` onwards is returned. `difficulty` must be
        left as None.
        """
        assert difficulty is None, "LiveAOPS does not support `difficulty` argument"
        frame = self.load_dataset(subset=subset, split=split).to_pandas()
        if end > 0:
            return frame.iloc[start:end]
        return frame.iloc[start:]
2 changes: 1 addition & 1 deletion skythought/evals/tasks/omni_math/omni_math.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
handler: math
handler: omni_math
dataset_path: "KbsdJames/Omni-MATH" # repo ID in huggingface
dataset_subset: null # which subset on huggingface
dataset_split: test_rule_based # Rule based evaluation
Expand Down