2 changes: 1 addition & 1 deletion tests/fixtures/configs.py
@@ -13,7 +13,7 @@ def foo_config_file(config_dir, input_dir, output_dir, foo_system_output_tiny_mu
"scorer": {
"input_dir": str(input_dir),
"output_dir": str(output_dir),
"systems": [".".join(path.stem.split(".")[:-1]) for path in foo_system_output_tiny_multilingual],
"systems": [str(path.parent).split("/")[-1] for path in foo_system_output_tiny_multilingual],
"languages": languages,
"file_suffix": "txt",
"metrics": [
6 changes: 3 additions & 3 deletions tests/fixtures/data.py
@@ -44,18 +44,18 @@ def foo_system_output_tiny(input_dir, foo_text_tiny):
def foo_system_output_tiny_multilingual(input_dir, foo_text_tiny, languages):
"""TODO"""
paths = []
+ Path(input_dir, "system_multi_1").mkdir()
for lang in languages:
paths.append(Path(input_dir, f"system_multi_1.{lang}.txt"))
paths.append(Path(input_dir, "system_multi_1", f"{lang}.txt"))
with open_file(paths[-1], "w") as fh:
print(foo_text_tiny, file=fh)
return paths


@pytest.fixture(scope="session")
def foo_dataset(input_dir, foo_system_output_tiny_multilingual, languages):
- system_name = ".".join(foo_system_output_tiny_multilingual[0].name.split(".")[:-2])
return {
"data_dir": input_dir,
"systems": [system_name],
"systems": [str(path.parent).split("/")[-1] for path in foo_system_output_tiny_multilingual],
"languages": languages,
}
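
The fixtures above switch multilingual system outputs from dotted filenames (`system_multi_1.{lang}.txt`) to one directory per system with one file per language. A minimal sketch of the resulting layout, with invented system and language names; `path.parent.name` is an equivalent spelling of the fixtures' `str(path.parent).split("/")[-1]`:

```python
from pathlib import Path

# Hypothetical input_dir; before this change the files would have been
# input/system_multi_1.ces_Latn_czec1258.txt etc., now they live in a
# per-system directory:
#   input/system_multi_1/ces_Latn_czec1258.txt
input_dir = Path("input")
system_dir = input_dir / "system_multi_1"
system_dir.mkdir(parents=True, exist_ok=True)
for lang in ("ces_Latn_czec1258", "deu_Latn_stan1295"):
    (system_dir / f"{lang}.txt").write_text("foo text\n", encoding="utf-8")

# The system label is now recovered from the parent directory's name:
systems = sorted({path.parent.name for path in input_dir.glob("*/*.txt")})
assert systems == ["system_multi_1"]
```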
6 changes: 4 additions & 2 deletions tests/tokeval/test_data.py
@@ -14,6 +14,8 @@ def foo_data(request, foo_text_tiny, tmp_path_factory):
"""TODO"""

def create_file(file_path, text): # noqa: ANN202
+ if not file_path.parent.exists():
+ file_path.parent.mkdir(parents=True)
with open_file(file_path, "w") as fh:
print(text, file=fh)

@@ -35,8 +37,8 @@ def create_file(file_path, text): # noqa: ANN202
elif "multi" in request.param:
data["languages"] = LANGUAGES
for lang in LANGUAGES:
data["filenames"].append(f"{system}.{lang}." + data["file_suffix"])
create_file(Path(data["input_dir"], data["filenames"][-1]), foo_text_tiny)
data["filenames"].append(f"{lang}." + data["file_suffix"])
create_file(Path(data["input_dir"], f"{system}", data["filenames"][-1]), foo_text_tiny)
else:
pytest.fail(f"Unknown system label ({request.param}).")

60 changes: 58 additions & 2 deletions tokeval/data.py
@@ -1,5 +1,6 @@
import logging
from pathlib import Path
+ from typing import Any

from attrs import converters, define, field, validators

@@ -9,6 +10,29 @@

TextType = list[list[str]]

+ LANG_SPEC_LEN = 3
+
+
+ @define(kw_only=True)
+ class LanguageInfo(dict):
+ """TODO"""
+
+ name: str = field(validator=validators.instance_of(str))
+ scripts: list[str] = field(validator=validators.instance_of(list))
+ glottocodes: list[str] = field(validator=validators.instance_of(list))
+ families: str = field(validator=validators.instance_of(str))
+ speakers: int = field(validator=validators.instance_of(int))
+ continent: str = field(validator=validators.instance_of(str))
+ wikipedia: str = field(validator=validators.instance_of(str))
+ tier: int = field(validator=validators.instance_of(int))
+ morphology: str = field(validator=validators.instance_of(str))
+ fineweb2: dict[str, int] = field(validator=validators.optional(validators.instance_of(dict)))
+
+ @classmethod
+ def create_entry(cls: "LanguageInfo", entry: dict) -> "LanguageInfo":
+ """TODO"""
+ return cls(**entry)

@define(kw_only=True)
class TokEvalData:
@@ -17,6 +41,7 @@ class TokEvalData:
data_dir: Path = field(converter=Path)
systems: list[str] = field(converter=converters.optional(list), factory=list)
languages: list[str] = field(converter=converters.optional(list), factory=list)
+ languages_info: dict[str, Any] = field(validator=validators.optional(validators.instance_of(dict)), default=None)
metrics: list["TokEvalMetric"] = field(factory=list) # noqa: F821
file_suffix: str = field(validator=validators.instance_of(str), default="txt")
input_file_stem: str = field(validator=validators.instance_of(str), default="input")
@@ -36,9 +61,9 @@ def __attrs_post_init__(self) -> None:
logger.debug("Loading %s ...", filename)
self._data[system_label] = load_tokenized_text_file(Path(self.data_dir, filename))
else:
logger.debug("Loading %s.{%s}.%s ...", system_label, ",".join(self.languages), self.file_suffix)
logger.debug("Loading %s/{%s}.%s ...", system_label, ",".join(self.languages), self.file_suffix)
self._data[system_label] = {
lang: load_tokenized_text_file(Path(self.data_dir, f"{system_label}.{lang}.{self.file_suffix}"))
lang: load_tokenized_text_file(Path(self.data_dir, f"{system_label}", f"{lang}.{self.file_suffix}"))
for lang in self.languages
}

@@ -51,6 +76,37 @@ def __attrs_post_init__(self) -> None:
filename = f"{self.reference_file_stem}.{self.file_suffix}"
self._data[self._reference_key] = load_tokenized_text_file(Path(self.data_dir, filename))

+ if self.languages_info is not None:
+ # The language info follows a strict data structure, so each language specification
+ # must follow it as well: <language>_<script>_<glottocode>.
+ for lang in self.languages:
+ lang_split = lang.split("_")
+ assert len(lang_split) == LANG_SPEC_LEN
+ if lang_split[0] not in self.languages_info:
+ logger.exception(
+ "Language %s not in the provided languages_info JSON file.\nAvailable languages: [%s]",
+ lang_split[0],
+ ",".join(self.languages_info.keys()),
+ )
+ if lang_split[1] not in self.languages_info[lang_split[0]]["scripts"]:
+ logger.exception(
+ "Script %s of language %s not listed in the languages_info JSON file.\n"
+ "languages_info['%s'] = %s",
+ lang_split[1],
+ lang_split[0],
+ lang_split[0],
+ self.languages_info[lang_split[0]],
+ )
+ if lang_split[2] not in self.languages_info[lang_split[0]]["glottocodes"]:
+ logger.exception(
+ "Glottocode %s of language %s not listed in the languages_info JSON file.\n"
+ "languages_info['%s'] = %s",
+ lang_split[2],
+ lang_split[0],
+ lang_split[0],
+ self.languages_info[lang_split[0]],
+ )
+
@property
def has_input_text(self) -> bool:
"""TODO"""
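The checks added in `__attrs_post_init__` above pin down both shapes at once: every `languages_info` entry must carry the `LanguageInfo` fields, and every language specification must split into exactly `LANG_SPEC_LEN == 3` underscore-separated parts, `<language>_<script>_<glottocode>`. A hedged sketch of both, with all field values invented for illustration:

```python
# Invented example entry; only the shape mirrors the LanguageInfo fields above.
languages_info = {
    "ces": {
        "name": "Czech",
        "scripts": ["Latn"],
        "glottocodes": ["czec1258"],
        "families": "Indo-European",
        "speakers": 10_000_000,
        "continent": "Europe",
        "wikipedia": "cs",
        "tier": 3,
        "morphology": "fusional",
        "fineweb2": {"documents": 0},
    }
}

# A language specification is <language>_<script>_<glottocode>:
lang = "ces_Latn_czec1258"
language, script, glottocode = lang.split("_")
assert language in languages_info
assert script in languages_info[language]["scripts"]
assert glottocode in languages_info[language]["glottocodes"]
```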
2 changes: 1 addition & 1 deletion tokeval/metrics/tokeval_metric.py
@@ -127,7 +127,7 @@ def score_all(self, data: TokEvalData, systems: list[str], languages: list[str])
res = np.zeros(shape=[len(systems), len(languages)])
for i, system_label in enumerate(systems):
for j, lang in enumerate(languages):
logger.debug("[%s] Scoring system %s (%lang)...", self.metric_label, system_label, lang)
logger.debug("[%s] Scoring system %s (%s)...", self.metric_label, system_label, lang)
res[i, j] = self.score(data=data, system_label=system_label, language=lang)
return res

24 changes: 22 additions & 2 deletions tokeval/scorer.py
@@ -2,13 +2,13 @@
import json
import logging
from pathlib import Path
- from typing import ClassVar
+ from typing import Any, ClassVar

import numpy as np
from attrs import converters, define, field, fields, validators
from omegaconf import DictConfig

- from tokeval.data import TokEvalData
+ from tokeval.data import LanguageInfo, TokEvalData
from tokeval.metrics import TokEvalMetric, build_metric

logger = logging.getLogger(__name__)
@@ -26,6 +26,10 @@ class ScorerResultSaver(dict):
metrics: list[str] = field(validator=validators.instance_of(list))
languages: list[str] = field(validator=validators.instance_of(list))

+ # TODO(varisd): replace Any with proper data structure definition
+ languages_info: dict[str, Any] = field(validator=validators.optional(validators.instance_of(dict)))
+
+ _languages_info_filename: ClassVar[str] = "languages_info.json"
_metadata_filename: ClassVar[str] = "metadata.json"
_results_filename: ClassVar[str] = "results.npz"

@@ -54,6 +58,11 @@ def save_results(self, results: dict) -> None:
path = Path(self.output_dir, self._results_filename)
np.savez(path, **results)

+ if self.languages_info:
+ path = Path(self.output_dir, self._languages_info_filename)
+ with path.open("w") as fh:
+ json.dump(self.languages_info, sort_keys=True, indent=2, fp=fh)
+

@define(kw_only=True)
class TokEvalScorer:
@@ -80,6 +89,7 @@ class TokEvalScorer:
output_dir: Path = field(converter=converters.optional(Path), init=False, default=None)
systems: list[str] = field(init=False)
languages: list[str] = field(init=False, factory=list)
+ languages_info: dict[str, Any] = field(init=False, default=None)
file_suffix: str = field(init=False, default="txt")

metrics: dict[str, TokEvalMetric] = field(init=False, default=None)
@@ -91,6 +101,14 @@ def __attrs_post_init__(self) -> None:
for param in self.list_parameters():
if param.name == "config":
continue
if param.name == "languages_info":
if getattr(self.config.scorer, param.name, None) is not None:
path = Path(getattr(self.config.scorer, param.name))
self.languages_info = {
lang: LanguageInfo.create_entry(entry)
for lang, entry in json.load(path.open("r", encoding="utf-8")).items()
}
continue
if hasattr(self.config.scorer, param.name):
setattr(self, param.name, getattr(self.config.scorer, param.name))
if param.default is not None and getattr(self, param.name, None) is None:
@@ -103,6 +121,7 @@ def __attrs_post_init__(self) -> None:
data_dir=self.input_dir,
systems=self.systems,
languages=self.languages,
+ languages_info=self.languages_info,
metrics=self.metrics.values(),
file_suffix=self.file_suffix,
)
@@ -131,6 +150,7 @@ def run(self) -> dict[str, dict[str, np.ndarray]]:
tokenizers=list(self.systems),
metrics=list(self.metrics.keys()),
languages=list(self.languages),
+ languages_info=self.languages_info,
).save_results(results)
else:
logger.info("No scorer.output_dir was provided. Printing results to STDOUT:\n")
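When `scorer.languages_info` points at such a JSON file and an `output_dir` is configured, `ScorerResultSaver` now writes the parsed info next to the scores. A sketch of reading a finished run back; the `output` directory name is an assumption, the filenames come from the saver's ClassVars, and per `score_all` above each entry of `results.npz` is a systems × languages matrix:

```python
import json
from pathlib import Path

import numpy as np

output_dir = Path("output")  # hypothetical scorer.output_dir

# results.npz holds one score matrix per metric (systems x languages).
results = np.load(output_dir / "results.npz")
metadata = json.loads((output_dir / "metadata.json").read_text(encoding="utf-8"))
languages_info = json.loads((output_dir / "languages_info.json").read_text(encoding="utf-8"))

for metric, matrix in results.items():
    print(metric, matrix.shape)
```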