Skip to content

Commit 849ef2d

Browse files
authored
Merge branch 'main' into head1x1
2 parents ef26ae0 + 827a6cf commit 849ef2d

File tree

15 files changed

+1191
-202
lines changed

15 files changed

+1191
-202
lines changed

NAM/convnet.cpp

Lines changed: 90 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -129,39 +129,69 @@ long nam::convnet::ConvNetBlock::get_out_channels() const
129129
return this->conv.get_out_channels();
130130
}
131131

132-
nam::convnet::_Head::_Head(const int channels, std::vector<float>::iterator& weights)
132+
nam::convnet::_Head::_Head(const int in_channels, const int out_channels, std::vector<float>::iterator& weights)
133133
{
134-
this->_weight.resize(channels);
135-
for (int i = 0; i < channels; i++)
136-
this->_weight[i] = *(weights++);
137-
this->_bias = *(weights++);
134+
// Weights are stored row-major: first row (output 0), then row 1 (output 1), etc.
135+
// For each output channel: [w0, w1, ..., w_{in_channels-1}]
136+
// Then biases: [bias0, bias1, ..., bias_{out_channels-1}]
137+
this->_weight.resize(out_channels, in_channels);
138+
for (int out_ch = 0; out_ch < out_channels; out_ch++)
139+
{
140+
for (int in_ch = 0; in_ch < in_channels; in_ch++)
141+
{
142+
this->_weight(out_ch, in_ch) = *(weights++);
143+
}
144+
}
145+
146+
// Biases for each output channel
147+
this->_bias.resize(out_channels);
148+
for (int out_ch = 0; out_ch < out_channels; out_ch++)
149+
{
150+
this->_bias(out_ch) = *(weights++);
151+
}
138152
}
139153

