Skip to content

Commit 704f309

Browse files
author
João Felipe Santos
committed
Remove redundant manual loop unrolling from activations and element-wise ops
ARM assembly analysis (-O2 -DNDEBUG) confirmed:

- GCC auto-unrolls simple activation loops; manual 4-wide gives no benefit
- expf() serializes sigmoid/SiLU; unrolling can't help
- Eigen element-wise ops (.leftCols + .leftCols) produce identical codegen to raw float* loops when assertions are disabled

Simplify 5 activation classes to use inline helpers (relu, sigmoid, etc.) and revert 3 wavenet element-wise operations back to Eigen expressions.

Inline GEMM (Conv1x1/Conv1D), depthwise unrolling, FiLM unrolling, bias broadcast, and memcpy optimizations are retained — those show measurable wins on both desktop and Cortex-M7.

Also restored comments that were accidentally removed from wavenet.h.
1 parent 5d9ed6c commit 704f309

File tree

4 files changed

+15
-177
lines changed

4 files changed

+15
-177
lines changed

NAM/activations.cpp

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ std::unordered_map<std::string, nam::activations::Activation::Ptr> nam::activati
3737

3838
nam::activations::Activation::Ptr tanh_bak = nullptr;
3939
nam::activations::Activation::Ptr sigmoid_bak = nullptr;
40-
nam::activations::Activation::Ptr silu_bak = nullptr;
4140

4241
nam::activations::Activation::Ptr nam::activations::Activation::get_activation(const std::string name)
4342
{
@@ -198,14 +197,9 @@ void nam::activations::Activation::enable_lut(std::string function_name, float m
198197
fn = sigmoid;
199198
sigmoid_bak = _activations["Sigmoid"];
200199
}
201-
else if (function_name == "SiLU")
202-
{
203-
fn = swish;
204-
silu_bak = _activations["SiLU"];
205-
}
206200
else
207201
{
208-
throw std::runtime_error("Tried to enable LUT for a function other than Tanh, Sigmoid, or SiLU");
202+
throw std::runtime_error("Tried to enable LUT for a function other than Tanh or Sigmoid");
209203
}
210204
_activations[function_name] = std::make_shared<FastLUTActivation>(min, max, n_points, fn);
211205
}
@@ -220,12 +214,8 @@ void nam::activations::Activation::disable_lut(std::string function_name)
220214
{
221215
_activations["Sigmoid"] = sigmoid_bak;
222216
}
223-
else if (function_name == "SiLU")
224-
{
225-
_activations["SiLU"] = silu_bak;
226-
}
227217
else
228218
{
229-
throw std::runtime_error("Tried to disable LUT for a function other than Tanh, Sigmoid, or SiLU");
219+
throw std::runtime_error("Tried to disable LUT for a function other than Tanh or Sigmoid");
230220
}
231221
}

NAM/activations.h

Lines changed: 6 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -235,24 +235,8 @@ class ActivationReLU : public Activation
235235
public:
236236
void apply(float* data, long size) override
237237
{
238-
// Optimized ReLU with loop unrolling
239-
long pos = 0;
240-
// Process 4 elements at a time
241-
for (; pos + 3 < size; pos += 4)
242-
{
243-
// Branchless ReLU using conditional
244-
const float v0 = data[pos], v1 = data[pos + 1];
245-
const float v2 = data[pos + 2], v3 = data[pos + 3];
246-
data[pos] = v0 > 0.0f ? v0 : 0.0f;
247-
data[pos + 1] = v1 > 0.0f ? v1 : 0.0f;
248-
data[pos + 2] = v2 > 0.0f ? v2 : 0.0f;
249-
data[pos + 3] = v3 > 0.0f ? v3 : 0.0f;
250-
}
251-
// Handle remainder
252-
for (; pos < size; pos++)
253-
{
254-
data[pos] = data[pos] > 0.0f ? data[pos] : 0.0f;
255-
}
238+
for (long pos = 0; pos < size; pos++)
239+
data[pos] = relu(data[pos]);
256240
}
257241
};
258242

@@ -316,23 +300,8 @@ class ActivationSigmoid : public Activation
316300
public:
317301
void apply(float* data, long size) override
318302
{
319-
long pos = 0;
320-
// Process 4 elements at a time
321-
for (; pos + 3 < size; pos += 4)
322-
{
323-
const float x0 = data[pos], x1 = data[pos + 1];
324-
const float x2 = data[pos + 2], x3 = data[pos + 3];
325-
326-
data[pos] = 1.0f / (1.0f + expf(-x0));
327-
data[pos + 1] = 1.0f / (1.0f + expf(-x1));
328-
data[pos + 2] = 1.0f / (1.0f + expf(-x2));
329-
data[pos + 3] = 1.0f / (1.0f + expf(-x3));
330-
}
331-
// Handle remainder
332-
for (; pos < size; pos++)
333-
{
303+
for (long pos = 0; pos < size; pos++)
334304
data[pos] = sigmoid(data[pos]);
335-
}
336305
}
337306
};
338307

