Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 235 additions & 0 deletions kubeflow/trainer/backends/kubernetes/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,3 +845,238 @@ def test_get_args_from_dataset_preprocess_config(test_case: TestCase):
assert test_case.expected_status == FAILED
assert type(e) is test_case.expected_error
print("test execution complete")


def _build_builtin_runtime() -> types.Runtime:
    """Construct the torchtune builtin-trainer Runtime fixture shared by tests."""
    trainer_kwargs = {
        "trainer_type": types.TrainerType.BUILTIN_TRAINER,
        "framework": "torchtune",
        "device": "gpu",
        "device_count": "1",
        "image": "ghcr.io/kubeflow/trainer/torchtune",
    }
    builtin_trainer = types.RuntimeTrainer(**trainer_kwargs)
    # The builtin runtime is expected to launch via the torchtune CLI.
    builtin_trainer.set_command(constants.TORCH_TUNE_COMMAND)
    return types.Runtime(name="torchtune-llama3.2-1b", trainer=builtin_trainer)


@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="empty config produces empty args",
            expected_status=SUCCESS,
            config={
                "fine_tuning_config": types.TorchTuneConfig(),
            },
            expected_output=[],
        ),
        TestCase(
            name="dtype only produces dtype arg",
            expected_status=SUCCESS,
            config={
                "fine_tuning_config": types.TorchTuneConfig(
                    dtype=types.DataType.BF16,
                ),
            },
            expected_output=[f"dtype={types.DataType.BF16}"],
        ),
        TestCase(
            name="all scalar fields produce correct args",
            expected_status=SUCCESS,
            config={
                "fine_tuning_config": types.TorchTuneConfig(
                    dtype=types.DataType.FP32,
                    batch_size=8,
                    epochs=3,
                    loss=types.Loss.CEWithChunkedOutputLoss,
                ),
            },
            expected_output=[
                f"dtype={types.DataType.FP32}",
                "batch_size=8",
                "epochs=3",
                f"loss={types.Loss.CEWithChunkedOutputLoss}",
            ],
        ),
        TestCase(
            name="config with peft appends lora args",
            expected_status=SUCCESS,
            config={
                "fine_tuning_config": types.TorchTuneConfig(
                    peft_config=types.LoraConfig(
                        lora_rank=32,
                        lora_alpha=16,
                    ),
                ),
            },
            expected_output=[
                "model.lora_rank=32",
                "model.lora_alpha=16",
                "model.lora_attn_modules=[q_proj,v_proj,output_proj]",
            ],
        ),
        TestCase(
            name="config with dataset preprocess appends dataset args",
            expected_status=SUCCESS,
            config={
                "fine_tuning_config": types.TorchTuneConfig(
                    dataset_preprocess_config=types.TorchTuneInstructDataset(
                        split="train",
                    ),
                ),
            },
            expected_output=[
                f"dataset={constants.TORCH_TUNE_INSTRUCT_DATASET}",
                "dataset.split=train",
            ],
        ),
        TestCase(
            name="initializer with directory dataset produces data_dir arg",
            expected_status=SUCCESS,
            config={
                "fine_tuning_config": types.TorchTuneConfig(
                    dtype=types.DataType.BF16,
                ),
                "initializer": types.Initializer(
                    dataset=types.HuggingFaceDatasetInitializer(
                        storage_uri="hf://tatsu-lab/alpaca",
                    ),
                ),
            },
            expected_output=[
                f"dtype={types.DataType.BF16}",
                f"dataset.data_dir={constants.DATASET_PATH}/.",
            ],
        ),
        TestCase(
            name="initializer with file dataset produces data_files arg",
            expected_status=SUCCESS,
            config={
                "fine_tuning_config": types.TorchTuneConfig(),
                "initializer": types.Initializer(
                    dataset=types.HuggingFaceDatasetInitializer(
                        storage_uri="hf://tatsu-lab/alpaca/data.json",
                    ),
                ),
            },
            expected_output=[
                f"dataset.data_files={constants.DATASET_PATH}/data.json",
            ],
        ),
        TestCase(
            name="invalid dtype raises ValueError",
            expected_status=FAILED,
            config={
                "fine_tuning_config": types.TorchTuneConfig(
                    dtype="invalid",
                ),
            },
            expected_error=ValueError,
        ),
    ],
)
def test_get_args_using_torchtune_config(test_case: TestCase):
    """Verify get_args_using_torchtune_config builds the expected torchtune
    CLI args for each TorchTuneConfig / Initializer combination, and raises
    the declared error type for invalid configs.
    """
    print("Executing test:", test_case.name)
    try:
        args = utils.get_args_using_torchtune_config(
            test_case.config["fine_tuning_config"],
            # Optional: only some cases supply an initializer.
            test_case.config.get("initializer"),
        )

        # Success path: the produced args must match exactly, in order.
        assert test_case.expected_status == SUCCESS
        assert args == test_case.expected_output

    except Exception as e:
        # Failure path: only the declared error type is acceptable.
        assert test_case.expected_status == FAILED
        assert type(e) is test_case.expected_error
    print("test execution complete")


