diff --git a/mipcandy/inference.py b/mipcandy/inference.py index b493ed1..16c5be5 100644 --- a/mipcandy/inference.py +++ b/mipcandy/inference.py @@ -1,4 +1,5 @@ from abc import ABCMeta +from collections.abc import Generator from math import log, ceil from os import PathLike, listdir from os.path import isdir, basename, exists @@ -68,16 +69,36 @@ def predict_image(self, image: torch.Tensor, *, batch: bool = False) -> torch.Te output = restoring_module(output) return output if batch else output.squeeze(0) - def _predict(self, x: SupportedPredictant | UnsupervisedDataset) -> tuple[list[torch.Tensor], list[str] | None]: + def _predict(self, x: SupportedPredictant | UnsupervisedDataset) -> Generator[ + tuple[torch.Tensor, str | None], None, None]: if isinstance(x, PathBasedUnsupervisedDataset): - return [self.predict_image(case) for case in x], x.paths() + for case, path in zip(x, x.paths()): + yield self.predict_image(case), path + return if isinstance(x, UnsupervisedDataset): - return [self.predict_image(case) for case in x], None - images, filenames = parse_predictant(x, Loader) - return [self.predict_image(image) for image in images], filenames + for case in x: + yield self.predict_image(case), None + return + if isinstance(x, str): + if isdir(x): + for case in listdir(x): + yield self.predict_image(Loader.do_load(f"{x}/{case}")), case + else: + yield self.predict_image(Loader.do_load(x)), basename(x) + return + if isinstance(x, torch.Tensor): + yield self.predict_image(x), None + return + for case in x: + if isinstance(case, str): + yield self.predict_image(Loader.do_load(case)), case[case.rfind("/") + 1:] + elif isinstance(case, torch.Tensor): + yield self.predict_image(case), None + else: + raise TypeError(f"Unexpected type of element {type(case)}") def predict(self, x: SupportedPredictant | UnsupervisedDataset) -> list[torch.Tensor]: - return self._predict(x)[0] + return [output for output, _ in self._predict(x)] @staticmethod def save_prediction(output: torch.Tensor, path: str | PathLike[str]) -> None: @@ -96,9 +117,19 @@ def save_predictions(self, outputs: Sequence[torch.Tensor], folder: str | PathLi def predict_to_files(self, x: SupportedPredictant | UnsupervisedDataset, folder: str | PathLike[str]) -> list[str] | None: - outputs, filenames = self._predict(x) - self.save_predictions(outputs, folder, filenames=filenames) - return filenames + if not exists(folder): + raise FileNotFoundError(f"Folder {folder} does not exist") + result: list[str] | None = None + for i, (output, name) in enumerate(self._predict(x)): + if name is not None: + if result is None: + result = [] + result.append(name) + else: + ext = "png" if output.ndim == 3 and output.shape[0] in (1, 3) else "mha" + name = f"prediction_{i}.{ext}" + self.save_prediction(output, f"{folder}/{name}") + return result def __call__(self, x: SupportedPredictant | UnsupervisedDataset) -> list[torch.Tensor]: return self.predict(x)