Skip to content

Commit d1ffcb3

Browse files
jfsantosJoão Felipe Santos
andauthored
Head 1x1 convolution (#189)
* Implemented head1x1 operation, API to config and activate it still not implemented * Merged with bottleneck updates * Fixed some shape issues, added tests, updated all tests to pass head1x1_params when instantiating Wavenet and LayerArray * Updated head1x1 test to make sure it works when number of output channels is different from number of bottleneck channels * Uncommented and simplified gated tests now that everything else is working. * Fixing issue post merge * Updated CMakeLists.txt to solve issue with Eigen * Updated init issue in wavenet.cpp, relaxed test boundary since test was passing on ARM but failing on x86 * Addressing comments --------- Co-authored-by: João Felipe Santos <santosjf@pm.me>
1 parent 827a6cf commit d1ffcb3

File tree

9 files changed

+548
-49
lines changed

9 files changed

+548
-49
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ else()
2525
endif()
2626

2727
set(NAM_DEPS_PATH "${CMAKE_CURRENT_SOURCE_DIR}/Dependencies")
28+
include_directories(SYSTEM "${NAM_DEPS_PATH}/eigen")
2829

2930
add_subdirectory(tools)
3031

NAM/wavenet.cpp

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,29 @@ void nam::wavenet::_Layer::SetMaxBufferSize(const int maxBufferSize)
1919
// Pre-allocate output buffers
2020
const long channels = this->get_channels();
2121
this->_output_next_layer.resize(channels, maxBufferSize);
22-
// _output_head stores the activated portion: bottleneck rows (the actual bottleneck value, not doubled)
23-
this->_output_head.resize(this->_bottleneck, maxBufferSize);
22+
// _output_head stores the activated portion: bottleneck rows when no head1x1, or head1x1 out_channels when head1x1 is active
23+
if (_head1x1)
24+
{
25+
this->_output_head.resize(_head1x1->get_out_channels(), maxBufferSize);
26+
this->_output_head.setZero(); // Ensure consistent initialization across platforms
27+
_head1x1->SetMaxBufferSize(maxBufferSize);
28+
}
29+
else
30+
{
31+
this->_output_head.resize(this->_bottleneck, maxBufferSize);
32+
this->_output_head.setZero(); // Ensure consistent initialization across platforms
33+
}
2434
}
2535

2636
void nam::wavenet::_Layer::set_weights_(std::vector<float>::iterator& weights)
2737
{
2838
this->_conv.set_weights_(weights);
2939
this->_input_mixin.set_weights_(weights);
3040
this->_1x1.set_weights_(weights);
41+
if (this->_head1x1)
42+
{
43+
this->_head1x1->set_weights_(weights);
44+
}
3145
}
3246

3347
void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::MatrixXf& condition, const int num_frames)
@@ -61,31 +75,39 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
6175
_1x1.process_(_z.topRows(bottleneck), num_frames); // Might not be RT safe
6276
}
6377

64-
// Store output to head (skip connection: activated conv output)
65-
if (!this->_gated)
66-
this->_output_head.leftCols(num_frames).noalias() = this->_z.leftCols(num_frames);
67-
else
68-
this->_output_head.leftCols(num_frames).noalias() = this->_z.topRows(bottleneck).leftCols(num_frames);
78+
if (this->_head1x1) {
79+
if (!this->_gated)
80+
this->_head1x1->process_(this->_z.leftCols(num_frames), num_frames);
81+
else
82+
this->_head1x1->process(this->_z.topRows(bottleneck).leftCols(num_frames), num_frames);
83+
this->_output_head.leftCols(num_frames).noalias() = this->_head1x1->GetOutput().leftCols(num_frames);
84+
} else {
85+
// Store output to head (skip connection: activated conv output)
86+
if (!this->_gated)
87+
this->_output_head.leftCols(num_frames).noalias() = this->_z.leftCols(num_frames);
88+
else
89+
this->_output_head.leftCols(num_frames).noalias() = this->_z.topRows(bottleneck).leftCols(num_frames);
90+
}
91+
6992
// Store output to next layer (residual connection: input + _1x1 output)
7093
this->_output_next_layer.leftCols(num_frames).noalias() =
7194
input.leftCols(num_frames) + _1x1.GetOutput().leftCols(num_frames);
7295
}
7396

74-
7597
// LayerArray =================================================================
7698

