37 changes: 14 additions & 23 deletions NAM/activations.cpp
@@ -15,10 +15,10 @@ static nam::activations::ActivationLeakyHardTanh _LEAKY_HARD_TANH;
bool nam::activations::Activation::using_fast_tanh = false;

// Helper to create a non-owning shared_ptr (no-op deleter) for singletons
template<typename T>
template <typename T>
nam::activations::Activation::Ptr make_singleton_ptr(T& singleton)
{
return nam::activations::Activation::Ptr(&singleton, [](nam::activations::Activation*){});
return nam::activations::Activation::Ptr(&singleton, [](nam::activations::Activation*) {});
}

std::unordered_map<std::string, nam::activations::Activation::Ptr> nam::activations::Activation::_activations = {
@@ -31,8 +31,7 @@ std::unordered_map<std::string, nam::activations::Activation::Ptr> nam::activati
{"SiLU", make_singleton_ptr(_SWISH)},
{"Hardswish", make_singleton_ptr(_HARD_SWISH)},
{"LeakyHardtanh", make_singleton_ptr(_LEAKY_HARD_TANH)},
{"PReLU", make_singleton_ptr(_PRELU)}
};
{"PReLU", make_singleton_ptr(_PRELU)}};

nam::activations::Activation::Ptr tanh_bak = nullptr;
nam::activations::Activation::Ptr sigmoid_bak = nullptr;
@@ -130,20 +129,13 @@ nam::activations::Activation::Ptr nam::activations::Activation::get_activation(c
{
switch (config.type)
{
case ActivationType::Tanh:
return _activations["Tanh"];
case ActivationType::Hardtanh:
return _activations["Hardtanh"];
case ActivationType::Fasttanh:
return _activations["Fasttanh"];
case ActivationType::ReLU:
return _activations["ReLU"];
case ActivationType::Sigmoid:
return _activations["Sigmoid"];
case ActivationType::SiLU:
return _activations["SiLU"];
case ActivationType::Hardswish:
return _activations["Hardswish"];
case ActivationType::Tanh: return _activations["Tanh"];
case ActivationType::Hardtanh: return _activations["Hardtanh"];
case ActivationType::Fasttanh: return _activations["Fasttanh"];
case ActivationType::ReLU: return _activations["ReLU"];
case ActivationType::Sigmoid: return _activations["Sigmoid"];
case ActivationType::SiLU: return _activations["SiLU"];
case ActivationType::Hardswish: return _activations["Hardswish"];
case ActivationType::LeakyReLU:
if (config.negative_slope.has_value())
{
@@ -161,11 +153,10 @@
}
return std::make_shared<ActivationPReLU>(0.01f);
case ActivationType::LeakyHardtanh:
return std::make_shared<ActivationLeakyHardTanh>(
config.min_val.value_or(-1.0f), config.max_val.value_or(1.0f), config.min_slope.value_or(0.01f),
config.max_slope.value_or(0.01f));
default:
return nullptr;
return std::make_shared<ActivationLeakyHardTanh>(config.min_val.value_or(-1.0f), config.max_val.value_or(1.0f),
config.min_slope.value_or(0.01f),
config.max_slope.value_or(0.01f));
default: return nullptr;
}
}

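Note (not part of the diff): a minimal sketch of how a caller might drive the config-based factory above. It assumes ActivationType and ActivationConfig live in nam::activations, as the surrounding code suggests, and that any optional left unset falls back to the value_or() defaults visible in the switch.

#include "NAM/activations.h"

void activation_factory_sketch()
{
  using namespace nam::activations;

  // Parameter-free activation via the convenience constructor:
  Activation::Ptr tanh_act = Activation::get_activation(ActivationConfig::simple(ActivationType::Tanh));

  // Parameterized activation; unset optionals fall back to the defaults above
  // (for LeakyHardtanh: min_val=-1, max_val=1, min_slope=max_slope=0.01).
  ActivationConfig cfg;
  cfg.type = ActivationType::LeakyHardtanh;
  cfg.min_val = -2.0f;
  cfg.max_val = 2.0f;
  Activation::Ptr leaky_ht = Activation::get_activation(cfg);
}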
14 changes: 7 additions & 7 deletions NAM/activations.h
@@ -42,12 +42,12 @@ struct ActivationConfig
ActivationType type;

// Optional parameters (used by specific activation types)
std::optional<float> negative_slope; // LeakyReLU, PReLU (single)
std::optional<float> negative_slope; // LeakyReLU, PReLU (single)
std::optional<std::vector<float>> negative_slopes; // PReLU (per-channel)
std::optional<float> min_val; // LeakyHardtanh
std::optional<float> max_val; // LeakyHardtanh
std::optional<float> min_slope; // LeakyHardtanh
std::optional<float> max_slope; // LeakyHardtanh
std::optional<float> min_val; // LeakyHardtanh
std::optional<float> max_val; // LeakyHardtanh
std::optional<float> min_slope; // LeakyHardtanh
std::optional<float> max_slope; // LeakyHardtanh

// Convenience constructors
static ActivationConfig simple(ActivationType t);
@@ -274,13 +274,13 @@ class ActivationPReLU
{
// Matrix is organized as (channels, time_steps)
unsigned long actual_channels = static_cast<unsigned long>(matrix.rows());

// Prepare the slopes for the current matrix size
std::vector<float> slopes_for_channels = negative_slopes;

// Fail loudly if input has more channels than activation
assert(actual_channels == negative_slopes.size());

// Apply each negative slope to its corresponding channel
for (unsigned long channel = 0; channel < actual_channels; channel++)
{
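Note (not part of the diff): with the assert above, ActivationPReLU requires the processed matrix to have exactly one row (channel) per configured slope. Below is a hedged sketch of requesting a per-channel PReLU through the config struct; whether get_activation() forwards negative_slopes this way is an assumption, since that branch is partly collapsed in the diff above.

#include <vector>
#include "NAM/activations.h"

void per_channel_prelu_sketch()
{
  nam::activations::ActivationConfig cfg;
  cfg.type = nam::activations::ActivationType::PReLU;
  cfg.negative_slopes = std::vector<float>{0.01f, 0.02f, 0.05f, 0.1f}; // one slope per channel
  auto prelu = nam::activations::Activation::get_activation(cfg);
  // Any (channels x time_steps) matrix run through this activation must have
  // exactly 4 rows, per the assert in ActivationPReLU above.
}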
86 changes: 86 additions & 0 deletions NAM/film.h
@@ -0,0 +1,86 @@
#pragma once

#include <Eigen/Dense>
#include <cassert>
#include <vector>

#include "dsp.h"

namespace nam
{
// Feature-wise Linear Modulation (FiLM)
//
// Given an `input` (input_dim x num_frames) and a `condition`
// (condition_dim x num_frames), compute:
// scale, shift = Conv1x1(condition) split across channels
// output = input * scale + shift (elementwise)
class FiLM
{
public:
FiLM(const int condition_dim, const int input_dim, const bool shift)
: _cond_to_scale_shift(condition_dim, (shift ? 2 : 1) * input_dim, /*bias=*/true)
, _do_shift(shift)
{
}

// Get the entire internal output buffer. This is intended for internal wiring
// between layers; callers should treat the buffer as pre-allocated storage
// and only consider the first `num_frames` columns valid for a given
// processing call. Slice with .leftCols(num_frames) as needed.
Eigen::MatrixXf& GetOutput() { return _output; }
const Eigen::MatrixXf& GetOutput() const { return _output; }

void SetMaxBufferSize(const int maxBufferSize)
{
_cond_to_scale_shift.SetMaxBufferSize(maxBufferSize);
_output.resize(get_input_dim(), maxBufferSize);
}

void set_weights_(std::vector<float>::iterator& weights) { _cond_to_scale_shift.set_weights_(weights); }

long get_condition_dim() const { return _cond_to_scale_shift.get_in_channels(); }
long get_input_dim() const
{
return _do_shift ? (_cond_to_scale_shift.get_out_channels() / 2) : _cond_to_scale_shift.get_out_channels();
}

// :param input: (input_dim x num_frames)
// :param condition: (condition_dim x num_frames)
// Writes (input_dim x num_frames) into internal output buffer; access via GetOutput().
void Process(const Eigen::MatrixXf& input, const Eigen::MatrixXf& condition, const int num_frames)
{
assert(get_input_dim() == input.rows());
assert(get_condition_dim() == condition.rows());
assert(num_frames <= input.cols());
assert(num_frames <= condition.cols());
assert(num_frames <= _output.cols());

_cond_to_scale_shift.process_(condition, num_frames);
const auto& scale_shift = _cond_to_scale_shift.GetOutput();

const auto scale = scale_shift.topRows(get_input_dim()).leftCols(num_frames);
if (_do_shift)
{
// scale = top input_dim, shift = bottom input_dim
const auto shift = scale_shift.bottomRows(get_input_dim()).leftCols(num_frames);
_output.leftCols(num_frames).array() = input.leftCols(num_frames).array() * scale.array() + shift.array();
}
else
{
_output.leftCols(num_frames).array() = input.leftCols(num_frames).array() * scale.array();
}
}

// in-place
void Process_(Eigen::MatrixXf& input, const Eigen::MatrixXf& condition, const int num_frames)
{
Process(input, condition, num_frames);
input.leftCols(num_frames).noalias() = _output.leftCols(num_frames);
}

private:
Conv1x1 _cond_to_scale_shift; // condition_dim -> (shift ? 2 : 1) * input_dim
Eigen::MatrixXf _output; // input_dim x maxBufferSize
bool _do_shift;
};
} // namespace nam
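Note (not part of the diff): a minimal usage sketch of the new FiLM class. The sizes are illustrative; in real use the weights come from the model file via set_weights_, whose exact float count and ordering are determined by the internal Conv1x1 and are not shown here.

#include <Eigen/Dense>
#include "NAM/film.h"

void film_sketch()
{
  const int condition_dim = 4, input_dim = 8, max_frames = 64;
  nam::FiLM film(condition_dim, input_dim, /*shift=*/true);
  film.SetMaxBufferSize(max_frames); // pre-allocates the (input_dim x max_frames) output buffer

  // std::vector<float>::iterator it = weights.begin();
  // film.set_weights_(it); // consumes the Conv1x1 weights before any processing

  Eigen::MatrixXf input = Eigen::MatrixXf::Random(input_dim, max_frames);
  Eigen::MatrixXf condition = Eigen::MatrixXf::Random(condition_dim, max_frames);

  const int num_frames = 32; // may be smaller than the allocated buffer
  film.Process(input, condition, num_frames);
  // Only the first num_frames columns of the output buffer are valid:
  Eigen::MatrixXf modulated = film.GetOutput().leftCols(num_frames);

  // Or modulate the input in place:
  film.Process_(input, condition, num_frames);
}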
6 changes: 4 additions & 2 deletions NAM/gating_activations.h
@@ -32,7 +32,8 @@ class GatingActivation
* @param input_channels Number of input channels (default: 1)
* @param gating_channels Number of gating channels (default: 1)
*/
GatingActivation(activations::Activation::Ptr input_act, activations::Activation::Ptr gating_act, int input_channels = 1)
GatingActivation(activations::Activation::Ptr input_act, activations::Activation::Ptr gating_act,
int input_channels = 1)
: input_activation(input_act)
, gating_activation(gating_act)
, num_channels(input_channels)
@@ -109,7 +110,8 @@ class BlendingActivation
* @param blend_act Activation function for blending channels
* @param input_channels Number of input channels
*/
BlendingActivation(activations::Activation::Ptr input_act, activations::Activation::Ptr blend_act, int input_channels = 1)
BlendingActivation(activations::Activation::Ptr input_act, activations::Activation::Ptr blend_act,
int input_channels = 1)
: input_activation(input_act)
, blending_activation(blend_act)
, num_channels(input_channels)
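Note (not part of the diff): a hedged construction sketch for the reflowed constructors above. It assumes GatingActivation and BlendingActivation live in namespace nam and that the Activation::Ptr arguments come from the config-based factory in activations.cpp; how the resulting objects are applied to buffers is not shown in this diff.

#include "NAM/activations.h"
#include "NAM/gating_activations.h"

void gating_sketch()
{
  using namespace nam::activations;
  Activation::Ptr tanh_act = Activation::get_activation(ActivationConfig::simple(ActivationType::Tanh));
  Activation::Ptr sigmoid_act = Activation::get_activation(ActivationConfig::simple(ActivationType::Sigmoid));

  // A WaveNet-style gate would pair a Tanh input path with a Sigmoid gating path:
  nam::GatingActivation gate(tanh_act, sigmoid_act, /*input_channels=*/1);

  // BlendingActivation takes an input activation and a blending activation instead:
  nam::BlendingActivation blend(tanh_act, sigmoid_act, /*input_channels=*/1);
}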