@@ -341,28 +310,8 @@ class ActivationSwish : public Activation
341310
public:
342311
void apply(float* data, long size) override
343312
{
344-
long pos = 0;
345-
// Process 4 elements at a time: swish(x) = x * sigmoid(x) = x / (1 + exp(-x))
346-
for (; pos + 3 < size; pos += 4)
347-
{
348-
const float x0 = data[pos], x1 = data[pos + 1];
349-
const float x2 = data[pos + 2], x3 = data[pos + 3];
350-
351-
const float s0 = 1.0f / (1.0f + expf(-x0));
352-
const float s1 = 1.0f / (1.0f + expf(-x1));
353-
const float s2 = 1.0f / (1.0f + expf(-x2));
354-
const float s3 = 1.0f / (1.0f + expf(-x3));
355-
356-
data[pos] = x0 * s0;
357-
data[pos + 1] = x1 * s1;
358-
data[pos + 2] = x2 * s2;
359-
data[pos + 3] = x3 * s3;
360-
}
361-
// Handle remainder
362-
for (; pos < size; pos++)
363-
{
313+
for (long pos = 0; pos < size; pos++)
364314
data[pos] = swish(data[pos]);
365-
}
366315
}
367316
};
368317

@@ -371,32 +320,8 @@ class ActivationHardSwish : public Activation
371320
public:
372321
void apply(float* data, long size) override
373322
{
374-
const float inv6 = 1.0f / 6.0f;
375-
long pos = 0;
376-
// Process 4 elements at a time
377-
for (; pos + 3 < size; pos += 4)
378-
{
379-
const float x0 = data[pos], x1 = data[pos + 1];
380-
const float x2 = data[pos + 2], x3 = data[pos + 3];
381-
382-
const float t0 = x0 + 3.0f, t1 = x1 + 3.0f;
383-
const float t2 = x2 + 3.0f, t3 = x3 + 3.0f;
384-
385-
const float c0 = t0 < 0.0f ? 0.0f : (t0 > 6.0f ? 6.0f : t0);
386-
const float c1 = t1 < 0.0f ? 0.0f : (t1 > 6.0f ? 6.0f : t1);
387-
const float c2 = t2 < 0.0f ? 0.0f : (t2 > 6.0f ? 6.0f : t2);
388-
const float c3 = t3 < 0.0f ? 0.0f : (t3 > 6.0f ? 6.0f : t3);
389-
390-
data[pos] = x0 * c0 * inv6;
391-
data[pos + 1] = x1 * c1 * inv6;
392-
data[pos + 2] = x2 * c2 * inv6;
393-
data[pos + 3] = x3 * c3 * inv6;
394-
}
395-
// Handle remainder
396-
for (; pos < size; pos++)
397-
{
323+
for (long pos = 0; pos < size; pos++)
398324
data[pos] = hardswish(data[pos]);
399-
}
400325
}
401326
};
402327

@@ -405,23 +330,8 @@ class ActivationSoftsign : public Activation
405330
public:
406331
void apply(float* data, long size) override
407332
{
408-
long pos = 0;
409-
// Process 4 elements at a time
410-
for (; pos + 3 < size; pos += 4)
411-
{
412-
const float x0 = data[pos], x1 = data[pos + 1];
413-
const float x2 = data[pos + 2], x3 = data[pos + 3];
414-
415-
data[pos] = x0 / (1.0f + fabsf(x0));
416-
data[pos + 1] = x1 / (1.0f + fabsf(x1));
417-
data[pos + 2] = x2 / (1.0f + fabsf(x2));
418-
data[pos + 3] = x3 / (1.0f + fabsf(x3));
419-
}
420-
// Handle remainder
421-
for (; pos < size; pos++)
422-
{
333+
for (long pos = 0; pos < size; pos++)
423334
data[pos] = softsign(data[pos]);
424-
}
425335
}
426336
};
427337

NAM/wavenet.cpp

Lines changed: 0 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -124,33 +124,8 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
124124
Eigen::MatrixXf& input_mixin_output = this->_input_mixin.GetOutput();
125125
this->_input_mixin_post_film->Process_(input_mixin_output, condition, num_frames);
126126
}
127-
#ifdef NAM_USE_INLINE_GEMM
128-
// Optimized matrix addition for small channel counts
129-
{
130-
const int channels = (int)_conv.get_out_channels();
131-
const float* __restrict__ conv_ptr = _conv.GetOutput().data();
132-
const float* __restrict__ mixin_ptr = _input_mixin.GetOutput().data();
133-
float* __restrict__ z_ptr = this->_z.data();
134-
const int total = channels * num_frames;
135-
136-
// Unrolled addition
137-
int i = 0;
138-
for (; i + 3 < total; i += 4)
139-
{
140-
z_ptr[i] = conv_ptr[i] + mixin_ptr[i];
141-
z_ptr[i + 1] = conv_ptr[i + 1] + mixin_ptr[i + 1];
142-
z_ptr[i + 2] = conv_ptr[i + 2] + mixin_ptr[i + 2];
143-
z_ptr[i + 3] = conv_ptr[i + 3] + mixin_ptr[i + 3];
144-
}
145-
for (; i < total; i++)
146-
{
147-
z_ptr[i] = conv_ptr[i] + mixin_ptr[i];
148-
}
149-
}
150-
#else
151127
this->_z.leftCols(num_frames).noalias() =
152128
_conv.GetOutput().leftCols(num_frames) + _input_mixin.GetOutput().leftCols(num_frames);
153-
#endif
154129