140-
void nam::convnet::_Head::process_(const Eigen::MatrixXf& input, Eigen::VectorXf& output, const long i_start,
154+
void nam::convnet::_Head::process_(const Eigen::MatrixXf& input, Eigen::MatrixXf& output, const long i_start,
141155
const long i_end) const
142156
{
143157
const long length = i_end - i_start;
144-
output.resize(length);
145-
for (long i = 0, j = i_start; i < length; i++, j++)
146-
output(i) = this->_bias + input.col(j).dot(this->_weight);
158+
const long out_channels = this->_weight.rows();
159+
160+
// Resize output to (out_channels x length)
161+
output.resize(out_channels, length);
162+
163+
// Extract input slice: (in_channels x length)
164+
Eigen::MatrixXf input_slice = input.middleCols(i_start, length);
165+
166+
// Compute output = weight * input_slice: (out_channels x in_channels) * (in_channels x length) = (out_channels x
167+
// length)
168+
output.noalias() = this->_weight * input_slice;
169+
170+
// Add bias to each column: output.colwise() += bias
171+
// output is (out_channels x length), bias is (out_channels x 1), so colwise() += works
172+
output.colwise() += this->_bias;
147173
}
148174

149-
nam::convnet::ConvNet::ConvNet(const int channels, const std::vector<int>& dilations, const bool batchnorm,
150-
const std::string activation, std::vector<float>& weights,
151-
const double expected_sample_rate, const int groups)
152-
: Buffer(*std::max_element(dilations.begin(), dilations.end()), expected_sample_rate)
175+
nam::convnet::ConvNet::ConvNet(const int in_channels, const int out_channels, const int channels,
176+
const std::vector<int>& dilations, const bool batchnorm, const std::string activation,
177+
std::vector<float>& weights, const double expected_sample_rate, const int groups)
178+
: Buffer(in_channels, out_channels, *std::max_element(dilations.begin(), dilations.end()), expected_sample_rate)
153179
{
154180
this->_verify_weights(channels, dilations, batchnorm, weights.size());
155181
this->_blocks.resize(dilations.size());
156182
std::vector<float>::iterator it = weights.begin();
183+
// First block takes in_channels input, subsequent blocks take channels input
157184
for (size_t i = 0; i < dilations.size(); i++)
158-
this->_blocks[i].set_weights_(i == 0 ? 1 : channels, channels, dilations[i], batchnorm, activation, groups, it);
185+
this->_blocks[i].set_weights_(
186+
i == 0 ? in_channels : channels, channels, dilations[i], batchnorm, activation, groups, it);
159187
// Only need _block_vals for the head (one entry)
160188
// Conv1D layers manage their own buffers now
161189
this->_block_vals.resize(1);
162190
this->_block_vals[0].setZero();
163-
std::fill(this->_input_buffer.begin(), this->_input_buffer.end(), 0.0f);
164-
this->_head = _Head(channels, it);
191+
192+
// Create single head that outputs all channels
193+
this->_head = _Head(channels, out_channels, it);
194+
165195
if (it != weights.end())
166196
throw std::runtime_error("Didn't touch all the weights when initializing ConvNet");
167197

@@ -171,18 +201,25 @@ nam::convnet::ConvNet::ConvNet(const int channels, const std::vector<int>& dilat
171201
}
172202

173203

174-
void nam::convnet::ConvNet::process(NAM_SAMPLE* input, NAM_SAMPLE* output, const int num_frames)
204+
void nam::convnet::ConvNet::process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames)
175205

176206
{
177207
this->_update_buffers_(input, num_frames);
178-
// Main computation!
179-
const long i_start = this->_input_buffer_offset;
180-
const long i_end = i_start + num_frames;
208+
const int in_channels = NumInputChannels();
209+
const int out_channels = NumOutputChannels();
210+
211+
// For multi-channel, we process each input channel independently through the network
212+
// and sum outputs to each output channel (simple implementation)
213+
// This can be extended later for more sophisticated cross-channel processing
181214

182-
// Convert input buffer to matrix for first layer
183-
Eigen::MatrixXf input_matrix(1, num_frames);
184-
for (int i = 0; i < num_frames; i++)
185-
input_matrix(0, i) = this->_input_buffer[i_start + i];
215+
// Convert input buffers to matrix for first layer (stack input channels)
216+
Eigen::MatrixXf input_matrix(in_channels, num_frames);
217+
const long i_start = this->_input_buffer_offset;
218+
for (int ch = 0; ch < in_channels; ch++)
219+
{
220+
for (int i = 0; i < num_frames; i++)
221+
input_matrix(ch, i) = this->_input_buffers[ch][i_start + i];
222+
}
186223

187224
// Process through ConvNetBlock layers
188225
// Each block now uses Conv1D's internal buffers via Process() and GetOutput()
@@ -206,23 +243,33 @@ void nam::convnet::ConvNet::process(NAM_SAMPLE* input, NAM_SAMPLE* output, const
206243
this->_blocks[i].Process(block_input, num_frames);
207244
}
208245

209-
// Process head with output from last Conv1D
210-
// Head still needs the old interface, so we need to provide it via a matrix
211-
// We still need _block_vals[0] for the head interface
246+
// Process head for all output channels at once
247+
// We need _block_vals[0] for the head interface
248+
const long buffer_size = (long)this->_input_buffers[0].size();
212249
if (this->_block_vals[0].rows() != this->_blocks.back().get_out_channels()
213-
|| this->_block_vals[0].cols() != (long)this->_input_buffer.size())
250+
|| this->_block_vals[0].cols() != buffer_size)
214251
{
215-
this->_block_vals[0].resize(this->_blocks.back().get_out_channels(), this->_input_buffer.size());
252+
this->_block_vals[0].resize(this->_blocks.back().get_out_channels(), buffer_size);
216253
}
254+
217255
// Copy last block output to _block_vals for head
218256
auto last_output = this->_blocks.back().GetOutput(num_frames);
219-
this->_block_vals[0].middleCols(i_start, num_frames) = last_output;
220-
221-
this->_head.process_(this->_block_vals[0], this->_head_output, i_start, i_end);
222-
223-
// Copy to required output array
224-
for (int s = 0; s < num_frames; s++)
225-
output[s] = this->_head_output(s);
257+
const long buffer_offset = this->_input_buffer_offset;
258+
const long buffer_i_end = buffer_offset + num_frames;
259+
// last_output is (channels x num_frames), _block_vals[0] is (channels x buffer_size)
260+
// Copy to the correct location in _block_vals
261+
this->_block_vals[0].block(0, buffer_offset, last_output.rows(), num_frames) = last_output;
262+
263+
// Process head - outputs all channels at once
264+
// Head will resize _head_output internally
265+
this->_head.process_(this->_block_vals[0], this->_head_output, buffer_offset, buffer_i_end);
266+
267+
// Copy to output arrays for each channel
268+
for (int ch = 0; ch < out_channels; ch++)
269+
{
270+
for (int s = 0; s < num_frames; s++)
271+
output[ch][s] = this->_head_output(ch, s);
272+
}
226273

227274
// Prepare for next call:
228275
nam::Buffer::_advance_input_buffer_(num_frames);
@@ -245,11 +292,12 @@ void nam::convnet::ConvNet::SetMaxBufferSize(const int maxBufferSize)
245292
}
246293
}
247294

