-
Notifications
You must be signed in to change notification settings - Fork 122
Head 1x1 convolution #189
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Head 1x1 convolution #189
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
30c6c97
Implemented head1x1 operation, API to config and activate it still no…
e2bf465
Merged with bottleneck updates
e9f922c
Merge
abef748
Fixed some shape issues, added tests, updated all tests to pass head1…
3ab6091
Updated head1x1 test to make sure it works when number of output chan…
ef26ae0
Uncommented and simplified gated tests now that everything else is wo…
9070d18
Merge branch 'main' into head1x1
jfsantos fa2b2ee
Fixing issue post merge
5b65bfd
Updated CMakeLists.txt to solve issue with Eigen
27020f8
Updated init issue in wavenet.cpp, relaxed test boundary since test w…
1d21470
Addressing comments
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,15 +19,29 @@ void nam::wavenet::_Layer::SetMaxBufferSize(const int maxBufferSize) | |
| // Pre-allocate output buffers | ||
| const long channels = this->get_channels(); | ||
| this->_output_next_layer.resize(channels, maxBufferSize); | ||
| // _output_head stores the activated portion: bottleneck rows (the actual bottleneck value, not doubled) | ||
| this->_output_head.resize(this->_bottleneck, maxBufferSize); | ||
| // _output_head stores the activated portion: bottleneck rows when no head1x1, or head1x1 out_channels when head1x1 is active | ||
| if (_head1x1) | ||
| { | ||
| this->_output_head.resize(_head1x1->get_out_channels(), maxBufferSize); | ||
| this->_output_head.setZero(); // Ensure consistent initialization across platforms | ||
| _head1x1->SetMaxBufferSize(maxBufferSize); | ||
| } | ||
| else | ||
| { | ||
| this->_output_head.resize(this->_bottleneck, maxBufferSize); | ||
| this->_output_head.setZero(); // Ensure consistent initialization across platforms | ||
| } | ||
| } | ||
|
|
||
| void nam::wavenet::_Layer::set_weights_(std::vector<float>::iterator& weights) | ||
| { | ||
| this->_conv.set_weights_(weights); | ||
| this->_input_mixin.set_weights_(weights); | ||
| this->_1x1.set_weights_(weights); | ||
| if (this->_head1x1) | ||
| { | ||
| this->_head1x1->set_weights_(weights); | ||
| } | ||
| } | ||
|
|
||
| void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::MatrixXf& condition, const int num_frames) | ||
|
|
@@ -61,31 +75,39 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma | |
| _1x1.process_(_z.topRows(bottleneck), num_frames); // Might not be RT safe | ||
| } | ||
|
|
||
| // Store output to head (skip connection: activated conv output) | ||
| if (!this->_gated) | ||
| this->_output_head.leftCols(num_frames).noalias() = this->_z.leftCols(num_frames); | ||
| else | ||
| this->_output_head.leftCols(num_frames).noalias() = this->_z.topRows(bottleneck).leftCols(num_frames); | ||
| if (this->_head1x1) { | ||
| if (!this->_gated) | ||
| this->_head1x1->process_(this->_z.leftCols(num_frames), num_frames); | ||
| else | ||
| this->_head1x1->process(this->_z.topRows(bottleneck).leftCols(num_frames), num_frames); | ||
| this->_output_head.leftCols(num_frames).noalias() = this->_head1x1->GetOutput().leftCols(num_frames); | ||
| } else { | ||
| // Store output to head (skip connection: activated conv output) | ||
| if (!this->_gated) | ||
| this->_output_head.leftCols(num_frames).noalias() = this->_z.leftCols(num_frames); | ||
| else | ||
| this->_output_head.leftCols(num_frames).noalias() = this->_z.topRows(bottleneck).leftCols(num_frames); | ||
| } | ||
|
|
||
| // Store output to next layer (residual connection: input + _1x1 output) | ||
| this->_output_next_layer.leftCols(num_frames).noalias() = | ||
| input.leftCols(num_frames) + _1x1.GetOutput().leftCols(num_frames); | ||
| } | ||
|
|
||
|
|
||
| // LayerArray ================================================================= | ||
|
|
||
| nam::wavenet::_LayerArray::_LayerArray(const int input_size, const int condition_size, const int head_size, | ||
| const int channels, const int bottleneck, const int kernel_size, | ||
| const std::vector<int>& dilations, const std::string activation, | ||
| const bool gated, const bool head_bias, const int groups_input, | ||
| const int groups_1x1) | ||
| const int groups_1x1, const Head1x1Params& head1x1_params) | ||
| : _rechannel(input_size, channels, false) | ||
| , _head_rechannel(bottleneck, head_size, head_bias) | ||
| , _bottleneck(bottleneck) | ||
| { | ||
| for (size_t i = 0; i < dilations.size(); i++) | ||
| this->_layers.push_back(_Layer( | ||
| condition_size, channels, bottleneck, kernel_size, dilations[i], activation, gated, groups_input, groups_1x1)); | ||
| condition_size, channels, bottleneck, kernel_size, dilations[i], activation, gated, groups_input, groups_1x1, head1x1_params)); | ||
| } | ||
|
|
||
| void nam::wavenet::_LayerArray::SetMaxBufferSize(const int maxBufferSize) | ||
|
|
@@ -212,7 +234,8 @@ nam::wavenet::WaveNet::WaveNet(const int in_channels, | |
| layer_array_params[i].input_size, layer_array_params[i].condition_size, layer_array_params[i].head_size, | ||
| layer_array_params[i].channels, layer_array_params[i].bottleneck, layer_array_params[i].kernel_size, | ||
| layer_array_params[i].dilations, layer_array_params[i].activation, layer_array_params[i].gated, | ||
| layer_array_params[i].head_bias, layer_array_params[i].groups_input, layer_array_params[i].groups_1x1)); | ||
| layer_array_params[i].head_bias, layer_array_params[i].groups_input, layer_array_params[i].groups_1x1, | ||
| layer_array_params[i].head1x1_params)); | ||
| if (i > 0) | ||
| if (layer_array_params[i].channels != layer_array_params[i - 1].head_size) | ||
| { | ||
|
|
@@ -324,10 +347,17 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st | |
| const int groups_1x1 = layer_config.value("groups_1x1", 1); // defaults to 1 | ||
| const int channels = layer_config["channels"]; | ||
| const int bottleneck = layer_config.value("bottleneck", channels); // defaults to channels if not present | ||
|
|
||
| // Parse head1x1 parameters | ||
| bool head1x1_active = layer_config.value("head1x1_active", false); | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: if [inline code lost in extraction]. Nit because it'll still work for now; I just like loud failures :) |
||
| int head1x1_out_channels = layer_config.value("head1x1_out_channels", channels); | ||
| int head1x1_groups = layer_config.value("head1x1_groups", 1); | ||
| nam::wavenet::Head1x1Params head1x1_params(head1x1_active, head1x1_out_channels, head1x1_groups); | ||
|
|
||
| layer_array_params.push_back(nam::wavenet::LayerArrayParams( | ||
| layer_config["input_size"], layer_config["condition_size"], layer_config["head_size"], channels, bottleneck, | ||
| layer_config["kernel_size"], layer_config["dilations"], layer_config["activation"], layer_config["gated"], | ||
| layer_config["head_bias"], groups, groups_1x1)); | ||
| layer_config["head_bias"], groups, groups_1x1, head1x1_params)); | ||
| } | ||
| const bool with_head = !config["head"].is_null(); | ||
| const float head_scale = config["head_scale"]; | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting that this wasn't(?) a (big?) problem previously. But I'm not great at CMake :)