Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nam/models/_constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Model version is independent from package version as of package version 0.5.2 so that
# the schema of the package can iterate at a different pace from that of the model
# files.
MODEL_VERSION = "0.7.0"
253 changes: 67 additions & 186 deletions nam/models/wavenet/_slimmable_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
Implements the "channel slicing" method introduced in https://arxiv.org/abs/2511.07470
"""

import abc as _abc
from typing import Optional as _Optional
from typing import Sequence as _Sequence
from typing import Tuple as _Tuple
Expand All @@ -17,15 +16,12 @@
from ._slimmable import Slimmable as _Slimmable


def _ratio_to_channels(
ratio: float, allowed_channels: _Sequence[int]
) -> _Tuple[int, int]:
def _ratio_to_channel_index(ratio: float, num_options: int) -> int:
"""
Convert ratio in [0, 1] to integer channel count, minimum 1.
Also return the index of the chosen channel entry.
"""
i = min(int(_np.floor(ratio * len(allowed_channels))), len(allowed_channels) - 1)
return allowed_channels[i], i
return min(int(_np.floor(ratio * num_options)), num_options - 1)


class _AllowedChannelsValueError(ValueError):
Expand Down Expand Up @@ -108,7 +104,11 @@ def _init_channel_causal(


class SlimmableConv1dBase(_conv.Conv1d, _Slimmable):
"""Base for slimmable 1D conv layers. Subclasses implement _get_adjusted_weight_and_bias."""
"""
Base for slimmable 1D conv layers.