155130
if (this->_activation_pre_film)
156131
{
@@ -282,28 +257,8 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
282257
// Store output to next layer (residual connection: input + layer1x1 output, or just input if layer1x1 inactive)
283258
if (this->_layer1x1)
284259
{
285-
#ifdef NAM_USE_INLINE_GEMM
286-
{
287-
const int channels = (int)this->get_channels();
288-
const int total = channels * num_frames;
289-
const float* __restrict__ in_ptr = input.data();
290-
const float* __restrict__ layer_ptr = this->_layer1x1->GetOutput().data();
291-
float* __restrict__ dst = this->_output_next_layer.data();
292-
int i = 0;
293-
for (; i + 3 < total; i += 4)
294-
{
295-
dst[i] = in_ptr[i] + layer_ptr[i];
296-
dst[i + 1] = in_ptr[i + 1] + layer_ptr[i + 1];
297-
dst[i + 2] = in_ptr[i + 2] + layer_ptr[i + 2];
298-
dst[i + 3] = in_ptr[i + 3] + layer_ptr[i + 3];
299-
}
300-
for (; i < total; i++)
301-
dst[i] = in_ptr[i] + layer_ptr[i];
302-
}
303-
#else
304260
this->_output_next_layer.leftCols(num_frames).noalias() =
305261
input.leftCols(num_frames) + this->_layer1x1->GetOutput().leftCols(num_frames);
306-
#endif
307262
}
308263
else
309264
{
@@ -415,26 +370,7 @@ void nam::wavenet::_LayerArray::ProcessInner(const Eigen::MatrixXf& layer_inputs
415370
}
416371

417372
// Accumulate head output from this layer
418-
#ifdef NAM_USE_INLINE_GEMM
419-
{
420-
const int channels = (int)this->_head_output_size;
421-
const int total = channels * num_frames;
422-
const float* __restrict__ src = this->_layers[i].GetOutputHead().data();
423-
float* __restrict__ dst = this->_head_inputs.data();
424-
int j = 0;
425-
for (; j + 3 < total; j += 4)
426-
{
427-
dst[j] += src[j];
428-
dst[j + 1] += src[j + 1];
429-
dst[j + 2] += src[j + 2];
430-
dst[j + 3] += src[j + 3];
431-
}
432-
for (; j < total; j++)
433-
dst[j] += src[j];
434-
}
435-
#else
436373
this->_head_inputs.leftCols(num_frames).noalias() += this->_layers[i].GetOutputHead().leftCols(num_frames);
437-
#endif
438374
}
439375

440376
// Store output from last layer - use memcpy for pure copy

NAM/wavenet.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -385,9 +385,11 @@ class _Layer
385385
std::unique_ptr<Conv1x1> _layer1x1;
386386
// The post-activation 1x1 convolution outputting to the head, optional
387387
std::unique_ptr<Conv1x1> _head1x1;
388-
388+
// The internal state
389389
Eigen::MatrixXf _z;
390+
// Output to next layer (residual connection: input + layer1x1 output, or just input if layer1x1 inactive)
390391
Eigen::MatrixXf _output_next_layer;
392+
// Output to head (skip connection: activated conv output)
391393
Eigen::MatrixXf _output_head;
392394

393395
activations::Activation::Ptr _activation;
@@ -604,12 +606,12 @@ class _LayerArray
604606

605607
// The layer objects
606608
std::vector<_Layer> _layers;
607-
609+
// Output from last layer (for next layer array)
608610
Eigen::MatrixXf _layer_outputs;
609-
Eigen::MatrixXf _head_inputs;
610-
611611
// Accumulated head inputs from all layers
612612
// Size is _head_output_size (= head1x1.out_channels if head1x1 active, else bottleneck)
613+
Eigen::MatrixXf _head_inputs;
614+
613615
// Rechannel for the head (_head_output_size -> head_size)
614616
Conv1x1 _head_rechannel;
615617

@@ -668,9 +670,9 @@ class WaveNet : public DSP
668670
void set_weights_(std::vector<float>::iterator& weights);
669671

670672
protected:
673+
// Element-wise arrays:
671674
Eigen::MatrixXf _condition_input;
672675
Eigen::MatrixXf _condition_output;
673-
674676
std::unique_ptr<DSP> _condition_dsp;
675677
// Temporary buffers for condition DSP processing (to avoid allocations in _process_condition)
676678
std::vector<std::vector<NAM_SAMPLE>> _condition_dsp_input_buffers;

0 commit comments

Comments (0)