@@ -129,39 +129,69 @@ long nam::convnet::ConvNetBlock::get_out_channels() const
129129 return this ->conv .get_out_channels ();
130130}
131131
132- nam::convnet::_Head::_Head (const int channels , std::vector<float >::iterator& weights)
132+ nam::convnet::_Head::_Head (const int in_channels, const int out_channels , std::vector<float >::iterator& weights)
133133{
134- this ->_weight .resize (channels);
135- for (int i = 0 ; i < channels; i++)
136- this ->_weight [i] = *(weights++);
137- this ->_bias = *(weights++);
134+ // Weights are stored row-major: first row (output 0), then row 1 (output 1), etc.
135+ // For each output channel: [w0, w1, ..., w_{in_channels-1}]
136+ // Then biases: [bias0, bias1, ..., bias_{out_channels-1}]
137+ this ->_weight .resize (out_channels, in_channels);
138+ for (int out_ch = 0 ; out_ch < out_channels; out_ch++)
139+ {
140+ for (int in_ch = 0 ; in_ch < in_channels; in_ch++)
141+ {
142+ this ->_weight (out_ch, in_ch) = *(weights++);
143+ }
144+ }
145+
146+ // Biases for each output channel
147+ this ->_bias .resize (out_channels);
148+ for (int out_ch = 0 ; out_ch < out_channels; out_ch++)
149+ {
150+ this ->_bias (out_ch) = *(weights++);
151+ }
138152}
139153
140- void nam::convnet::_Head::process_ (const Eigen::MatrixXf& input, Eigen::VectorXf & output, const long i_start,
154+ void nam::convnet::_Head::process_ (const Eigen::MatrixXf& input, Eigen::MatrixXf & output, const long i_start,
141155 const long i_end) const
142156{
143157 const long length = i_end - i_start;
144- output.resize (length);
145- for (long i = 0 , j = i_start; i < length; i++, j++)
146- output (i) = this ->_bias + input.col (j).dot (this ->_weight );
158+ const long out_channels = this ->_weight .rows ();
159+
160+ // Resize output to (out_channels x length)
161+ output.resize (out_channels, length);
162+
163+ // Extract input slice: (in_channels x length)
164+ Eigen::MatrixXf input_slice = input.middleCols (i_start, length);
165+
166+ // Compute output = weight * input_slice: (out_channels x in_channels) * (in_channels x length) = (out_channels x
167+ // length)
168+ output.noalias () = this ->_weight * input_slice;
169+
170+ // Add bias to each column: output.colwise() += bias
171+ // output is (out_channels x length), bias is (out_channels x 1), so colwise() += works
172+ output.colwise () += this ->_bias ;
147173}
148174
149- nam::convnet::ConvNet::ConvNet (const int channels , const std::vector< int >& dilations , const bool batchnorm ,
150- const std::string activation, std::vector< float >& weights ,
151- const double expected_sample_rate, const int groups)
152- : Buffer(*std::max_element (dilations.begin(), dilations.end()), expected_sample_rate)
175+ nam::convnet::ConvNet::ConvNet (const int in_channels , const int out_channels , const int channels ,
176+ const std::vector< int >& dilations, const bool batchnorm, const std::string activation ,
177+ std::vector< float >& weights, const double expected_sample_rate, const int groups)
178+ : Buffer(in_channels, out_channels, *std::max_element (dilations.begin(), dilations.end()), expected_sample_rate)
153179{
154180 this ->_verify_weights (channels, dilations, batchnorm, weights.size ());
155181 this ->_blocks .resize (dilations.size ());
156182 std::vector<float >::iterator it = weights.begin ();
183+ // First block takes in_channels input, subsequent blocks take channels input
157184 for (size_t i = 0 ; i < dilations.size (); i++)
158- this ->_blocks [i].set_weights_ (i == 0 ? 1 : channels, channels, dilations[i], batchnorm, activation, groups, it);
185+ this ->_blocks [i].set_weights_ (
186+ i == 0 ? in_channels : channels, channels, dilations[i], batchnorm, activation, groups, it);
159187 // Only need _block_vals for the head (one entry)
160188 // Conv1D layers manage their own buffers now
161189 this ->_block_vals .resize (1 );
162190 this ->_block_vals [0 ].setZero ();
163- std::fill (this ->_input_buffer .begin (), this ->_input_buffer .end (), 0 .0f );
164- this ->_head = _Head (channels, it);
191+
192+ // Create single head that outputs all channels
193+ this ->_head = _Head (channels, out_channels, it);
194+
165195 if (it != weights.end ())
166196 throw std::runtime_error (" Didn't touch all the weights when initializing ConvNet" );
167197
@@ -171,18 +201,25 @@ nam::convnet::ConvNet::ConvNet(const int channels, const std::vector<int>& dilat
171201}
172202
173203
174- void nam::convnet::ConvNet::process (NAM_SAMPLE* input, NAM_SAMPLE* output, const int num_frames)
204+ void nam::convnet::ConvNet::process (NAM_SAMPLE** input, NAM_SAMPLE* * output, const int num_frames)
175205
176206{
177207 this ->_update_buffers_ (input, num_frames);
178- // Main computation!
179- const long i_start = this ->_input_buffer_offset ;
180- const long i_end = i_start + num_frames;
208+ const int in_channels = NumInputChannels ();
209+ const int out_channels = NumOutputChannels ();
210+
211+ // For multi-channel, we process each input channel independently through the network
212+ // and sum outputs to each output channel (simple implementation)
213+ // This can be extended later for more sophisticated cross-channel processing
181214
182- // Convert input buffer to matrix for first layer
183- Eigen::MatrixXf input_matrix (1 , num_frames);
184- for (int i = 0 ; i < num_frames; i++)
185- input_matrix (0 , i) = this ->_input_buffer [i_start + i];
215+ // Convert input buffers to matrix for first layer (stack input channels)
216+ Eigen::MatrixXf input_matrix (in_channels, num_frames);
217+ const long i_start = this ->_input_buffer_offset ;
218+ for (int ch = 0 ; ch < in_channels; ch++)
219+ {
220+ for (int i = 0 ; i < num_frames; i++)
221+ input_matrix (ch, i) = this ->_input_buffers [ch][i_start + i];
222+ }
186223
187224 // Process through ConvNetBlock layers
188225 // Each block now uses Conv1D's internal buffers via Process() and GetOutput()
@@ -206,23 +243,33 @@ void nam::convnet::ConvNet::process(NAM_SAMPLE* input, NAM_SAMPLE* output, const
206243 this ->_blocks [i].Process (block_input, num_frames);
207244 }
208245
209- // Process head with output from last Conv1D
210- // Head still needs the old interface, so we need to provide it via a matrix
211- // We still need _block_vals [0] for the head interface
246+ // Process head for all output channels at once
247+ // We need _block_vals[0] for the head interface
248+ const long buffer_size = ( long ) this -> _input_buffers [0 ]. size ();
212249 if (this ->_block_vals [0 ].rows () != this ->_blocks .back ().get_out_channels ()
213- || this ->_block_vals [0 ].cols () != ( long ) this -> _input_buffer . size () )
250+ || this ->_block_vals [0 ].cols () != buffer_size )
214251 {
215- this ->_block_vals [0 ].resize (this ->_blocks .back ().get_out_channels (), this -> _input_buffer . size () );
252+ this ->_block_vals [0 ].resize (this ->_blocks .back ().get_out_channels (), buffer_size );
216253 }
254+
217255 // Copy last block output to _block_vals for head
218256 auto last_output = this ->_blocks .back ().GetOutput (num_frames);
219- this ->_block_vals [0 ].middleCols (i_start, num_frames) = last_output;
220-
221- this ->_head .process_ (this ->_block_vals [0 ], this ->_head_output , i_start, i_end);
222-
223- // Copy to required output array
224- for (int s = 0 ; s < num_frames; s++)
225- output[s] = this ->_head_output (s);
257+ const long buffer_offset = this ->_input_buffer_offset ;
258+ const long buffer_i_end = buffer_offset + num_frames;
259+ // last_output is (channels x num_frames), _block_vals[0] is (channels x buffer_size)
260+ // Copy to the correct location in _block_vals
261+ this ->_block_vals [0 ].block (0 , buffer_offset, last_output.rows (), num_frames) = last_output;
262+
263+ // Process head - outputs all channels at once
264+ // Head will resize _head_output internally
265+ this ->_head .process_ (this ->_block_vals [0 ], this ->_head_output , buffer_offset, buffer_i_end);
266+
267+ // Copy to output arrays for each channel
268+ for (int ch = 0 ; ch < out_channels; ch++)
269+ {
270+ for (int s = 0 ; s < num_frames; s++)
271+ output[ch][s] = this ->_head_output (ch, s);
272+ }
226273
227274 // Prepare for next call:
228275 nam::Buffer::_advance_input_buffer_ (num_frames);
@@ -245,11 +292,12 @@ void nam::convnet::ConvNet::SetMaxBufferSize(const int maxBufferSize)
245292 }
246293}
247294
248- void nam::convnet::ConvNet::_update_buffers_ (NAM_SAMPLE* input, const int num_frames)
295+ void nam::convnet::ConvNet::_update_buffers_ (NAM_SAMPLE** input, const int num_frames)
249296{
250297 this ->Buffer ::_update_buffers_ (input, num_frames);
251298
252- const long buffer_size = (long )this ->_input_buffer .size ();
299+ // All channels use the same buffer size
300+ const long buffer_size = (long )this ->_input_buffers [0 ].size ();
253301
254302 // Only need _block_vals[0] for the head
255303 // Conv1D layers manage their own buffers now
@@ -281,8 +329,11 @@ std::unique_ptr<nam::DSP> nam::convnet::Factory(const nlohmann::json& config, st
281329 const bool batchnorm = config[" batchnorm" ];
282330 const std::string activation = config[" activation" ];
283331 const int groups = config.value (" groups" , 1 ); // defaults to 1
332+ // Default to 1 channel in/out for backward compatibility
333+ const int in_channels = config.value (" in_channels" , 1 );
334+ const int out_channels = config.value (" out_channels" , 1 );
284335 return std::make_unique<nam::convnet::ConvNet>(
285- channels, dilations, batchnorm, activation, weights, expectedSampleRate, groups);
336+ in_channels, out_channels, channels, dilations, batchnorm, activation, weights, expectedSampleRate, groups);
286337}
287338
288339namespace
0 commit comments