Subclasses configure allowed_in_channels and allowed_out_channels in __init__
"""

def __init__(
self,
Expand Down Expand Up @@ -179,12 +179,61 @@ def forward(self, input: _torch.Tensor) -> _torch.Tensor:
input, w, b, self.stride, self.padding, self.dilation, self.groups
)

def _get_adjusted_weight_and_bias(
    self,
) -> _Tuple[_torch.Tensor, _Optional[_torch.Tensor]]:
    """
    Slice the weight and bias tensors down to the currently-selected channel
    counts (channel-slicing slimming).

    When boosting is enabled and we are not at the smallest slice, the
    sub-tensor corresponding to the previous (smaller) channel counts is
    detached so that gradients only flow into the newly exposed channels.

    :return: (weight, bias); bias is None when the layer has no bias.
    """
    i_in, i_out = self._get_slim_indices()

    out_channels = self._allowed_out_channels[i_out]
    in_channels = self._allowed_in_channels[i_in]

    w_full = self.weight[:out_channels, :in_channels, :]
    b_full = None if self.bias is None else self.bias[:out_channels]
    # No boosting at the smallest slice: there is no "previous" sub-network.
    if not self._boosting or (i_in == 0 and i_out == 0):
        return w_full, b_full

    # Boosting: channel counts of the previous (next-smaller) slice
    in_channels_prev = self._allowed_in_channels[i_in - 1]
    out_channels_prev = self._allowed_out_channels[i_out - 1]

    # Boosting mask for w: mask_w1 selects the previously-trained region
    # (detached), mask_w2 selects the newly exposed region (trainable).
    mask_w1 = _torch.zeros_like(w_full, dtype=_torch.bool)
    mask_w2 = _torch.ones_like(w_full, dtype=_torch.bool)
    mask_w1[:out_channels_prev, :in_channels_prev, :] = True
    mask_w2[:out_channels_prev, :in_channels_prev, :] = False

    w_boost = (mask_w1 * w_full).detach() + mask_w2 * w_full

    # Boosting mask for b: same detach-old / train-new split along out channels.
    if b_full is not None:
        mask_b1 = _torch.zeros_like(b_full, dtype=_torch.bool)
        mask_b2 = _torch.ones_like(b_full, dtype=_torch.bool)
        mask_b1[:out_channels_prev] = True
        mask_b2[:out_channels_prev] = False
        b_boost = (mask_b1 * b_full).detach() + mask_b2 * b_full
    else:
        b_boost = None
    return w_boost, b_boost

def _get_slim_indices(self) -> _Tuple[int, int]:
    """
    Core of channel-slicing slimming: resolve the current slimming value to
    the indices of the channel options being sliced to on the input and
    output sides.

    :return:
        i_in: index into _allowed_in_channels to slice to
        i_out: index into _allowed_out_channels to slice to
    """
    ratio = self._slimming_value
    return (
        _ratio_to_channel_index(ratio, len(self._allowed_in_channels)),
        _ratio_to_channel_index(ratio, len(self._allowed_out_channels)),
    )


class _SlimmableRechannelIn(_conv.RechannelIn, SlimmableConv1dBase):
Expand Down Expand Up @@ -214,40 +263,6 @@ def __init__(
super().__init__(
in_channels, *args, allowed_in_channels=allowed_in_channels, **kwargs
)
self._is_first = is_first

def _get_adjusted_weight_and_bias(
self,
) -> _Tuple[_torch.Tensor, _Optional[_torch.Tensor]]:
out_channels, i_out = _ratio_to_channels(
self._slimming_value, self._allowed_out_channels
)
if self._is_first:
in_channels = self.in_channels
i_in = 0
else:
in_channels, i_in = _ratio_to_channels(
self._slimming_value, self._allowed_in_channels
)
w_full = self.weight[:out_channels, :in_channels, :]
b_full = None if self.bias is None else self.bias[:out_channels]
if not self._boosting or (self._is_first or i_in == 0):
return w_full, b_full
in_channels_prev = self._allowed_in_channels[i_in - 1]
out_channels_prev = self._allowed_out_channels[i_out - 1]
w_prev = w_full[:out_channels_prev, :in_channels_prev, :]
w_diff = _torch.zeros_like(w_full)
w_diff[:out_channels_prev, :in_channels_prev, :] = (
self.weight[:out_channels_prev, :in_channels_prev, :].detach() - w_prev
)
w = w_full + w_diff
if self.bias is None:
return w, None
b_slice = self.bias[:out_channels]
b_prev = b_slice[:out_channels_prev]
b_diff = _torch.zeros_like(b_slice)
b_diff[:out_channels_prev] = self.bias[:out_channels_prev].detach() - b_prev
return w, b_slice + b_diff


class _SlimmableLayerConv(_conv.LayerConv, SlimmableConv1dBase):
Expand Down Expand Up @@ -285,36 +300,6 @@ def __init__(
allowed_out_channels=allowed_out_channels,
**kwargs,
)
self._output_paired = output_paired

def _get_adjusted_weight_and_bias(
self,
) -> _Tuple[_torch.Tensor, _Optional[_torch.Tensor]]:
in_channels, i_in = _ratio_to_channels(
self._slimming_value, self._allowed_in_channels
)
out_channels = 2 * in_channels if self._output_paired else in_channels
w_full = self.weight[:out_channels, :in_channels, :]
b_full = None if self.bias is None else self.bias[:out_channels]
if not self._boosting or i_in == 0:
return w_full, b_full
in_channels_prev = self._allowed_in_channels[i_in - 1]
out_channels_prev = (
2 * in_channels_prev if self._output_paired else in_channels_prev
)
w_prev = w_full[:out_channels_prev, :in_channels_prev, :]
w_diff = _torch.zeros_like(w_full)
w_diff[:out_channels_prev, :in_channels_prev, :] = (
self.weight[:out_channels_prev, :in_channels_prev, :].detach() - w_prev
)
w = w_full + w_diff
if self.bias is None:
return w, None
b_slice = self.bias[:out_channels]
b_prev = b_slice[:out_channels_prev]
b_diff = _torch.zeros_like(b_slice)
b_diff[:out_channels_prev] = self.bias[:out_channels_prev].detach() - b_prev
return w, b_slice + b_diff


class _SlimmableInputMixer(_conv.InputMixer, SlimmableConv1dBase):
Expand Down Expand Up @@ -353,95 +338,30 @@ def __init__(
output_paired=output_paired,
**kwargs,
)
self._output_paired = output_paired

def _get_adjusted_weight_and_bias(
self,
) -> _Tuple[_torch.Tensor, _Optional[_torch.Tensor]]:
out_channels, i_out = _ratio_to_channels(
self._slimming_value, self._allowed_out_channels
)
w_full = self.weight[:out_channels, :, :]
b_full = None if self.bias is None else self.bias[:out_channels]
if not self._boosting or i_out == 0:
return w_full, b_full
out_channels_prev = self._allowed_out_channels[i_out - 1]
w_prev = w_full[:out_channels_prev, :, :]
w_diff = _torch.zeros_like(w_full)
w_diff[:out_channels_prev, :, :] = (
self.weight[:out_channels_prev, :, :].detach() - w_prev
)
w = w_full + w_diff
if self.bias is None:
return w, None
b_slice = self.bias[:out_channels]
b_prev = b_slice[:out_channels_prev]
b_diff = _torch.zeros_like(b_slice)
b_diff[:out_channels_prev] = self.bias[:out_channels_prev].detach() - b_prev
return w, b_slice + b_diff


class _SlimmableLayer1x1(SlimmableConv1dBase):
"""1x1 conv in residual path. Slice both in and out (must be equal for slimmable)."""

def _get_adjusted_weight_and_bias(
self,
) -> _Tuple[_torch.Tensor, _Optional[_torch.Tensor]]:
in_channels, i_in = _ratio_to_channels(
self._slimming_value, self._allowed_in_channels
)
out_channels, _ = _ratio_to_channels(
self._slimming_value, self._allowed_out_channels
)
w_full = self.weight[:out_channels, :in_channels, :]
b_full = None if self.bias is None else self.bias[:out_channels]
if not self._boosting or i_in == 0:
return w_full, b_full
in_channels_prev = self._allowed_in_channels[i_in - 1]
out_channels_prev = self._allowed_out_channels[i_in - 1]
w_prev = w_full[:out_channels_prev, :in_channels_prev, :]
w_diff = _torch.zeros_like(w_full)
w_diff[:out_channels_prev, :in_channels_prev, :] = (
self.weight[:out_channels_prev, :in_channels_prev, :].detach() - w_prev
)
w = w_full + w_diff
if self.bias is None:
return w, None
b_slice = self.bias[:out_channels]
b_prev = b_slice[:out_channels_prev]
b_diff = _torch.zeros_like(b_slice)
b_diff[:out_channels_prev] = self.bias[:out_channels_prev].detach() - b_prev
return w, b_slice + b_diff


class _SlimmableHead1x1(SlimmableConv1dBase):
    """
    1x1 conv to the head collector
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *args,
        allowed_in_channels: _Optional[_Sequence[int]] = None,
        allowed_out_channels: _Optional[_Sequence[int]] = None,
        **kwargs,
    ):
        # Not supported yet; fail loudly at construction instead of
        # silently building a head that cannot be slimmed correctly.
        raise NotImplementedError("Slimmable head 1x1 not implemented")


