Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/optimizer_config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
"type": "null"
}
],
"default": 0.5,
"default": null,
"description": "Set to float to prevent data leak between scoring and decision nodes.",
"title": "Separation Ratio"
},
Expand Down Expand Up @@ -498,7 +498,7 @@
"scheme": "ho",
"n_folds": 3,
"validation_size": 0.2,
"separation_ratio": 0.5,
"separation_ratio": null,
"is_few_shot_train": false,
"examples_per_intent": 8
}
Expand Down
10 changes: 9 additions & 1 deletion src/autointent/_optimization_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel, Field, PositiveInt, field_validator

Expand All @@ -14,6 +14,10 @@
get_default_hfmodel_config,
initialize_embedder_config,
)
from .utils import load_preset

if TYPE_CHECKING:
from .custom_types import SearchSpacePreset


class OptimizationConfig(BaseModel):
Expand Down Expand Up @@ -46,3 +50,7 @@ def validate_embedder_config(cls, v: Any) -> EmbedderConfig: # noqa: ANN401
hpo_config: HPOConfig = HPOConfig()

seed: PositiveInt = 42

@classmethod
def from_preset(cls, preset: SearchSpacePreset) -> OptimizationConfig:
return cls.model_validate(load_preset(preset))
2 changes: 1 addition & 1 deletion src/autointent/configs/_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class DataConfig(BaseModel):
)
"""Fraction of train samples to allocate for validation (if input dataset doesn't contain validation split)."""
separation_ratio: FloatFromZeroToOne | None = Field(
0.5, description="Set to float to prevent data leak between scoring and decision nodes."
None, description="Set to float to prevent data leak between scoring and decision nodes."
)
"""Set to float to prevent data leak between scoring and decision nodes."""
is_few_shot_train: bool = Field(False, description="Whether to use few-shot training.")
Expand Down
8 changes: 6 additions & 2 deletions tests/data/test_data_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def mock_split():


def test_data_handler_initialization(sample_multiclass_data):
handler = DataHandler(dataset=Dataset.from_dict(sample_multiclass_data), random_seed=42)
handler = DataHandler(
dataset=Dataset.from_dict(sample_multiclass_data), config=DataConfig(separation_ratio=0.5), random_seed=42
)

assert handler.multilabel is False
assert handler.dataset.n_classes == 2
Expand All @@ -84,7 +86,9 @@ def test_data_handler_initialization(sample_multiclass_data):


def test_data_handler_multilabel_mode(sample_multilabel_data):
handler = DataHandler(dataset=Dataset.from_dict(sample_multilabel_data), random_seed=42)
handler = DataHandler(
dataset=Dataset.from_dict(sample_multilabel_data), config=DataConfig(separation_ratio=0.5), random_seed=42
)

assert handler.multilabel is True
assert handler.dataset.n_classes == 2
Expand Down
Loading