7799
nam::wavenet::_LayerArray::_LayerArray(const int input_size, const int condition_size, const int head_size,
78100
const int channels, const int bottleneck, const int kernel_size,
79101
const std::vector<int>& dilations, const std::string activation,
80102
const bool gated, const bool head_bias, const int groups_input,
81-
const int groups_1x1)
103+
const int groups_1x1, const Head1x1Params& head1x1_params)
82104
: _rechannel(input_size, channels, false)
83105
, _head_rechannel(bottleneck, head_size, head_bias)
84106
, _bottleneck(bottleneck)
85107
{
86108
for (size_t i = 0; i < dilations.size(); i++)
87109
this->_layers.push_back(_Layer(
88-
condition_size, channels, bottleneck, kernel_size, dilations[i], activation, gated, groups_input, groups_1x1));
110+
condition_size, channels, bottleneck, kernel_size, dilations[i], activation, gated, groups_input, groups_1x1, head1x1_params));
89111
}
90112

91113
void nam::wavenet::_LayerArray::SetMaxBufferSize(const int maxBufferSize)
@@ -212,7 +234,8 @@ nam::wavenet::WaveNet::WaveNet(const int in_channels,
212234
layer_array_params[i].input_size, layer_array_params[i].condition_size, layer_array_params[i].head_size,
213235
layer_array_params[i].channels, layer_array_params[i].bottleneck, layer_array_params[i].kernel_size,
214236
layer_array_params[i].dilations, layer_array_params[i].activation, layer_array_params[i].gated,
215-
layer_array_params[i].head_bias, layer_array_params[i].groups_input, layer_array_params[i].groups_1x1));
237+
layer_array_params[i].head_bias, layer_array_params[i].groups_input, layer_array_params[i].groups_1x1,
238+
layer_array_params[i].head1x1_params));
216239
if (i > 0)
217240
if (layer_array_params[i].channels != layer_array_params[i - 1].head_size)
218241
{
@@ -324,10 +347,17 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
324347
const int groups_1x1 = layer_config.value("groups_1x1", 1); // defaults to 1
325348
const int channels = layer_config["channels"];
326349
const int bottleneck = layer_config.value("bottleneck", channels); // defaults to channels if not present
350+
351+
// Parse head1x1 parameters
352+
bool head1x1_active = layer_config.value("head1x1_active", false);
353+
int head1x1_out_channels = layer_config.value("head1x1_out_channels", channels);
354+
int head1x1_groups = layer_config.value("head1x1_groups", 1);
355+
nam::wavenet::Head1x1Params head1x1_params(head1x1_active, head1x1_out_channels, head1x1_groups);
356+
327357
layer_array_params.push_back(nam::wavenet::LayerArrayParams(
328358
layer_config["input_size"], layer_config["condition_size"], layer_config["head_size"], channels, bottleneck,
329359
layer_config["kernel_size"], layer_config["dilations"], layer_config["activation"], layer_config["gated"],
330-
layer_config["head_bias"], groups, groups_1x1));
360+
layer_config["head_bias"], groups, groups_1x1, head1x1_params));
331361
}
332362
const bool with_head = !config["head"].is_null();
333363
const float head_scale = config["head_scale"];

NAM/wavenet.h

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <string>
44
#include <vector>
5+
#include <memory>
56