class _SlimmableHeadRechannel(_conv.HeadRechannel, SlimmableConv1dBase):
"""
Head rechannel: output size si fixed on the last layer array."""
"""Head rechannel: output size is fixed on the last layer array."""

def __init__(
self,
Expand Down Expand Up @@ -473,45 +393,6 @@ def __init__(
allowed_out_channels=allowed_out_channels,
**kwargs,
)
self._is_last = is_last

def _get_adjusted_weight_and_bias(
self,
) -> _Tuple[_torch.Tensor, _Optional[_torch.Tensor]]:
in_channels, i_in = _ratio_to_channels(
self._slimming_value, self._allowed_in_channels
)
if self._is_last:
out_channels = self._allowed_out_channels[0]
out_channels_prev = out_channels
i_out = i_in
else:
out_channels, i_out = _ratio_to_channels(
self._slimming_value, self._allowed_out_channels
)
w_full = self.weight[:out_channels, :in_channels, :]
b_full = None if self.bias is None else self.bias[:out_channels]
if not self._boosting or i_in == 0:
return w_full, b_full
in_channels_prev = self._allowed_in_channels[i_in - 1]
out_channels_prev = (
self._allowed_out_channels[0]
if self._is_last
else self._allowed_out_channels[i_out - 1]
)
w_prev = w_full[:out_channels_prev, :in_channels_prev, :]
w_diff = _torch.zeros_like(w_full)
w_diff[:out_channels_prev, :in_channels_prev, :] = (
self.weight[:out_channels_prev, :in_channels_prev, :].detach() - w_prev
)
w = w_full + w_diff
if self.bias is None:
return w, None
b_slice = self.bias[:out_channels]
b_prev = b_slice[:out_channels_prev]
b_diff = _torch.zeros_like(b_slice)
b_diff[:out_channels_prev] = self.bias[:out_channels_prev].detach() - b_prev
return w, b_slice + b_diff


class_set = _conv.ClassSet(
Expand Down
12 changes: 6 additions & 6 deletions nam/train/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1738,13 +1738,13 @@ def validate_data(
try:
ds = _init_dataset(data_config, split)
ds.teardown()
pytorch_data_split_validation_dict[split.value] = (
_PyTorchDataSplitValidation(passed=True, msg=None)
)
pytorch_data_split_validation_dict[
split.value
] = _PyTorchDataSplitValidation(passed=True, msg=None)
except _DataError as e:
pytorch_data_split_validation_dict[split.value] = (
_PyTorchDataSplitValidation(passed=False, msg=str(e))
)
pytorch_data_split_validation_dict[
split.value
] = _PyTorchDataSplitValidation(passed=False, msg=str(e))
pytorch_data_validation = _PyTorchDataValidation(
passed=all(v.passed for v in pytorch_data_split_validation_dict.values()),
**pytorch_data_split_validation_dict,
Expand Down
Loading
Loading