Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions nam/models/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
Loss functions
"""

from typing import Literal as _Literal
from typing import Optional as _Optional

import torch as _torch
import torch.nn as _nn

from .._dependencies.auraloss.freq import (
MultiResolutionSTFTLoss as _MultiResolutionSTFTLoss,
Expand Down Expand Up @@ -111,3 +113,99 @@ def mse_fft(preds: _torch.Tensor, targets: _torch.Tensor) -> _torch.Tensor:
ft = _torch.fft.fft(targets)
e = fp - ft
return _torch.mean(_torch.square(e.abs()))


class SpectralBandLoss(_nn.Module):
"""
Penalize spectral error in a specific frequency band (e.g. 7–12 kHz for ring).
Uses time-averaged spectrum to avoid transient-driven gradients.

:param sample_rate: Sample rate in Hz
:param fft_size: FFT size (4096 gives ~12 Hz resolution at 48 kHz)
:param hop_length: STFT hop length
:param low_hz: Bottom of problem band (Hz)
:param high_hz: Top of problem band (Hz)
:param weight: Loss weight
:param penalize: "excess" = only when pred > target (fixes ring), "missing" = only when pred < target
:param log_scale: If True, use log magnitude (ear-like, reduces transient dominance)
"""

def __init__(
self,
sample_rate: int = 48000,
fft_size: int = 4096,
hop_length: _Optional[int] = None,
low_hz: float = 8000,
high_hz: float = 12000,
weight: float = 1.0,
penalize: _Literal["excess", "missing"] = "excess",
log_scale: bool = True,
):
super().__init__()
self.fft_size = fft_size
self.hop_length = hop_length if hop_length is not None else fft_size // 4
self.weight = weight
self.penalize = penalize
self.log_scale = log_scale
self._low_hz = low_hz
self._high_hz = high_hz
self._sample_rate = sample_rate

freqs = _torch.fft.rfftfreq(fft_size, d=1.0 / sample_rate)
self.register_buffer(
"band_mask",
((freqs >= low_hz) & (freqs <= high_hz)).float(),
)

def update_sample_rate(self, sample_rate: int) -> None:
"""Recompute band_mask for a different sample rate (e.g. after dataset handshake)."""
if sample_rate == self._sample_rate:
return
self._sample_rate = sample_rate
freqs = _torch.fft.rfftfreq(self.fft_size, d=1.0 / sample_rate)
new_mask = ((freqs >= self._low_hz) & (freqs <= self._high_hz)).float()
self.band_mask.copy_(new_mask.to(device=self.band_mask.device))

def forward(
self,
pred: _torch.Tensor,
target: _torch.Tensor,
window: _Optional[_torch.Tensor] = None,
) -> _torch.Tensor:
"""
:param pred: (B, L) or (B, 1, L)
:param target: Same shape as pred
:param window: Optional Hann window; created if None
:return: Scalar loss
"""
if pred.dim() == 3:
pred = pred.squeeze(1)
if target.dim() == 3:
target = target.squeeze(1)

if window is None:
window = _torch.hann_window(self.fft_size, device=pred.device)

def mean_spectrum(x: _torch.Tensor) -> _torch.Tensor:
stft = _torch.stft(
x,
n_fft=self.fft_size,
hop_length=self.hop_length,
window=window,
return_complex=True,
)
mag = stft.abs() # (B, F, T)
mean_mag = mag.mean(dim=-1) # (B, F)
if self.log_scale:
return _torch.log(mean_mag + 1e-8)
return mean_mag

pred_spec = mean_spectrum(pred)
target_spec = mean_spectrum(target)

if self.penalize == "excess":
band_error = _torch.relu(pred_spec - target_spec) * self.band_mask
else:
band_error = _torch.relu(target_spec - pred_spec) * self.band_mask

return self.weight * _torch.mean(band_error**2)
2 changes: 2 additions & 0 deletions nam/train/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1494,6 +1494,8 @@ def train(
)
sample_rate = train_dataloader.dataset.sample_rate
model.net.sample_rate = sample_rate
if model._spectral_band is not None:
model._spectral_band.update_sample_rate(sample_rate)

# Put together the metadata that's needed in checkpoints:
settings_metadata = _metadata.Settings(ignore_checks=ignore_checks)
Expand Down
3 changes: 3 additions & 0 deletions nam/train/full.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def main(
):
if not outdir.exists():
raise RuntimeError(f"No output location found at {outdir}")

# Write
for basename, config in (
("data", data_config),
Expand All @@ -165,6 +166,8 @@ def main(
f"{dataset_train.sample_rate}, {dataset_validation.sample_rate}"
)
model.net.sample_rate = dataset_train.sample_rate
if model._spectral_band is not None:
model._spectral_band.update_sample_rate(dataset_train.sample_rate)

# Perform handshakes:
dataset_train.handshake(model.net)
Expand Down
26 changes: 26 additions & 0 deletions nam/train/lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ..models.factory import init as _init_model
from ..models.factory import register as _register_model
from ..models.linear import Linear as _Linear
from ..models.losses import SpectralBandLoss as _SpectralBandLoss
from ..models.losses import apply_pre_emphasis_filter as _apply_pre_emphasis_filter
from ..models.losses import esr as _esr
from ..models.losses import mse as _mse
Expand Down Expand Up @@ -97,6 +98,7 @@ class ValLossNameError(ValueError):
pre_emph_coef: _Optional[float] = None
pre_emph_mrstft_weight: _Optional[float] = None
pre_emph_mrstft_coef: _Optional[float] = None
spectral_band: _Optional[_Dict[str, _Any]] = None
custom_losses: _Optional[_Dict[str, _CustomLoss]] = None

@classmethod
Expand Down Expand Up @@ -162,6 +164,7 @@ def get_mrstft_weight() -> _Optional[float]:
"mrstft_weight": mrstft_weight,
"pre_emph_mrstft_weight": config.get("pre_emph_mrstft_weight"),
"pre_emph_mrstft_coef": config.get("pre_emph_mrstft_coef"),
"spectral_band": config.get("spectral_band"),
"custom_losses": custom_losses,
}

Expand Down Expand Up @@ -211,6 +214,22 @@ def __init__(
# Keeping it on-device is preferable, but if that fails, then remember to drop
# it to cpu from then on.
self._mrstft_device: _Optional[_torch.device] = None
# Create _spectral_band eagerly when configured so checkpoint loading (strict=True) works
if self._loss_config.spectral_band is not None:
cfg = self._loss_config.spectral_band
sample_rate = int(getattr(self._net, "sample_rate", None) or 48000)
self._spectral_band = _SpectralBandLoss(
sample_rate=sample_rate,
fft_size=cfg.get("fft_size", 4096),
hop_length=cfg.get("hop_length"),
low_hz=cfg.get("low_hz", 8000),
high_hz=cfg.get("high_hz", 12000),
weight=1.0,
penalize=cfg.get("penalize", "excess"),
log_scale=cfg.get("log_scale", True),
)
else:
self._spectral_band = None

@classmethod
def init_from_config(cls, config):
Expand Down Expand Up @@ -404,6 +423,13 @@ def get_mse_loss():
loss_dict["MRSTFT"] = _LossItem(
self._loss_config.mrstft_weight, self._mrstft_loss(preds, targets)
)
# Spectral band loss (penalize energy in problem band, e.g. 7–12 kHz ring)
if self._spectral_band is not None:
weight = self._loss_config.spectral_band.get("weight", 0.1)
loss_dict["SpectralBand"] = _LossItem(
weight,
self._spectral_band(preds, targets),
)
# Pre-emphasized MRSTFT
if self._loss_config.pre_emph_mrstft_weight is not None:
loss_dict["Pre-emphasized MRSTFT"] = _LossItem(
Expand Down
16 changes: 16 additions & 0 deletions tests/test_nam/test_models/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,21 @@ def test_mrstft_loss_doesnt_fall_back_to_cpu():
_losses.multi_resolution_stft_loss(preds, targets, device="mps")


def test_spectral_band_loss():
    """SpectralBandLoss penalizes error in the specified frequency band."""
    loss = _losses.SpectralBandLoss(
        sample_rate=44100,
        low_hz=7000,
        high_hz=12000,
        penalize="excess",
        weight=1.0,
    )
    x = _torch.randn(2, 16000)
    # Identical signals have no excess energy anywhere: loss is exactly zero.
    assert loss(x, x).item() == 0.0
    # Generic inputs: scalar, nonnegative.
    out = loss(x, _torch.randn(2, 16000))
    assert out.dim() == 0
    assert out.item() >= 0.0
    # (B, 1, L) inputs are squeezed and accepted.
    assert loss(x.unsqueeze(1), x.unsqueeze(1)).item() == 0.0


if __name__ == "__main__":
_pytest.main()
Loading