67
#include "json.hpp"
78
#include <Eigen/Dense>
@@ -13,17 +14,36 @@ namespace nam
1314
{
1415
namespace wavenet
1516
{
17+
// Parameters for head1x1 configuration
18+
struct Head1x1Params
19+
{
20+
Head1x1Params(bool active_, int out_channels_, int groups_)
21+
: active(active_), out_channels(out_channels_), groups(groups_) {}
22+
23+
const bool active;
24+
const int out_channels;
25+
const int groups;
26+
};
27+
1628
class _Layer
1729
{
1830
public:
1931
_Layer(const int condition_size, const int channels, const int bottleneck, const int kernel_size, const int dilation,
20-
const std::string activation, const bool gated, const int groups_input, const int groups_1x1)
21-
: _conv(channels, gated ? 2 * bottleneck : bottleneck, kernel_size, true, dilation, groups_input)
32+
const std::string activation, const bool gated, const int groups_input, const int groups_1x1,
33+
const Head1x1Params& head1x1_params)
34+
: _conv(channels, gated ? 2 * bottleneck : bottleneck, kernel_size, true, dilation)
2235
, _input_mixin(condition_size, gated ? 2 * bottleneck : bottleneck, false)
23-
, _1x1(bottleneck, channels, true, groups_1x1)
36+
, _1x1(bottleneck, channels, groups_1x1)
2437
, _activation(activations::Activation::get_activation(activation)) // needs to support activations with parameters
25-
, _gated(gated)
26-
, _bottleneck(bottleneck) {};
38+
, _gated(gated)
39+
, _bottleneck(bottleneck)
40+
{
41+
if (head1x1_params.active)
42+
{
43+
_head1x1 = std::make_unique<Conv1x1>(bottleneck, head1x1_params.out_channels, true, head1x1_params.groups);
44+
}
45+
};
46+
2747
// Resize all arrays to be able to process `maxBufferSize` frames.
2848
void SetMaxBufferSize(const int maxBufferSize);
2949
// Set the parameters of this module
@@ -63,6 +83,8 @@ class _Layer
6383
Conv1x1 _input_mixin;
6484
// The post-activation 1x1 convolution
6585
Conv1x1 _1x1;
86+
// The post-activation 1x1 convolution outputting to the head, optional
87+
std::unique_ptr<Conv1x1> _head1x1;
6688
// The internal state
6789
Eigen::MatrixXf _z;
6890
// Output to next layer (residual connection: input + _1x1 output)
@@ -81,7 +103,7 @@ class LayerArrayParams
81103
LayerArrayParams(const int input_size_, const int condition_size_, const int head_size_, const int channels_,
82104
const int bottleneck_, const int kernel_size_, const std::vector<int>&& dilations_,
83105
const std::string activation_, const bool gated_, const bool head_bias_, const int groups_input,
84-
const int groups_1x1_)
106+
const int groups_1x1_, const Head1x1Params& head1x1_params_)
85107
: input_size(input_size_)
86108
, condition_size(condition_size_)
87109
, head_size(head_size_)
@@ -94,6 +116,7 @@ class LayerArrayParams
94116
, head_bias(head_bias_)
95117
, groups_input(groups_input)
96118
, groups_1x1(groups_1x1_)
119+
, head1x1_params(head1x1_params_)
97120
{
98121
}
99122

@@ -109,6 +132,7 @@ class LayerArrayParams
109132
const bool head_bias;
110133
const int groups_input;
111134
const int groups_1x1;
135+
const Head1x1Params head1x1_params;
112136
};
113137

114138
// An array of layers with the same channels, kernel sizes, activations.
@@ -118,7 +142,7 @@ class _LayerArray
118142
_LayerArray(const int input_size, const int condition_size, const int head_size, const int channels,
119143
const int bottleneck, const int kernel_size, const std::vector<int>& dilations,
120144
const std::string activation, const bool gated, const bool head_bias, const int groups_input,
121-
const int groups_1x1);
145+
const int groups_1x1, const Head1x1Params& head1x1_params);
122146

123147
void SetMaxBufferSize(const int maxBufferSize);
124148

tools/run_tests.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "test/test_wavenet/test_layer_array.cpp"
1515
#include "test/test_wavenet/test_full.cpp"
1616
#include "test/test_wavenet/test_real_time_safe.cpp"
17+
#include "test/test_wavenet/test_head1x1.cpp"
1718
#include "test/test_gating_activations.cpp"
1819
#include "test/test_wavenet_gating_compatibility.cpp"
1920
#include "test/test_blending_detailed.cpp"
@@ -115,6 +116,11 @@ int main()
115116
test_wavenet::test_full::test_wavenet_zero_input();
116117
test_wavenet::test_full::test_wavenet_different_buffer_sizes();
117118
test_wavenet::test_full::test_wavenet_prewarm();
119+
test_wavenet::test_head1x1::test_head1x1_inactive();
120+
test_wavenet::test_head1x1::test_head1x1_active();
121+
test_wavenet::test_head1x1::test_head1x1_gated();
122+
test_wavenet::test_head1x1::test_head1x1_groups();
123+
test_wavenet::test_head1x1::test_head1x1_different_out_channels();
118124
test_wavenet::test_allocation_tracking_pass();
119125
test_wavenet::test_allocation_tracking_fail();
120126
test_wavenet::test_conv1d_process_realtime_safe();

tools/test/test_wavenet/test_full.cpp

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@ void test_wavenet_model()
2929
const bool with_head = false;
3030
const int groups = 1;
3131
const int groups_1x1 = 1;
32+
const bool head1x1_active = false;
3233

