Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 40 additions & 9 deletions mipcandy/inference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABCMeta
from collections.abc import Generator
from math import log, ceil
from os import PathLike, listdir
from os.path import isdir, basename, exists
Expand Down Expand Up @@ -68,16 +69,36 @@ def predict_image(self, image: torch.Tensor, *, batch: bool = False) -> torch.Te
output = restoring_module(output)
return output if batch else output.squeeze(0)

def _predict(self, x: SupportedPredictant | UnsupervisedDataset) -> tuple[list[torch.Tensor], list[str] | None]:
def _predict(self, x: SupportedPredictant | UnsupervisedDataset) -> Generator[
tuple[torch.Tensor, str | None], None, None]:
if isinstance(x, PathBasedUnsupervisedDataset):
return [self.predict_image(case) for case in x], x.paths()
for case, path in zip(x, x.paths()):
yield self.predict_image(case), path
return
if isinstance(x, UnsupervisedDataset):
return [self.predict_image(case) for case in x], None
images, filenames = parse_predictant(x, Loader)
return [self.predict_image(image) for image in images], filenames
for case in x:
yield self.predict_image(case), None
return
if isinstance(x, str):
if isdir(x):
for case in listdir(x):
yield self.predict_image(Loader.do_load(f"{x}/{case}")), case
else:
yield self.predict_image(Loader.do_load(x)), basename(x)
return
if isinstance(x, torch.Tensor):
yield self.predict_image(x), None
return
for case in x:
if isinstance(case, str):
yield self.predict_image(Loader.do_load(case)), case[case.rfind("/") + 1:]
elif isinstance(case, torch.Tensor):
yield self.predict_image(case), None
else:
raise TypeError(f"Unexpected type of element {type(case)}")

def predict(self, x: SupportedPredictant | UnsupervisedDataset) -> list[torch.Tensor]:
return self._predict(x)[0]
return [output for output, _ in self._predict(x)]

@staticmethod
def save_prediction(output: torch.Tensor, path: str | PathLike[str]) -> None:
Expand All @@ -96,9 +117,19 @@ def save_predictions(self, outputs: Sequence[torch.Tensor], folder: str | PathLi

def predict_to_files(self, x: SupportedPredictant | UnsupervisedDataset,
folder: str | PathLike[str]) -> list[str] | None:
outputs, filenames = self._predict(x)
self.save_predictions(outputs, folder, filenames=filenames)
return filenames
if not exists(folder):
raise FileNotFoundError(f"Folder {folder} does not exist")
result: list[str] | None = None
for i, (output, name) in enumerate(self._predict(x)):
if name is not None:
if result is None:
result = []
result.append(name)
else:
ext = "png" if output.ndim == 3 and output.shape[0] in (1, 3) else "mha"
name = f"prediction_{i}.{ext}"
self.save_prediction(output, f"{folder}/{name}")
return result

def __call__(self, x: SupportedPredictant | UnsupervisedDataset) -> list[torch.Tensor]:
return self.predict(x)