diff --git a/NAM/activations.cpp b/NAM/activations.cpp
index 3db6024..c9eeb95 100644
--- a/NAM/activations.cpp
+++ b/NAM/activations.cpp
@@ -1,29 +1,43 @@
 #include "activations.h"
 
-nam::activations::ActivationTanh _TANH = nam::activations::ActivationTanh();
-nam::activations::ActivationFastTanh _FAST_TANH = nam::activations::ActivationFastTanh();
-nam::activations::ActivationHardTanh _HARD_TANH = nam::activations::ActivationHardTanh();
-nam::activations::ActivationReLU _RELU = nam::activations::ActivationReLU();
-nam::activations::ActivationLeakyReLU _LEAKY_RELU =
-  nam::activations::ActivationLeakyReLU(0.01); // FIXME does not parameterize LeakyReLU
-nam::activations::ActivationPReLU _PRELU = nam::activations::ActivationPReLU(0.01); // Same as leaky ReLU by default
-nam::activations::ActivationSigmoid _SIGMOID = nam::activations::ActivationSigmoid();
-nam::activations::ActivationSwish _SWISH = nam::activations::ActivationSwish();
-nam::activations::ActivationHardSwish _HARD_SWISH = nam::activations::ActivationHardSwish();
-nam::activations::ActivationLeakyHardTanh _LEAKY_HARD_TANH = nam::activations::ActivationLeakyHardTanh();
+// Global singleton instances (statically allocated, never deleted)
+static nam::activations::ActivationTanh _TANH;
+static nam::activations::ActivationFastTanh _FAST_TANH;
+static nam::activations::ActivationHardTanh _HARD_TANH;
+static nam::activations::ActivationReLU _RELU;
+static nam::activations::ActivationLeakyReLU _LEAKY_RELU(0.01); // FIXME does not parameterize LeakyReLU
+static nam::activations::ActivationPReLU _PRELU(0.01); // Same as leaky ReLU by default
+static nam::activations::ActivationSigmoid _SIGMOID;
+static nam::activations::ActivationSwish _SWISH;
+static nam::activations::ActivationHardSwish _HARD_SWISH;
+static nam::activations::ActivationLeakyHardTanh _LEAKY_HARD_TANH;
 
 bool nam::activations::Activation::using_fast_tanh = false;
 
-std::unordered_map<std::string, nam::activations::Activation*> nam::activations::Activation::_activations = {
-  {"Tanh", &_TANH},  {"Hardtanh", &_HARD_TANH},   {"Fasttanh", &_FAST_TANH},
-  {"ReLU", &_RELU},  {"LeakyReLU", &_LEAKY_RELU}, {"Sigmoid", &_SIGMOID},
-  {"SiLU", &_SWISH}, {"Hardswish", &_HARD_SWISH}, {"LeakyHardtanh", &_LEAKY_HARD_TANH},
-  {"PReLU", &_PRELU}};
+// Helper to create a non-owning shared_ptr (no-op deleter) for singletons
+template <typename T>
+nam::activations::Activation::Ptr make_singleton_ptr(T& singleton)
+{
+  return nam::activations::Activation::Ptr(&singleton, [](nam::activations::Activation*) {});
+}
+
+std::unordered_map<std::string, nam::activations::Activation::Ptr> nam::activations::Activation::_activations = {
+  {"Tanh", make_singleton_ptr(_TANH)},
+  {"Hardtanh", make_singleton_ptr(_HARD_TANH)},
+  {"Fasttanh", make_singleton_ptr(_FAST_TANH)},
+  {"ReLU", make_singleton_ptr(_RELU)},
+  {"LeakyReLU", make_singleton_ptr(_LEAKY_RELU)},
+  {"Sigmoid", make_singleton_ptr(_SIGMOID)},
+  {"SiLU", make_singleton_ptr(_SWISH)},
+  {"Hardswish", make_singleton_ptr(_HARD_SWISH)},
+  {"LeakyHardtanh", make_singleton_ptr(_LEAKY_HARD_TANH)},
+  {"PReLU", make_singleton_ptr(_PRELU)}
+};
 
-nam::activations::Activation* tanh_bak = nullptr;
-nam::activations::Activation* sigmoid_bak = nullptr;
+nam::activations::Activation::Ptr tanh_bak = nullptr;
+nam::activations::Activation::Ptr sigmoid_bak = nullptr;
 
-nam::activations::Activation* nam::activations::Activation::get_activation(const std::string name)
+nam::activations::Activation::Ptr nam::activations::Activation::get_activation(const std::string name)
 {
   if (_activations.find(name) == _activations.end())
     return nullptr;
@@ -31,6 +45,130 @@ nam::activations::Activation* nam::activations::Activation::get_activation(const
     return _activations[name];
 }
 
+// ActivationConfig implementation
+nam::activations::ActivationConfig nam::activations::ActivationConfig::simple(ActivationType t)
+{
+  ActivationConfig config;
+  config.type = t;
+  return config;
+}
+
+nam::activations::ActivationConfig nam::activations::ActivationConfig::from_json(const nlohmann::json& j)
+{
+  ActivationConfig config;
+
+  // Map from string to ActivationType
+  static const std::unordered_map<std::string, ActivationType> type_map = {
+    {"Tanh", ActivationType::Tanh},
+    {"Hardtanh", ActivationType::Hardtanh},
+    {"Fasttanh", ActivationType::Fasttanh},
+    {"ReLU", ActivationType::ReLU},
+    {"LeakyReLU", ActivationType::LeakyReLU},
+    {"PReLU", ActivationType::PReLU},
+    {"Sigmoid", ActivationType::Sigmoid},
+    {"SiLU", ActivationType::SiLU},
+    {"Hardswish", ActivationType::Hardswish},
+    {"LeakyHardtanh", ActivationType::LeakyHardtanh},
+    {"LeakyHardTanh", ActivationType::LeakyHardtanh} // Support both casings
+  };
+
+  // If it's a string, simple lookup
+  if (j.is_string())
+  {
+    std::string name = j.get<std::string>();
+    auto it = type_map.find(name);
+    if (it == type_map.end())
+    {
+      throw std::runtime_error("Unknown activation type: " + name);
+    }
+    config.type = it->second;
+    return config;
+  }
+
+  // If it's an object, parse type and parameters
+  if (j.is_object())
+  {
+    std::string type_str = j["type"].get<std::string>();
+    auto it = type_map.find(type_str);
+    if (it == type_map.end())
+    {
+      throw std::runtime_error("Unknown activation type: " + type_str);
+    }
+    config.type = it->second;
+
+    // Parse optional parameters based on activation type
+    if (config.type == ActivationType::PReLU)
+    {
+      if (j.find("negative_slope") != j.end())
+      {
+        config.negative_slope = j["negative_slope"].get<float>();
+      }
+      else if (j.find("negative_slopes") != j.end())
+      {
+        config.negative_slopes = j["negative_slopes"].get<std::vector<float>>();
+      }
+    }
+    else if (config.type == ActivationType::LeakyReLU)
+    {
+      config.negative_slope = j.value("negative_slope", 0.01f);
+    }
+    else if (config.type == ActivationType::LeakyHardtanh)
+    {
+      config.min_val = j.value("min_val", -1.0f);
+      config.max_val = j.value("max_val", 1.0f);
+      config.min_slope = j.value("min_slope", 0.01f);
+      config.max_slope = j.value("max_slope", 0.01f);
+    }
+
+    return config;
+  }
+
+  throw std::runtime_error("Invalid activation config: expected string or object");
+}
+
+nam::activations::Activation::Ptr nam::activations::Activation::get_activation(const ActivationConfig& config)
+{
+  switch (config.type)
+  {
+    case ActivationType::Tanh:
+      return _activations["Tanh"];
+    case ActivationType::Hardtanh:
+      return _activations["Hardtanh"];
+    case ActivationType::Fasttanh:
+      return _activations["Fasttanh"];
+    case ActivationType::ReLU:
+      return _activations["ReLU"];
+    case ActivationType::Sigmoid:
+      return _activations["Sigmoid"];
+    case ActivationType::SiLU:
+      return _activations["SiLU"];
+    case ActivationType::Hardswish:
+      return _activations["Hardswish"];
+    case ActivationType::LeakyReLU:
+      if (config.negative_slope.has_value())
+      {
+        return std::make_shared<ActivationLeakyReLU>(config.negative_slope.value());
+      }
+      return _activations["LeakyReLU"];
+    case ActivationType::PReLU:
+      if (config.negative_slopes.has_value())
+      {
+        return std::make_shared<ActivationPReLU>(config.negative_slopes.value());
+      }
+      else if (config.negative_slope.has_value())
+      {
+        return std::make_shared<ActivationPReLU>(config.negative_slope.value());
+      }
+      return std::make_shared<ActivationPReLU>(0.01f);
+    case ActivationType::LeakyHardtanh:
+      return std::make_shared<ActivationLeakyHardTanh>(
+        config.min_val.value_or(-1.0f), config.max_val.value_or(1.0f), config.min_slope.value_or(0.01f),
+        config.max_slope.value_or(0.01f));
+    default:
+      return nullptr;
+  }
+}
+
 void nam::activations::Activation::enable_fast_tanh()
 {
   nam::activations::Activation::using_fast_tanh = true;
@@ -69,8 +207,7 @@ void nam::activations::Activation::enable_lut(std::string function_name, float m
   {
     throw std::runtime_error("Tried to enable LUT for a function other than Tanh or Sigmoid");
   }
-  FastLUTActivation lut_activation(min, max, n_points, fn);
-  _activations[function_name] = &lut_activation;
+  _activations[function_name] = std::make_shared<FastLUTActivation>(min, max, n_points, fn);
 }
 
 void nam::activations::Activation::disable_lut(std::string function_name)
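
Usage sketch (illustrative, not part of the patch): exercising the two entry points implemented above. The JSON shapes follow ActivationConfig::from_json; the wrapping function and values are placeholders.

// Sketch only -- assumes the declarations added in NAM/activations.h below.
#include "NAM/activations.h"

void activation_config_example()
{
  using namespace nam::activations;

  // String form keeps the legacy behavior: resolves to a shared singleton.
  Activation::Ptr tanh = Activation::get_activation(ActivationConfig::from_json(nlohmann::json("Tanh")));

  // Object form carries parameters and yields a dedicated instance.
  nlohmann::json j = {{"type", "LeakyReLU"}, {"negative_slope", 0.2f}};
  Activation::Ptr leaky = Activation::get_activation(ActivationConfig::from_json(j));

  float x[3] = {-1.0f, 0.0f, 1.0f};
  leaky->apply(x, 3); // -> {-0.2f, 0.0f, 1.0f}
}
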
diff --git a/NAM/activations.h b/NAM/activations.h
index 6b4b6a2..977b458 100644
--- a/NAM/activations.h
+++ b/NAM/activations.h
@@ -1,16 +1,58 @@
 #pragma once
 
 #include <algorithm>
-#include <string>
 #include <cmath> // expf
+#include <memory>
+#include <optional>
+#include <stdexcept>
+#include <string>
 #include <unordered_map>
+#include <vector>
+
 #include <Eigen/Dense>
-#include <functional>
+
+#include "json.hpp"
 
 namespace nam
 {
 namespace activations
 {
+
+// Forward declaration
+class Activation;
+
+// Strongly-typed activation type enum
+enum class ActivationType
+{
+  Tanh,
+  Hardtanh,
+  Fasttanh,
+  ReLU,
+  LeakyReLU,
+  PReLU,
+  Sigmoid,
+  SiLU, // aka Swish
+  Hardswish,
+  LeakyHardtanh
+};
+
+// Strongly-typed activation configuration
+struct ActivationConfig
+{
+  ActivationType type;
+
+  // Optional parameters (used by specific activation types)
+  std::optional<float> negative_slope; // LeakyReLU, PReLU (single)
+  std::optional<std::vector<float>> negative_slopes; // PReLU (per-channel)
+  std::optional<float> min_val; // LeakyHardtanh
+  std::optional<float> max_val; // LeakyHardtanh
+  std::optional<float> min_slope; // LeakyHardtanh
+  std::optional<float> max_slope; // LeakyHardtanh
+
+  // Convenience constructors
+  static ActivationConfig simple(ActivationType t);
+  static ActivationConfig from_json(const nlohmann::json& j);
+};
 
 inline float relu(float x)
 {
   return x > 0.0f ? x : 0.0f;
@@ -91,6 +133,9 @@ inline float hardswish(float x)
 class Activation
 {
 public:
+  // Type alias for shared pointer to Activation
+  using Ptr = std::shared_ptr<Activation>;
+
   Activation() = default;
   virtual ~Activation() = default;
   virtual void apply(Eigen::MatrixXf& matrix) { apply(matrix.data(), matrix.rows() * matrix.cols()); }
@@ -101,7 +146,9 @@ class Activation
   }
   virtual void apply(float* data, long size) {}
 
-  static Activation* get_activation(const std::string name);
+  static Ptr get_activation(const std::string name);
+  static Ptr get_activation(const ActivationConfig& config);
+  static Ptr get_activation(const nlohmann::json& activation_config);
   static void enable_fast_tanh();
   static void disable_fast_tanh();
   static bool using_fast_tanh;
@@ -109,7 +156,7 @@ class Activation
   static void disable_lut(std::string function_name);
 
 protected:
-  static std::unordered_map<std::string, Activation*> _activations;
+  static std::unordered_map<std::string, Ptr> _activations;
 };
 
 // identity function activation
@@ -226,20 +273,21 @@ class ActivationPReLU : public Activation
   void apply(Eigen::MatrixXf& matrix) override
   {
     // Matrix is organized as (channels, time_steps)
-    int n_channels = negative_slopes.size();
-    int actual_channels = matrix.rows();
-
-    // NOTE: check not done during runtime on release builds
-    // model loader should make sure dimensions match
-    assert(actual_channels == n_channels);
-
+    unsigned long actual_channels = static_cast<unsigned long>(matrix.rows());
+
+    // Prepare the slopes for the current matrix size
+    std::vector<float> slopes_for_channels = negative_slopes;
+
+    // Fail loudly if input has more channels than activation
+    assert(actual_channels == negative_slopes.size());
+
     // Apply each negative slope to its corresponding channel
-    for (int channel = 0; channel < std::min(n_channels, actual_channels); channel++)
+    for (unsigned long channel = 0; channel < actual_channels; channel++)
     {
       // Apply the negative slope to all time steps in this channel
-      for (int time_step = 0; time_step < matrix.rows(); time_step++)
+      for (int time_step = 0; time_step < matrix.cols(); time_step++)
       {
-        matrix(channel, time_step) = leaky_relu(matrix(channel, time_step), negative_slopes[channel]);
+        matrix(channel, time_step) = leaky_relu(matrix(channel, time_step), slopes_for_channels[channel]);
       }
     }
   }
diff --git a/NAM/convnet.cpp b/NAM/convnet.cpp
index 8bbcded..fc7c151 100644
--- a/NAM/convnet.cpp
+++ b/NAM/convnet.cpp
@@ -48,7 +48,8 @@ void nam::convnet::BatchNorm::process_(Eigen::MatrixXf& x, const long i_start, c
 }
 
 void nam::convnet::ConvNetBlock::set_weights_(const int in_channels, const int out_channels, const int _dilation,
-                                              const bool batchnorm, const std::string activation, const int groups,
+                                              const bool batchnorm,
+                                              const activations::ActivationConfig& activation_config, const int groups,
                                               std::vector<float>::iterator& weights)
 {
   this->_batchnorm = batchnorm;
@@ -56,7 +57,7 @@ void nam::convnet::ConvNetBlock::set_weights_(const int in_channels, const int o
   this->conv.set_size_and_weights_(in_channels, out_channels, 2, _dilation, !batchnorm, groups, weights);
   if (this->_batchnorm)
     this->batchnorm = BatchNorm(out_channels, weights);
-  this->activation = activations::Activation::get_activation(activation);
+  this->activation = activations::Activation::get_activation(activation_config);
 }
 
 void nam::convnet::ConvNetBlock::SetMaxBufferSize(const int maxBufferSize)
@@ -173,8 +174,9 @@ void nam::convnet::_Head::process_(const Eigen::MatrixXf& input, Eigen::MatrixXf
 }
 
 nam::convnet::ConvNet::ConvNet(const int in_channels, const int out_channels, const int channels,
-                               const std::vector<int>& dilations, const bool batchnorm, const std::string activation,
-                               std::vector<float>& weights, const double expected_sample_rate, const int groups)
+                               const std::vector<int>& dilations, const bool batchnorm,
+                               const activations::ActivationConfig& activation_config, std::vector<float>& weights,
+                               const double expected_sample_rate, const int groups)
 : Buffer(in_channels, out_channels, *std::max_element(dilations.begin(), dilations.end()), expected_sample_rate)
 {
   this->_verify_weights(channels, dilations, batchnorm, weights.size());
@@ -183,7 +185,7 @@ nam::convnet::ConvNet::ConvNet(const int in_channels, const int out_channels, co
   // First block takes in_channels input, subsequent blocks take channels input
   for (size_t i = 0; i < dilations.size(); i++)
     this->_blocks[i].set_weights_(
-      i == 0 ? in_channels : channels, channels, dilations[i], batchnorm, activation, groups, it);
+      i == 0 ? in_channels : channels, channels, dilations[i], batchnorm, activation_config, groups, it);
   // Only need _block_vals for the head (one entry)
   // Conv1D layers manage their own buffers now
   this->_block_vals.resize(1);
@@ -327,13 +329,15 @@ std::unique_ptr<DSP> nam::convnet::Factory(const nlohmann::json& config, st
   const int channels = config["channels"];
   const std::vector<int> dilations = config["dilations"];
   const bool batchnorm = config["batchnorm"];
-  const std::string activation = config["activation"];
+  // Parse JSON into typed ActivationConfig at model loading boundary
+  const activations::ActivationConfig activation_config =
+    activations::ActivationConfig::from_json(config["activation"]);
   const int groups = config.value("groups", 1); // defaults to 1
   // Default to 1 channel in/out for backward compatibility
   const int in_channels = config.value("in_channels", 1);
   const int out_channels = config.value("out_channels", 1);
   return std::make_unique<ConvNet>(
-    in_channels, out_channels, channels, dilations, batchnorm, activation, weights, expectedSampleRate, groups);
+    in_channels, out_channels, channels, dilations, batchnorm, activation_config, weights, expectedSampleRate, groups);
 }
 
 namespace
diff --git a/NAM/convnet.h b/NAM/convnet.h
index d1e846c..1765311 100644
--- a/NAM/convnet.h
+++ b/NAM/convnet.h
@@ -9,8 +9,10 @@
 
 #include <Eigen/Dense>
 
+#include "activations.h"
 #include "conv1d.h"
 #include "dsp.h"
+#include "json.hpp"
 
 namespace nam
 {
@@ -44,7 +46,8 @@ class ConvNetBlock
 public:
   ConvNetBlock() {};
   void set_weights_(const int in_channels, const int out_channels, const int _dilation, const bool batchnorm,
-                    const std::string activation, const int groups, std::vector<float>::iterator& weights);
+                    const activations::ActivationConfig& activation_config, const int groups,
+                    std::vector<float>::iterator& weights);
   void SetMaxBufferSize(const int maxBufferSize);
   // Process input matrix directly (new API, similar to WaveNet)
   void Process(const Eigen::MatrixXf& input, const int num_frames);
@@ -58,7 +61,7 @@ class ConvNetBlock
 private:
   BatchNorm batchnorm;
   bool _batchnorm = false;
-  activations::Activation* activation = nullptr;
+  activations::Activation::Ptr activation;
   Eigen::MatrixXf _output; // Output buffer owned by the block
 };
 
@@ -78,7 +81,7 @@ class ConvNet : public Buffer
 {
 public:
   ConvNet(const int in_channels, const int out_channels, const int channels, const std::vector<int>& dilations,
-          const bool batchnorm, const std::string activation, std::vector<float>& weights,
+          const bool batchnorm, const activations::ActivationConfig& activation_config, std::vector<float>& weights,
           const double expected_sample_rate = -1.0, const int groups = 1);
   ~ConvNet() = default;
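
Aside (illustrative, not part of the patch): the two shapes the "activation" field can take at the ConvNet Factory boundary above. Field names come from ActivationConfig::from_json; the other keys are the ones the Factory reads, and the values are placeholders.

// Sketch of a model config blob the updated nam::convnet::Factory would accept.
nlohmann::json convnet_config = {
  {"channels", 16}, // placeholder sizes
  {"dilations", {1, 2, 4, 8}},
  {"batchnorm", true},
  // Legacy form still parses: {"activation", "Tanh"}
  // Parameterized form:
  {"activation", {{"type", "LeakyHardtanh"}, {"min_val", -2.0}, {"max_val", 2.0}, {"min_slope", 0.1}, {"max_slope", 0.1}}},
  {"groups", 1}
};
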
diff --git a/NAM/gating_activations.h b/NAM/gating_activations.h
index dbe5b03..ad49fb8 100644
--- a/NAM/gating_activations.h
+++ b/NAM/gating_activations.h
@@ -32,7 +32,7 @@ class GatingActivation
    * @param input_channels Number of input channels (default: 1)
    * @param gating_channels Number of gating channels (default: 1)
    */
-  GatingActivation(activations::Activation* input_act, activations::Activation* gating_act, int input_channels = 1)
+  GatingActivation(activations::Activation::Ptr input_act, activations::Activation::Ptr gating_act, int input_channels = 1)
   : input_activation(input_act)
   , gating_activation(gating_act)
   , num_channels(input_channels)
@@ -94,8 +94,8 @@ class GatingActivation
   int get_output_channels() const { return num_channels; }
 
 private:
-  activations::Activation* input_activation;
-  activations::Activation* gating_activation;
+  activations::Activation::Ptr input_activation;
+  activations::Activation::Ptr gating_activation;
   int num_channels;
   Eigen::MatrixXf input_buffer;
 };
@@ -109,7 +109,7 @@ class BlendingActivation
    * @param blend_act Activation function for blending channels
    * @param input_channels Number of input channels
    */
-  BlendingActivation(activations::Activation* input_act, activations::Activation* blend_act, int input_channels = 1)
+  BlendingActivation(activations::Activation::Ptr input_act, activations::Activation::Ptr blend_act, int input_channels = 1)
   : input_activation(input_act)
   , blending_activation(blend_act)
   , num_channels(input_channels)
@@ -169,8 +169,8 @@ class BlendingActivation
   int get_output_channels() const { return num_channels; }
 
 private:
-  activations::Activation* input_activation;
-  activations::Activation* blending_activation;
+  activations::Activation::Ptr input_activation;
+  activations::Activation::Ptr blending_activation;
   int num_channels;
   Eigen::MatrixXf input_buffer;
 };
diff --git a/NAM/wavenet.cpp b/NAM/wavenet.cpp
index 2cb749e..58887dd 100644
--- a/NAM/wavenet.cpp
+++ b/NAM/wavenet.cpp
@@ -113,7 +113,8 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
 
 nam::wavenet::_LayerArray::_LayerArray(const int input_size, const int condition_size, const int head_size,
                                        const int channels, const int bottleneck, const int kernel_size,
-                                       const std::vector<int>& dilations, const std::string activation,
+                                       const std::vector<int>& dilations,
+                                       const activations::ActivationConfig& activation_config,
                                        const GatingMode gating_mode, const bool head_bias, const int groups_input,
                                        const int groups_1x1, const Head1x1Params& head1x1_params,
                                        const std::string& secondary_activation)
@@ -122,7 +123,7 @@ nam::wavenet::_LayerArray::_LayerArray(const int input_size, const int condition
 , _bottleneck(bottleneck)
 {
   for (size_t i = 0; i < dilations.size(); i++)
-    this->_layers.push_back(_Layer(condition_size, channels, bottleneck, kernel_size, dilations[i], activation,
+    this->_layers.push_back(_Layer(condition_size, channels, bottleneck, kernel_size, dilations[i], activation_config,
                                    gating_mode, groups_input, groups_1x1, head1x1_params, secondary_activation));
 }
 
@@ -273,7 +274,7 @@ nam::wavenet::WaveNet::WaveNet(const int in_channels,
     this->_layer_arrays.push_back(nam::wavenet::_LayerArray(
       layer_array_params[i].input_size, layer_array_params[i].condition_size, layer_array_params[i].head_size,
       layer_array_params[i].channels, layer_array_params[i].bottleneck, layer_array_params[i].kernel_size,
-      layer_array_params[i].dilations, layer_array_params[i].activation, layer_array_params[i].gating_mode,
+      layer_array_params[i].dilations, layer_array_params[i].activation_config, layer_array_params[i].gating_mode,
       layer_array_params[i].head_bias, layer_array_params[i].groups_input, layer_array_params[i].groups_1x1,
       layer_array_params[i].head1x1_params, layer_array_params[i].secondary_activation));
     if (i > 0)
@@ -477,7 +478,9 @@ std::unique_ptr<DSP> nam::wavenet::Factory(const nlohmann::json& config, st
     const int head_size = layer_config["head_size"];
     const int kernel_size = layer_config["kernel_size"];
     const auto dilations = layer_config["dilations"];
-    const std::string activation = layer_config["activation"].get<std::string>();
+    // Parse JSON into typed ActivationConfig at model loading boundary
+    const activations::ActivationConfig activation_config =
+      activations::ActivationConfig::from_json(layer_config["activation"]);
     // Parse gating mode - support both old "gated" boolean and new "gating_mode" string
     GatingMode gating_mode = GatingMode::NONE;
     std::string secondary_activation;
@@ -531,7 +534,7 @@ std::unique_ptr<DSP> nam::wavenet::Factory(const nlohmann::json& config, st
     nam::wavenet::Head1x1Params head1x1_params(head1x1_active, head1x1_out_channels, head1x1_groups);
 
     layer_array_params.push_back(nam::wavenet::LayerArrayParams(
-      input_size, condition_size, head_size, channels, bottleneck, kernel_size, dilations, activation, gating_mode,
+      input_size, condition_size, head_size, channels, bottleneck, kernel_size, dilations, activation_config, gating_mode,
       head_bias, groups, groups_1x1, head1x1_params, secondary_activation));
   }
   const bool with_head = !config["head"].is_null();
diff --git a/NAM/wavenet.h b/NAM/wavenet.h
index e411385..93403d7 100644
--- a/NAM/wavenet.h
+++ b/NAM/wavenet.h
@@ -1,16 +1,17 @@
 #pragma once
 
-#include <algorithm>
-#include <cmath>
 #include <string>
 #include <vector>
+#include <memory>
+#include <unordered_map>
 
-#include "json.hpp"
 #include <Eigen/Dense>
-#include "dsp.h"
+#include "activations.h"
 #include "conv1d.h"
+#include "dsp.h"
 #include "gating_activations.h"
+#include "json.hpp"
 
 namespace nam
 {
@@ -48,14 +49,14 @@ struct Head1x1Params
 class _Layer
 {
 public:
-  // New constructor with GatingMode enum and configurable activations
+  // Constructor with GatingMode enum and typed ActivationConfig
   _Layer(const int condition_size, const int channels, const int bottleneck, const int kernel_size, const int dilation,
-         const std::string activation, const GatingMode gating_mode, const int groups_input, const int groups_1x1,
-         const Head1x1Params& head1x1_params, const std::string& secondary_activation)
+         const activations::ActivationConfig& activation_config, const GatingMode gating_mode, const int groups_input,
+         const int groups_1x1, const Head1x1Params& head1x1_params, const std::string& secondary_activation)
   : _conv(channels, (gating_mode != GatingMode::NONE) ? 2 * bottleneck : bottleneck, kernel_size, true, dilation)
  , _input_mixin(condition_size, (gating_mode != GatingMode::NONE) ? 2 * bottleneck : bottleneck, false)
  , _1x1(bottleneck, channels, groups_1x1)
-  , _activation(activations::Activation::get_activation(activation)) // needs to support activations with parameters
+  , _activation(activations::Activation::get_activation(activation_config))
  , _gating_mode(gating_mode)
  , _bottleneck(bottleneck)
  {
@@ -134,7 +135,7 @@ class _Layer
   // Output to head (skip connection: activated conv output)
   Eigen::MatrixXf _output_head;
 
-  activations::Activation* _activation;
+  activations::Activation::Ptr _activation;
   const GatingMode _gating_mode;
   const int _bottleneck; // Internal channel count (not doubled when gated)
 
@@ -148,7 +149,7 @@ class LayerArrayParams
 public:
   LayerArrayParams(const int input_size_, const int condition_size_, const int head_size_, const int channels_,
                    const int bottleneck_, const int kernel_size_, const std::vector<int>&& dilations_,
-                   const std::string activation_, const GatingMode gating_mode_, const bool head_bias_,
+                   const activations::ActivationConfig& activation_, const GatingMode gating_mode_, const bool head_bias_,
                    const int groups_input, const int groups_1x1_, const Head1x1Params& head1x1_params_,
                    const std::string& secondary_activation_)
   : input_size(input_size_)
@@ -158,7 +159,7 @@ class LayerArrayParams
   , bottleneck(bottleneck_)
   , kernel_size(kernel_size_)
   , dilations(std::move(dilations_))
-  , activation(activation_)
+  , activation_config(activation_)
   , gating_mode(gating_mode_)
   , head_bias(head_bias_)
   , groups_input(groups_input)
@@ -175,7 +176,7 @@ class LayerArrayParams
   const int bottleneck;
   const int kernel_size;
   std::vector<int> dilations;
-  const std::string activation;
+  const activations::ActivationConfig activation_config;
   const GatingMode gating_mode;
   const bool head_bias;
   const int groups_input;
@@ -188,11 +189,12 @@ class LayerArrayParams
 class _LayerArray
 {
 public:
-  // New constructor with GatingMode enum and configurable activations
+  // Constructor with GatingMode enum and typed ActivationConfig
   _LayerArray(const int input_size, const int condition_size, const int head_size, const int channels,
               const int bottleneck, const int kernel_size, const std::vector<int>& dilations,
-              const std::string activation, const GatingMode gating_mode, const bool head_bias, const int groups_input,
-              const int groups_1x1, const Head1x1Params& head1x1_params, const std::string& secondary_activation);
+              const activations::ActivationConfig& activation_config, const GatingMode gating_mode, const bool head_bias,
+              const int groups_input, const int groups_1x1, const Head1x1Params& head1x1_params,
+              const std::string& secondary_activation);
 
   void SetMaxBufferSize(const int maxBufferSize);
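
Aside (illustrative, not part of the patch): since the gating and WaveNet classes above now take activations::Activation::Ptr, a caller can either borrow a stack- or statically-owned activation through a no-op-deleter shared_ptr (the same pattern as make_singleton_ptr above and make_test_ptr in the tests below) or pass an owned pointer from the registry.

// Sketch only: two ways to hand activations to the Ptr-based API.
#include <memory>
#include "NAM/activations.h"
#include "NAM/gating_activations.h"

void gating_ptr_example()
{
  using nam::activations::Activation;

  // 1) Borrow a caller-owned activation; the empty deleter leaves lifetime to the caller.
  static nam::activations::ActivationIdentity identity;
  Activation::Ptr borrowed(&identity, [](Activation*) {});

  // 2) Shared/owned activation looked up from the registry.
  Activation::Ptr sigmoid = Activation::get_activation(std::string("Sigmoid"));

  nam::gating_activations::GatingActivation gate(borrowed, sigmoid, /*input_channels=*/1);
}
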
diff --git a/tools/run_tests.cpp b/tools/run_tests.cpp
index de3a2e2..879417a 100644
--- a/tools/run_tests.cpp
+++ b/tools/run_tests.cpp
@@ -44,6 +44,18 @@ int main()
   // This is enforced by an assert so it doesn't need to be tested
   // test_activations::TestPReLU::test_wrong_number_of_channels();
 
+  // Typed ActivationConfig tests
+  test_activations::TestTypedActivationConfig::test_simple_config();
+  test_activations::TestTypedActivationConfig::test_all_simple_types();
+  test_activations::TestTypedActivationConfig::test_leaky_relu_config();
+  test_activations::TestTypedActivationConfig::test_prelu_single_slope_config();
+  test_activations::TestTypedActivationConfig::test_prelu_multi_slope_config();
+  test_activations::TestTypedActivationConfig::test_leaky_hardtanh_config();
+  test_activations::TestTypedActivationConfig::test_from_json_string();
+  test_activations::TestTypedActivationConfig::test_from_json_object();
+  test_activations::TestTypedActivationConfig::test_from_json_prelu_multi();
+  test_activations::TestTypedActivationConfig::test_unknown_activation_throws();
+
   test_dsp::test_construct();
   test_dsp::test_get_input_level();
   test_dsp::test_get_output_level();
diff --git a/tools/test/test_activations.cpp b/tools/test/test_activations.cpp
index e9f7a86..abbdd23 100644
--- a/tools/test/test_activations.cpp
+++ b/tools/test/test_activations.cpp
@@ -41,7 +41,7 @@ class TestFastTanh
   {
     const std::string name = "Fasttanh";
     auto a = nam::activations::Activation::get_activation(name);
-    _test_class(a);
+    _test_class(a.get());
   }
 
 private:
@@ -94,7 +94,7 @@ class TestLeakyReLU
   {
     const std::string name = "LeakyReLU";
     auto a = nam::activations::Activation::get_activation(name);
-    _test_class(a);
+    _test_class(a.get());
   }
 
 private:
@@ -195,4 +195,150 @@ class TestPReLU
   }
 };
 
+class TestTypedActivationConfig
+{
+public:
+  static void test_simple_config()
+  {
+    // Test simple() factory method
+    auto config = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
+    assert(config.type == nam::activations::ActivationType::ReLU);
+    assert(!config.negative_slope.has_value());
+    assert(!config.negative_slopes.has_value());
+
+    auto act = nam::activations::Activation::get_activation(config);
+    assert(act != nullptr);
+  }
+
+  static void test_all_simple_types()
+  {
+    // Test that all simple activation types work
+    std::vector<nam::activations::ActivationType> types = {
+      nam::activations::ActivationType::Tanh,     nam::activations::ActivationType::Hardtanh,
+      nam::activations::ActivationType::Fasttanh, nam::activations::ActivationType::ReLU,
+      nam::activations::ActivationType::Sigmoid,  nam::activations::ActivationType::SiLU,
+      nam::activations::ActivationType::Hardswish};
+
+    for (auto type : types)
+    {
+      auto config = nam::activations::ActivationConfig::simple(type);
+      auto act = nam::activations::Activation::get_activation(config);
+      assert(act != nullptr);
+    }
+  }
+
+  static void test_leaky_relu_config()
+  {
+    // Test LeakyReLU with custom negative slope
+    nam::activations::ActivationConfig config;
+    config.type = nam::activations::ActivationType::LeakyReLU;
+    config.negative_slope = 0.2f;
+
+    auto act = nam::activations::Activation::get_activation(config);
+    assert(act != nullptr);
+
+    // Verify the behavior
+    std::vector<float> data = {-1.0f, 0.0f, 1.0f};
+    act->apply(data.data(), (long)data.size());
+    assert(fabs(data[0] - (-0.2f)) < 1e-6); // -1.0 * 0.2 = -0.2
+    assert(fabs(data[1] - 0.0f) < 1e-6);
+    assert(fabs(data[2] - 1.0f) < 1e-6);
+  }
+
+  static void test_prelu_single_slope_config()
+  {
+    // Test PReLU with single slope
+    nam::activations::ActivationConfig config;
+    config.type = nam::activations::ActivationType::PReLU;
+    config.negative_slope = 0.25f;
+
+    auto act = nam::activations::Activation::get_activation(config);
+    assert(act != nullptr);
+  }
+
+  static void test_prelu_multi_slope_config()
+  {
+    // Test PReLU with multiple slopes (per-channel)
+    nam::activations::ActivationConfig config;
+    config.type = nam::activations::ActivationType::PReLU;
+    config.negative_slopes = std::vector<float>{0.1f, 0.2f, 0.3f};
+
+    auto act = nam::activations::Activation::get_activation(config);
+    assert(act != nullptr);
+
+    // Verify per-channel behavior
+    Eigen::MatrixXf data(3, 2);
+    data << -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f;
+
+    act->apply(data);
+
+    // Channel 0 (slope 0.1): -1.0 * 0.1 = -0.1
+    assert(fabs(data(0, 0) - (-0.1f)) < 1e-6);
+    // Channel 1 (slope 0.2): -1.0 * 0.2 = -0.2
+    assert(fabs(data(1, 0) - (-0.2f)) < 1e-6);
+    // Channel 2 (slope 0.3): -1.0 * 0.3 = -0.3
+    assert(fabs(data(2, 0) - (-0.3f)) < 1e-6);
+    // Positive values unchanged
+    assert(fabs(data(0, 1) - 1.0f) < 1e-6);
+  }
+
+  static void test_leaky_hardtanh_config()
+  {
+    // Test LeakyHardtanh with custom parameters
+    nam::activations::ActivationConfig config;
+    config.type = nam::activations::ActivationType::LeakyHardtanh;
+    config.min_val = -2.0f;
+    config.max_val = 2.0f;
+    config.min_slope = 0.1f;
+    config.max_slope = 0.1f;
+
+    auto act = nam::activations::Activation::get_activation(config);
+    assert(act != nullptr);
+  }
+
+  static void test_from_json_string()
+  {
+    // Test from_json with string input
+    nlohmann::json j = "ReLU";
+    auto config = nam::activations::ActivationConfig::from_json(j);
+    assert(config.type == nam::activations::ActivationType::ReLU);
+  }
+
+  static void test_from_json_object()
+  {
+    // Test from_json with object input
+    nlohmann::json j = {{"type", "LeakyReLU"}, {"negative_slope", 0.15f}};
+    auto config = nam::activations::ActivationConfig::from_json(j);
+    assert(config.type == nam::activations::ActivationType::LeakyReLU);
+    assert(config.negative_slope.has_value());
+    assert(fabs(config.negative_slope.value() - 0.15f) < 1e-6);
+  }
+
+  static void test_from_json_prelu_multi()
+  {
+    // Test from_json with PReLU multi-slope
+    nlohmann::json j = {{"type", "PReLU"}, {"negative_slopes", {0.1f, 0.2f, 0.3f, 0.4f}}};
+    auto config = nam::activations::ActivationConfig::from_json(j);
+    assert(config.type == nam::activations::ActivationType::PReLU);
+    assert(config.negative_slopes.has_value());
+    assert(config.negative_slopes.value().size() == 4);
+  }
+
+  static void test_unknown_activation_throws()
+  {
+    // Test that unknown activation type throws
+    nlohmann::json j = "UnknownActivation";
+    bool threw = false;
+    try
+    {
+      nam::activations::ActivationConfig::from_json(j);
+    }
+    catch (const std::runtime_error& e)
+    {
+      threw = true;
+    }
+    assert(threw);
+  }
+};
+
 }; // namespace test_activations
diff --git a/tools/test/test_blending_detailed.cpp b/tools/test/test_blending_detailed.cpp
index 7b774f5..b5c38db 100644
--- a/tools/test/test_blending_detailed.cpp
+++ b/tools/test/test_blending_detailed.cpp
@@ -5,6 +5,7 @@
 #include <cassert>
 #include <cmath>
 #include <iostream>
+#include <memory>
 
 #include "NAM/gating_activations.h"
 #include "NAM/activations.h"
@@ -12,6 +13,13 @@
 namespace test_blending_detailed
 {
 
+// Helper to create a non-owning shared_ptr for stack-allocated activations in tests
+template <typename T>
+nam::activations::Activation::Ptr make_test_ptr(T& activation)
+{
+  return nam::activations::Activation::Ptr(&activation, [](nam::activations::Activation*) {});
+}
+
 class TestBlendingDetailed
 {
 public:
@@ -29,7 +37,7 @@ class TestBlendingDetailed
     // Test with default (linear) activations
     nam::activations::ActivationIdentity identity_act;
     nam::activations::ActivationIdentity identity_blend_act;
-    nam::gating_activations::BlendingActivation blending_act(&identity_act, &identity_blend_act, 2);
+    nam::gating_activations::BlendingActivation blending_act(make_test_ptr(identity_act), make_test_ptr(identity_blend_act), 2);
 
     blending_act.apply(input, output);
 
     // With linear activations:
@@ -42,8 +50,8 @@ class TestBlendingDetailed
     assert(fabs(output(1, 1) - 4.0f) < 1e-6);
 
     // Test with sigmoid blending activation
-    nam::activations::Activation* sigmoid_act = nam::activations::Activation::get_activation("Sigmoid");
-    nam::gating_activations::BlendingActivation blending_act_sigmoid(&identity_act, sigmoid_act, 2);
+    auto sigmoid_act = nam::activations::Activation::get_activation(std::string("Sigmoid"));
+    nam::gating_activations::BlendingActivation blending_act_sigmoid(make_test_ptr(identity_act), sigmoid_act, 2);
 
     Eigen::MatrixXf output_sigmoid(2, 2);
     blending_act_sigmoid.apply(input, output_sigmoid);
 
@@ -80,7 +88,7 @@ class TestBlendingDetailed
     // Test with ReLU activation on input (which will change values < 0 to 0)
     nam::activations::ActivationReLU relu_act;
     nam::activations::ActivationIdentity identity_act;
-    nam::gating_activations::BlendingActivation blending_act(&relu_act, &identity_act, 1);
+    nam::gating_activations::BlendingActivation blending_act(make_test_ptr(relu_act), make_test_ptr(identity_act), 1);
 
     blending_act.apply(input, output);
 
diff --git a/tools/test/test_convnet.cpp b/tools/test/test_convnet.cpp
index 56bd5ec..0f55482 100644
--- a/tools/test/test_convnet.cpp
+++ b/tools/test/test_convnet.cpp
@@ -18,7 +18,7 @@ void test_convnet_basic()
   const int channels = 2;
   const std::vector<int> dilations{1, 2};
   const bool batchnorm = false;
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const double expected_sample_rate = 48000.0;
 
   // Calculate weights needed:
@@ -65,7 +65,7 @@ void test_convnet_batchnorm()
   const int channels = 1;
   const std::vector<int> dilations{1};
   const bool batchnorm = true;
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const double expected_sample_rate = 48000.0;
 
   // Calculate weights needed:
@@ -110,7 +110,7 @@ void test_convnet_multiple_blocks()
   const int channels = 2;
   const std::vector<int> dilations{1, 2, 4};
   const bool batchnorm = false;
-  const std::string activation = "Tanh";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Tanh);
   const double expected_sample_rate = 48000.0;
 
   // Calculate weights needed:
@@ -158,7 +158,7 @@ void test_convnet_zero_input()
   const int channels = 1;
   const std::vector<int> dilations{1};
   const bool batchnorm = false;
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const double expected_sample_rate = 48000.0;
 
   std::vector<float> weights;
@@ -195,7 +195,7 @@ void test_convnet_different_buffer_sizes()
   const int channels = 1;
   const std::vector<int> dilations{1};
   const bool batchnorm = false;
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const double expected_sample_rate = 48000.0;
 
   std::vector<float> weights;
@@ -235,7 +235,7 @@ void test_convnet_prewarm()
   const int channels = 2;
   const std::vector<int> dilations{1, 2, 4};
   const bool batchnorm = false;
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const double expected_sample_rate = 48000.0;
 
   std::vector<float> weights;
@@ -278,7 +278,7 @@ void test_convnet_multiple_calls()
   const int channels = 1;
   const std::vector<int> dilations{1};
   const bool batchnorm = false;
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const double expected_sample_rate = 48000.0;
 
   std::vector<float> weights;
diff --git a/tools/test/test_gating_activations.cpp b/tools/test/test_gating_activations.cpp
index a67b872..90fd21a 100644
--- a/tools/test/test_gating_activations.cpp
+++ b/tools/test/test_gating_activations.cpp
@@ -5,6 +5,7 @@
 #include <cassert>
 #include <cmath>
 #include <iostream>
+#include <memory>
 
 #include "NAM/gating_activations.h"
 #include "NAM/activations.h"
@@ -12,6 +13,13 @@
 namespace test_gating_activations
 {
 
+// Helper to create a non-owning shared_ptr for stack-allocated activations in tests
+template <typename T>
+nam::activations::Activation::Ptr make_test_ptr(T& activation)
+{
+  return nam::activations::Activation::Ptr(&activation, [](nam::activations::Activation*) {});
+}
+
 class TestGatingActivation
 {
 public:
@@ -26,7 +34,7 @@ class TestGatingActivation
     // Create gating activation with default activations (1 input channel, 1 gating channel)
     nam::activations::ActivationIdentity identity_act;
     nam::activations::ActivationSigmoid sigmoid_act;
-    nam::gating_activations::GatingActivation gating_act(&identity_act, &sigmoid_act, 1);
+    nam::gating_activations::GatingActivation gating_act(make_test_ptr(identity_act), make_test_ptr(sigmoid_act), 1);
 
     // Apply the activation
     gating_act.apply(input, output);
@@ -52,7 +60,7 @@ class TestGatingActivation
     Eigen::MatrixXf output(1, 2);
 
     // Create gating activation with custom activations
-    nam::gating_activations::GatingActivation gating_act(&leaky_relu, &leaky_relu2, 1);
+    nam::gating_activations::GatingActivation gating_act(make_test_ptr(leaky_relu), make_test_ptr(leaky_relu2), 1);
 
     // Apply the activation
     gating_act.apply(input, output);
@@ -85,7 +93,7 @@ class TestBlendingActivation
     // Create blending activation (1 input channel)
     nam::activations::ActivationIdentity identity_act;
     nam::activations::ActivationIdentity identity_blend_act;
-    nam::gating_activations::BlendingActivation blending_act(&identity_act, &identity_blend_act, 1);
+    nam::gating_activations::BlendingActivation blending_act(make_test_ptr(identity_act), make_test_ptr(identity_blend_act), 1);
 
     // Apply the activation
     blending_act.apply(input, output);
@@ -107,7 +115,7 @@ class TestBlendingActivation
     // Test with default (linear) activations
     nam::activations::ActivationIdentity identity_act;
     nam::activations::ActivationIdentity identity_blend_act;
-    nam::gating_activations::BlendingActivation blending_act(&identity_act, &identity_blend_act, 1);
+    nam::gating_activations::BlendingActivation blending_act(make_test_ptr(identity_act), make_test_ptr(identity_blend_act), 1);
 
     blending_act.apply(input, output);
 
     // With linear activations, blending should be:
@@ -118,8 +126,8 @@ class TestBlendingActivation
     assert(fabs(output(0, 1) - (-1.0f)) < 1e-6);
 
     // Test with sigmoid blending activation
-    nam::activations::Activation* sigmoid_act = nam::activations::Activation::get_activation("Sigmoid");
-    nam::gating_activations::BlendingActivation blending_act2(&identity_act, sigmoid_act, 1);
+    auto sigmoid_act = nam::activations::Activation::get_activation(std::string("Sigmoid"));
+    nam::gating_activations::BlendingActivation blending_act2(make_test_ptr(identity_act), sigmoid_act, 1);
 
     blending_act2.apply(input, output);
 
     // With sigmoid blending, alpha values should be between 0 and 1
@@ -149,7 +157,7 @@ class TestBlendingActivation
     Eigen::MatrixXf output(1, 2);
 
     // Create blending activation with custom activations
-    nam::gating_activations::BlendingActivation blending_act(&leaky_relu, &leaky_relu2, 1);
+    nam::gating_activations::BlendingActivation blending_act(make_test_ptr(leaky_relu), make_test_ptr(leaky_relu2), 1);
 
     // Apply the activation
     blending_act.apply(input, output);
@@ -167,7 +175,7 @@ class TestBlendingActivation
     nam::activations::ActivationIdentity identity_act;
     nam::activations::ActivationIdentity identity_blend_act;
-    nam::gating_activations::BlendingActivation blending_act(&identity_act, &identity_blend_act, 1);
+    nam::gating_activations::BlendingActivation blending_act(make_test_ptr(identity_act), make_test_ptr(identity_blend_act), 1);
 
     // This should trigger an assert and terminate the program
     // We can't easily test asserts in a unit test framework without special handling
@@ -188,7 +196,7 @@ class TestBlendingActivation
     nam::activations::ActivationIdentity identity_act;
     nam::activations::ActivationIdentity identity_blend_act;
 
-    nam::gating_activations::BlendingActivation blending_act(&identity_act, &identity_blend_act, 1);
+    nam::gating_activations::BlendingActivation blending_act(make_test_ptr(identity_act), make_test_ptr(identity_blend_act), 1);
 
     blending_act.apply(input, output);
 
     assert(fabs(output(0, 0) - 0.0f) < 1e-6);
diff --git a/tools/test/test_input_buffer_verification.cpp b/tools/test/test_input_buffer_verification.cpp
index 01aa9e2..2f89565 100644
--- a/tools/test/test_input_buffer_verification.cpp
+++ b/tools/test/test_input_buffer_verification.cpp
@@ -5,6 +5,7 @@
 #include <cassert>
 #include <cmath>
 #include <iostream>
+#include <memory>
 
 #include "NAM/gating_activations.h"
 #include "NAM/activations.h"
@@ -12,6 +13,13 @@
 namespace test_input_buffer_verification
 {
 
+// Helper to create a non-owning shared_ptr for stack-allocated activations in tests
+template <typename T>
+nam::activations::Activation::Ptr make_test_ptr(T& activation)
+{
+  return nam::activations::Activation::Ptr(&activation, [](nam::activations::Activation*) {});
+}
+
 class TestInputBufferVerification
 {
 public:
@@ -26,7 +34,7 @@ class TestInputBufferVerification
     // Use ReLU activation which will set negative values to 0
    nam::activations::ActivationReLU relu_act;
    nam::activations::ActivationIdentity identity_act;
-    nam::gating_activations::BlendingActivation blending_act(&relu_act, &identity_act, 1);
+    nam::gating_activations::BlendingActivation blending_act(make_test_ptr(relu_act), make_test_ptr(identity_act), 1);
 
     // Apply the activation
     blending_act.apply(input, output);
@@ -53,7 +61,7 @@ class TestInputBufferVerification
     // Use LeakyReLU with slope 0.1
     nam::activations::ActivationLeakyReLU leaky_relu(0.1f);
     nam::activations::ActivationIdentity identity_act;
-    nam::gating_activations::BlendingActivation blending_act(&leaky_relu, &identity_act, 1);
+    nam::gating_activations::BlendingActivation blending_act(make_test_ptr(leaky_relu), make_test_ptr(identity_act), 1);
 
     blending_act.apply(input, output);
 
diff --git a/tools/test/test_wavenet/test_condition_processing.cpp b/tools/test/test_wavenet/test_condition_processing.cpp
index c4b5b1a..4a4c902 100644
--- a/tools/test/test_wavenet/test_condition_processing.cpp
+++ b/tools/test/test_wavenet/test_condition_processing.cpp
@@ -27,7 +27,7 @@ std::unique_ptr<nam::wavenet::WaveNet> create_simple_wavenet(
   const int bottleneck = channels;
   const int kernel_size = 1;
   std::vector<int> dilations{1};
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE;
   const bool head_bias = false;
   const bool with_head = false;
diff --git a/tools/test/test_wavenet/test_full.cpp b/tools/test/test_wavenet/test_full.cpp
index ee498c1..6be8787 100644
--- a/tools/test/test_wavenet/test_full.cpp
+++ b/tools/test/test_wavenet/test_full.cpp
@@ -22,7 +22,7 @@ void test_wavenet_model()
   const int bottleneck = channels;
   const int kernel_size = 1;
   std::vector<int> dilations{1};
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE;
   const bool head_bias = false;
   const float head_scale = 1.0f;
@@ -83,7 +83,7 @@ void test_wavenet_multiple_arrays()
   const int channels = 1;
   const int kernel_size = 1;
   std::vector<int> dilations{1};
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE;
   const bool head_bias = false;
   const float head_scale = 0.5f;
@@ -146,7 +146,7 @@ void test_wavenet_zero_input()
   const int bottleneck = channels;
   const int kernel_size = 1;
   std::vector<int> dilations{1};
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE;
   const bool head_bias = false;
   const float head_scale = 1.0f;
@@ -195,7 +195,7 @@ void test_wavenet_different_buffer_sizes()
   const int bottleneck = channels;
   const int kernel_size = 1;
   std::vector<int> dilations{1};
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE;
   const bool head_bias = false;
   const float head_scale = 1.0f;
@@ -247,7 +247,7 @@ void test_wavenet_prewarm()
   const int bottleneck = channels;
   const int kernel_size = 3;
   std::vector<int> dilations{1, 2, 4};
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE;
   const bool head_bias = false;
   const float head_scale = 1.0f;
diff --git a/tools/test/test_wavenet/test_head1x1.cpp b/tools/test/test_wavenet/test_head1x1.cpp
index 18ff70b..5714a6b 100644
--- a/tools/test/test_wavenet/test_head1x1.cpp
+++ b/tools/test/test_wavenet/test_head1x1.cpp
@@ -21,7 +21,7 @@ void test_head1x1_inactive()
   const int bottleneck = channels;
   const int kernelSize = 1;
   const int dilation = 1;
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE;
   const int groups_input = 1;
   const int groups_1x1 = 1;
@@ -87,7 +87,7 @@ void test_head1x1_active()
   const int bottleneck = channels;
   const int kernelSize = 1;
   const int dilation = 1;
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE;
   const int groups_input = 1;
   const int groups_1x1 = 1;
@@ -160,7 +160,7 @@ void test_head1x1_gated()
   const int bottleneck = channels;
   const int kernelSize = 1;
   const int dilation = 1;
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::GATED;
   const int groups_input = 1;
   const int groups_1x1 = 1;
@@ -250,7 +250,7 @@ void test_head1x1_groups()
   const int bottleneck = channels;
   const int kernelSize = 1;
   const int dilation = 1;
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE;
   const int groups_input = 1;
   const int groups_1x1 = 1;
@@ -330,7 +330,7 @@ void test_head1x1_different_out_channels()
   const int bottleneck = channels;
   const int kernelSize = 1;
   const int dilation = 1;
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE;
   const int groups_input = 1;
   const int groups_1x1 = 1;
diff --git a/tools/test/test_wavenet/test_layer.cpp b/tools/test/test_wavenet/test_layer.cpp
index ae43274..e349401 100644
--- a/tools/test/test_wavenet/test_layer.cpp
+++ b/tools/test/test_wavenet/test_layer.cpp
@@ -21,7 +21,7 @@ void test_gated()
   const int bottleneck = channels;
   const int kernelSize = 1;
   const int dilation = 1;
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::GATED;
   const int groups_input = 1;
   const int groups_1x1 = 1;
@@ -97,7 +97,7 @@ void test_layer_getters()
   const int bottleneck = channels;
   const int kernelSize = 3;
   const int dilation = 2;
-  const std::string activation = "Tanh";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Tanh);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE;
   const int groups_input = 1;
   const int groups_1x1 = 1;
@@ -119,7 +119,7 @@ void test_non_gated_layer()
   const int bottleneck = channels;
   const int kernelSize = 1;
   const int dilation = 1;
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE;
   const int groups_input = 1;
   const int groups_1x1 = 1;
@@ -193,7 +193,8 @@ void test_layer_activations()
   const int groups_input = 1;
   const int groups_1x1 = 1;
   nam::wavenet::Head1x1Params head1x1_params(false, channels, 1);
-  auto layer = nam::wavenet::_Layer(conditionSize, channels, bottleneck, kernelSize, dilation, "Tanh", gating_mode,
+  auto tanh_config = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Tanh);
+  auto layer = nam::wavenet::_Layer(conditionSize, channels, bottleneck, kernelSize, dilation, tanh_config, gating_mode,
                                     groups_input, groups_1x1, head1x1_params, "");
   std::vector<float> weights{1.0f, 0.0f, 1.0f, 1.0f, 0.0f};
   auto it = weights.begin();
@@ -224,7 +225,7 @@ void test_layer_multichannel()
   const int bottleneck = channels;
   const int kernelSize = 1;
   const int dilation = 1;
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE;
   const int groups_input = 1;
   const int groups_1x1 = 1;
@@ -294,7 +295,7 @@ void test_layer_bottleneck()
   const int bottleneck = 2; // bottleneck < channels
   const int kernelSize = 1;
   const int dilation = 1;
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE;
   const int groups_input = 1;
   const int groups_1x1 = 1;
@@ -371,7 +372,7 @@ void test_layer_bottleneck_gated()
   const int bottleneck = 2; // bottleneck < channels
   const int kernelSize = 1;
   const int dilation = 1;
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::GATED;
   // gated doubles the internal bottleneck channels
   const int groups_input = 1;
diff --git a/tools/test/test_wavenet/test_layer_array.cpp b/tools/test/test_wavenet/test_layer_array.cpp
index d5916a2..ba7cb39 100644
--- a/tools/test/test_wavenet/test_layer_array.cpp
+++ b/tools/test/test_wavenet/test_layer_array.cpp
@@ -22,7 +22,7 @@ void test_layer_array_basic()
   const int bottleneck = channels;
   const int kernel_size = 1;
   std::vector<int> dilations{1, 2};
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE;
   const bool head_bias = false;
   const int groups = 1;
@@ -81,7 +81,7 @@ void test_layer_array_receptive_field()
   const int bottleneck = channels;
   const int kernel_size = 3;
   std::vector<int> dilations{1, 2, 4};
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE;
   const bool head_bias = false;
   const int groups = 1;
@@ -112,7 +112,7 @@ void test_layer_array_with_head_input()
   const int bottleneck = channels;
   const int kernel_size = 1;
   std::vector<int> dilations{1};
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE;
   const bool head_bias = false;
   const int groups = 1;
diff --git a/tools/test/test_wavenet/test_real_time_safe.cpp b/tools/test/test_wavenet/test_real_time_safe.cpp
index cc04150..b89e9dd 100644
--- a/tools/test/test_wavenet/test_real_time_safe.cpp
+++ b/tools/test/test_wavenet/test_real_time_safe.cpp
@@ -432,7 +432,7 @@ void test_layer_process_realtime_safe()
   const int bottleneck = channels;
   const int kernel_size = 1;
   const int dilation = 1;
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE;
   const int groups_input = 1;
   const int groups_1x1 = 1;
@@ -488,7 +488,7 @@ void test_layer_bottleneck_process_realtime_safe()
   const int bottleneck = 2; // bottleneck < channels
   const int kernel_size = 1;
   const int dilation = 1;
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE;
   const int groups_input = 1;
   const int groups_1x1 = 1;
@@ -574,7 +574,7 @@ void test_layer_grouped_process_realtime_safe()
   const int bottleneck = channels;
   const int kernel_size = 2;
   const int dilation = 1;
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE;
   const int groups_input = 2; // groups_input > 1
   const int groups_1x1 = 2; // 1x1 is also grouped
@@ -685,7 +685,7 @@ void test_layer_array_process_realtime_safe()
   const int bottleneck = channels;
   const int kernel_size = 1;
   std::vector<int> dilations{1};
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE;
   const bool head_bias = false;
   const int groups = 1;
@@ -749,7 +749,7 @@ void test_process_realtime_safe()
   const int channels = 1;
   const int kernel_size = 1;
   std::vector<int> dilations{1};
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE;
   const bool head_bias = false;
   const float head_scale = 1.0f;
@@ -827,7 +827,7 @@ void test_process_3in_2out_realtime_safe()
   const int channels = 4; // internal channels
   const int bottleneck = 2; // bottleneck (will be used for head)
   const int kernel_size = 1;
-  const std::string activation = "ReLU";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU);
   const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE;
   const bool head_bias = false;
   const float head_scale = 1.0f;
diff --git a/tools/test/test_wavenet_configurable_gating.cpp b/tools/test/test_wavenet_configurable_gating.cpp
index a98326c..2c0d13b 100644
--- a/tools/test/test_wavenet_configurable_gating.cpp
+++ b/tools/test/test_wavenet_configurable_gating.cpp
@@ -21,7 +21,7 @@ class TestConfigurableGating
   const int bottleneck = 2;
   const int kernelSize = 3;
   const int dilation = 1;
-  const std::string activation = "Tanh";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Tanh);
   const int groups_input = 1;
   const int groups_1x1 = 1;
   nam::wavenet::Head1x1Params head1x1_params(false, channels, 1);
@@ -48,7 +48,7 @@ class TestConfigurableGating
   const int bottleneck = 2;
   const int kernelSize = 3;
   const int dilation = 1;
-  const std::string activation = "Tanh";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Tanh);
   const int groups_input = 1;
   const int groups_1x1 = 1;
   nam::wavenet::Head1x1Params head1x1_params(false, channels, 1);
@@ -78,7 +78,7 @@ class TestConfigurableGating
   const int bottleneck = 2;
   const int kernel_size = 3;
   const std::vector<int> dilations = {1, 2};
-  const std::string activation = "Tanh";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Tanh);
   const bool head_bias = false;
   const int groups_input = 1;
   const int groups_1x1 = 1;
@@ -111,7 +111,7 @@ class TestConfigurableGating
   const int bottleneck = 2;
   const int kernel_size = 3;
   const std::vector<int> dilations = {1};
-  const std::string activation = "Tanh";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Tanh);
   const bool head_bias = false;
   const int groups_input = 1;
   const int groups_1x1 = 1;
@@ -174,7 +174,7 @@ class TestConfigurableGating
   const int bottleneck = 2;
   const int kernelSize = 3;
   const int dilation = 1;
-  const std::string activation = "Tanh";
+  const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Tanh);
   const int groups_input = 1;
   const int groups_1x1 = 1;
   nam::wavenet::Head1x1Params head1x1_params(false, channels, 1);
diff --git a/tools/test/test_wavenet_gating_compatibility.cpp b/tools/test/test_wavenet_gating_compatibility.cpp
index f3ad8e6..6e443bb 100644
--- a/tools/test/test_wavenet_gating_compatibility.cpp
+++ b/tools/test/test_wavenet_gating_compatibility.cpp
@@ -5,6 +5,7 @@
 #include <cassert>
 #include <cmath>
 #include <iostream>
+#include <memory>
 
 #include "NAM/gating_activations.h"
 #include "NAM/activations.h"
@@ -12,6 +13,13 @@
 namespace test_wavenet_gating_compatibility
 {
 
+// Helper to create a non-owning shared_ptr for stack-allocated activations in tests
+template <typename T>
+nam::activations::Activation::Ptr make_test_ptr(T& activation)
+{
+  return nam::activations::Activation::Ptr(&activation, [](nam::activations::Activation*) {});
+}
+
 class TestWavenetGatingCompatibility
 {
 public:
@@ -35,7 +43,7 @@ class TestWavenetGatingCompatibility
     // Wavenet uses: input activation (default/linear) and sigmoid for gating
     nam::activations::ActivationIdentity identity_act;
     nam::activations::ActivationSigmoid sigmoid_act;
-    nam::gating_activations::GatingActivation gating_act(&identity_act, &sigmoid_act, channels);
+    nam::gating_activations::GatingActivation gating_act(make_test_ptr(identity_act), make_test_ptr(sigmoid_act), channels);
 
     // Apply the activation
     gating_act.apply(input, output);
@@ -84,7 +92,7 @@ class TestWavenetGatingCompatibility
 
     nam::activations::ActivationIdentity identity_act;
     nam::activations::ActivationSigmoid sigmoid_act;
-    nam::gating_activations::GatingActivation gating_act(&identity_act, &sigmoid_act, channels);
+    nam::gating_activations::GatingActivation gating_act(make_test_ptr(identity_act), make_test_ptr(sigmoid_act), channels);
 
     gating_act.apply(input, output);
 
     // Verify each column was processed independently
@@ -120,7 +128,7 @@ class TestWavenetGatingCompatibility
 
     nam::activations::ActivationIdentity identity_act;
     nam::activations::ActivationSigmoid sigmoid_act;
-    nam::gating_activations::GatingActivation gating_act(&identity_act, &sigmoid_act, channels);
+    nam::gating_activations::GatingActivation gating_act(make_test_ptr(identity_act), make_test_ptr(sigmoid_act), channels);
 
     // This should not crash or produce incorrect results due to memory contiguity issues
     gating_act.apply(input, output);
@@ -155,7 +163,7 @@ class TestWavenetGatingCompatibility
 
     nam::activations::ActivationIdentity identity_act;
     nam::activations::ActivationSigmoid sigmoid_act;
-    nam::gating_activations::GatingActivation gating_act(&identity_act, &sigmoid_act, channels);
+    nam::gating_activations::GatingActivation gating_act(make_test_ptr(identity_act), make_test_ptr(sigmoid_act), channels);
 
     gating_act.apply(input, output);
 
     // Verify dimensions