@pytest.mark.parametrize(
    "test_case",
    [
        # num_nodes maps to Trainer.numNodes; scalar config fields become args.
        TestCase(
            name="valid config with num_nodes and batch_size",
            expected_status=SUCCESS,
            config={
                "runtime": _build_builtin_runtime(),
                "trainer": types.BuiltinTrainer(
                    config=types.TorchTuneConfig(
                        num_nodes=2,
                        batch_size=8,
                    ),
                ),
            },
            expected_output=models.TrainerV1alpha1Trainer(
                command=["tune", "run"],
                args=["batch_size=8"],
                numNodes=2,
            ),
        ),
        # resources_per_node "gpu" counts are translated to the
        # nvidia.com/gpu resource in both limits and requests.
        TestCase(
            name="valid config with resources_per_node",
            expected_status=SUCCESS,
            config={
                "runtime": _build_builtin_runtime(),
                "trainer": types.BuiltinTrainer(
                    config=types.TorchTuneConfig(
                        resources_per_node={"gpu": 2},
                    ),
                ),
            },
            expected_output=models.TrainerV1alpha1Trainer(
                command=["tune", "run"],
                args=[],
                resourcesPerNode=models.IoK8sApiCoreV1ResourceRequirements(
                    limits={
                        "nvidia.com/gpu": models.IoK8sApimachineryPkgApiResourceQuantity(2),
                    },
                    requests={
                        "nvidia.com/gpu": models.IoK8sApimachineryPkgApiResourceQuantity(2),
                    },
                ),
            ),
        ),
        # A default TorchTuneConfig still yields the runtime command
        # with no extra args.
        TestCase(
            name="empty config produces trainer with empty args",
            expected_status=SUCCESS,
            config={
                "runtime": _build_builtin_runtime(),
                "trainer": types.BuiltinTrainer(
                    config=types.TorchTuneConfig(),
                ),
            },
            expected_output=models.TrainerV1alpha1Trainer(
                command=["tune", "run"],
                args=[],
            ),
        ),
        # A non-TorchTuneConfig config object must be rejected.
        TestCase(
            name="invalid config type raises ValueError",
            expected_status=FAILED,
            config={
                "runtime": _build_builtin_runtime(),
                "trainer": types.BuiltinTrainer(
                    config="invalid_config",
                ),
            },
            expected_error=ValueError,
        ),
    ],
)
def test_get_trainer_cr_from_builtin_trainer(test_case: TestCase):
    # Verify get_trainer_cr_from_builtin_trainer converts a runtime +
    # BuiltinTrainer pair into the expected TrainerV1alpha1Trainer CR,
    # and raises the declared error type for invalid trainer configs.
    print("Executing test:", test_case.name)
    try:
        trainer_cr = utils.get_trainer_cr_from_builtin_trainer(
            test_case.config["runtime"],
            test_case.config["trainer"],
            # No case here sets an initializer; .get() passes None through.
            test_case.config.get("initializer"),
        )

        # Success path: generated CR must equal the expected model exactly.
        assert test_case.expected_status == SUCCESS
        assert trainer_cr == test_case.expected_output

    except Exception as e:
        # Failure path: only the declared error type is acceptable.
        assert test_case.expected_status == FAILED
        assert type(e) is test_case.expected_error
    print("test execution complete")
Loading