Skip to content

Commit 1124fe4

Browse files
add from_preset to OptimizationConfig (#279)
* add `from_preset` to `OptimizationConfig`
* set default separation ratio to None
* Update optimizer_config.schema.json
* update unit tests

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent e103bc7 commit 1124fe4

File tree

4 files changed

+18
-6
lines changed

4 files changed

+18
-6
lines changed

docs/optimizer_config.schema.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@
113113
"type": "null"
114114
}
115115
],
116-
"default": 0.5,
116+
"default": null,
117117
"description": "Set to float to prevent data leak between scoring and decision nodes.",
118118
"title": "Separation Ratio"
119119
},
@@ -498,7 +498,7 @@
498498
"scheme": "ho",
499499
"n_folds": 3,
500500
"validation_size": 0.2,
501-
"separation_ratio": 0.5,
501+
"separation_ratio": null,
502502
"is_few_shot_train": false,
503503
"examples_per_intent": 8
504504
}

src/autointent/_optimization_config.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Any
3+
from typing import TYPE_CHECKING, Any
44

55
from pydantic import BaseModel, Field, PositiveInt, field_validator
66

@@ -14,6 +14,10 @@
1414
get_default_hfmodel_config,
1515
initialize_embedder_config,
1616
)
17+
from .utils import load_preset
18+
19+
if TYPE_CHECKING:
20+
from .custom_types import SearchSpacePreset
1721

1822

1923
class OptimizationConfig(BaseModel):
@@ -46,3 +50,7 @@ def validate_embedder_config(cls, v: Any) -> EmbedderConfig: # noqa: ANN401
4650
hpo_config: HPOConfig = HPOConfig()
4751

4852
seed: PositiveInt = 42
53+
54+
@classmethod
55+
def from_preset(cls, preset: SearchSpacePreset) -> OptimizationConfig:
56+
return cls.model_validate(load_preset(preset))

src/autointent/configs/_optimization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class DataConfig(BaseModel):
2929
)
3030
"""Fraction of train samples to allocate for validation (if input dataset doesn't contain validation split)."""
3131
separation_ratio: FloatFromZeroToOne | None = Field(
32-
0.5, description="Set to float to prevent data leak between scoring and decision nodes."
32+
None, description="Set to float to prevent data leak between scoring and decision nodes."
3333
)
3434
"""Set to float to prevent data leak between scoring and decision nodes."""
3535
is_few_shot_train: bool = Field(False, description="Whether to use few-shot training.")

tests/data/test_data_handler.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ def mock_split():
7373

7474

7575
def test_data_handler_initialization(sample_multiclass_data):
76-
handler = DataHandler(dataset=Dataset.from_dict(sample_multiclass_data), random_seed=42)
76+
handler = DataHandler(
77+
dataset=Dataset.from_dict(sample_multiclass_data), config=DataConfig(separation_ratio=0.5), random_seed=42
78+
)
7779

7880
assert handler.multilabel is False
7981
assert handler.dataset.n_classes == 2
@@ -84,7 +86,9 @@ def test_data_handler_initialization(sample_multiclass_data):
8486

8587

8688
def test_data_handler_multilabel_mode(sample_multilabel_data):
87-
handler = DataHandler(dataset=Dataset.from_dict(sample_multilabel_data), random_seed=42)
89+
handler = DataHandler(
90+
dataset=Dataset.from_dict(sample_multilabel_data), config=DataConfig(separation_ratio=0.5), random_seed=42
91+
)
8892

8993
assert handler.multilabel is True
9094
assert handler.dataset.n_classes == 2

0 commit comments

Comments
 (0)