Skip to content

Commit 302b973

Browse files
authored
Eval Support Update (#5658)
* fix(evaluate): Remove GPT OSS model evaluation restriction Remove the check that blocked evaluation for openai-reasoning-gpt-oss-20b and openai-reasoning-gpt-oss-120b base models. * test(evaluate): Update GPT OSS tests to verify models are allowed Update TestGPTOSSModelValidation to assert that openai-reasoning-gpt-oss-20b and openai-reasoning-gpt-oss-120b models can be used for evaluation, matching the removal of the restriction in base_evaluator.
1 parent 27f4b43 commit 302b973

File tree

2 files changed

+21
-27
lines changed

2 files changed

+21
-27
lines changed

sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -314,12 +314,6 @@ def _resolve_model_info(cls, v: Union[str, BaseTrainer, ModelPackage], values: d
314314
sagemaker_session=session
315315
)
316316

317-
# Check if model is GPT OSS (not supported for evaluation)
318-
if model_info.base_model_name in ["openai-reasoning-gpt-oss-20b", "openai-reasoning-gpt-oss-120b"]:
319-
raise ValueError(
320-
"Evaluation is currently not supported for models created from GPT OSS 20B base model"
321-
)
322-
323317
# If model is a ModelPackage object or ARN (has source_model_package_arn),
324318
# validate that the resolved base_model_arn is a hub content ARN
325319
if model_info.source_model_package_arn:

sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,43 +1012,43 @@ def test_evaluate_not_implemented(self, mock_resolve, mock_session, mock_model_i
10121012

10131013

10141014
class TestGPTOSSModelValidation:
1015-
"""Tests for GPT OSS model validation."""
1015+
"""Tests for GPT OSS model validation - models should be allowed for evaluation."""
10161016

10171017
@patch("sagemaker.train.common_utils.model_resolution._resolve_base_model")
1018-
def test_gpt_oss_20b_model_blocked(self, mock_resolve, mock_session):
1019-
"""Test that GPT OSS 20B model is blocked from evaluation."""
1018+
def test_gpt_oss_20b_model_allowed(self, mock_resolve, mock_session):
1019+
"""Test that GPT OSS 20B model is allowed for evaluation."""
10201020
mock_info = MagicMock()
10211021
mock_info.base_model_name = "openai-reasoning-gpt-oss-20b"
10221022
mock_info.base_model_arn = DEFAULT_HUB_CONTENT_ARN
10231023
mock_info.source_model_package_arn = None
10241024
mock_resolve.return_value = mock_info
10251025

1026-
with pytest.raises(ValidationError, match="Evaluation is currently not supported for models created from GPT OSS 20B base model"):
1027-
BaseEvaluator(
1028-
model="openai-reasoning-gpt-oss-20b",
1029-
s3_output_path=DEFAULT_S3_OUTPUT,
1030-
mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
1031-
model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
1032-
sagemaker_session=mock_session,
1033-
)
1026+
evaluator = BaseEvaluator(
1027+
model="openai-reasoning-gpt-oss-20b",
1028+
s3_output_path=DEFAULT_S3_OUTPUT,
1029+
mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
1030+
model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
1031+
sagemaker_session=mock_session,
1032+
)
1033+
assert evaluator.model == "openai-reasoning-gpt-oss-20b"
10341034

10351035
@patch("sagemaker.train.common_utils.model_resolution._resolve_base_model")
1036-
def test_gpt_oss_120b_model_blocked(self, mock_resolve, mock_session):
1037-
"""Test that GPT OSS 120B model is blocked from evaluation."""
1036+
def test_gpt_oss_120b_model_allowed(self, mock_resolve, mock_session):
1037+
"""Test that GPT OSS 120B model is allowed for evaluation."""
10381038
mock_info = MagicMock()
10391039
mock_info.base_model_name = "openai-reasoning-gpt-oss-120b"
10401040
mock_info.base_model_arn = DEFAULT_HUB_CONTENT_ARN
10411041
mock_info.source_model_package_arn = None
10421042
mock_resolve.return_value = mock_info
10431043

1044-
with pytest.raises(ValidationError, match="Evaluation is currently not supported for models created from GPT OSS 20B base model"):
1045-
BaseEvaluator(
1046-
model="openai-reasoning-gpt-oss-120b",
1047-
s3_output_path=DEFAULT_S3_OUTPUT,
1048-
mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
1049-
model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
1050-
sagemaker_session=mock_session,
1051-
)
1044+
evaluator = BaseEvaluator(
1045+
model="openai-reasoning-gpt-oss-120b",
1046+
s3_output_path=DEFAULT_S3_OUTPUT,
1047+
mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
1048+
model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
1049+
sagemaker_session=mock_session,
1050+
)
1051+
assert evaluator.model == "openai-reasoning-gpt-oss-120b"
10521052

10531053
@patch("sagemaker.train.common_utils.model_resolution._resolve_base_model")
10541054
def test_non_gpt_oss_model_allowed(self, mock_resolve, mock_session, mock_model_info):

0 commit comments

Comments
 (0)