Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions src/graphnet/data/dataconverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,16 +215,17 @@ def _assign_event_no(
data[k][extractor_name],
index=[0] if n_rows == 1 else None,
)
if extractor_name in dataframe_dict.keys():
dataframe_dict[extractor_name].append(df)
else:
dataframe_dict[extractor_name] = [df]
if not df.empty:
if extractor_name in dataframe_dict.keys():
dataframe_dict[extractor_name].append(df)
else:
dataframe_dict[extractor_name] = [df]

# Merge each list of dataframes if wanted by writer
if self._save_method.expects_merged_dataframes:
for key in dataframe_dict.keys():
dataframe_dict[key] = pd.concat(
dataframe_dict[key], axis=0
[df for df in dataframe_dict[key] if not df.empty], axis=0
).reset_index(drop=True)
return dataframe_dict

Expand Down Expand Up @@ -275,10 +276,11 @@ def get_map_function(
"""Identify map function to use (pure python or multiprocess)."""
# Choose relevant map-function given the requested number of workers.
n_workers = min(self._num_workers, nb_files)
self._num_workers = n_workers
if n_workers > 1:
self.info(
f"Starting pool of {n_workers} workers to process"
" {nb_files} {unit}"
f"{nb_files} {unit}"
)

manager = Manager()
Expand Down Expand Up @@ -321,7 +323,10 @@ def _update_shared_variables(

@final
def merge_files(
self, files: Optional[Union[List[str], str]] = None, **kwargs: Any
self,
files: Optional[Union[List[str], str]] = None,
output_dir: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Merge converted files.

Expand All @@ -330,6 +335,9 @@ def merge_files(

Args:
files: Intermediate files to be merged.
output_dir: Directory to save the merged files in.
**kwargs: Additional keyword arguments to be passed to the
`GraphNeTWriter.merge_files` method.
"""
if (files is None) & (len(self._output_files) > 0):
# If no input files are given, but output files from conversion
Expand All @@ -349,9 +357,10 @@ def merge_files(
"and you must therefore specify argument `files`."
)
assert files is not None

if output_dir is None:
output_dir = self._output_dir
# Merge files
merge_path = os.path.join(self._output_dir, "merged")
merge_path = os.path.join(output_dir, "merged")
self.info(f"Merging files to {merge_path}")
self._save_method.merge_files(
files=files_to_merge, output_dir=merge_path, **kwargs
Expand Down
7 changes: 6 additions & 1 deletion src/graphnet/data/extractors/combine_extractors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Module for combining multiple extractors into a single extractor."""

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

from graphnet.utilities.imports import has_icecube_package
from graphnet.data.extractors.icecube.i3extractor import I3Extractor
Expand Down Expand Up @@ -31,6 +31,11 @@ def __init__(self, extractors: List[I3Extractor], extractor_name: str):
super().__init__(extractor_name=extractor_name)
self._extractors = extractors

def set_gcd(self, i3_file: str, gcd_file: Optional[str] = None) -> None:
"""Set the GCD file for all extractors."""
for extractor in self._extractors:
extractor.set_gcd(i3_file, gcd_file)

def __call__(self, frame: "icetray.I3Frame") -> Dict[str, float]:
"""Extract data from frame using all extractors.

Expand Down
30 changes: 28 additions & 2 deletions src/graphnet/data/extractors/extractor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Base I3Extractor class(es)."""

from typing import Any, Union
from typing import Any, Union, Callable
from abc import ABC, abstractmethod
import pandas as pd

Expand All @@ -23,21 +23,42 @@ class Extractor(ABC, Logger):
An extractor is used in conjunction with a specific `FileReader`.
"""

def __init__(self, extractor_name: str):
def __init__(self, extractor_name: str, exclude: list = [None]):
"""Construct Extractor.

Args:
extractor_name: Name of the `Extractor` instance.
Used to keep track of the provenance of different
data, and to name tables to which this data is
saved. E.g. "mc_truth".
exclude: List of keys to exclude from the extracted data.
"""
# Member variable(s)
self._extractor_name: str = extractor_name
self._exclude = exclude

# Base class constructor
super().__init__(name=__name__, class_name=self.__class__.__name__)

def exclude(func: Callable) -> Callable:
"""Exclude specified keys from the extracted data."""

def wrapper(
self: "Extractor", *args: Any
) -> Union[dict, pd.DataFrame]:
result = func(self, *args)
if isinstance(result, dict):
for key in self._exclude:
if key in result:
del result[key]
elif isinstance(result, pd.DataFrame):
for key in self._exclude:
if key in result.columns:
result = result.drop(columns=[key])
return result

return wrapper

@abstractmethod
def __call__(self, data: Any) -> Union[dict, pd.DataFrame]:
"""Extract information from data."""
Expand All @@ -47,3 +68,8 @@ def __call__(self, data: Any) -> Union[dict, pd.DataFrame]:
def name(self) -> str:
"""Get the name of the `Extractor` instance."""
return self._extractor_name

def __init_subclass__(cls) -> None:
"""Initialize subclass and apply the exclude decorator to __call__."""
super().__init_subclass__()
cls.__call__ = cls.exclude(cls.__call__) # type: ignore
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is a pretty elegant solution!

5 changes: 3 additions & 2 deletions src/graphnet/data/extractors/icecube/i3extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ class I3Extractor(Extractor):
method.
"""

def __init__(self, extractor_name: str):
def __init__(self, extractor_name: str, exclude: list = [None]):
"""Construct I3Extractor.

Args:
extractor_name: Name of the `I3Extractor` instance. Used to keep
track of the provenance of different data, and to name tables
to which this data is saved.
exclude: List of features to exclude from the extractor.
"""
# Member variable(s)
self._i3_file: str = ""
Expand All @@ -35,7 +36,7 @@ def __init__(self, extractor_name: str):
self._calibration: Optional["icetray.I3Frame.Calibration"] = None

# Base class constructor
super().__init__(extractor_name=extractor_name)
super().__init__(extractor_name=extractor_name, exclude=exclude)

def set_gcd(self, i3_file: str, gcd_file: Optional[str] = None) -> None:
"""Extract GFrame and CFrame from i3/gcd-file pair.
Expand Down
5 changes: 3 additions & 2 deletions src/graphnet/data/extractors/icecube/i3featureextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,19 @@
class I3FeatureExtractor(I3Extractor):
"""Base class for extracting specific, reconstructed features."""

def __init__(self, pulsemap: str):
def __init__(self, pulsemap: str, exclude: list = [None]):
"""Construct I3FeatureExtractor.

Args:
pulsemap: Name of the pulse (series) map for which to extract
reconstructed features.
exclude: List of keys to exclude from the extracted data.
"""
# Member variable(s)
self._pulsemap = pulsemap

# Base class constructor
super().__init__(pulsemap)
super().__init__(pulsemap, exclude=exclude)


class I3FeatureExtractorIceCube86(I3FeatureExtractor):
Expand Down
3 changes: 2 additions & 1 deletion src/graphnet/data/extractors/icecube/i3genericextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
keys: Optional[Union[str, List[str]]] = None,
exclude_keys: Optional[Union[str, List[str]]] = None,
extractor_name: str = GENERIC_EXTRACTOR_NAME,
exclude: list = [None],
):
"""Construct I3GenericExtractor.

Expand Down Expand Up @@ -73,7 +74,7 @@ def __init__(
self._exclude_keys: Optional[List[str]] = exclude_keys

# Base class constructor
super().__init__(extractor_name)
super().__init__(extractor_name, exclude=exclude)

def _get_keys(self, frame: "icetray.I3Frame") -> List[str]:
"""Get the list of keys to be queried from `frame`.
Expand Down
4 changes: 2 additions & 2 deletions src/graphnet/data/extractors/icecube/i3hybridrecoextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
class I3GalacticPlaneHybridRecoExtractor(I3Extractor):
"""Class for extracting galatictic plane hybrid reconstruction."""

def __init__(self, name: str = "dnn_hybrid"):
def __init__(self, name: str = "dnn_hybrid", exclude: list = [None]):
"""Construct I3GalacticPlaneHybridRecoExtractor."""
# Base class constructor
super().__init__(name)
super().__init__(name, exclude=exclude)

def __call__(self, frame: "icetray.I3Frame") -> Dict[str, Any]:
"""Extract TUMs DNN reconcstructions and associated variables."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ def __init__(
self,
name: str = "northeren_tracks_muon_labels",
padding_value: int = -1,
exclude: list = [None],
):
"""Construct I3NTMuonLabelExtractor."""
# Base class constructor
super().__init__(name)
super().__init__(name, exclude=exclude)
self._padding_value = padding_value

def __call__(self, frame: "icetray.I3Frame") -> Dict[str, float]:
Expand Down
4 changes: 2 additions & 2 deletions src/graphnet/data/extractors/icecube/i3particleextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ class I3ParticleExtractor(I3Extractor):
with GraphNeT.
"""

def __init__(self, extractor_name: str):
def __init__(self, extractor_name: str, exclude: list = [None]):
"""Construct I3ParticleExtractor."""
# Base class constructor
super().__init__(extractor_name=extractor_name)
super().__init__(extractor_name=extractor_name, exclude=exclude)

def __call__(self, frame: "icetray.I3Frame") -> Dict[str, float]:
"""Extract I3Particle properties from I3Particle in frame."""
Expand Down
6 changes: 4 additions & 2 deletions src/graphnet/data/extractors/icecube/i3pisaextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
class I3PISAExtractor(I3Extractor):
"""Class for extracting quantities required by PISA."""

def __init__(self, name: str = "pisa_dependencies"):
def __init__(
self, name: str = "pisa_dependencies", exclude: list = [None]
):
"""Construct `I3PISAExtractor`."""
# Base class constructor
super().__init__(name)
super().__init__(name, exclude=exclude)

def __call__(
self, frame: "icetray.I3Frame", padding_value: float = -1.0
Expand Down
4 changes: 2 additions & 2 deletions src/graphnet/data/extractors/icecube/i3quesoextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ def __init__(
self,
name: str = "queso",
padding_value: int = -1,
exclude: list = [None],
):
"""Construct I3QUESOExtractor."""
# Base class constructor
super().__init__(name)
super().__init__(name, exclude=exclude)
self._padding_value = padding_value

def __call__(self, frame: "icetray.I3Frame") -> Dict[str, float]:
Expand All @@ -42,5 +43,4 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, float]:
key: value,
}
)

return output
5 changes: 3 additions & 2 deletions src/graphnet/data/extractors/icecube/i3retroextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
class I3RetroExtractor(I3Extractor):
"""Class for extracting RETRO reconstruction."""

def __init__(self, name: str = "retro"):
def __init__(self, name: str = "retro", exclude: list = [None]):
"""Construct `I3RetroExtractor`."""
# Base class constructor
super().__init__(name)
super().__init__(name, exclude=exclude)

def __call__(self, frame: "icetray.I3Frame") -> Dict[str, Any]:
"""Extract RETRO reconstruction and associated quantities."""
Expand Down Expand Up @@ -100,6 +100,7 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, Any]:
frame["I3MCWeightDict"], "weight", default_value=-1
)

output = {k: v for k, v in output.items() if k not in self._exclude}
return output

def _frame_contains_retro(self, frame: "icetray.I3Frame") -> bool:
Expand Down
5 changes: 2 additions & 3 deletions src/graphnet/data/extractors/icecube/i3splinempeextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
class I3SplineMPEICExtractor(I3Extractor):
"""Class for extracting SplineMPE pointing predictions."""

def __init__(self, name: str = "spline_mpe_ic"):
def __init__(self, name: str = "spline_mpe_ic", exclude: list = [None]):
"""Construct I3SplineMPEICExtractor."""
# Base class constructor
super().__init__(name)
super().__init__(name, exclude=exclude)

def __call__(self, frame: "icetray.I3Frame") -> Dict[str, float]:
"""Extract SplineMPE pointing predictions."""
Expand All @@ -26,5 +26,4 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, float]:
"azimuth_spline_mpe_ic": frame["SplineMPEIC"].dir.azimuth,
}
)

return output
4 changes: 3 additions & 1 deletion src/graphnet/data/extractors/icecube/i3truthextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
borders: Optional[List[np.ndarray]] = None,
mctree: Optional[str] = "I3MCTree",
extend_boundary: Optional[float] = 0.0,
exclude: list = [None],
):
"""Construct I3TruthExtractor.

Expand All @@ -43,9 +44,10 @@ def __init__(
mctree: Str of which MCTree to use for truth values.
extend_boundary: Distance to extend the convex hull of the detector
for defining starting events.
exclude: List of keys to exclude from the extracted data.
"""
# Base class constructor
super().__init__(name)
super().__init__(name, exclude=exclude)

if borders is None:
border_xy = np.array(
Expand Down
5 changes: 2 additions & 3 deletions src/graphnet/data/extractors/icecube/i3tumextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
class I3TUMExtractor(I3Extractor):
"""Class for extracting TUM DNN predictions."""

def __init__(self, name: str = "tum_dnn"):
def __init__(self, name: str = "tum_dnn", exclude: list = [None]):
"""Construct I3TUMExtractor."""
# Base class constructor
super().__init__(name)
super().__init__(name, exclude=exclude)

def __call__(self, frame: "icetray.I3Frame") -> Dict[str, float]:
"""Extract TUM DNN recoconstruction and associated variables."""
Expand All @@ -29,5 +29,4 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, float]:
"tum_bdt_sigma": frame["TUM_bdt_sigma"].value,
}
)

return output