Skip to content

Commit 7848cda

Browse files
author
João Felipe Santos
committed
Restoring missing comments
1 parent ec5097c commit 7848cda

File tree

6 files changed

+24
-5
lines changed

6 files changed

+24
-5
lines changed

NAM/convnet.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,8 +329,10 @@ nam::convnet::ConvNetConfig nam::convnet::parse_config_json(const nlohmann::json
329329
c.channels = config["channels"];
330330
c.dilations = config["dilations"].get<std::vector<int>>();
331331
c.batchnorm = config["batchnorm"];
332+
// Parse JSON into typed ActivationConfig at model loading boundary
332333
c.activation = activations::ActivationConfig::from_json(config["activation"]);
333-
c.groups = config.value("groups", 1);
334+
c.groups = config.value("groups", 1); // defaults to 1
335+
// Default to 1 channel in/out for backward compatibility
334336
c.in_channels = config.value("in_channels", 1);
335337
c.out_channels = config.value("out_channels", 1);
336338
return c;

NAM/dsp.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ nam::linear::LinearConfig nam::linear::parse_config_json(const nlohmann::json& c
306306
LinearConfig c;
307307
c.receptive_field = config["receptive_field"];
308308
c.bias = config["bias"];
309+
// Default to 1 channel in/out for backward compatibility
309310
c.in_channels = config.value("in_channels", 1);
310311
c.out_channels = config.value("out_channels", 1);
311312
return c;

NAM/get_dsp.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@ std::unique_ptr<DSP> create_dsp(std::unique_ptr<ModelConfig> config, std::vector
173173
{
174174
auto out = config->create(std::move(weights), metadata.sample_rate);
175175
apply_metadata(*out, metadata);
176+
// "pre-warm" the model to settle initial conditions
177+
// Can this be removed now that it's part of Reset()?
176178
out->prewarm();
177179
return out;
178180
}

NAM/lstm.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ nam::lstm::LSTMConfig nam::lstm::parse_config_json(const nlohmann::json& config)
170170
c.num_layers = config["num_layers"];
171171
c.input_size = config["input_size"];
172172
c.hidden_size = config["hidden_size"];
173+
// Default to 1 channel in/out for backward compatibility
173174
c.in_channels = config.value("in_channels", 1);
174175
c.out_channels = config.value("out_channels", 1);
175176
return c;

NAM/registry.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ class FactoryConfig : public ModelConfig
5454
/// automatically register a factory when the program starts.
5555
struct Helper
5656
{
57+
/// \param name Architecture name
58+
/// \param factory Factory function
5759
Helper(const std::string& name, FactoryFunction factory)
5860
{
5961
// Capture factory by value in the lambda

NAM/wavenet.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -592,11 +592,11 @@ nam::wavenet::WaveNetConfig nam::wavenet::parse_config_json(const nlohmann::json
592592
{
593593
nlohmann::json layer_config = config["layers"][i];
594594

595-
const int groups = layer_config.value("groups_input", 1);
596-
const int groups_input_mixin = layer_config.value("groups_input_mixin", 1);
595+
const int groups = layer_config.value("groups_input", 1); // defaults to 1
596+
const int groups_input_mixin = layer_config.value("groups_input_mixin", 1); // defaults to 1
597597

598598
const int channels = layer_config["channels"];
599-
const int bottleneck = layer_config.value("bottleneck", channels);
599+
const int bottleneck = layer_config.value("bottleneck", channels); // defaults to channels if not present
600600

601601
// Parse layer1x1 parameters
602602
bool layer1x1_active = true;
@@ -633,12 +633,13 @@ nam::wavenet::WaveNetConfig nam::wavenet::parse_config_json(const nlohmann::json
633633
}
634634
else
635635
{
636+
// Single activation config - duplicate it for all layers
636637
const activations::ActivationConfig activation_config =
637638
activations::ActivationConfig::from_json(layer_config["activation"]);
638639
activation_configs.resize(num_layers, activation_config);
639640
}
640641

641-
// Parse gating mode(s)
642+
// Parse gating mode(s) - support both single value and array, and old "gated" boolean
642643
std::vector<GatingMode> gating_modes;
643644
std::vector<activations::ActivationConfig> secondary_activation_configs;
644645

@@ -663,6 +664,7 @@ nam::wavenet::WaveNetConfig nam::wavenet::parse_config_json(const nlohmann::json
663664
GatingMode mode = parse_gating_mode_str(gating_mode_str);
664665
gating_modes.push_back(mode);
665666

667+
// Parse corresponding secondary activation if gating is enabled
666668
if (mode != GatingMode::NONE)
667669
{
668670
if (layer_config.find("secondary_activation") != layer_config.end())
@@ -680,12 +682,14 @@ nam::wavenet::WaveNetConfig nam::wavenet::parse_config_json(const nlohmann::json
680682
}
681683
else
682684
{
685+
// Single secondary activation - use for all gated layers
683686
secondary_activation_configs.push_back(
684687
activations::ActivationConfig::from_json(layer_config["secondary_activation"]));
685688
}
686689
}
687690
else
688691
{
692+
// Default to Sigmoid for backward compatibility
689693
secondary_activation_configs.push_back(
690694
activations::ActivationConfig::simple(activations::ActivationType::Sigmoid));
691695
}
@@ -701,6 +705,7 @@ nam::wavenet::WaveNetConfig nam::wavenet::parse_config_json(const nlohmann::json
701705
+ std::to_string(gating_modes.size()) + ") must match dilations size ("
702706
+ std::to_string(num_layers) + ")");
703707
}
708+
// Validate secondary_activation array size if it's an array
704709
if (layer_config.find("secondary_activation") != layer_config.end()
705710
&& layer_config["secondary_activation"].is_array())
706711
{
@@ -714,6 +719,7 @@ nam::wavenet::WaveNetConfig nam::wavenet::parse_config_json(const nlohmann::json
714719
}
715720
else
716721
{
722+
// Single gating mode - duplicate for all layers
717723
std::string gating_mode_str = layer_config["gating_mode"].get<std::string>();
718724
GatingMode gating_mode = parse_gating_mode_str(gating_mode_str);
719725
gating_modes.resize(num_layers, gating_mode);
@@ -728,12 +734,14 @@ nam::wavenet::WaveNetConfig nam::wavenet::parse_config_json(const nlohmann::json
728734
}
729735
else
730736
{
737+
// Default to Sigmoid for backward compatibility
731738
secondary_activation_config = activations::ActivationConfig::simple(activations::ActivationType::Sigmoid);
732739
}
733740
}
734741
secondary_activation_configs.resize(num_layers, secondary_activation_config);
735742
}
736743
}
744+
// Backward compatibility: convert old "gated" boolean to new enum
737745
else if (layer_config.find("gated") != layer_config.end())
738746
{
739747
bool gated = layer_config["gated"];
@@ -753,6 +761,7 @@ nam::wavenet::WaveNetConfig nam::wavenet::parse_config_json(const nlohmann::json
753761
}
754762
else
755763
{
764+
// Default to NONE for all layers
756765
gating_modes.resize(num_layers, GatingMode::NONE);
757766
secondary_activation_configs.resize(num_layers, activations::ActivationConfig{});
758767
}
@@ -785,6 +794,7 @@ nam::wavenet::WaveNetConfig nam::wavenet::parse_config_json(const nlohmann::json
785794
return nam::wavenet::_FiLMParams(active, shift, film_groups);
786795
};
787796

797+
// Parse FiLM parameters
788798
nam::wavenet::_FiLMParams conv_pre_film_params = parse_film_params("conv_pre_film");
789799
nam::wavenet::_FiLMParams conv_post_film_params = parse_film_params("conv_post_film");
790800
nam::wavenet::_FiLMParams input_mixin_pre_film_params = parse_film_params("input_mixin_pre_film");
@@ -794,6 +804,7 @@ nam::wavenet::WaveNetConfig nam::wavenet::parse_config_json(const nlohmann::json
794804
nam::wavenet::_FiLMParams _layer1x1_post_film_params = parse_film_params("layer1x1_post_film");
795805
nam::wavenet::_FiLMParams head1x1_post_film_params = parse_film_params("head1x1_post_film");
796806

807+
// Validation: if layer1x1_post_film is active, layer1x1 must also be active
797808
if (_layer1x1_post_film_params.active && !layer1x1_active)
798809
{
799810
throw std::runtime_error("Layer array " + std::to_string(i)

0 commit comments

Comments
 (0)