248-
void nam::convnet::ConvNet::_update_buffers_(NAM_SAMPLE* input, const int num_frames)
295+
void nam::convnet::ConvNet::_update_buffers_(NAM_SAMPLE** input, const int num_frames)
249296
{
250297
this->Buffer::_update_buffers_(input, num_frames);
251298

252-
const long buffer_size = (long)this->_input_buffer.size();
299+
// All channels use the same buffer size
300+
const long buffer_size = (long)this->_input_buffers[0].size();
253301

254302
// Only need _block_vals[0] for the head
255303
// Conv1D layers manage their own buffers now
@@ -281,8 +329,11 @@ std::unique_ptr<nam::DSP> nam::convnet::Factory(const nlohmann::json& config, st
281329
const bool batchnorm = config["batchnorm"];
282330
const std::string activation = config["activation"];
283331
const int groups = config.value("groups", 1); // defaults to 1
332+
// Default to 1 channel in/out for backward compatibility
333+
const int in_channels = config.value("in_channels", 1);
334+
const int out_channels = config.value("out_channels", 1);
284335
return std::make_unique<nam::convnet::ConvNet>(
285-
channels, dilations, batchnorm, activation, weights, expectedSampleRate, groups);
336+
in_channels, out_channels, channels, dilations, batchnorm, activation, weights, expectedSampleRate, groups);
286337
}
287338

288339
namespace

NAM/convnet.h

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,32 +66,33 @@ class _Head
6666
{
6767
public:
6868
_Head() {};
69-
_Head(const int channels, std::vector<float>::iterator& weights);
70-
void process_(const Eigen::MatrixXf& input, Eigen::VectorXf& output, const long i_start, const long i_end) const;
69+
_Head(const int in_channels, const int out_channels, std::vector<float>::iterator& weights);
70+
void process_(const Eigen::MatrixXf& input, Eigen::MatrixXf& output, const long i_start, const long i_end) const;
7171

7272
private:
73-
Eigen::VectorXf _weight;
74-
float _bias = 0.0f;
73+
Eigen::MatrixXf _weight; // (out_channels, in_channels)
74+
Eigen::VectorXf _bias; // (out_channels,)
7575
};
7676

7777
class ConvNet : public Buffer
7878
{
7979
public:
80-
ConvNet(const int channels, const std::vector<int>& dilations, const bool batchnorm, const std::string activation,
81-
std::vector<float>& weights, const double expected_sample_rate = -1.0, const int groups = 1);
80+
ConvNet(const int in_channels, const int out_channels, const int channels, const std::vector<int>& dilations,
81+
const bool batchnorm, const std::string activation, std::vector<float>& weights,
82+
const double expected_sample_rate = -1.0, const int groups = 1);
8283
~ConvNet() = default;
8384

84-
void process(NAM_SAMPLE* input, NAM_SAMPLE* output, const int num_frames) override;
85+
void process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames) override;
8586
void SetMaxBufferSize(const int maxBufferSize) override;
8687

8788
protected:
8889
std::vector<ConvNetBlock> _blocks;
8990
std::vector<Eigen::MatrixXf> _block_vals;
90-
Eigen::VectorXf _head_output;
91+
Eigen::MatrixXf _head_output; // (out_channels, num_frames)
9192
_Head _head;
9293
void _verify_weights(const int channels, const std::vector<int>& dilations, const bool batchnorm,
9394
const size_t actual_weights);
94-
void _update_buffers_(NAM_SAMPLE* input, const int num_frames) override;
95+
void _update_buffers_(NAM_SAMPLE** input, const int num_frames) override;
9596
void _rewind_buffers_() override;
9697

9798
int mPrewarmSamples = 0; // Pre-compute during initialization

0 commit comments

Comments (0)