34+
nam::wavenet::Head1x1Params head1x1_params(head1x1_active, channels, 1);
3335
nam::wavenet::LayerArrayParams params(input_size, condition_size, head_size, channels, bottleneck, kernel_size,
34-
std::move(dilations), activation, gated, head_bias, groups, groups_1x1);
36+
std::move(dilations), activation, gated, head_bias, groups, groups_1x1, head1x1_params);
3537
std::vector<nam::wavenet::LayerArrayParams> layer_array_params;
3638
layer_array_params.push_back(std::move(params));
3739

@@ -91,14 +93,17 @@ void test_wavenet_multiple_arrays()
9193
std::vector<int> dilations1{1};
9294
const int bottleneck = channels;
9395
const int groups_1x1 = 1;
96+
const bool head1x1_active = false;
97+
98+
nam::wavenet::Head1x1Params head1x1_params(head1x1_active, channels, 1);
9499
layer_array_params.push_back(nam::wavenet::LayerArrayParams(input_size, condition_size, head_size, channels,
95-
bottleneck, kernel_size, std::move(dilations1),
96-
activation, gated, head_bias, groups, groups_1x1));
100+
bottleneck, kernel_size, std::move(dilations1), activation,
101+
gated, head_bias, groups, groups_1x1, head1x1_params));
97102
// Second array (head_size of first must match channels of second)
98103
std::vector<int> dilations2{1};
99104
layer_array_params.push_back(nam::wavenet::LayerArrayParams(head_size, condition_size, head_size, channels,
100-
bottleneck, kernel_size, std::move(dilations2),
101-
activation, gated, head_bias, groups, groups_1x1));
105+
bottleneck, kernel_size, std::move(dilations2), activation,
106+
gated, head_bias, groups, groups_1x1, head1x1_params));
102107

103108
std::vector<float> weights;
104109
// Array 0: rechannel, layer, head_rechannel
@@ -145,9 +150,11 @@ void test_wavenet_zero_input()
145150
const bool with_head = false;
146151
const int groups = 1;
147152
const int groups_1x1 = 1;
153+
const bool head1x1_active = false;
154+
nam::wavenet::Head1x1Params head1x1_params(head1x1_active, channels, 1);
148155

149156
nam::wavenet::LayerArrayParams params(input_size, condition_size, head_size, channels, bottleneck, kernel_size,
150-
std::move(dilations), activation, gated, head_bias, groups, groups_1x1);
157+
std::move(dilations), activation, gated, head_bias, groups, groups_1x1, head1x1_params);
151158
std::vector<nam::wavenet::LayerArrayParams> layer_array_params;
152159
layer_array_params.push_back(std::move(params));
153160

@@ -190,9 +197,11 @@ void test_wavenet_different_buffer_sizes()
190197
const bool with_head = false;
191198
const int groups = 1;
192199
const int groups_1x1 = 1;
200+
const bool head1x1_active = false;
201+
nam::wavenet::Head1x1Params head1x1_params(head1x1_active, channels, 1);
193202

194203
nam::wavenet::LayerArrayParams params(input_size, condition_size, head_size, channels, bottleneck, kernel_size,
195-
std::move(dilations), activation, gated, head_bias, groups, groups_1x1);
204+
std::move(dilations), activation, gated, head_bias, groups, groups_1x1, head1x1_params);
196205
std::vector<nam::wavenet::LayerArrayParams> layer_array_params;
197206
layer_array_params.push_back(std::move(params));
198207

@@ -238,9 +247,12 @@ void test_wavenet_prewarm()
238247
const bool with_head = false;
239248
const int groups = 1;
240249
const int groups_1x1 = 1;
250+
const bool head1x1_active = false;
251+
252+
nam::wavenet::Head1x1Params head1x1_params(head1x1_active, channels, 1);
241253

242254
nam::wavenet::LayerArrayParams params(input_size, condition_size, head_size, channels, bottleneck, kernel_size,
243-
std::move(dilations), activation, gated, head_bias, groups, groups_1x1);
255+
std::move(dilations), activation, gated, head_bias, groups, groups_1x1, head1x1_params);
244256
std::vector<nam::wavenet::LayerArrayParams> layer_array_params;
245257
layer_array_params.push_back(std::move(params));
246258

0 commit comments

Comments
 (0)