diff --git a/NAM/dsp.cpp b/NAM/dsp.cpp index 023c42a..02a4a13 100644 --- a/NAM/dsp.cpp +++ b/NAM/dsp.cpp @@ -417,7 +417,7 @@ Eigen::MatrixXf nam::Conv1x1::process(const Eigen::MatrixXf& input, const int nu return result; } -void nam::Conv1x1::process_(const Eigen::MatrixXf& input, const int num_frames) +void nam::Conv1x1::process_(const Eigen::Ref<const Eigen::MatrixXf>& input, const int num_frames) { assert(num_frames <= _output.cols()); diff --git a/NAM/dsp.h b/NAM/dsp.h index ee6f762..73319a2 100644 --- a/NAM/dsp.h +++ b/NAM/dsp.h @@ -207,7 +207,8 @@ class Conv1x1 Eigen::MatrixXf process(const Eigen::MatrixXf& input) const { return process(input, (int)input.cols()); }; Eigen::MatrixXf process(const Eigen::MatrixXf& input, const int num_frames) const; // Store output to pre-allocated _output; access with GetOutput() - void process_(const Eigen::MatrixXf& input, const int num_frames); + // Uses Eigen::Ref to accept matrices and block expressions without creating temporaries (real-time safe) + void process_(const Eigen::Ref<const Eigen::MatrixXf>& input, const int num_frames); long get_out_channels() const { return this->_weight.rows(); }; long get_in_channels() const { return this->_weight.cols(); }; diff --git a/NAM/film.h b/NAM/film.h index 6b273e6..b5376f0 100644 --- a/NAM/film.h +++ b/NAM/film.h @@ -47,7 +47,9 @@ class FiLM // :param input: (input_dim x num_frames) // :param condition: (condition_dim x num_frames) // Writes (input_dim x num_frames) into internal output buffer; access via GetOutput(). 
- void Process(const Eigen::MatrixXf& input, const Eigen::MatrixXf& condition, const int num_frames) + // Uses Eigen::Ref to accept matrices and block expressions without creating temporaries (real-time safe) + void Process(const Eigen::Ref<const Eigen::MatrixXf>& input, const Eigen::Ref<const Eigen::MatrixXf>& condition, + const int num_frames) { assert(get_input_dim() == input.rows()); assert(get_condition_dim() == condition.rows()); @@ -72,7 +74,9 @@ class FiLM } // in-place - void Process_(Eigen::MatrixXf& input, const Eigen::MatrixXf& condition, const int num_frames) + // Uses Eigen::Ref to accept matrices and block expressions without creating temporaries (real-time safe) + void Process_(Eigen::Ref<Eigen::MatrixXf> input, const Eigen::Ref<const Eigen::MatrixXf>& condition, + const int num_frames) { Process(input, condition, num_frames); input.leftCols(num_frames).noalias() = _output.leftCols(num_frames); diff --git a/NAM/gating_activations.h b/NAM/gating_activations.h index 996a676..335a984 100644 --- a/NAM/gating_activations.h +++ b/NAM/gating_activations.h @@ -42,9 +42,10 @@ class GatingActivation { throw std::invalid_argument("GatingActivation: number of input channels must be positive"); } - // Initialize input buffer with correct size + // Initialize buffers with correct size // Note: current code copies column-by-column so we only need (num_channels, 1) input_buffer.resize(num_channels, 1); + gating_buffer.resize(num_channels, 1); } ~GatingActivation() = default; @@ -64,23 +65,20 @@ class GatingActivation assert(output.cols() == input.cols()); // Process column-by-column to ensure memory contiguity (important for column-major matrices) + // Uses pre-allocated buffers to avoid allocations in the loop (real-time safe) const int num_samples = input.cols(); for (int i = 0; i < num_samples; i++) { - // Store pre-activation input values in buffer to avoid overwriting issues + // Copy to pre-allocated buffers and apply activations in-place input_buffer = input.block(0, i, num_channels, 1); + input_activation->apply(input_buffer); - // Apply 
activation to input channels - Eigen::MatrixXf input_block = input.block(0, i, num_channels, 1); - input_activation->apply(input_block); - - // Apply activation to gating channels - Eigen::MatrixXf gating_block = input.block(num_channels, i, num_channels, 1); - gating_activation->apply(gating_block); + gating_buffer = input.block(num_channels, i, num_channels, 1); + gating_activation->apply(gating_buffer); // Element-wise multiplication and store result // For wavenet compatibility, we assume one-to-one mapping - output.block(0, i, num_channels, 1) = input_block.array() * gating_block.array(); + output.block(0, i, num_channels, 1) = input_buffer.array() * gating_buffer.array(); } } @@ -99,6 +97,7 @@ class GatingActivation activations::Activation::Ptr gating_activation; int num_channels; Eigen::MatrixXf input_buffer; + Eigen::MatrixXf gating_buffer; }; class BlendingActivation @@ -118,9 +117,11 @@ class BlendingActivation { assert(num_channels > 0); - // Initialize input buffer with correct size + // Initialize buffers with correct size // Note: current code copies column-by-column so we only need (num_channels, 1) + pre_activation_buffer.resize(num_channels, 1); input_buffer.resize(num_channels, 1); + blend_buffer.resize(num_channels, 1); } ~BlendingActivation() = default; @@ -140,23 +141,24 @@ class BlendingActivation assert(output.cols() == input.cols()); // Process column-by-column to ensure memory contiguity + // Uses pre-allocated buffers to avoid allocations in the loop (real-time safe) const int num_samples = input.cols(); for (int i = 0; i < num_samples; i++) { // Store pre-activation input values in buffer - input_buffer = input.block(0, i, num_channels, 1); + pre_activation_buffer = input.block(0, i, num_channels, 1); - // Apply activation to input channels - Eigen::MatrixXf input_block = input.block(0, i, num_channels, 1); - input_activation->apply(input_block); + // Copy to pre-allocated buffer and apply activation to input channels + input_buffer = 
input.block(0, i, num_channels, 1); + input_activation->apply(input_buffer); - // Apply activation to blend channels to compute alpha - Eigen::MatrixXf blend_block = input.block(num_channels, i, num_channels, 1); - blending_activation->apply(blend_block); + // Copy to pre-allocated buffer and apply activation to blend channels to compute alpha + blend_buffer = input.block(num_channels, i, num_channels, 1); + blending_activation->apply(blend_buffer); // Weighted blending: alpha * activated_input + (1 - alpha) * pre_activation_input output.block(0, i, num_channels, 1) = - blend_block.array() * input_block.array() + (1.0f - blend_block.array()) * input_buffer.array(); + blend_buffer.array() * input_buffer.array() + (1.0f - blend_buffer.array()) * pre_activation_buffer.array(); } } @@ -174,7 +176,9 @@ class BlendingActivation activations::Activation::Ptr input_activation; activations::Activation::Ptr blending_activation; int num_channels; + Eigen::MatrixXf pre_activation_buffer; Eigen::MatrixXf input_buffer; + Eigen::MatrixXf blend_buffer; }; diff --git a/NAM/wavenet.cpp b/NAM/wavenet.cpp index 1224041..9717df9 100644 --- a/NAM/wavenet.cpp +++ b/NAM/wavenet.cpp @@ -47,8 +47,6 @@ void nam::wavenet::_Layer::SetMaxBufferSize(const int maxBufferSize) this->_activation_pre_film->SetMaxBufferSize(maxBufferSize); if (this->_activation_post_film) this->_activation_post_film->SetMaxBufferSize(maxBufferSize); - if (this->_gating_activation_post_film) - this->_gating_activation_post_film->SetMaxBufferSize(maxBufferSize); if (this->_1x1_post_film) this->_1x1_post_film->SetMaxBufferSize(maxBufferSize); if (this->_head1x1_post_film) @@ -77,8 +75,6 @@ void nam::wavenet::_Layer::set_weights_(std::vector::iterator& weights) this->_activation_pre_film->set_weights_(weights); if (this->_activation_post_film) this->_activation_post_film->set_weights_(weights); - if (this->_gating_activation_post_film) - this->_gating_activation_post_film->set_weights_(weights); if (this->_1x1_post_film) 
this->_1x1_post_film->set_weights_(weights); if (this->_head1x1_post_film) @@ -150,12 +146,12 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma auto input_block = this->_z.leftCols(num_frames); auto output_block = this->_z.topRows(bottleneck).leftCols(num_frames); this->_gating_activation->apply(input_block, output_block); - if (this->_gating_activation_post_film) + if (this->_activation_post_film) { // Use Process() for blocks and copy result back - this->_gating_activation_post_film->Process(this->_z.topRows(bottleneck), condition, num_frames); + this->_activation_post_film->Process(this->_z.topRows(bottleneck), condition, num_frames); this->_z.topRows(bottleneck).leftCols(num_frames).noalias() = - this->_gating_activation_post_film->GetOutput().leftCols(num_frames); + this->_activation_post_film->GetOutput().leftCols(num_frames); } _1x1.process_(this->_z.topRows(bottleneck), num_frames); } @@ -219,23 +215,23 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma nam::wavenet::_LayerArray::_LayerArray( const int input_size, const int condition_size, const int head_size, const int channels, const int bottleneck, const int kernel_size, const std::vector<int>& dilations, const activations::ActivationConfig& activation_config, - const GatingMode gating_mode, const bool head_bias, const int groups_input, const int groups_1x1, - const Head1x1Params& head1x1_params, const std::string& secondary_activation, const _FiLMParams& conv_pre_film_params, + const GatingMode gating_mode, const bool head_bias, const int groups_input, const int groups_input_mixin, + const int groups_1x1, const Head1x1Params& head1x1_params, + const activations::ActivationConfig& secondary_activation_config, const _FiLMParams& conv_pre_film_params, const _FiLMParams& conv_post_film_params, const _FiLMParams& input_mixin_pre_film_params, const _FiLMParams& input_mixin_post_film_params, const _FiLMParams& activation_pre_film_params, - const 
_FiLMParams& activation_post_film_params, const _FiLMParams& gating_activation_post_film_params, - const _FiLMParams& _1x1_post_film_params, const _FiLMParams& head1x1_post_film_params) + const _FiLMParams& activation_post_film_params, const _FiLMParams& _1x1_post_film_params, + const _FiLMParams& head1x1_post_film_params) : _rechannel(input_size, channels, false) -, _head_rechannel(bottleneck, head_size, head_bias) -, _bottleneck(bottleneck) +, _head_rechannel(head1x1_params.active ? head1x1_params.out_channels : bottleneck, head_size, head_bias) +, _head_output_size(head1x1_params.active ? head1x1_params.out_channels : bottleneck) { for (size_t i = 0; i < dilations.size(); i++) - this->_layers.push_back(_Layer(condition_size, channels, bottleneck, kernel_size, dilations[i], activation_config, - gating_mode, groups_input, groups_1x1, head1x1_params, secondary_activation, - conv_pre_film_params, conv_post_film_params, input_mixin_pre_film_params, - input_mixin_post_film_params, activation_pre_film_params, - activation_post_film_params, gating_activation_post_film_params, - _1x1_post_film_params, head1x1_post_film_params)); + this->_layers.push_back( + _Layer(condition_size, channels, bottleneck, kernel_size, dilations[i], activation_config, gating_mode, + groups_input, groups_input_mixin, groups_1x1, head1x1_params, secondary_activation_config, + conv_pre_film_params, conv_post_film_params, input_mixin_pre_film_params, input_mixin_post_film_params, + activation_pre_film_params, activation_post_film_params, _1x1_post_film_params, head1x1_post_film_params)); } void nam::wavenet::_LayerArray::SetMaxBufferSize(const int maxBufferSize) @@ -249,7 +245,8 @@ void nam::wavenet::_LayerArray::SetMaxBufferSize(const int maxBufferSize) // Pre-allocate output buffers const long channels = this->_get_channels(); this->_layer_outputs.resize(channels, maxBufferSize); - this->_head_inputs.resize(this->_bottleneck, maxBufferSize); + // _head_inputs size matches actual head output: 
head1x1.out_channels if active, else bottleneck + this->_head_inputs.resize(this->_head_output_size, maxBufferSize); } @@ -386,12 +383,12 @@ nam::wavenet::WaveNet::WaveNet(const int in_channels, layer_array_params[i].input_size, layer_array_params[i].condition_size, layer_array_params[i].head_size, layer_array_params[i].channels, layer_array_params[i].bottleneck, layer_array_params[i].kernel_size, layer_array_params[i].dilations, layer_array_params[i].activation_config, layer_array_params[i].gating_mode, - layer_array_params[i].head_bias, layer_array_params[i].groups_input, layer_array_params[i].groups_1x1, - layer_array_params[i].head1x1_params, layer_array_params[i].secondary_activation, - layer_array_params[i].conv_pre_film_params, layer_array_params[i].conv_post_film_params, - layer_array_params[i].input_mixin_pre_film_params, layer_array_params[i].input_mixin_post_film_params, - layer_array_params[i].activation_pre_film_params, layer_array_params[i].activation_post_film_params, - layer_array_params[i].gating_activation_post_film_params, layer_array_params[i]._1x1_post_film_params, + layer_array_params[i].head_bias, layer_array_params[i].groups_input, layer_array_params[i].groups_input_mixin, + layer_array_params[i].groups_1x1, layer_array_params[i].head1x1_params, + layer_array_params[i].secondary_activation_config, layer_array_params[i].conv_pre_film_params, + layer_array_params[i].conv_post_film_params, layer_array_params[i].input_mixin_pre_film_params, + layer_array_params[i].input_mixin_post_film_params, layer_array_params[i].activation_pre_film_params, + layer_array_params[i].activation_post_film_params, layer_array_params[i]._1x1_post_film_params, layer_array_params[i].head1x1_post_film_params)); if (i > 0) if (layer_array_params[i].channels != layer_array_params[i - 1].head_size) @@ -583,7 +580,8 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st { nlohmann::json layer_config = config["layers"][i]; - const int groups = 
layer_config.value("groups", 1); // defaults to 1 + const int groups = layer_config.value("groups_input", 1); // defaults to 1 + const int groups_input_mixin = layer_config.value("groups_input_mixin", 1); // defaults to 1 const int groups_1x1 = layer_config.value("groups_1x1", 1); // defaults to 1 const int channels = layer_config["channels"]; @@ -599,7 +597,7 @@ std::unique_ptr<DSP> nam::wavenet::Factory(const nlohmann::json& config, st activations::ActivationConfig::from_json(layer_config["activation"]); // Parse gating mode - support both old "gated" boolean and new "gating_mode" string GatingMode gating_mode = GatingMode::NONE; - std::string secondary_activation; + activations::ActivationConfig secondary_activation_config; if (layer_config.find("gating_mode") != layer_config.end()) { @@ -607,17 +605,17 @@ std::unique_ptr<DSP> nam::wavenet::Factory(const nlohmann::json& config, st if (gating_mode_str == "gated") { gating_mode = GatingMode::GATED; - secondary_activation = layer_config["secondary_activation"].get<std::string>(); + secondary_activation_config = activations::ActivationConfig::from_json(layer_config["secondary_activation"]); } else if (gating_mode_str == "blended") { gating_mode = GatingMode::BLENDED; - secondary_activation = layer_config["secondary_activation"].get<std::string>(); + secondary_activation_config = activations::ActivationConfig::from_json(layer_config["secondary_activation"]); } else if (gating_mode_str == "none") { gating_mode = GatingMode::NONE; - secondary_activation.clear(); + // Leave secondary_activation_config with empty type } else throw std::runtime_error("Invalid gating_mode: " + gating_mode_str); @@ -629,12 +627,9 @@ std::unique_ptr<DSP> nam::wavenet::Factory(const nlohmann::json& config, st gating_mode = gated ? 
GatingMode::GATED : GatingMode::NONE; if (gated) { - secondary_activation = "Sigmoid"; - } - else - { - secondary_activation.clear(); + secondary_activation_config = activations::ActivationConfig::simple(activations::ActivationType::Sigmoid); } + // else: leave secondary_activation_config uninitialized } else { @@ -644,9 +639,16 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st const bool head_bias = layer_config["head_bias"]; // Parse head1x1 parameters - bool head1x1_active = layer_config.value("head1x1_active", false); - int head1x1_out_channels = layer_config.value("head1x1_out_channels", channels); - int head1x1_groups = layer_config.value("head1x1_groups", 1); + bool head1x1_active = false; + int head1x1_out_channels = channels; + int head1x1_groups = 1; + if (layer_config.find("head_1x1") != layer_config.end()) + { + const auto& head1x1_config = layer_config["head_1x1"]; + head1x1_active = head1x1_config["active"]; + head1x1_out_channels = head1x1_config["out_channels"]; + head1x1_groups = head1x1_config["groups"]; + } nam::wavenet::Head1x1Params head1x1_params(head1x1_active, head1x1_out_channels, head1x1_groups); // Helper function to parse FiLM parameters @@ -668,16 +670,14 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st nam::wavenet::_FiLMParams input_mixin_post_film_params = parse_film_params("input_mixin_post_film"); nam::wavenet::_FiLMParams activation_pre_film_params = parse_film_params("activation_pre_film"); nam::wavenet::_FiLMParams activation_post_film_params = parse_film_params("activation_post_film"); - nam::wavenet::_FiLMParams gating_activation_post_film_params = parse_film_params("gating_activation_post_film"); nam::wavenet::_FiLMParams _1x1_post_film_params = parse_film_params("1x1_post_film"); nam::wavenet::_FiLMParams head1x1_post_film_params = parse_film_params("head1x1_post_film"); layer_array_params.push_back(nam::wavenet::LayerArrayParams( input_size, condition_size, head_size, 
channels, bottleneck, kernel_size, dilations, activation_config, - gating_mode, head_bias, groups, groups_1x1, head1x1_params, secondary_activation, conv_pre_film_params, - conv_post_film_params, input_mixin_pre_film_params, input_mixin_post_film_params, activation_pre_film_params, - activation_post_film_params, gating_activation_post_film_params, _1x1_post_film_params, - head1x1_post_film_params)); + gating_mode, head_bias, groups, groups_input_mixin, groups_1x1, head1x1_params, secondary_activation_config, + conv_pre_film_params, conv_post_film_params, input_mixin_pre_film_params, input_mixin_post_film_params, + activation_pre_film_params, activation_post_film_params, _1x1_post_film_params, head1x1_post_film_params)); } const bool with_head = !config["head"].is_null(); const float head_scale = config["head_scale"]; diff --git a/NAM/wavenet.h b/NAM/wavenet.h index a27861b..336d2e4 100644 --- a/NAM/wavenet.h +++ b/NAM/wavenet.h @@ -64,15 +64,17 @@ class _Layer // Constructor with GatingMode enum and typed ActivationConfig _Layer(const int condition_size, const int channels, const int bottleneck, const int kernel_size, const int dilation, const activations::ActivationConfig& activation_config, const GatingMode gating_mode, const int groups_input, - const int groups_1x1, const Head1x1Params& head1x1_params, const std::string& secondary_activation, - const _FiLMParams& conv_pre_film_params, const _FiLMParams& conv_post_film_params, - const _FiLMParams& input_mixin_pre_film_params, const _FiLMParams& input_mixin_post_film_params, - const _FiLMParams& activation_pre_film_params, const _FiLMParams& activation_post_film_params, - const _FiLMParams& gating_activation_post_film_params, const _FiLMParams& _1x1_post_film_params, + const int groups_input_mixin, const int groups_1x1, const Head1x1Params& head1x1_params, + const activations::ActivationConfig& secondary_activation_config, const _FiLMParams& conv_pre_film_params, + const _FiLMParams& conv_post_film_params, const 
_FiLMParams& input_mixin_pre_film_params, + const _FiLMParams& input_mixin_post_film_params, const _FiLMParams& activation_pre_film_params, + const _FiLMParams& activation_post_film_params, const _FiLMParams& _1x1_post_film_params, const _FiLMParams& head1x1_post_film_params) - : _conv(channels, (gating_mode != GatingMode::NONE) ? 2 * bottleneck : bottleneck, kernel_size, true, dilation) - , _input_mixin(condition_size, (gating_mode != GatingMode::NONE) ? 2 * bottleneck : bottleneck, false) - , _1x1(bottleneck, channels, groups_1x1) + : _conv(channels, (gating_mode != GatingMode::NONE) ? 2 * bottleneck : bottleneck, kernel_size, true, dilation, + groups_input) + , _input_mixin( + condition_size, (gating_mode != GatingMode::NONE) ? 2 * bottleneck : bottleneck, false, groups_input_mixin) + , _1x1(bottleneck, channels, true, groups_1x1) , _activation(activations::Activation::get_activation(activation_config)) , _gating_mode(gating_mode) , _bottleneck(bottleneck) @@ -81,26 +83,25 @@ { _head1x1 = std::make_unique<Conv1x1>(bottleneck, head1x1_params.out_channels, true, head1x1_params.groups); } + else + { + // If there's a post-head 1x1 FiLM but no head 1x1, this is redundant--don't allow it + if (head1x1_post_film_params.active) + { + throw std::invalid_argument("Do not use post-head 1x1 FiLM if there is no head 1x1"); + } + } // Validate & initialize gating/blending activation if (gating_mode == GatingMode::GATED) { - if (secondary_activation.empty()) - throw std::invalid_argument("secondary_activation must be provided for gated mode"); _gating_activation = std::make_unique<GatingActivation>( - _activation, activations::Activation::get_activation(secondary_activation), bottleneck); + _activation, activations::Activation::get_activation(secondary_activation_config), bottleneck); } else if (gating_mode == GatingMode::BLENDED) { - if (secondary_activation.empty()) - throw std::invalid_argument("secondary_activation must be provided for blended mode"); _blending_activation = 
std::make_unique<BlendingActivation>( - _activation, activations::Activation::get_activation(secondary_activation), bottleneck); - } - else - { - if (!secondary_activation.empty()) - throw std::invalid_argument("secondary_activation provided for none mode"); + _activation, activations::Activation::get_activation(secondary_activation_config), bottleneck); } // Initialize FiLM objects @@ -132,11 +133,6 @@ class _Layer { _activation_post_film = std::make_unique<FiLM>(condition_size, bottleneck, activation_post_film_params.shift); } - if (gating_activation_post_film_params.active) - { - _gating_activation_post_film = - std::make_unique<FiLM>(condition_size, bottleneck, gating_activation_post_film_params.shift); - } if (_1x1_post_film_params.active) { _1x1_post_film = std::make_unique<FiLM>(condition_size, channels, _1x1_post_film_params.shift); } @@ -211,7 +207,6 @@ class _Layer std::unique_ptr<FiLM> _input_mixin_post_film; std::unique_ptr<FiLM> _activation_pre_film; std::unique_ptr<FiLM> _activation_post_film; - std::unique_ptr<FiLM> _gating_activation_post_film; std::unique_ptr<FiLM> _1x1_post_film; std::unique_ptr<FiLM> _head1x1_post_film; }; @@ -222,13 +217,13 @@ class LayerArrayParams LayerArrayParams(const int input_size_, const int condition_size_, const int head_size_, const int channels_, const int bottleneck_, const int kernel_size_, const std::vector<int>&& dilations_, const activations::ActivationConfig& activation_, const GatingMode gating_mode_, - const bool head_bias_, const int groups_input, const int groups_1x1_, - const Head1x1Params& head1x1_params_, const std::string& secondary_activation_, + const bool head_bias_, const int groups_input, const int groups_input_mixin_, const int groups_1x1_, + const Head1x1Params& head1x1_params_, + const activations::ActivationConfig& secondary_activation_config_, const _FiLMParams& conv_pre_film_params_, const _FiLMParams& conv_post_film_params_, const _FiLMParams& input_mixin_pre_film_params_, const _FiLMParams& input_mixin_post_film_params_, const _FiLMParams& activation_pre_film_params_, 
const _FiLMParams& activation_post_film_params_, - const _FiLMParams& gating_activation_post_film_params_, const _FiLMParams& _1x1_post_film_params_, - const _FiLMParams& head1x1_post_film_params_) + const _FiLMParams& _1x1_post_film_params_, const _FiLMParams& head1x1_post_film_params_) : input_size(input_size_) , condition_size(condition_size_) , head_size(head_size_) @@ -240,16 +235,16 @@ class LayerArrayParams , gating_mode(gating_mode_) , head_bias(head_bias_) , groups_input(groups_input) + , groups_input_mixin(groups_input_mixin_) , groups_1x1(groups_1x1_) , head1x1_params(head1x1_params_) - , secondary_activation(secondary_activation_) + , secondary_activation_config(secondary_activation_config_) , conv_pre_film_params(conv_pre_film_params_) , conv_post_film_params(conv_post_film_params_) , input_mixin_pre_film_params(input_mixin_pre_film_params_) , input_mixin_post_film_params(input_mixin_post_film_params_) , activation_pre_film_params(activation_pre_film_params_) , activation_post_film_params(activation_post_film_params_) - , gating_activation_post_film_params(gating_activation_post_film_params_) , _1x1_post_film_params(_1x1_post_film_params_) , head1x1_post_film_params(head1x1_post_film_params_) { @@ -266,16 +261,16 @@ class LayerArrayParams const GatingMode gating_mode; const bool head_bias; const int groups_input; + const int groups_input_mixin; const int groups_1x1; const Head1x1Params head1x1_params; - const std::string secondary_activation; + const activations::ActivationConfig secondary_activation_config; const _FiLMParams conv_pre_film_params; const _FiLMParams conv_post_film_params; const _FiLMParams input_mixin_pre_film_params; const _FiLMParams input_mixin_post_film_params; const _FiLMParams activation_pre_film_params; const _FiLMParams activation_post_film_params; - const _FiLMParams gating_activation_post_film_params; const _FiLMParams _1x1_post_film_params; const _FiLMParams head1x1_post_film_params; }; @@ -288,11 +283,11 @@ class _LayerArray 
_LayerArray(const int input_size, const int condition_size, const int head_size, const int channels, const int bottleneck, const int kernel_size, const std::vector<int>& dilations, const activations::ActivationConfig& activation_config, const GatingMode gating_mode, - const bool head_bias, const int groups_input, const int groups_1x1, const Head1x1Params& head1x1_params, - const std::string& secondary_activation, const _FiLMParams& conv_pre_film_params, - const _FiLMParams& conv_post_film_params, const _FiLMParams& input_mixin_pre_film_params, - const _FiLMParams& input_mixin_post_film_params, const _FiLMParams& activation_pre_film_params, - const _FiLMParams& activation_post_film_params, const _FiLMParams& gating_activation_post_film_params, + const bool head_bias, const int groups_input, const int groups_input_mixin, const int groups_1x1, + const Head1x1Params& head1x1_params, const activations::ActivationConfig& secondary_activation_config, + const _FiLMParams& conv_pre_film_params, const _FiLMParams& conv_post_film_params, + const _FiLMParams& input_mixin_pre_film_params, const _FiLMParams& input_mixin_post_film_params, + const _FiLMParams& activation_pre_film_params, const _FiLMParams& activation_post_film_params, const _FiLMParams& _1x1_post_film_params, const _FiLMParams& head1x1_post_film_params); void SetMaxBufferSize(const int maxBufferSize); @@ -331,14 +326,15 @@ class _LayerArray std::vector<_Layer> _layers; // Output from last layer (for next layer array) Eigen::MatrixXf _layer_outputs; - // Accumulated head inputs from all layers (bottleneck channels) + // Accumulated head inputs from all layers + // Size is _head_output_size (= head1x1.out_channels if head1x1 active, else bottleneck) Eigen::MatrixXf _head_inputs; - // Rechannel for the head (bottleneck -> head_size) + // Rechannel for the head (_head_output_size -> head_size) Conv1x1 _head_rechannel; - // Bottleneck size (internal channel count) - const int _bottleneck; + // Head output size from each 
layer (head1x1.out_channels if active, else bottleneck) + const int _head_output_size; long _get_channels() const; // Common processing logic after head inputs are set diff --git a/example_models/wavenet_a2_max.nam b/example_models/wavenet_a2_max.nam new file mode 100644 index 0000000..cea7420 --- /dev/null +++ b/example_models/wavenet_a2_max.nam @@ -0,0 +1,2802 @@ +{ + "notes": [ + "This model is meant as a 'test case' to contain all of the new features that are being considered for A2.", + "It doesn't have slimmability." + ], + "version": "0.5.4", + "metadata": { + "date": { + "year": 2026, + "month": 1, + "day": 19, + "hour": 14, + "minute": 51, + "second": 21 + }, + "loudness": -20.0, + "gain": 0.2, + "name": "Generated WaveNet Model", + "modeled_by": "create_wavenet.py" + }, + "architecture": "WaveNet", + "config": { + "condition_dsp": { + "version": "0.5.4", + "metadata": { + "date": { + "year": 2026, + "month": 1, + "day": 19, + "hour": 14, + "minute": 51, + "second": 18 + }, + "loudness": -20.0, + "gain": 0.2, + "name": "Generated WaveNet Model", + "modeled_by": "create_wavenet.py" + }, + "architecture": "WaveNet", + "config": { + "layers": [ + { + "input_size": 1, + "condition_size": 1, + "head_size": 4, + "channels": 3, + "bottleneck": 6, + "kernel_size": 2, + "dilations": [ + 1, + 2 + ], + "activation": { + "type": "SiLU" + }, + "gating_mode": "gated", + "head_bias": false, + "groups_input": 3, + "groups_input_mixin": 1, + "groups_1x1": 3, + "head_1x1": { + "active": true, + "out_channels": 6, + "groups": 3 + }, + "secondary_activation": "Hardswish", + "conv_pre_film": { + "active": true, + "shift": true + }, + "conv_post_film": { + "active": true, + "shift": true + }, + "input_mixin_pre_film": { + "active": true, + "shift": true + }, + "input_mixin_post_film": { + "active": true, + "shift": true + }, + "activation_pre_film": { + "active": true, + "shift": true + }, + "activation_post_film": { + "active": true, + "shift": true + }, + "1x1_post_film": { + 
"active": true, + "shift": true + }, + "head1x1_post_film": { + "active": true, + "shift": true + } + }, + { + "input_size": 3, + "condition_size": 1, + "head_size": 8, + "channels": 4, + "bottleneck": 2, + "kernel_size": 3, + "dilations": [ + 1, + 3, + 5 + ], + "activation": { + "type": "PReLU", + "negative_slopes": [ + 0.04, + 0.05 + ] + }, + "gating_mode": "blended", + "head_bias": false, + "groups_input": 1, + "groups_input_mixin": 1, + "groups_1x1": 1, + "head_1x1": { + "active": true, + "out_channels": 4, + "groups": 2 + }, + "secondary_activation": { + "type": "LeakyHardtanh", + "min_val": 0.0, + "max_val": 0.9, + "min_slope": 0.0, + "max_slope": 0.02 + }, + "conv_pre_film": { + "active": true, + "shift": false + }, + "conv_post_film": { + "active": true, + "shift": false + }, + "input_mixin_pre_film": { + "active": true, + "shift": false + }, + "input_mixin_post_film": { + "active": true, + "shift": false + }, + "activation_pre_film": { + "active": true, + "shift": false + }, + "activation_post_film": { + "active": true, + "shift": false + }, + "1x1_post_film": { + "active": true, + "shift": false + }, + "head1x1_post_film": { + "active": true, + "shift": false + } + } + ], + "head": null, + "head_scale": 0.02 + }, + "weights": [ + 0.2788535969157675, + -0.9499784895546661, + -0.4499413632617615, + -0.5535785237023545, + 0.4729424283280248, + 0.3533989748458226, + 0.7843591354096908, + -0.8261223347411677, + -0.15615636062945915, + -0.9404055611238593, + -0.5627240503927933, + 0.010710576206724776, + -0.9469280606322728, + -0.602324698626703, + 0.2997688755590464, + 0.08988296120643335, + -0.5591187559186066, + 0.17853136775181744, + 0.6188609133556533, + -0.987002480643878, + 0.6116385036656158, + 0.3962787899764537, + -0.31949896696401625, + -0.6890410003764369, + 0.9144261444135624, + -0.32681090977474647, + -0.8145083132397042, + -0.806567246333072, + 0.6949887326949196, + 0.20745206273378214, + 0.6142565465487604, + 0.45946357338763577, + 
0.07245618290940148, + 0.9462315279587412, + -0.24293124558329304, + 0.104081262546454, + 0.6588093285059897, + 0.2370395047284921, + 0.7234138006215545, + 0.15470429051352408, + 0.40914367242984695, + -0.9083512326886756, + -0.5442034486969063, + -0.42122407279578566, + -0.840416046152745, + -0.5344182272779396, + -0.7979971411805418, + -0.44405279377981577, + 0.27136888852880037, + -0.2703356420598315, + -0.2596380657662347, + -0.5809859384570246, + -0.4660443559017733, + 0.873309175424988, + 0.2960707704931871, + 0.21826201133397638, + -0.657722703603806, + 0.45825359590069836, + -0.6731950124761432, + -0.24108911648470444, + 0.9790467012731905, + 0.2799995197081857, + 0.11389948754929247, + 0.3692285019797492, + 0.6857038403796192, + 0.5519998230924896, + -0.5419038560717913, + -0.9357995121919245, + -0.36909390388183616, + -0.46451824804859454, + -0.5780343128273471, + 0.8858194286701089, + 0.7527352529453377, + -0.37064423840304417, + 0.3108773305897601, + -0.20873619787867148, + 0.829095179481087, + -0.0822962948252024, + -0.4702396670038951, + -0.5067449846120331, + 0.12273626832630158, + -0.47451678295412947, + 0.16917198044708104, + 0.795645767204954, + -0.20119898971920547, + -0.5613584816854333, + 0.9950752129902205, + 0.01905258735292903, + -0.8181811756524122, + -0.9057672491505309, + -0.7807017392986817, + 0.2548920834061801, + 0.5841587287259282, + -0.15568006640063192, + -0.8729445876960857, + -0.23676142698692648, + 0.9922427604801936, + 0.058228690198274036, + 0.9421567552272363, + 0.7215594044689961, + -0.9770379561143607, + 0.4414436387203893, + 0.36342073805314956, + 0.07394066081759032, + -0.46634962009491443, + 0.2819235971596161, + -0.7768956528082471, + -0.13046949866179003, + -0.09255258734158711, + 0.9076318550421603, + 0.7517058807563881, + -0.4732218984978185, + 0.0011722261005966406, + -0.6426962389397373, + 0.825255678689641, + 0.7410371396735338, + -0.4031104171027342, + 0.2778989897320103, + 0.21794042287634463, + 
-0.6943214629007304, + 0.5250216001503025, + 0.07875806023925147, + 0.5572529572611165, + 0.06070734439035497, + -0.998856207744113, + -0.3516878859906538, + -0.9610465152283354, + 0.8581972325292342, + 0.7574437556463685, + 0.6633310587223589, + -0.38497174919467714, + -0.8841496670116249, + 0.7560191984080811, + 0.8938988905959881, + -0.8286930958642424, + -0.02801907336677245, + -0.8615749630632328, + 0.5212043305144631, + 0.5316688586139755, + -0.7432170710004744, + -0.04943524380253739, + 0.0996071869898878, + -0.4698867421198818, + 0.7448660821705149, + -0.15372411959822618, + -0.5764035891158359, + 0.07859217755891668, + 0.45986213817995236, + -0.5976978732206082, + -0.3765674173982101, + 0.9902987133217893, + 0.299756115278907, + -0.12379983217099189, + 0.035151682071181245, + -0.7579916082634686, + -0.5506053259368853, + -0.32382887570508934, + 0.17661743691446663, + -0.539770534806846, + -0.559565231096881, + -0.8580138279819349, + 0.2622059145401978, + -0.5421164323776912, + 0.8108400260122559, + 0.719270800507493, + -0.8582853002226931, + -0.5239907312620096, + 0.33795555659256116, + -0.5715263852591228, + -0.73537630254995, + 0.871028481161342, + 0.14208618665056894, + -0.05465794737641172, + 0.5692388485815068, + 0.6149939955332868, + -0.6191801712762446, + -0.8061383715423533, + -0.13789763518724496, + -0.15284275396015845, + -0.06595066392665005, + 0.4581516989197012, + 0.34672909458660306, + 0.9683304227319323, + -0.8031642576960822, + -0.19475743579546245, + -0.3213947892100737, + 0.7233450727055821, + -0.5026873321594287, + -0.619582183118377, + -0.1027729043337362, + -0.15623672033119163, + -0.44290971066611906, + -0.500387104235799, + 0.8465311985520256, + -0.11373850989308609, + 0.7226982095236612, + 0.10065062489969612, + -0.8988233409502375, + 0.9985649368254532, + 0.6720551701599038, + 0.9379925145695025, + 0.8527339660162552, + 0.6973914688286109, + -0.667377778792172, + -0.02871774909856306, + -0.5725054016016367, + -0.19791941490109477, 
+ -0.8827292000556421, + -0.2420537620461678, + 0.9706176875594519, + -0.4695938836556961, + 0.5681412038971387, + -0.08998326532171341, + -0.15398502801967417, + 0.9146352817193464, + 0.9908453789854277, + 0.11153664681123643, + 0.436816550592652, + -0.6904063494518717, + -0.4065843490108716, + 0.9374187299383177, + 0.15836058163251243, + 0.08439040274854848, + 0.4959511207581282, + -0.8856694541850338, + 0.16835518891794243, + 0.005700765839027122, + 0.7054397840965707, + -0.6851345441210335, + 0.9215578065489007, + -0.8397770695188262, + -0.6283500780385536, + 0.19007021290005532, + 0.3504251072081803, + -0.5295922099981376, + -0.7602267721057516, + 0.780574628258875, + -0.5075693044227503, + 0.1890383070668824, + 0.23876302066420618, + -0.16155016932825506, + 0.16734457858244944, + 0.04556543106391775, + 0.8694125154728545, + -0.5914816011529271, + 0.4323836015788296, + -0.522628094768308, + -0.208428306417491, + 0.34338044591994255, + -0.40000584040247555, + -0.36764560745629193, + 0.5037289848288042, + -0.8549137710136854, + -0.08342895476282775, + 0.9969088817088847, + 0.9921928957101889, + -0.853478557800734, + -0.5736913754659192, + -0.4695991704991973, + 0.8665187559874181, + 0.7617283473728791, + 0.7585404849690855, + -0.2609458225222321, + -0.6845063352855361, + 0.667489909279614, + 0.4070798501747419, + 0.22335553145190024, + 0.9744661272630086, + 0.3079526354214652, + -0.9843537856956841, + 0.6342082702309233, + -0.4012424956000442, + 0.32677742993215464, + 0.8778600078542078, + -0.7314177712132646, + -0.7691426591617956, + -0.7859280445811647, + 0.10644728176963181, + -0.45530357537036736, + 0.20965965406044784, + 0.43522437427759586, + -0.5928053753450941, + 0.26847591777015944, + -0.4720321967391812, + -0.02293629570124689, + 0.8106729821586465, + 0.6922074265897109, + -0.8154030645745332, + -0.15284845487254728, + -0.44663955205549666, + -0.9929086218244354, + 0.5422384460392542, + 0.27422675460275925, + -0.4760894751313036, + 0.48246181669586163, 
+ 0.10336084225278253, + -0.14462616203864131, + -0.9806606007833201, + -0.8495122798524659, + 0.7662127866002859, + 0.8078571431197863, + 0.09118057841104465, + 0.6691900397720334, + 0.16501913297958803, + -0.7038124288650347, + -0.7451089614357225, + -0.38348330013973264, + 0.79796297748518, + 0.5922446097760834, + 0.7214051640018055, + 0.7978492730529492, + -0.5798469233204919, + -0.5009405215541511, + -0.7944127566564287, + 0.5602324837428854, + 0.7682694029020178, + -0.1872452203357664, + 0.2413230203014256, + -0.6908933233355907, + 0.8597620313873489, + 0.7292113924399279, + 0.9524120658619257, + 0.6215434398807937, + 0.7628324093266488, + -0.9504272762036226, + 0.4731289435101642, + -0.33562906410714266, + 0.8616317720966511, + 0.6044702778742779, + 0.7281280567505588, + 0.621498633148778, + -0.46638858081105594, + 0.5747490182709423, + -0.7838087471940858, + 0.7443335658121795, + 0.7171865026755633, + -0.5551325649086711, + 0.6331732111938579, + -0.07939353064211585, + -0.38961826532279886, + 0.5906909983057236, + -0.5448090251844593, + -0.952671130597097, + -0.6137404233445827, + -0.34347609760458697, + 0.7287058840605727, + 0.9337782080967223, + -0.4417500145562572, + 0.2829634772152554, + -0.20064323127987826, + 0.9622993743965202, + 0.07243146495744379, + 0.8784742806494314, + -0.7693164962971448, + 0.9408012220444559, + -0.6428643676550727, + 0.9250686315231109, + -0.46906727495406275, + -0.7831949055705778, + -0.1308724828707113, + 0.4570901213054086, + -0.37264537161001754, + 0.21241770661228654, + 0.022846119338956195, + -0.22960913331054567, + 0.15317608699319907, + -0.4905549877228361, + 0.4175705676683412, + -0.9966174435627411, + 0.8511503309981654, + 0.0769039941855838, + 0.438859998289691, + 0.48390015567895306, + 0.34125700886599897, + -0.27155705643747163, + -0.8600523777473796, + 0.32847536982254466, + -0.3395999279148072, + -0.37216870988328066, + 0.696030559012671, + 0.43950852602790036, + -0.39935546357747165, + -0.3814306755826935, + 
-0.1832141827615663, + -0.19519922588455074, + -0.40868959494810597, + -0.7454244018816936, + -0.1591073324541834, + 0.880727341460366, + 0.35463589054546585, + 0.8056110914651653, + 0.23102983190276105, + -0.39810025086886935, + 0.09587442627139642, + -0.999188120605425, + -0.4261725662621456, + -0.1402237000203308, + 0.15996956239136395, + 0.30941124740614323, + -0.07002361950597158, + -0.11568040139038516, + -0.5725971980217994, + -0.053627628181347475, + 0.8023616516565084, + 0.5920495202535605, + -0.6606172076038905, + -0.8304089265497565, + 0.030904019830432894, + 0.26588171153159146, + -0.3296234891803982, + 0.6368469290733285, + 0.5022762750814644, + 0.3455913411143341, + -0.5507186680054235, + -0.6017401345468467, + -0.9511492245463473, + -0.5103149118432997, + -0.04972731156238974, + 0.6994753892494638, + -0.8543435416308618, + -0.1711179780045613, + 0.2595307614754274, + -0.6111295265205814, + 0.39270850098100984, + -0.011245661979126131, + -0.5120311208431223, + 0.31211602222356816, + -0.9889103637239365, + 0.5019289532369458, + 0.5400923771480501, + -0.7868254068729221, + -0.14970761211453176, + -0.6482266365869367, + 0.9159320845590795, + 0.03591550088748163, + -0.8995632297187182, + -0.5016034406800567, + 0.6966726947033195, + -0.08707634905965489, + 0.602833203444529, + 0.3351554651727062, + 0.975784906132896, + 0.1909046369388394, + 0.9000792168863119, + 0.782851851620874, + 0.2253046455235257, + 0.4385479225519342, + 0.00955632964880393, + 0.6611383394428301, + 0.09574390122165677, + 0.7944162064665243, + 0.4873108843191698, + -0.05065112635389335, + -0.4816169030699613, + -0.5055205249806809, + 0.27532287355231255, + 0.5316273685943309, + 0.0425996256559642, + 0.2534968739635626, + -0.4508051061649234, + -0.8450332922705284, + -0.428543698273695, + -0.4565697858356308, + -0.3605808631624754, + 0.08030444503691281, + -0.7232518769676886, + -0.5374770405436264, + 0.38789962459810456, + 0.41283828338910444, + -0.8715422985722439, + 
-0.18480126066682656, + 0.08522228100783069, + -0.1684515317936881, + -0.5863312209711795, + -0.15971296445314875, + 0.8096769566803539, + 0.1681588284084503, + 0.39104597299593613, + 0.7134640646078685, + 0.5311891522361389, + -0.23923794214457672, + -0.9882078328320139, + -0.2964823946563506, + 0.5069502501187717, + 0.7068959011382092, + 0.9068606769402126, + -0.1619574347403563, + 0.49503133795610155, + 0.09226461946767817, + 0.20650517788248268, + -0.5589226113523622, + -0.5611567307571277, + -0.1283280479067268, + -0.9419503601065695, + -0.3277409126032351, + 0.3582837700566994, + -0.19136666172472583, + -0.6699105375929824, + -0.0652197015353797, + -0.7447444054376786, + 0.24451392194812938, + -0.9460670961897246, + -0.21195948732059056, + 0.12878396604954823, + -0.9457959073193751, + 0.2854992960186715, + -0.7286010255388715, + -0.07660311189697655, + -0.8994307330227449, + -0.24179227162372086, + -0.5766794315770369, + -0.3463083902373829, + 0.5224594157880542, + -0.24174756887172744, + 0.5040196471095697, + 0.6638485703105452, + -0.4954569364352388, + -0.8361875344767149, + -0.9612334258999786, + 0.0788380958450674, + 0.9998156570184185, + -0.30007931255963216, + 0.30028818649975, + 0.5624660992217898, + 0.30350931048777885, + 0.5084664081190524, + 0.8992234654319777, + -0.6012786352749342, + -0.9592399653593353, + -0.6952353084304193, + -0.7475580502515875, + 0.33891768923982135, + 0.12793916386003823, + -0.5640709181869927, + 0.3989299424923016, + 0.5337961967124816, + -0.6644217132643955, + 0.21449498778186338, + 0.49585130391057164, + -0.7709342572422047, + 0.6386023486221701, + 0.9294415460681753, + -0.7838025006847869, + -0.948643149005068, + -0.3760855112106096, + 0.3546945737008176, + 0.9163456764117919, + -0.20669111696674536, + 0.4300294100989368, + -0.8480070443138801, + 0.3812288318659607, + 0.2544847912020889, + -0.7961973891080469, + 0.544961769902448, + 0.7005864781775926, + 0.20082322963368826, + -0.7578898698653698, + 0.9676887030293426, + 
0.5652706927220392, + -0.30559246938310203, + -0.1432439735305111, + -0.25885824756388764, + 0.011921579354155831, + -0.31753765027743364, + 0.6991512539991547, + 0.6446618361801162, + -0.7889222587120028, + 0.9215751344291572, + 0.27117021220289184, + 0.6574146220049164, + 0.41461728741215387, + -0.1290257099846459, + 0.46759060802678376, + 0.9309474624761553, + -0.45983520722519833, + 0.6163984376135119, + 0.07634581289651154, + -0.033004992233607755, + -0.12885101399221144, + 0.46205242861024454, + -0.46320892390154933, + 0.7034263200386637, + 0.6614620377812068, + -0.8266742038886512, + 0.7632623680055439, + -0.512273121618088, + -0.07058306679353654, + 0.22066340846104127, + -0.24202139174347193, + -0.9426000044598208, + 0.7019056726249182, + -0.6363202856846164, + -0.575760299640522, + 0.595664713656193, + -0.3193223136714327, + 0.7606399595164508, + 0.4023675006644032, + -0.4474628484876302, + -0.9796977711226458, + 0.8961251555540624, + -0.8287740760839575, + 0.4401493282083886, + -0.022844306302425288, + 0.5163293069649328, + 0.38121867889304606, + 0.29180579948190455, + -0.018357329842827674, + 0.5858657362646349, + -0.8138932988906566, + -0.5568071990549397, + 0.3835743105904037, + -0.3875879397398232, + 0.16311117066473435, + -0.05347902248095049, + 0.0618438622915356, + -0.1489923745895705, + 0.49187087342721925, + -0.3384174056081397, + 0.4057098843715379, + -0.45816714604760733, + -0.497192647598117, + -0.7586882304953884, + -0.6148314130969652, + -0.7608905174898943, + 0.07172793076749961, + 0.5243792189682059, + -0.6297003151246761, + -0.5672307202427886, + -0.031602825574582605, + 0.44917000218610315, + 0.9532140457661136, + 0.0492737382172157, + -0.43400259345008996, + -0.7989477838184185, + -0.6117648438178531, + -0.5450336730348906, + -0.6411169126291933, + -0.9717032654878723, + 0.06827017853532658, + -0.4513773464334654, + 0.9485898622054758, + 0.1067179335972166, + 0.3948347858203889, + -0.7474410140831822, + 0.736922394525888, + 
-0.018242610952442284, + 0.7454394699970692, + 0.1481284392526645, + -0.061206110138979586, + -0.1190624037279795, + -0.6312726592198366, + -0.8972465656328443, + 0.8821271935362065, + -0.04454162546862461, + 0.6442312905988608, + -0.19858511549612046, + -0.8518356594887204, + 0.2588914139039966, + -0.8927818514149337, + -0.7016048310526897, + 0.12567919416912843, + -0.39232897631554464, + 0.9878362454204404, + -0.7630968755770768, + 0.5288868906523885, + 0.21263530248592555, + 0.5814816596868388, + -0.5486257258699307, + 0.04514507026043568, + -0.09897107022249152, + -0.11455799196134331, + 0.7203333316523601, + 0.9800625209335958, + -0.3892395113612792, + 0.24205464214429706, + 0.21926182449026066, + 0.48017861096909, + 0.895180400650756, + -0.5844241883528973, + -0.577949609389214, + 0.32085627438407327, + -0.6858858132382457, + -0.652372903347131, + -0.8498702621976642, + -0.9946485547942285, + -0.09899259076459521, + 0.18762239023912652, + -0.4174814219379448, + -0.5370475308879483, + 0.41391165975806654, + 0.4059751161874343, + -0.09193734718591351, + 0.3747698401424855, + 0.8478220889650974, + 0.5756560532934627, + 0.25011601432844377, + 0.32236608570673564, + 0.8673369168910265, + -0.1497220677915767, + 0.08912475742115733, + 0.2952694468049615, + 0.8168228904252122, + 0.6532623193001192, + -0.8571803262883717, + -0.6681544215585506, + -0.3847763747715349, + 0.4979154441392466, + 0.13841409863818477, + -0.4227788233883165, + -0.7512926836505835, + 0.37735598242418744, + 0.3994673699514346, + 0.8853524814880815, + 0.0009443542358511525, + -0.012409561309831396, + -0.839116296201398, + -0.9202784316327739, + -0.13594267167845486, + -0.35535683328340006, + -0.4992641951759922, + -0.8173462267380296, + 0.9238222043856734, + 0.6719172278123626, + 0.1503982184398609, + 0.9015725556127048, + 0.9991448337549027, + 0.3445631686064785, + -0.4609779480658922, + -0.9195366537107676, + 0.5125376608251142, + -0.05899835349673288, + 0.3030189788673727, + 
0.832145575853535, + -0.6370217055537482, + 0.17065925055570874, + 0.26956943890817375, + -0.016548395617973277, + -0.8175151874123456, + -0.3040778874106991, + -0.33338321266981596, + 0.3402670190423709, + 0.7154661888007052, + -0.3403926728420692, + 0.387347347966865, + -0.4235644092770885, + 0.8903870791264212, + 0.6271320694759452, + 0.10019321794356584, + -0.09034818279654111, + -0.37096568565966415, + -0.35345242744801064, + 0.9403694536171543, + -0.19164988599850874, + 0.029192504658210217, + 0.9762384295635083, + 0.31532077296625793, + 0.08518718871439002, + -0.173504858401357, + -0.6248349172109458, + -0.2764412816966495, + 0.5128863081111474, + 0.25081748429001793, + 0.5199810721220339, + -0.5928835232994523, + 0.09843927817066334, + 0.8553455216888954, + -0.1237678098552435, + 0.39650005824434453, + -0.7571478332747603, + 0.9462936316433888, + 0.21774333416392988, + -0.5214050753446993, + -0.6832436723895352, + 0.10167801400533993, + 0.10450281810746609, + -0.8135815974369549, + 0.9845142787904393, + 0.8258597569683359, + -0.0771042116424332, + -0.7650677018208849, + 0.6642863475378025, + -0.003248990585609457, + 0.4332066519835822, + 0.01774403013902348, + -0.4531502065734516, + 0.6694478911533888, + 0.9604892652067509, + -0.5125381878497233, + 0.10253015376093932, + -0.23282797350570017, + 0.8437362998631237, + 0.016481783185787968, + 0.7586525102929522, + 0.7280538688571763, + -0.44750519435591807, + 0.5800123640062269, + -0.17011515289730972, + 0.8684967873650402, + 0.015475351619312905, + 0.6410989463711567, + -0.4343220334343467, + -0.4028883004564656, + 0.17387544482832218, + 0.9978046664586775, + -0.020719306688598005, + -0.7028091632344218, + 0.07716115540774737, + -0.30975211661398494, + 0.10383483414162198, + 0.08686012591734604, + -0.08931076626699785, + -0.35644529954192516, + -0.6226952525857892, + 0.3949968552411427, + 0.14359528396980226, + -0.5328751078842138, + 0.5510889501985177, + -0.9127054018053962, + 0.4894103031303918, + 
0.4104557620500051, + 0.6228178051297537, + -0.22784249501803533, + 0.32737765896906934, + 0.6414951034191101, + 0.9616362773195701, + -0.009342700767152934, + -0.9259607773088476, + 0.00458230026593176, + 0.1803608586191341, + 0.7394006267244309, + 0.7483807481086167, + -0.11938758045864573, + 0.05190217361512284, + -0.08614385105151556, + 0.44488765514128525, + -0.1800427605072712, + 0.3095626528554676, + -0.6912775624569023, + -0.061018798030536114, + 0.9384072611484344, + -0.32287753189713664, + 0.3854091971737641, + 0.2996733051585452, + 0.7035305847013829, + 0.7046826731317861, + 0.7186843683365332, + -0.2399812049546899, + -0.36667769213320067, + 0.437434850445966, + 0.5188036186687273, + 0.7447660347970726, + -0.9282018003246268, + -0.8631585056368971, + 0.2623220337092811, + 0.8418581975605934, + 0.9948518459124307, + 0.4935327334758539, + -0.13205705615374863, + -0.8031137472336041, + 0.26749565756158256, + 0.7451584652141814, + -0.11264289667658844, + 0.3880023255864691, + 0.806848124100972, + -0.9080180626455745, + 0.5922869303244236, + -0.4132644480695824, + -0.25031782044937034, + -0.7088604086012031, + 0.06233263629752006, + 0.13185612383187362, + 0.5850389477619065, + -0.6600327018638052, + -0.8420632986800107, + 0.7416791972897097, + 0.2394207370748307, + -0.5183404162873739, + 0.8256580320471096, + -0.7137645597429807, + -0.07770017329000489, + -0.49204532117291544, + -0.4893465836816919, + -0.9812051370902579, + 0.6092661539502926, + 0.802418847197766, + 0.3552217713981589, + -0.6840487558564261, + -0.11654043279237114, + -0.3088687511138366, + 0.17514341025284286, + 0.2778774047200243, + -0.15138212307820464, + -0.4998035511845975, + 0.6906078502851973, + -0.6015660017822153, + -0.23061350206530906, + -0.0335838778815678, + -0.5255885961448508, + 0.14384538470147779, + 0.14962386035729014, + 0.9853840872537951, + -0.4095384922334613, + 0.9558889691537253, + 0.31645963185736137, + -0.45103923964354453, + 0.13185803391139794, + 
0.37159898546802417, + 0.48933768233065034, + -0.9019114984480001, + 0.21281298615294908, + -0.006545426952259348, + 0.8083105817874507, + -0.42761169708161106, + 0.5977202390151974, + 0.2141299963285319, + -0.29535808832936694, + 0.27323575601178374, + 0.24178232626060825, + 0.35552891724671887, + 0.44185675334141794, + 0.3183630806335074, + 0.6766742339250331, + 0.25649620737379664, + 0.8068074081459358, + 0.2926812177810456, + -0.38213423209472785, + -0.11835361967437685, + 0.1591476107368055, + 0.4647195358784766, + -0.8197332485075621, + -0.40977909674168944, + 0.49496172987674303, + -0.6487198591113867, + -0.7356804045064329, + 0.07881551796884212, + 0.9429791624226798, + 0.06170474740557008, + 0.8269739489633938, + 0.6609452391347947, + -0.4860598308734754, + 0.6493796250848145, + -0.03630434025251761, + 0.612976987587533, + 0.4931187014340901, + -0.32256949239622745, + -0.7696605851004812, + 0.9257865857377552, + -0.7184859699882347, + 0.9330004189255041, + 0.7202811937976437, + 0.4484334241510364, + 0.9598844855638069, + 0.9345394946000647, + 0.6091752880411239, + -0.26844990118886725, + 0.5813639371778749, + -0.9721626897981976, + 0.07314461653811821, + -0.09042794453225067, + 0.3456567637475072, + 0.3446815947020263, + 0.1691201833043321, + 0.6448346024535485, + 0.8805837835591084, + -0.7833077956015386, + -0.5323561955768337, + -0.9499507007070354, + 0.7684696904297044, + 0.1228147644997577, + 0.8305118174863189, + -0.55726559985202, + -0.8735659176796091, + 0.6477107027808953, + 0.8187752768557841, + -0.39561965094149953, + -0.18340828840917456, + -0.7204459749855578, + 0.8925230657638925, + -0.39127083128798956, + -0.014750762043588272, + -0.8056160027563568, + 0.7745186170570046, + -0.728671902587327, + -0.09271248622221484, + 0.34097243770034247, + 0.4862802430463431, + 0.8919481715588642, + -0.16174649317055434, + 0.48453802953063163, + -0.69095419518033, + -0.17023094512638615, + -0.801956730578943, + -0.02130592442077095, + -0.18376822860459896, 
+ 0.903043050762119, + -0.9345674262899062, + -0.2589400825311292, + -0.1132338278785967, + 0.9011103397028539, + 0.7109003866119092, + -0.8012907507877394, + 0.3713605309625707, + 0.08893172296428986, + 0.9556850589040935, + -0.28265231757536413, + -0.20372071451125384, + -0.6203828756778409, + -0.7556805618254725, + 0.6960663769273621, + -0.09056526285896571, + 0.325537476123956, + 0.2834089344664352, + 0.19429191903909016, + -0.9572850905272587, + 0.5735891809092335, + -0.5128622056719527, + -0.7481522293839142, + 0.12915595181592665, + -0.8627796943512882, + 0.5303147517771689, + -0.585685259306683, + -0.5680972961626505, + 0.7393908535390894, + -0.34288089313553916, + -0.7048916401171126, + 0.8010620712635164, + -0.9943288970399673, + 0.7168122527603527, + -0.710624039358984, + -0.7400157371131264, + -0.49869160654375233, + -0.6510057581972131, + 0.3221152851946336, + -0.9484397004276033, + -0.9702793455385823, + 0.5799693284695078, + -0.5241367878190739, + -0.3524570760759551, + -0.6515075971876456, + -0.895201964277647, + 0.483436113908299, + 0.05217105319573423, + 0.4913305500679914, + -0.04750806915349948, + 0.5560340786284055, + 0.026475915218383328, + -0.781891979992309, + 0.007677379571642717, + 0.8908312859402125, + -0.9132699262016566, + 0.5664539919607592, + 0.7339618155196765 + ], + "sample_rate": 48000 + }, + "layers": [ + { + "input_size": 1, + "condition_size": 8, + "head_size": 1, + "channels": 4, + "bottleneck": 4, + "kernel_size": 4, + "dilations": [ + 1, + 2 + ], + "activation": { + "type": "PReLU", + "negative_slope": 0.015 + }, + "gating_mode": "none", + "head_bias": true, + "groups_input": 1, + "groups_input_mixin": 4, + "groups_1x1": 2, + "head_1x1": { + "active": true, + "out_channels": 4, + "groups": 2 + }, + "secondary_activation": "", + "conv_pre_film": { + "active": true, + "shift": true + }, + "conv_post_film": { + "active": true, + "shift": true + }, + "input_mixin_pre_film": { + "active": true, + "shift": true + }, + 
"input_mixin_post_film": { + "active": true, + "shift": true + }, + "activation_pre_film": { + "active": true, + "shift": true + }, + "activation_post_film": { + "active": true, + "shift": true + }, + "1x1_post_film": { + "active": true, + "shift": true + }, + "head1x1_post_film": { + "active": true, + "shift": true + } + } + ], + "head": null, + "head_scale": 0.02 + }, + "weights": [ + 0.2788535969157675, + -0.9499784895546661, + -0.4499413632617615, + -0.5535785237023545, + 0.4729424283280248, + 0.3533989748458226, + 0.7843591354096908, + -0.8261223347411677, + -0.15615636062945915, + -0.9404055611238593, + -0.5627240503927933, + 0.010710576206724776, + -0.9469280606322728, + -0.602324698626703, + 0.2997688755590464, + 0.08988296120643335, + -0.5591187559186066, + 0.17853136775181744, + 0.6188609133556533, + -0.987002480643878, + 0.6116385036656158, + 0.3962787899764537, + -0.31949896696401625, + -0.6890410003764369, + 0.9144261444135624, + -0.32681090977474647, + -0.8145083132397042, + -0.806567246333072, + 0.6949887326949196, + 0.20745206273378214, + 0.6142565465487604, + 0.45946357338763577, + 0.07245618290940148, + 0.9462315279587412, + -0.24293124558329304, + 0.104081262546454, + 0.6588093285059897, + 0.2370395047284921, + 0.7234138006215545, + 0.15470429051352408, + 0.40914367242984695, + -0.9083512326886756, + -0.5442034486969063, + -0.42122407279578566, + -0.840416046152745, + -0.5344182272779396, + -0.7979971411805418, + -0.44405279377981577, + 0.27136888852880037, + -0.2703356420598315, + -0.2596380657662347, + -0.5809859384570246, + -0.4660443559017733, + 0.873309175424988, + 0.2960707704931871, + 0.21826201133397638, + -0.657722703603806, + 0.45825359590069836, + -0.6731950124761432, + -0.24108911648470444, + 0.9790467012731905, + 0.2799995197081857, + 0.11389948754929247, + 0.3692285019797492, + 0.6857038403796192, + 0.5519998230924896, + -0.5419038560717913, + -0.9357995121919245, + -0.36909390388183616, + -0.46451824804859454, + 
-0.5780343128273471, + 0.8858194286701089, + 0.7527352529453377, + -0.37064423840304417, + 0.3108773305897601, + -0.20873619787867148, + 0.829095179481087, + -0.0822962948252024, + -0.4702396670038951, + -0.5067449846120331, + 0.12273626832630158, + -0.47451678295412947, + 0.16917198044708104, + 0.795645767204954, + -0.20119898971920547, + -0.5613584816854333, + 0.9950752129902205, + 0.01905258735292903, + -0.8181811756524122, + -0.9057672491505309, + -0.7807017392986817, + 0.2548920834061801, + 0.5841587287259282, + -0.15568006640063192, + -0.8729445876960857, + -0.23676142698692648, + 0.9922427604801936, + 0.058228690198274036, + 0.9421567552272363, + 0.7215594044689961, + -0.9770379561143607, + 0.4414436387203893, + 0.36342073805314956, + 0.07394066081759032, + -0.46634962009491443, + 0.2819235971596161, + -0.7768956528082471, + -0.13046949866179003, + -0.09255258734158711, + 0.9076318550421603, + 0.7517058807563881, + -0.4732218984978185, + 0.0011722261005966406, + -0.6426962389397373, + 0.825255678689641, + 0.7410371396735338, + -0.4031104171027342, + 0.2778989897320103, + 0.21794042287634463, + -0.6943214629007304, + 0.5250216001503025, + 0.07875806023925147, + 0.5572529572611165, + 0.06070734439035497, + -0.998856207744113, + -0.3516878859906538, + -0.9610465152283354, + 0.8581972325292342, + 0.7574437556463685, + 0.6633310587223589, + -0.38497174919467714, + -0.8841496670116249, + 0.7560191984080811, + 0.8938988905959881, + -0.8286930958642424, + -0.02801907336677245, + -0.8615749630632328, + 0.5212043305144631, + 0.5316688586139755, + -0.7432170710004744, + -0.04943524380253739, + 0.0996071869898878, + -0.4698867421198818, + 0.7448660821705149, + -0.15372411959822618, + -0.5764035891158359, + 0.07859217755891668, + 0.45986213817995236, + -0.5976978732206082, + -0.3765674173982101, + 0.9902987133217893, + 0.299756115278907, + -0.12379983217099189, + 0.035151682071181245, + -0.7579916082634686, + -0.5506053259368853, + -0.32382887570508934, + 
0.17661743691446663, + -0.539770534806846, + -0.559565231096881, + -0.8580138279819349, + 0.2622059145401978, + -0.5421164323776912, + 0.8108400260122559, + 0.719270800507493, + -0.8582853002226931, + -0.5239907312620096, + 0.33795555659256116, + -0.5715263852591228, + -0.73537630254995, + 0.871028481161342, + 0.14208618665056894, + -0.05465794737641172, + 0.5692388485815068, + 0.6149939955332868, + -0.6191801712762446, + -0.8061383715423533, + -0.13789763518724496, + -0.15284275396015845, + -0.06595066392665005, + 0.4581516989197012, + 0.34672909458660306, + 0.9683304227319323, + -0.8031642576960822, + -0.19475743579546245, + -0.3213947892100737, + 0.7233450727055821, + -0.5026873321594287, + -0.619582183118377, + -0.1027729043337362, + -0.15623672033119163, + -0.44290971066611906, + -0.500387104235799, + 0.8465311985520256, + -0.11373850989308609, + 0.7226982095236612, + 0.10065062489969612, + -0.8988233409502375, + 0.9985649368254532, + 0.6720551701599038, + 0.9379925145695025, + 0.8527339660162552, + 0.6973914688286109, + -0.667377778792172, + -0.02871774909856306, + -0.5725054016016367, + -0.19791941490109477, + -0.8827292000556421, + -0.2420537620461678, + 0.9706176875594519, + -0.4695938836556961, + 0.5681412038971387, + -0.08998326532171341, + -0.15398502801967417, + 0.9146352817193464, + 0.9908453789854277, + 0.11153664681123643, + 0.436816550592652, + -0.6904063494518717, + -0.4065843490108716, + 0.9374187299383177, + 0.15836058163251243, + 0.08439040274854848, + 0.4959511207581282, + -0.8856694541850338, + 0.16835518891794243, + 0.005700765839027122, + 0.7054397840965707, + -0.6851345441210335, + 0.9215578065489007, + -0.8397770695188262, + -0.6283500780385536, + 0.19007021290005532, + 0.3504251072081803, + -0.5295922099981376, + -0.7602267721057516, + 0.780574628258875, + -0.5075693044227503, + 0.1890383070668824, + 0.23876302066420618, + -0.16155016932825506, + 0.16734457858244944, + 0.04556543106391775, + 0.8694125154728545, + -0.5914816011529271, + 
0.4323836015788296, + -0.522628094768308, + -0.208428306417491, + 0.34338044591994255, + -0.40000584040247555, + -0.36764560745629193, + 0.5037289848288042, + -0.8549137710136854, + -0.08342895476282775, + 0.9969088817088847, + 0.9921928957101889, + -0.853478557800734, + -0.5736913754659192, + -0.4695991704991973, + 0.8665187559874181, + 0.7617283473728791, + 0.7585404849690855, + -0.2609458225222321, + -0.6845063352855361, + 0.667489909279614, + 0.4070798501747419, + 0.22335553145190024, + 0.9744661272630086, + 0.3079526354214652, + -0.9843537856956841, + 0.6342082702309233, + -0.4012424956000442, + 0.32677742993215464, + 0.8778600078542078, + -0.7314177712132646, + -0.7691426591617956, + -0.7859280445811647, + 0.10644728176963181, + -0.45530357537036736, + 0.20965965406044784, + 0.43522437427759586, + -0.5928053753450941, + 0.26847591777015944, + -0.4720321967391812, + -0.02293629570124689, + 0.8106729821586465, + 0.6922074265897109, + -0.8154030645745332, + -0.15284845487254728, + -0.44663955205549666, + -0.9929086218244354, + 0.5422384460392542, + 0.27422675460275925, + -0.4760894751313036, + 0.48246181669586163, + 0.10336084225278253, + -0.14462616203864131, + -0.9806606007833201, + -0.8495122798524659, + 0.7662127866002859, + 0.8078571431197863, + 0.09118057841104465, + 0.6691900397720334, + 0.16501913297958803, + -0.7038124288650347, + -0.7451089614357225, + -0.38348330013973264, + 0.79796297748518, + 0.5922446097760834, + 0.7214051640018055, + 0.7978492730529492, + -0.5798469233204919, + -0.5009405215541511, + -0.7944127566564287, + 0.5602324837428854, + 0.7682694029020178, + -0.1872452203357664, + 0.2413230203014256, + -0.6908933233355907, + 0.8597620313873489, + 0.7292113924399279, + 0.9524120658619257, + 0.6215434398807937, + 0.7628324093266488, + -0.9504272762036226, + 0.4731289435101642, + -0.33562906410714266, + 0.8616317720966511, + 0.6044702778742779, + 0.7281280567505588, + 0.621498633148778, + -0.46638858081105594, + 0.5747490182709423, + 
-0.7838087471940858, + 0.7443335658121795, + 0.7171865026755633, + -0.5551325649086711, + 0.6331732111938579, + -0.07939353064211585, + -0.38961826532279886, + 0.5906909983057236, + -0.5448090251844593, + -0.952671130597097, + -0.6137404233445827, + -0.34347609760458697, + 0.7287058840605727, + 0.9337782080967223, + -0.4417500145562572, + 0.2829634772152554, + -0.20064323127987826, + 0.9622993743965202, + 0.07243146495744379, + 0.8784742806494314, + -0.7693164962971448, + 0.9408012220444559, + -0.6428643676550727, + 0.9250686315231109, + -0.46906727495406275, + -0.7831949055705778, + -0.1308724828707113, + 0.4570901213054086, + -0.37264537161001754, + 0.21241770661228654, + 0.022846119338956195, + -0.22960913331054567, + 0.15317608699319907, + -0.4905549877228361, + 0.4175705676683412, + -0.9966174435627411, + 0.8511503309981654, + 0.0769039941855838, + 0.438859998289691, + 0.48390015567895306, + 0.34125700886599897, + -0.27155705643747163, + -0.8600523777473796, + 0.32847536982254466, + -0.3395999279148072, + -0.37216870988328066, + 0.696030559012671, + 0.43950852602790036, + -0.39935546357747165, + -0.3814306755826935, + -0.1832141827615663, + -0.19519922588455074, + -0.40868959494810597, + -0.7454244018816936, + -0.1591073324541834, + 0.880727341460366, + 0.35463589054546585, + 0.8056110914651653, + 0.23102983190276105, + -0.39810025086886935, + 0.09587442627139642, + -0.999188120605425, + -0.4261725662621456, + -0.1402237000203308, + 0.15996956239136395, + 0.30941124740614323, + -0.07002361950597158, + -0.11568040139038516, + -0.5725971980217994, + -0.053627628181347475, + 0.8023616516565084, + 0.5920495202535605, + -0.6606172076038905, + -0.8304089265497565, + 0.030904019830432894, + 0.26588171153159146, + -0.3296234891803982, + 0.6368469290733285, + 0.5022762750814644, + 0.3455913411143341, + -0.5507186680054235, + -0.6017401345468467, + -0.9511492245463473, + -0.5103149118432997, + -0.04972731156238974, + 0.6994753892494638, + -0.8543435416308618, + 
-0.1711179780045613, + 0.2595307614754274, + -0.6111295265205814, + 0.39270850098100984, + -0.011245661979126131, + -0.5120311208431223, + 0.31211602222356816, + -0.9889103637239365, + 0.5019289532369458, + 0.5400923771480501, + -0.7868254068729221, + -0.14970761211453176, + -0.6482266365869367, + 0.9159320845590795, + 0.03591550088748163, + -0.8995632297187182, + -0.5016034406800567, + 0.6966726947033195, + -0.08707634905965489, + 0.602833203444529, + 0.3351554651727062, + 0.975784906132896, + 0.1909046369388394, + 0.9000792168863119, + 0.782851851620874, + 0.2253046455235257, + 0.4385479225519342, + 0.00955632964880393, + 0.6611383394428301, + 0.09574390122165677, + 0.7944162064665243, + 0.4873108843191698, + -0.05065112635389335, + -0.4816169030699613, + -0.5055205249806809, + 0.27532287355231255, + 0.5316273685943309, + 0.0425996256559642, + 0.2534968739635626, + -0.4508051061649234, + -0.8450332922705284, + -0.428543698273695, + -0.4565697858356308, + -0.3605808631624754, + 0.08030444503691281, + -0.7232518769676886, + -0.5374770405436264, + 0.38789962459810456, + 0.41283828338910444, + -0.8715422985722439, + -0.18480126066682656, + 0.08522228100783069, + -0.1684515317936881, + -0.5863312209711795, + -0.15971296445314875, + 0.8096769566803539, + 0.1681588284084503, + 0.39104597299593613, + 0.7134640646078685, + 0.5311891522361389, + -0.23923794214457672, + -0.9882078328320139, + -0.2964823946563506, + 0.5069502501187717, + 0.7068959011382092, + 0.9068606769402126, + -0.1619574347403563, + 0.49503133795610155, + 0.09226461946767817, + 0.20650517788248268, + -0.5589226113523622, + -0.5611567307571277, + -0.1283280479067268, + -0.9419503601065695, + -0.3277409126032351, + 0.3582837700566994, + -0.19136666172472583, + -0.6699105375929824, + -0.0652197015353797, + -0.7447444054376786, + 0.24451392194812938, + -0.9460670961897246, + -0.21195948732059056, + 0.12878396604954823, + -0.9457959073193751, + 0.2854992960186715, + -0.7286010255388715, + 
-0.07660311189697655, + -0.8994307330227449, + -0.24179227162372086, + -0.5766794315770369, + -0.3463083902373829, + 0.5224594157880542, + -0.24174756887172744, + 0.5040196471095697, + 0.6638485703105452, + -0.4954569364352388, + -0.8361875344767149, + -0.9612334258999786, + 0.0788380958450674, + 0.9998156570184185, + -0.30007931255963216, + 0.30028818649975, + 0.5624660992217898, + 0.30350931048777885, + 0.5084664081190524, + 0.8992234654319777, + -0.6012786352749342, + -0.9592399653593353, + -0.6952353084304193, + -0.7475580502515875, + 0.33891768923982135, + 0.12793916386003823, + -0.5640709181869927, + 0.3989299424923016, + 0.5337961967124816, + -0.6644217132643955, + 0.21449498778186338, + 0.49585130391057164, + -0.7709342572422047, + 0.6386023486221701, + 0.9294415460681753, + -0.7838025006847869, + -0.948643149005068, + -0.3760855112106096, + 0.3546945737008176, + 0.9163456764117919, + -0.20669111696674536, + 0.4300294100989368, + -0.8480070443138801, + 0.3812288318659607, + 0.2544847912020889, + -0.7961973891080469, + 0.544961769902448, + 0.7005864781775926, + 0.20082322963368826, + -0.7578898698653698, + 0.9676887030293426, + 0.5652706927220392, + -0.30559246938310203, + -0.1432439735305111, + -0.25885824756388764, + 0.011921579354155831, + -0.31753765027743364, + 0.6991512539991547, + 0.6446618361801162, + -0.7889222587120028, + 0.9215751344291572, + 0.27117021220289184, + 0.6574146220049164, + 0.41461728741215387, + -0.1290257099846459, + 0.46759060802678376, + 0.9309474624761553, + -0.45983520722519833, + 0.6163984376135119, + 0.07634581289651154, + -0.033004992233607755, + -0.12885101399221144, + 0.46205242861024454, + -0.46320892390154933, + 0.7034263200386637, + 0.6614620377812068, + -0.8266742038886512, + 0.7632623680055439, + -0.512273121618088, + -0.07058306679353654, + 0.22066340846104127, + -0.24202139174347193, + -0.9426000044598208, + 0.7019056726249182, + -0.6363202856846164, + -0.575760299640522, + 0.595664713656193, + -0.3193223136714327, + 
0.7606399595164508, + 0.4023675006644032, + -0.4474628484876302, + -0.9796977711226458, + 0.8961251555540624, + -0.8287740760839575, + 0.4401493282083886, + -0.022844306302425288, + 0.5163293069649328, + 0.38121867889304606, + 0.29180579948190455, + -0.018357329842827674, + 0.5858657362646349, + -0.8138932988906566, + -0.5568071990549397, + 0.3835743105904037, + -0.3875879397398232, + 0.16311117066473435, + -0.05347902248095049, + 0.0618438622915356, + -0.1489923745895705, + 0.49187087342721925, + -0.3384174056081397, + 0.4057098843715379, + -0.45816714604760733, + -0.497192647598117, + -0.7586882304953884, + -0.6148314130969652, + -0.7608905174898943, + 0.07172793076749961, + 0.5243792189682059, + -0.6297003151246761, + -0.5672307202427886, + -0.031602825574582605, + 0.44917000218610315, + 0.9532140457661136, + 0.0492737382172157, + -0.43400259345008996, + -0.7989477838184185, + -0.6117648438178531, + -0.5450336730348906, + -0.6411169126291933, + -0.9717032654878723, + 0.06827017853532658, + -0.4513773464334654, + 0.9485898622054758, + 0.1067179335972166, + 0.3948347858203889, + -0.7474410140831822, + 0.736922394525888, + -0.018242610952442284, + 0.7454394699970692, + 0.1481284392526645, + -0.061206110138979586, + -0.1190624037279795, + -0.6312726592198366, + -0.8972465656328443, + 0.8821271935362065, + -0.04454162546862461, + 0.6442312905988608, + -0.19858511549612046, + -0.8518356594887204, + 0.2588914139039966, + -0.8927818514149337, + -0.7016048310526897, + 0.12567919416912843, + -0.39232897631554464, + 0.9878362454204404, + -0.7630968755770768, + 0.5288868906523885, + 0.21263530248592555, + 0.5814816596868388, + -0.5486257258699307, + 0.04514507026043568, + -0.09897107022249152, + -0.11455799196134331, + 0.7203333316523601, + 0.9800625209335958, + -0.3892395113612792, + 0.24205464214429706, + 0.21926182449026066, + 0.48017861096909, + 0.895180400650756, + -0.5844241883528973, + -0.577949609389214, + 0.32085627438407327, + -0.6858858132382457, + 
-0.652372903347131, + -0.8498702621976642, + -0.9946485547942285, + -0.09899259076459521, + 0.18762239023912652, + -0.4174814219379448, + -0.5370475308879483, + 0.41391165975806654, + 0.4059751161874343, + -0.09193734718591351, + 0.3747698401424855, + 0.8478220889650974, + 0.5756560532934627, + 0.25011601432844377, + 0.32236608570673564, + 0.8673369168910265, + -0.1497220677915767, + 0.08912475742115733, + 0.2952694468049615, + 0.8168228904252122, + 0.6532623193001192, + -0.8571803262883717, + -0.6681544215585506, + -0.3847763747715349, + 0.4979154441392466, + 0.13841409863818477, + -0.4227788233883165, + -0.7512926836505835, + 0.37735598242418744, + 0.3994673699514346, + 0.8853524814880815, + 0.0009443542358511525, + -0.012409561309831396, + -0.839116296201398, + -0.9202784316327739, + -0.13594267167845486, + -0.35535683328340006, + -0.4992641951759922, + -0.8173462267380296, + 0.9238222043856734, + 0.6719172278123626, + 0.1503982184398609, + 0.9015725556127048, + 0.9991448337549027, + 0.3445631686064785, + -0.4609779480658922, + -0.9195366537107676, + 0.5125376608251142, + -0.05899835349673288, + 0.3030189788673727, + 0.832145575853535, + -0.6370217055537482, + 0.17065925055570874, + 0.26956943890817375, + -0.016548395617973277, + -0.8175151874123456, + -0.3040778874106991, + -0.33338321266981596, + 0.3402670190423709, + 0.7154661888007052, + -0.3403926728420692, + 0.387347347966865, + -0.4235644092770885, + 0.8903870791264212, + 0.6271320694759452, + 0.10019321794356584, + -0.09034818279654111, + -0.37096568565966415, + -0.35345242744801064, + 0.9403694536171543, + -0.19164988599850874, + 0.029192504658210217, + 0.9762384295635083, + 0.31532077296625793, + 0.08518718871439002, + -0.173504858401357, + -0.6248349172109458, + -0.2764412816966495, + 0.5128863081111474, + 0.25081748429001793, + 0.5199810721220339, + -0.5928835232994523, + 0.09843927817066334, + 0.8553455216888954, + -0.1237678098552435, + 0.39650005824434453, + -0.7571478332747603, + 
0.9462936316433888, + 0.21774333416392988, + -0.5214050753446993, + -0.6832436723895352, + 0.10167801400533993, + 0.10450281810746609, + -0.8135815974369549, + 0.9845142787904393, + 0.8258597569683359, + -0.0771042116424332, + -0.7650677018208849, + 0.6642863475378025, + -0.003248990585609457, + 0.4332066519835822, + 0.01774403013902348, + -0.4531502065734516, + 0.6694478911533888, + 0.9604892652067509, + -0.5125381878497233, + 0.10253015376093932, + -0.23282797350570017, + 0.8437362998631237, + 0.016481783185787968, + 0.7586525102929522, + 0.7280538688571763, + -0.44750519435591807, + 0.5800123640062269, + -0.17011515289730972, + 0.8684967873650402, + 0.015475351619312905, + 0.6410989463711567, + -0.4343220334343467, + -0.4028883004564656, + 0.17387544482832218, + 0.9978046664586775, + -0.020719306688598005, + -0.7028091632344218, + 0.07716115540774737, + -0.30975211661398494, + 0.10383483414162198, + 0.08686012591734604, + -0.08931076626699785, + -0.35644529954192516, + -0.6226952525857892, + 0.3949968552411427, + 0.14359528396980226, + -0.5328751078842138, + 0.5510889501985177, + -0.9127054018053962, + 0.4894103031303918, + 0.4104557620500051, + 0.6228178051297537, + -0.22784249501803533, + 0.32737765896906934, + 0.6414951034191101, + 0.9616362773195701, + -0.009342700767152934, + -0.9259607773088476, + 0.00458230026593176, + 0.1803608586191341, + 0.7394006267244309, + 0.7483807481086167, + -0.11938758045864573, + 0.05190217361512284, + -0.08614385105151556, + 0.44488765514128525, + -0.1800427605072712, + 0.3095626528554676, + -0.6912775624569023, + -0.061018798030536114, + 0.9384072611484344, + -0.32287753189713664, + 0.3854091971737641, + 0.2996733051585452, + 0.7035305847013829, + 0.7046826731317861, + 0.7186843683365332, + -0.2399812049546899, + -0.36667769213320067, + 0.437434850445966, + 0.5188036186687273, + 0.7447660347970726, + -0.9282018003246268, + -0.8631585056368971, + 0.2623220337092811, + 0.8418581975605934, + 0.9948518459124307, + 
0.4935327334758539, + -0.13205705615374863, + -0.8031137472336041, + 0.26749565756158256, + 0.7451584652141814, + -0.11264289667658844, + 0.3880023255864691, + 0.806848124100972, + -0.9080180626455745, + 0.5922869303244236, + -0.4132644480695824, + -0.25031782044937034, + -0.7088604086012031, + 0.06233263629752006, + 0.13185612383187362, + 0.5850389477619065, + -0.6600327018638052, + -0.8420632986800107, + 0.7416791972897097, + 0.2394207370748307, + -0.5183404162873739, + 0.8256580320471096, + -0.7137645597429807, + -0.07770017329000489, + -0.49204532117291544, + -0.4893465836816919, + -0.9812051370902579, + 0.6092661539502926, + 0.802418847197766, + 0.3552217713981589, + -0.6840487558564261, + -0.11654043279237114, + -0.3088687511138366, + 0.17514341025284286, + 0.2778774047200243, + -0.15138212307820464, + -0.4998035511845975, + 0.6906078502851973, + -0.6015660017822153, + -0.23061350206530906, + -0.0335838778815678, + -0.5255885961448508, + 0.14384538470147779, + 0.14962386035729014, + 0.9853840872537951, + -0.4095384922334613, + 0.9558889691537253, + 0.31645963185736137, + -0.45103923964354453, + 0.13185803391139794, + 0.37159898546802417, + 0.48933768233065034, + -0.9019114984480001, + 0.21281298615294908, + -0.006545426952259348, + 0.8083105817874507, + -0.42761169708161106, + 0.5977202390151974, + 0.2141299963285319, + -0.29535808832936694, + 0.27323575601178374, + 0.24178232626060825, + 0.35552891724671887, + 0.44185675334141794, + 0.3183630806335074, + 0.6766742339250331, + 0.25649620737379664, + 0.8068074081459358, + 0.2926812177810456, + -0.38213423209472785, + -0.11835361967437685, + 0.1591476107368055, + 0.4647195358784766, + -0.8197332485075621, + -0.40977909674168944, + 0.49496172987674303, + -0.6487198591113867, + -0.7356804045064329, + 0.07881551796884212, + 0.9429791624226798, + 0.06170474740557008, + 0.8269739489633938, + 0.6609452391347947, + -0.4860598308734754, + 0.6493796250848145, + -0.03630434025251761, + 0.612976987587533, + 
0.4931187014340901, + -0.32256949239622745, + -0.7696605851004812, + 0.9257865857377552, + -0.7184859699882347, + 0.9330004189255041, + 0.7202811937976437, + 0.4484334241510364, + 0.9598844855638069, + 0.9345394946000647, + 0.6091752880411239, + -0.26844990118886725, + 0.5813639371778749, + -0.9721626897981976, + 0.07314461653811821, + -0.09042794453225067, + 0.3456567637475072, + 0.3446815947020263, + 0.1691201833043321, + 0.6448346024535485, + 0.8805837835591084, + -0.7833077956015386, + -0.5323561955768337, + -0.9499507007070354, + 0.7684696904297044, + 0.1228147644997577, + 0.8305118174863189, + -0.55726559985202, + -0.8735659176796091, + 0.6477107027808953, + 0.8187752768557841, + -0.39561965094149953, + -0.18340828840917456, + -0.7204459749855578, + 0.8925230657638925, + -0.39127083128798956, + -0.014750762043588272, + -0.8056160027563568, + 0.7745186170570046, + -0.728671902587327, + -0.09271248622221484, + 0.34097243770034247, + 0.4862802430463431, + 0.8919481715588642, + -0.16174649317055434, + 0.48453802953063163, + -0.69095419518033, + -0.17023094512638615, + -0.801956730578943, + -0.02130592442077095, + -0.18376822860459896, + 0.903043050762119, + -0.9345674262899062, + -0.2589400825311292, + -0.1132338278785967, + 0.9011103397028539, + 0.7109003866119092, + -0.8012907507877394, + 0.3713605309625707, + 0.08893172296428986, + 0.9556850589040935, + -0.28265231757536413, + -0.20372071451125384, + -0.6203828756778409, + -0.7556805618254725, + 0.6960663769273621, + -0.09056526285896571, + 0.325537476123956, + 0.2834089344664352, + 0.19429191903909016, + -0.9572850905272587, + 0.5735891809092335, + -0.5128622056719527, + -0.7481522293839142, + 0.12915595181592665, + -0.8627796943512882, + 0.5303147517771689, + -0.585685259306683, + -0.5680972961626505, + 0.7393908535390894, + -0.34288089313553916, + -0.7048916401171126, + 0.8010620712635164, + -0.9943288970399673, + 0.7168122527603527, + -0.710624039358984, + -0.7400157371131264, + -0.49869160654375233, + 
-0.6510057581972131, + 0.3221152851946336, + -0.9484397004276033, + -0.9702793455385823, + 0.5799693284695078, + -0.5241367878190739, + -0.3524570760759551, + -0.6515075971876456, + -0.895201964277647, + 0.483436113908299, + 0.05217105319573423, + 0.4913305500679914, + -0.04750806915349948, + 0.5560340786284055, + 0.026475915218383328, + -0.781891979992309, + 0.007677379571642717, + 0.8908312859402125, + -0.9132699262016566, + 0.5664539919607592, + 0.7339618155196765, + 0.042902429426168176, + -0.08391495580463593, + 0.9280523662440574, + -0.8783491850109884, + -0.04203617800327342, + -0.19676549097487905, + 0.3721949921244654, + -0.01946229171154945, + 0.8194016582304586, + -0.853018568467099, + -0.8384190451784161, + 0.21659484726687994, + -0.8686355333597722, + -0.4499680008840936, + 0.26615344860203094, + 0.09671286809672108, + -0.3496291133626641, + 0.9892555117218471, + 0.06111367486264907, + -0.09256916484905675, + 0.21085358307062574, + -0.801643076641916, + 0.40355883709213236, + 0.7055854745911501, + 0.3018333297626108, + 0.5379254602094772, + 0.44167983331519833, + -0.5699538673450062, + -0.09689016806943695, + -0.5430128512708312, + -0.3221367623296496, + -0.09300219418510536, + -0.16802069947700948, + -0.8098283214487383, + -0.146471987478084, + 0.3302157261206178, + -0.2513979531289263, + -0.6947221504625525, + 0.8459700714687688, + -0.8657333837267045, + 0.663543776949709, + -0.8135397965926452, + -0.8068711348684288, + 0.4775919969774314, + 0.6235385705547847, + 0.11274147123070022, + 0.1729301654789539, + 0.12317282798414486, + -0.3407080371675899, + -0.7555374292908814, + -0.2928038407324647, + 0.3306810400058309, + 0.5005685005029565, + 0.7361842977381299, + 0.44212135749229886, + 0.9367972506229489, + 0.20082018244935407, + -0.29670762861370203, + 0.15583703677970973, + -0.5745223886559878, + 0.3134726059763042, + -0.5515102617848688, + -0.7835632361454667, + 0.6907468372026901, + -0.2648778987692997, + 0.5252112638736994, + 0.14820000866292538, 
+ 0.6144427423046888, + 0.6903103226567162, + 0.9490932042514164, + 0.6368537190813393, + 0.22714656107092956, + 0.2853983276596628, + -0.9474923370928332, + 0.8581685819898728, + 0.658921579919326, + -0.4651045496917787, + -0.6391678560782912, + 0.40539754573120934, + -0.3820306223401797, + -0.32035068644548326, + -0.9877884211926826, + 0.7397254130728763, + 0.13264218952275253, + -0.19843131200096886, + -0.7162506916974627, + 0.26634402531102763, + -0.9386858032381882, + 0.49222352401141345, + -0.5697342399329781, + -0.1603350124709937, + -0.31820803646133444, + -0.2598938150459922, + 0.4431919354853464, + 0.553671239933482, + 0.1351871132287945, + -0.8300859200456414, + -0.8947823471489564, + -0.6851802057856937, + 0.23567636385206114, + 0.34793742122622606, + -0.4557943129075621, + 0.32387738561652246, + -0.028676590218007503, + -0.11591162660444176, + -0.4536663112704611, + 0.5098862873367414, + -0.7723649837795965, + -0.1401727332042193, + -0.4335070598391013, + 0.35697250952026605, + -0.02673449327149835, + 0.33426511747269916, + -0.909165274791145, + -0.2094732078247421, + 0.19864991388890085, + -0.9846258282002343, + -0.39716127596312645, + -0.5775320415593754, + -0.725530389485346, + -0.4889609919570814, + -0.34375528812431777, + -0.9845401868613761, + 0.4940282468595356, + -0.6486103962483154, + -0.23958510856952753, + 0.40734252676532723, + 0.000524693112426311, + 0.6667084048397565, + 0.6124003731335277, + -0.8558490068156843, + 0.7235287240451771, + -0.9153954768744172, + -0.9625169268287086, + 0.8423248690048248, + 0.7242200272241328, + 0.15151832147366218, + 0.14679936177168607, + 0.41899792313786843, + -0.16461208031303887, + -0.7696532546724035, + -0.9582868819505024, + -0.3504636411089723, + 0.6026443086209043, + 0.23625052660848045, + 0.6640518261434143, + 0.8395395034827693, + -0.8237402374042306, + 0.6889687196293939, + -0.5133670503545327, + 0.17774257660582382, + 0.04792508600126322, + -0.20846660628133407, + -0.37945087632827734, + 
-0.3209734377030722, + -0.3338627550137361, + -0.6637345839086268, + 0.020966569084296616, + -0.7719467203228949, + 0.019904124645943932, + 0.8118454631601009, + -0.3012494690552394, + 0.4547582113478086, + 0.6378972030497039, + 0.6300740115000283, + -0.5274623021065514, + -0.7071115634446112, + -0.6054563943520335, + 0.20479797054633164, + 0.5204305910937841, + 0.3110180210374782, + -0.6457077421055417, + 0.5456961784951204, + -0.011765949965222733, + 0.5088916504939716, + 0.5197542992154969, + -0.10218948600987399, + 0.8483085167713187, + 0.12898356680559853, + 0.27059663812118684, + 0.2490435588837896, + 0.7284937496628323, + 0.2544348137995436, + -0.6980851972138462, + -0.8634274830084878, + -0.11558387232743028, + -0.3943591297219742, + -0.45065266502768164, + -0.8876557595738439, + 0.014673770579556322, + -0.37918429878737814, + -0.09617227259863559, + -0.8862198983320952, + 0.6633932633262429, + -0.84653799765309, + 0.7285000678504783, + 0.7105867429121806, + 0.23001677683114718, + 0.014135634667910235, + -0.07457668214545543, + 0.10863274267660783, + 0.583635594530717, + 0.7917535311136052, + -0.10053259269818748, + 0.6196318353959422, + 0.30367490929721974, + -0.3569464742391433, + -0.04874194411386101, + -0.6982778466087229, + -0.8762525997977986, + -0.7929962454592168, + 0.7982536679115471, + -0.31312444832346475, + 0.4286310983349253, + 0.009098003001914767, + -0.6548821771272701, + -0.5045125528034315, + -0.12448345138153405, + -0.1211564164747514, + 0.04549607053109117, + -0.6825075840369172, + -0.25429603579734583, + -0.43421284277116423, + -0.18246120550535938, + -0.3232657063171609, + 0.19577172476597893, + 0.5784538631273295, + 0.2946107139386138, + -0.8681762914405651, + -0.8109881049698113, + 0.356758689680162, + -0.43170605224377256, + 0.44746731000492357, + 0.31312817282095584, + 0.8126853943273544, + 0.7465593241210777, + -0.3332759278788868, + 0.1654790291728392, + -0.7171432388313643, + -0.30035842492831155, + 0.9353930153854892, + 
0.3969599256237619, + -0.21608403132927934, + 0.19008245630314957, + 0.8760043991315216, + -0.38083622516807814, + -0.24664138788739964, + 0.5833239157270871, + 0.6263695676296019, + 0.340232799989445, + 0.657917945788929, + 0.47754934425896556, + 0.3708288805551525, + 0.05278667946834803, + 0.29204964146697576, + -0.15318726735579258, + -0.2763438073064308, + -0.27480466203591125, + -0.6394741542463107, + -0.5716146775821993, + 0.8953365350686937, + -0.02745815825454545, + -0.5469139069322255, + -0.7248692936458414, + -0.8456698313982474, + 0.6888567773717666, + -0.7977184726642117, + 0.541749440726202, + 0.670239653269112, + 0.7673643309851232, + -0.9245050152848775, + -0.3264712560714369, + 0.5326152088944831, + -0.737901917140984, + -0.24656025835495066, + -0.675505575823099, + 0.662690113378251, + 0.5421956274619768, + 0.6180874392786022, + -0.6689216685118826, + -0.12465318972330652, + -0.17828277700693307, + 0.3527258443770658, + -0.5249395971061499, + -0.1116025803894487, + -0.4301441348732835, + 0.4970730361908726, + -0.10214407393331926, + 0.06802229939652227, + -0.3810642068744208, + 0.6172477421815727, + -0.061968789935494595, + 0.6702267854514148, + -0.26431808354993436, + 0.8942603404882459, + 0.9688795870631546, + -0.07664004311802164, + -0.43645653459244915, + -0.23625513161857903, + 0.05491957692296534, + 0.9325363064118946, + 0.6337824791624802, + 0.6025184483030952, + -0.7232029307931376, + -0.49999357682198586, + 0.2823580724088943, + 0.748233890104474, + 0.10908149084886243, + -0.7948205365031813, + 0.691784553466877, + 0.7023320961695576, + -0.42987397188134313, + 0.5262336605833817, + -0.45441740081728677, + 0.8106124179564824, + -0.7053026880167914, + -0.12505487961003836, + 0.8928265260235213, + -0.5559239861860388, + -0.09774401955858214, + -0.30082984372270216, + -0.9466596183941229, + -0.8934862256578511, + 0.004014229386470891, + -0.5284438522731347, + 0.9890507024765824, + -0.25017465316446486, + -0.9436249089436999, + 
0.8616518094999062, + 0.6783525752232114, + 0.29992136858836327, + 0.582761274964352, + -0.7248008245482567, + -0.42624120537346366, + 0.6595231663056451, + 0.39214397715196725, + -0.7224146163636276, + 0.41107235057816083, + -0.10279705203550682, + -0.9894976033887641, + -0.8415484574585574, + -0.4881521431259106, + 0.669926198564762, + 0.09760849088767087, + 0.4544695706498636, + 0.05554301177344856, + -0.7776262793515232, + -0.42379684392153893, + -0.39769761082756006, + -0.9045011067622974, + -0.1603489124931139, + 0.5877982172788763, + -0.08577276672710266, + -0.77828420941997, + 0.810293771323997, + 0.19347808563789393, + -0.9671292955882111, + 0.03075146041881216, + -0.5161237311580134, + -0.7128463195074799, + -0.14152221379333252, + 0.22961916555190132, + -0.5188715223869129, + -0.16686480958292038, + 0.3287426034840182, + -0.8287720900254765, + 0.9493089819044433, + -0.8646413541823079, + 0.052118890644341054, + 0.014655393159573205, + 0.9766629711929349, + 0.10830390483639096, + -0.2190925348798718, + -0.05972984368302803, + 0.2713415829372601, + 0.9620788451031206, + -0.49269947786157453, + -0.9675155377821048, + 0.5770400325590306, + -0.3103950136732174, + 0.4658820429012833, + 0.25651392495131287, + 0.5430027482197182, + 0.4703739696246225, + -0.33496278325603024, + -0.9113286234095836, + 0.09202749041538305, + 0.6270177311121763, + -0.6498217458965831, + 0.558285186956563, + -0.07075420050592451, + 0.3907785039921279, + 0.2634716955166756, + 0.6229953636952128, + -0.8737989259355508, + 0.5523807994068435, + -0.08464084510529357, + -0.4131148576499375, + -0.9123874486817538, + -0.6010603325696233, + -0.916188116139236, + 0.8667419599007946, + 0.030767178508997572, + 0.9782454045922468, + 0.08606139530837176, + -0.4933724695947652, + 0.5065818376377298, + -0.6177931385321782, + -0.2860516479292732, + 0.56168313395685, + 0.7315965541561151, + -0.33615062723731093, + -0.7510499835112332, + -0.26396165137065397, + 0.7789730340245629, + 
0.48661541103924244, + 0.7892749899101066, + -0.22671046347861568, + 0.9474471686306181, + -0.007593546925954708, + -0.004953215012745371, + 0.8486209332539272, + 0.03855170706988398, + 0.6022961748035476, + 0.45416264868527145, + -0.8421459878890643, + 0.2049065976604545, + 0.6446825590797733, + 0.09094879468927397, + -0.35757714357081194, + -0.8398621778500095, + 0.3218384429162733, + -0.38700828781849506, + 0.2052432554611996, + -0.14776785423390915, + 0.37952961689097253, + -0.29690603245601577, + -0.915289674299741, + 0.7400743501127842, + -0.2948813793830354, + 0.9963011955460981, + -0.4508892798502848, + 0.9600545583885582, + 0.8958087572061726, + -0.8499176700236815, + 0.2750250757665966, + -0.2733777386980354, + 0.6021919511243399, + 0.35882121562937974, + 0.9055787925592156, + -0.7144410632749005, + 0.21514580664171068, + 0.562623939486933, + -0.930402068402916 + ], + "sample_rate": 48000 +} \ No newline at end of file diff --git a/generate_weights_a2.py b/generate_weights_a2.py new file mode 100644 index 0000000..4aca1c1 --- /dev/null +++ b/generate_weights_a2.py @@ -0,0 +1,302 @@ +#!/usr/bin/env python3 +""" +Generate weights for wavenet_a2_max.nam file. 
+This script handles the full A2 architecture including: +- FiLM (Feature-wise Linear Modulation) modules +- head1x1 modules +- condition_dsp (nested WaveNet) +- Advanced gating modes (GATED, BLENDED, NONE) +- Complex activation configurations +""" + +import json +import random +from pathlib import Path +from typing import Dict, Any, List + + +def count_conv1d_weights(in_channels: int, out_channels: int, kernel_size: int, + has_bias: bool, groups: int = 1) -> int: + """Count weights for a Conv1D layer.""" + weight_count = kernel_size * (out_channels * in_channels // groups) + if has_bias: + weight_count += out_channels + return weight_count + + +def count_conv1x1_weights(in_channels: int, out_channels: int, + has_bias: bool, groups: int = 1) -> int: + """Count weights for a Conv1x1 layer (kernel_size=1).""" + weight_count = (out_channels * in_channels // groups) + if has_bias: + weight_count += out_channels + return weight_count + + +def count_film_weights(condition_dim: int, input_dim: int, has_shift: bool) -> int: + """ + Count weights for a FiLM (Feature-wise Linear Modulation) module. + FiLM uses a Conv1x1: condition_dim -> (2*input_dim if shift else input_dim), with bias + """ + out_channels = (2 * input_dim) if has_shift else input_dim + return count_conv1x1_weights(condition_dim, out_channels, has_bias=True, groups=1) + + +def parse_gating_mode(layer_config: Dict[str, Any]) -> str: + """Parse gating mode from layer config (handles both old and new formats).""" + if "gating_mode" in layer_config: + gating_mode_str = layer_config["gating_mode"] + if gating_mode_str in ["GATED", "BLENDED", "NONE"]: + return gating_mode_str + # Handle lowercase versions + return gating_mode_str.upper() + elif "gated" in layer_config: + # Backward compatibility + return "GATED" if layer_config["gated"] else "NONE" + else: + return "NONE" + + +def count_layer_weights(layer_config: Dict[str, Any], condition_size: int) -> int: + """ + Count weights for a single layer (one dilation). 
+ + A layer consists of: + 1. Conv1D: (channels, bottleneck*(2 if gated/blended else 1), kernel_size, bias=True, groups_input) + 2. Input mixin Conv1x1: (condition_size, bottleneck*(2 if gated/blended else 1), bias=False, groups_input_mixin) + 3. 1x1 Conv1x1: (bottleneck, channels, bias=True, groups_1x1) + 4. Optional head1x1 Conv1x1: (bottleneck, head1x1_out_channels, bias=True, head1x1_groups) + 5. FiLM modules (optional, various configurations) + """ + channels = layer_config["channels"] + bottleneck = layer_config.get("bottleneck", channels) + kernel_size = layer_config["kernel_size"] + groups_input = layer_config.get("groups_input", 1) + groups_input_mixin = layer_config.get("groups_input_mixin", 1) + groups_1x1 = layer_config.get("groups_1x1", 1) + + gating_mode = parse_gating_mode(layer_config) + + # Output channels are doubled for GATED and BLENDED modes + conv_out_channels = 2 * bottleneck if gating_mode in ["GATED", "BLENDED"] else bottleneck + + weight_count = 0 + + # 1. Conv1D weights + weight_count += count_conv1d_weights( + channels, conv_out_channels, kernel_size, + has_bias=True, groups=groups_input + ) + + # 2. Input mixin Conv1x1 weights + weight_count += count_conv1x1_weights( + condition_size, conv_out_channels, + has_bias=False, groups=groups_input_mixin + ) + + # 3. 1x1 Conv1x1 weights + weight_count += count_conv1x1_weights( + bottleneck, channels, + has_bias=True, groups=groups_1x1 + ) + + # 4. Optional head1x1 weights + head1x1_config = layer_config.get("head_1x1") or layer_config.get("head1x1") + if head1x1_config and head1x1_config.get("active", False): + head1x1_out_channels = head1x1_config.get("out_channels", channels) + head1x1_groups = head1x1_config.get("groups", 1) + weight_count += count_conv1x1_weights( + bottleneck, head1x1_out_channels, + has_bias=True, groups=head1x1_groups + ) + + # 5. 
FiLM module weights + # Parse all possible FiLM configurations + film_configs = [ + ("conv_pre_film", channels), + ("conv_post_film", conv_out_channels), + ("input_mixin_pre_film", condition_size), + ("input_mixin_post_film", conv_out_channels), + ("activation_pre_film", conv_out_channels), + ("activation_post_film", bottleneck), + ("1x1_post_film", channels), + ("head1x1_post_film", head1x1_config.get("out_channels", channels) if head1x1_config and head1x1_config.get("active") else 0) + ] + + for film_key, input_dim in film_configs: + if film_key in layer_config and layer_config[film_key]: + film_params = layer_config[film_key] + if isinstance(film_params, dict) and film_params.get("active", True): + has_shift = film_params.get("shift", True) + if input_dim > 0: # Only count if input_dim is valid + weight_count += count_film_weights(condition_size, input_dim, has_shift) + + return weight_count + + +def count_layer_array_weights(layer_config: Dict[str, Any]) -> int: + """ + Count the total number of weights for a layer array. + + Each layer array consists of: + 1. Rechannel Conv1x1: (input_size, channels, bias=False) + 2. Layers (one per dilation) + 3. 
Head rechannel Conv1x1: (head_output_size, head_size, bias=head_bias) + where head_output_size = head_1x1.out_channels if head_1x1 active, else bottleneck + """ + input_size = layer_config["input_size"] + condition_size = layer_config["condition_size"] + head_size = layer_config["head_size"] + channels = layer_config["channels"] + bottleneck = layer_config.get("bottleneck", channels) + dilations = layer_config["dilations"] + head_bias = layer_config.get("head_bias", False) + + # Determine head output size: head_1x1.out_channels if active, else bottleneck + head1x1_config = layer_config.get("head_1x1") or layer_config.get("head1x1") + if head1x1_config and head1x1_config.get("active", False): + head_output_size = head1x1_config.get("out_channels", channels) + else: + head_output_size = bottleneck + + num_layers = len(dilations) + + weight_count = 0 + + # 1. Rechannel weights + weight_count += count_conv1x1_weights(input_size, channels, has_bias=False, groups=1) + + # 2. For each layer in the array + for _ in range(num_layers): + weight_count += count_layer_weights(layer_config, condition_size) + + # 3. Head rechannel weights (input is head_output_size, not bottleneck) + weight_count += count_conv1x1_weights( + head_output_size, head_size, + has_bias=head_bias, groups=1 + ) + + return weight_count + + +def count_wavenet_weights(config: Dict[str, Any]) -> int: + """ + Count total weights for a WaveNet model (including optional condition_dsp). 
+ """ + weight_count = 0 + + # Count weights for each layer array + for layer_config in config["layers"]: + weight_count += count_layer_array_weights(layer_config) + + # Add head_scale (1 float) + weight_count += 1 + + return weight_count + + +def generate_weights(weight_count: int, seed: int = None, + weight_range: tuple = (-1.0, 1.0)) -> List[float]: + """Generate random weights in the specified range.""" + if seed is not None: + random.seed(seed) + return [random.uniform(*weight_range) for _ in range(weight_count)] + + +def process_model(input_path: Path, output_path: Path, seed: int = None) -> None: + """ + Load a .nam file with empty weights and generate random weights for it. + """ + # Load the input file + with open(input_path, 'r') as f: + model_data = json.load(f) + + print(f"Processing: {input_path}") + print(f"Architecture: {model_data.get('architecture', 'Unknown')}") + + # Process condition_dsp if present + if "config" in model_data and "condition_dsp" in model_data["config"]: + condition_dsp = model_data["config"]["condition_dsp"] + if condition_dsp and "config" in condition_dsp: + print("\nCounting weights for condition_dsp...") + condition_weights = count_wavenet_weights(condition_dsp["config"]) + print(f" Condition DSP weights: {condition_weights}") + + # Generate weights for condition_dsp + condition_dsp["weights"] = generate_weights(condition_weights, seed) + print(f" Generated {len(condition_dsp['weights'])} weights for condition_dsp") + + # Count main model weights + print("\nCounting weights for main model...") + main_weights = count_wavenet_weights(model_data["config"]) + print(f" Main model weights: {main_weights}") + + # Generate weights for main model + model_data["weights"] = generate_weights(main_weights, seed) + print(f" Generated {len(model_data['weights'])} weights for main model") + + # Print detailed breakdown + print("\nWeight breakdown:") + total_weights = 0 + + # Condition DSP breakdown + if "config" in model_data and 
"condition_dsp" in model_data["config"]: + condition_dsp = model_data["config"]["condition_dsp"] + if condition_dsp and "config" in condition_dsp: + print(" Condition DSP:") + for i, layer in enumerate(condition_dsp["config"]["layers"]): + layer_weights = count_layer_array_weights(layer) + print(f" Layer array {i+1}: {layer_weights} weights") + total_weights += layer_weights + total_weights += 1 # head_scale + + # Main model breakdown + print(" Main model:") + for i, layer in enumerate(model_data["config"]["layers"]): + layer_weights = count_layer_array_weights(layer) + print(f" Layer array {i+1}: {layer_weights} weights") + total_weights += layer_weights + total_weights += 1 # head_scale + + print(f"\nTotal weights generated: {total_weights}") + + # Write output file + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, 'w') as f: + json.dump(model_data, f, indent=4) + + print(f"\nOutput written to: {output_path}") + + +def main(): + import argparse + + parser = argparse.ArgumentParser( + description="Generate weights for A2 WaveNet models with empty weight arrays" + ) + parser.add_argument( + "--input", + type=Path, + default=Path("example_models/wavenet_a2_max.nam"), + help="Input .nam file with empty weights (default: example_models/wavenet_a2_max.nam)" + ) + parser.add_argument( + "--output", + type=Path, + default=Path("example_models/wavenet_a2_max_withweights.nam"), + help="Output .nam file (default: example_models/wavenet_a2_max_withweights.nam)" + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for weight generation (default: 42)" + ) + + args = parser.parse_args() + + process_model(args.input, args.output, args.seed) + + +if __name__ == "__main__": + main() diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 6fce4b6..8118e08 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -12,7 +12,7 @@ include_directories(tools ${NAM_DEPS_PATH}/nlohmann) add_executable(loadmodel 
loadmodel.cpp ${NAM_SOURCES}) add_executable(benchmodel benchmodel.cpp ${NAM_SOURCES}) -add_executable(run_tests run_tests.cpp ${NAM_SOURCES}) +add_executable(run_tests run_tests.cpp test/allocation_tracking.cpp ${NAM_SOURCES}) # Compile run_tests without optimizations to ensure allocation tracking works correctly # Also ensure assertions are enabled (NDEBUG is not defined) so tests actually run set_target_properties(run_tests PROPERTIES COMPILE_OPTIONS "-O0") diff --git a/tools/run_tests.cpp b/tools/run_tests.cpp index 3d0b4d1..c86acec 100644 --- a/tools/run_tests.cpp +++ b/tools/run_tests.cpp @@ -8,6 +8,7 @@ #include "test/test_convnet.cpp" #include "test/test_dsp.cpp" #include "test/test_film.cpp" +#include "test/test_film_realtime_safe.cpp" #include "test/test_fast_lut.cpp" #include "test/test_get_dsp.cpp" #include "test/test_ring_buffer.cpp" @@ -118,6 +119,17 @@ int main() test_film::test_process_inplace_scale_only(); test_film::test_process_inplace_partial_frames(); + test_film_realtime_safe::test_allocation_tracking_pass(); + test_film_realtime_safe::test_allocation_tracking_fail(); + test_film_realtime_safe::test_film_process_with_shift_realtime_safe(); + test_film_realtime_safe::test_film_process_without_shift_realtime_safe(); + test_film_realtime_safe::test_film_process_inplace_with_shift_realtime_safe(); + test_film_realtime_safe::test_film_process_inplace_without_shift_realtime_safe(); + test_film_realtime_safe::test_film_process_large_dimensions_realtime_safe(); + test_film_realtime_safe::test_film_process_partial_frames_realtime_safe(); + test_film_realtime_safe::test_film_process_varying_dimensions_realtime_safe(); + test_film_realtime_safe::test_film_process_consecutive_calls_realtime_safe(); + test_wavenet::test_layer::test_gated(); test_wavenet::test_layer::test_layer_getters(); test_wavenet::test_layer::test_non_gated_layer(); @@ -148,6 +160,8 @@ int main() test_wavenet::test_layer_grouped_process_realtime_safe(); 
test_wavenet::test_layer_all_films_with_shift_realtime_safe(); test_wavenet::test_layer_all_films_without_shift_realtime_safe(); + test_wavenet::test_layer_post_activation_film_gated_realtime_safe(); + test_wavenet::test_layer_post_activation_film_blended_realtime_safe(); test_wavenet::test_layer_array_process_realtime_safe(); test_wavenet::test_process_realtime_safe(); test_wavenet::test_process_3in_2out_realtime_safe(); @@ -209,14 +223,13 @@ int main() test_get_dsp::test_null_output_level(); // Finally, some end-to-end tests. - std::cerr << "Running end-to-end tests" << std::endl; test_get_dsp::test_load_and_process_nam_files(); + std::cout << "Success!" << std::endl; #ifdef ADDASSERT + std::cerr << "===============================================================" << std::endl; std::cerr << "Checking that we're successfully asserting. We should now fail." << std::endl; assert(false); #endif - - std::cout << "Success!" << std::endl; return 0; } diff --git a/tools/test/allocation_tracking.cpp b/tools/test/allocation_tracking.cpp new file mode 100644 index 0000000..983ed11 --- /dev/null +++ b/tools/test/allocation_tracking.cpp @@ -0,0 +1,90 @@ +// Allocation tracking implementation +// This file contains the actual definitions of the global tracking variables +// and the overridden malloc/free/new/delete operators. 
+ +#include "allocation_tracking.h" + +// Allocation tracking globals - definitions +namespace allocation_tracking +{ +volatile int g_allocation_count = 0; +volatile int g_deallocation_count = 0; +volatile bool g_tracking_enabled = false; + +// Original malloc/free functions +void* (*original_malloc)(size_t) = nullptr; +void (*original_free)(void*) = nullptr; +void* (*original_realloc)(void*, size_t) = nullptr; +} // namespace allocation_tracking + +// Override malloc/free to track Eigen allocations (Eigen uses malloc directly) +extern "C" { +void* malloc(size_t size) +{ + if (!allocation_tracking::original_malloc) + allocation_tracking::original_malloc = reinterpret_cast(dlsym(RTLD_NEXT, "malloc")); + void* ptr = allocation_tracking::original_malloc(size); + if (allocation_tracking::g_tracking_enabled && ptr != nullptr) + ++allocation_tracking::g_allocation_count; + return ptr; +} + +void free(void* ptr) +{ + if (!allocation_tracking::original_free) + allocation_tracking::original_free = reinterpret_cast(dlsym(RTLD_NEXT, "free")); + if (allocation_tracking::g_tracking_enabled && ptr != nullptr) + ++allocation_tracking::g_deallocation_count; + allocation_tracking::original_free(ptr); +} + +void* realloc(void* ptr, size_t size) +{ + if (!allocation_tracking::original_realloc) + allocation_tracking::original_realloc = reinterpret_cast(dlsym(RTLD_NEXT, "realloc")); + void* new_ptr = allocation_tracking::original_realloc(ptr, size); + if (allocation_tracking::g_tracking_enabled) + { + if (ptr != nullptr && new_ptr != ptr) + ++allocation_tracking::g_deallocation_count; // Old pointer was freed + if (new_ptr != nullptr && new_ptr != ptr) + ++allocation_tracking::g_allocation_count; // New allocation + } + return new_ptr; +} +} + +// Overload global new/delete operators to track allocations +void* operator new(std::size_t size) +{ + void* ptr = std::malloc(size); + if (!ptr) + throw std::bad_alloc(); + if (allocation_tracking::g_tracking_enabled) + 
++allocation_tracking::g_allocation_count; + return ptr; +} + +void* operator new[](std::size_t size) +{ + void* ptr = std::malloc(size); + if (!ptr) + throw std::bad_alloc(); + if (allocation_tracking::g_tracking_enabled) + ++allocation_tracking::g_allocation_count; + return ptr; +} + +void operator delete(void* ptr) noexcept +{ + if (allocation_tracking::g_tracking_enabled && ptr != nullptr) + ++allocation_tracking::g_deallocation_count; + std::free(ptr); +} + +void operator delete[](void* ptr) noexcept +{ + if (allocation_tracking::g_tracking_enabled && ptr != nullptr) + ++allocation_tracking::g_deallocation_count; + std::free(ptr); +} diff --git a/tools/test/allocation_tracking.h b/tools/test/allocation_tracking.h new file mode 100644 index 0000000..00c110a --- /dev/null +++ b/tools/test/allocation_tracking.h @@ -0,0 +1,106 @@ +// Allocation tracking infrastructure for real-time safety tests +// This header provides tools to detect memory allocations/deallocations +// during real-time critical code paths. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include + +// Allocation tracking globals +namespace allocation_tracking +{ +extern volatile int g_allocation_count; +extern volatile int g_deallocation_count; +extern volatile bool g_tracking_enabled; + +// Original malloc/free functions +extern void* (*original_malloc)(size_t); +extern void (*original_free)(void*); +extern void* (*original_realloc)(void*, size_t); + +// Helper function to run allocation tracking tests +// setup: Function to run before tracking starts (can be nullptr) +// test: Function to run while tracking allocations (required) +// teardown: Function to run after tracking stops (can be nullptr) +// expected_allocations: Expected number of allocations (default 0) +// expected_deallocations: Expected number of deallocations (default 0) +// test_name: Name of the test for error messages +template +void run_allocation_test(std::function setup, TestFunc test, std::function teardown, + int expected_allocations, int expected_deallocations, const char* test_name) +{ + // Run setup if provided + if (setup) + setup(); + + // Reset allocation counters and enable tracking + g_allocation_count = 0; + g_deallocation_count = 0; + g_tracking_enabled = true; + + // Run the test code + test(); + + // Disable tracking before any cleanup + g_tracking_enabled = false; + + // Run teardown if provided + if (teardown) + teardown(); + + // Assert expected allocations/deallocations + if (g_allocation_count != expected_allocations || g_deallocation_count != expected_deallocations) + { + std::cerr << "ERROR: " << test_name << " - Expected " << expected_allocations << " allocations, " + << expected_deallocations << " deallocations. 
Got " << g_allocation_count << " allocations, " + << g_deallocation_count << " deallocations.\n"; + std::abort(); + } +} + +// Convenience wrapper for tests that expect zero allocations (most common case) +template +void run_allocation_test_no_allocations(std::function setup, TestFunc test, std::function teardown, + const char* test_name) +{ + run_allocation_test(setup, test, teardown, 0, 0, test_name); +} + +// Convenience wrapper for tests that expect allocations (for testing the tracking mechanism) +template +void run_allocation_test_expect_allocations(std::function setup, TestFunc test, std::function teardown, + const char* test_name) +{ + // Run setup if provided + if (setup) + setup(); + + // Reset allocation counters and enable tracking + g_allocation_count = 0; + g_deallocation_count = 0; + g_tracking_enabled = true; + + // Run the test code + test(); + + // Disable tracking before any cleanup + g_tracking_enabled = false; + + // Run teardown if provided + if (teardown) + teardown(); + + // Assert that allocations occurred (this test verifies our tracking works) + if (g_allocation_count == 0 && g_deallocation_count == 0) + { + std::cerr << "ERROR: " << test_name + << " - Expected allocations/deallocations but none occurred (tracking may not be working)\n"; + std::abort(); + } +} +} // namespace allocation_tracking diff --git a/tools/test/test_film_realtime_safe.cpp b/tools/test/test_film_realtime_safe.cpp new file mode 100644 index 0000000..af301db --- /dev/null +++ b/tools/test/test_film_realtime_safe.cpp @@ -0,0 +1,515 @@ +// Test to verify FiLM::Process and FiLM::Process_ are real-time safe (no allocations/frees) + +#include +#include +#include +#include +#include + +#include "NAM/film.h" +#include "allocation_tracking.h" + +namespace test_film_realtime_safe +{ +using namespace allocation_tracking; + +// Test that pre-allocated Eigen operations with noalias() don't allocate +void test_allocation_tracking_pass() +{ + const int rows = 10; + const int cols 
= 20; + + // Pre-allocate matrices for matrix product: c = a * b + // a is rows x cols, b is cols x rows, so c is rows x rows + Eigen::MatrixXf a(rows, cols); + Eigen::MatrixXf b(cols, rows); + Eigen::MatrixXf c(rows, rows); + + a.setConstant(1.0f); + b.setConstant(2.0f); + + run_allocation_test_no_allocations( + nullptr, // No setup needed + [&]() { + // Matrix product with noalias() - should not allocate (all matrices pre-allocated) + c.noalias() = a * b; + }, + nullptr, // No teardown needed + "test_allocation_tracking_pass"); + + // Verify result: c should be rows x rows with value 2*cols (each element is sum of cols elements of value 2) + assert(c.rows() == rows && c.cols() == rows); + assert(std::abs(c(0, 0) - 2.0f * cols) < 0.001f); +} + +// Test that creating a new matrix causes allocations (should be caught) +void test_allocation_tracking_fail() +{ + run_allocation_test_expect_allocations( + nullptr, // No setup needed + [&]() { + // This operation should allocate (creating new matrix) + Eigen::MatrixXf a(10, 20); + a.setConstant(1.0f); + }, + nullptr, // No teardown needed + "test_allocation_tracking_fail"); +} + +// Test that FiLM::Process() method with shift does not allocate or free memory +void test_film_process_with_shift_realtime_safe() +{ + // Setup: Create a FiLM with shift enabled + const int condition_dim = 2; + const int input_dim = 3; + nam::FiLM film(condition_dim, input_dim, /*shift=*/true); + + const int maxBufferSize = 256; + film.SetMaxBufferSize(maxBufferSize); + + // Set weights: all-zero weights with fixed biases + std::vector weights; + weights.resize((2 * input_dim) * condition_dim + (2 * input_dim), 0.0f); + + // Set biases for scale and shift + const int bias_offset = (2 * input_dim) * condition_dim; + weights[bias_offset + 0] = 2.0f; // scale0 + weights[bias_offset + 1] = -1.0f; // scale1 + weights[bias_offset + 2] = 0.5f; // scale2 + weights[bias_offset + 3] = 10.0f; // shift0 + weights[bias_offset + 4] = -20.0f; // shift1 + 
weights[bias_offset + 5] = 3.0f; // shift2 + + auto it = weights.begin(); + film.set_weights_(it); + + // Test with several different buffer sizes + std::vector buffer_sizes{1, 8, 16, 32, 64, 128, 256}; + + for (int buffer_size : buffer_sizes) + { + // Prepare input/condition matrices (allocate before tracking) + Eigen::MatrixXf input(input_dim, buffer_size); + Eigen::MatrixXf condition(condition_dim, buffer_size); + input.setConstant(0.5f); + condition.setConstant(0.5f); + + std::string test_name = "FiLM Process (with shift) - Buffer size " + std::to_string(buffer_size); + run_allocation_test_no_allocations( + nullptr, // No setup needed + [&]() { + // Call Process() - this should not allocate or free + film.Process(input, condition, buffer_size); + }, + nullptr, // No teardown needed + test_name.c_str()); + + // Verify output is valid + auto output = film.GetOutput().leftCols(buffer_size); + assert(output.rows() == input_dim && output.cols() == buffer_size); + assert(std::isfinite(output(0, 0))); + assert(std::isfinite(output(input_dim - 1, buffer_size - 1))); + } +} + +// Test that FiLM::Process() method without shift does not allocate or free memory +void test_film_process_without_shift_realtime_safe() +{ + // Setup: Create a FiLM with shift disabled (scale-only mode) + const int condition_dim = 2; + const int input_dim = 3; + nam::FiLM film(condition_dim, input_dim, /*shift=*/false); + + const int maxBufferSize = 256; + film.SetMaxBufferSize(maxBufferSize); + + // Set weights: all-zero weights with fixed biases for scale + std::vector weights; + weights.resize(input_dim * condition_dim + input_dim, 0.0f); + + // Set biases for scale + const int bias_offset = input_dim * condition_dim; + weights[bias_offset + 0] = 2.0f; // scale0 + weights[bias_offset + 1] = -1.0f; // scale1 + weights[bias_offset + 2] = 0.5f; // scale2 + + auto it = weights.begin(); + film.set_weights_(it); + + // Test with several different buffer sizes + std::vector buffer_sizes{1, 8, 16, 32, 
64, 128, 256}; + + for (int buffer_size : buffer_sizes) + { + // Prepare input/condition matrices (allocate before tracking) + Eigen::MatrixXf input(input_dim, buffer_size); + Eigen::MatrixXf condition(condition_dim, buffer_size); + input.setConstant(0.5f); + condition.setConstant(0.5f); + + std::string test_name = "FiLM Process (without shift) - Buffer size " + std::to_string(buffer_size); + run_allocation_test_no_allocations( + nullptr, // No setup needed + [&]() { + // Call Process() - this should not allocate or free + film.Process(input, condition, buffer_size); + }, + nullptr, // No teardown needed + test_name.c_str()); + + // Verify output is valid + auto output = film.GetOutput().leftCols(buffer_size); + assert(output.rows() == input_dim && output.cols() == buffer_size); + assert(std::isfinite(output(0, 0))); + assert(std::isfinite(output(input_dim - 1, buffer_size - 1))); + } +} + +// Test that FiLM::Process_() in-place method with shift does not allocate or free memory +void test_film_process_inplace_with_shift_realtime_safe() +{ + // Setup: Create a FiLM with shift enabled + const int condition_dim = 2; + const int input_dim = 3; + nam::FiLM film(condition_dim, input_dim, /*shift=*/true); + + const int maxBufferSize = 256; + film.SetMaxBufferSize(maxBufferSize); + + // Set weights: all-zero weights with fixed biases + std::vector weights; + weights.resize((2 * input_dim) * condition_dim + (2 * input_dim), 0.0f); + + // Set biases for scale and shift + const int bias_offset = (2 * input_dim) * condition_dim; + weights[bias_offset + 0] = 2.0f; // scale0 + weights[bias_offset + 1] = -1.0f; // scale1 + weights[bias_offset + 2] = 0.5f; // scale2 + weights[bias_offset + 3] = 10.0f; // shift0 + weights[bias_offset + 4] = -20.0f; // shift1 + weights[bias_offset + 5] = 3.0f; // shift2 + + auto it = weights.begin(); + film.set_weights_(it); + + // Test with several different buffer sizes + std::vector buffer_sizes{1, 8, 16, 32, 64, 128, 256}; + + for (int 
buffer_size : buffer_sizes) + { + // Prepare input/condition matrices (allocate before tracking) + Eigen::MatrixXf input(input_dim, buffer_size); + Eigen::MatrixXf condition(condition_dim, buffer_size); + input.setConstant(0.5f); + condition.setConstant(0.5f); + + std::string test_name = "FiLM Process_ (in-place with shift) - Buffer size " + std::to_string(buffer_size); + run_allocation_test_no_allocations( + nullptr, // No setup needed + [&]() { + // Call Process_() - this should not allocate or free + film.Process_(input, condition, buffer_size); + }, + nullptr, // No teardown needed + test_name.c_str()); + + // Verify output is valid (input should be modified in-place) + assert(input.rows() == input_dim && input.cols() >= buffer_size); + assert(std::isfinite(input(0, 0))); + assert(std::isfinite(input(input_dim - 1, buffer_size - 1))); + } +} + +// Test that FiLM::Process_() in-place method without shift does not allocate or free memory +void test_film_process_inplace_without_shift_realtime_safe() +{ + // Setup: Create a FiLM with shift disabled (scale-only mode) + const int condition_dim = 2; + const int input_dim = 3; + nam::FiLM film(condition_dim, input_dim, /*shift=*/false); + + const int maxBufferSize = 256; + film.SetMaxBufferSize(maxBufferSize); + + // Set weights: all-zero weights with fixed biases for scale + std::vector weights; + weights.resize(input_dim * condition_dim + input_dim, 0.0f); + + // Set biases for scale + const int bias_offset = input_dim * condition_dim; + weights[bias_offset + 0] = 2.0f; // scale0 + weights[bias_offset + 1] = -1.0f; // scale1 + weights[bias_offset + 2] = 0.5f; // scale2 + + auto it = weights.begin(); + film.set_weights_(it); + + // Test with several different buffer sizes + std::vector buffer_sizes{1, 8, 16, 32, 64, 128, 256}; + + for (int buffer_size : buffer_sizes) + { + // Prepare input/condition matrices (allocate before tracking) + Eigen::MatrixXf input(input_dim, buffer_size); + Eigen::MatrixXf 
condition(condition_dim, buffer_size); + input.setConstant(0.5f); + condition.setConstant(0.5f); + + std::string test_name = "FiLM Process_ (in-place without shift) - Buffer size " + std::to_string(buffer_size); + run_allocation_test_no_allocations( + nullptr, // No setup needed + [&]() { + // Call Process_() - this should not allocate or free + film.Process_(input, condition, buffer_size); + }, + nullptr, // No teardown needed + test_name.c_str()); + + // Verify output is valid (input should be modified in-place) + assert(input.rows() == input_dim && input.cols() >= buffer_size); + assert(std::isfinite(input(0, 0))); + assert(std::isfinite(input(input_dim - 1, buffer_size - 1))); + } +} + +// Test that FiLM::Process() with larger dimensions does not allocate or free memory +void test_film_process_large_dimensions_realtime_safe() +{ + // Setup: Create a FiLM with larger dimensions + const int condition_dim = 8; + const int input_dim = 16; + nam::FiLM film(condition_dim, input_dim, /*shift=*/true); + + const int maxBufferSize = 256; + film.SetMaxBufferSize(maxBufferSize); + + // Set weights: all-zero weights with fixed biases + std::vector weights; + weights.resize((2 * input_dim) * condition_dim + (2 * input_dim), 0.0f); + + // Set biases for scale and shift (use simple pattern) + const int bias_offset = (2 * input_dim) * condition_dim; + for (int i = 0; i < input_dim; i++) + { + weights[bias_offset + i] = 1.0f + 0.1f * i; // scale + weights[bias_offset + input_dim + i] = 0.5f * i; // shift + } + + auto it = weights.begin(); + film.set_weights_(it); + + // Test with several different buffer sizes + std::vector buffer_sizes{1, 8, 16, 32, 64, 128, 256}; + + for (int buffer_size : buffer_sizes) + { + // Prepare input/condition matrices (allocate before tracking) + Eigen::MatrixXf input(input_dim, buffer_size); + Eigen::MatrixXf condition(condition_dim, buffer_size); + input.setConstant(0.5f); + condition.setConstant(0.5f); + + std::string test_name = "FiLM Process 
(large dimensions: condition_dim=" + std::to_string(condition_dim) + + ", input_dim=" + std::to_string(input_dim) + ") - Buffer size " + + std::to_string(buffer_size); + run_allocation_test_no_allocations( + nullptr, // No setup needed + [&]() { + // Call Process() - this should not allocate or free + film.Process(input, condition, buffer_size); + }, + nullptr, // No teardown needed + test_name.c_str()); + + // Verify output is valid + auto output = film.GetOutput().leftCols(buffer_size); + assert(output.rows() == input_dim && output.cols() == buffer_size); + assert(std::isfinite(output(0, 0))); + assert(std::isfinite(output(input_dim - 1, buffer_size - 1))); + } +} + +// Test that FiLM::Process() with partial frame processing does not allocate or free memory +void test_film_process_partial_frames_realtime_safe() +{ + // Setup: Create a FiLM + const int condition_dim = 2; + const int input_dim = 3; + nam::FiLM film(condition_dim, input_dim, /*shift=*/true); + + const int maxBufferSize = 256; + film.SetMaxBufferSize(maxBufferSize); + + // Set weights: all-zero weights with fixed biases + std::vector weights; + weights.resize((2 * input_dim) * condition_dim + (2 * input_dim), 0.0f); + + // Set biases for scale and shift + const int bias_offset = (2 * input_dim) * condition_dim; + weights[bias_offset + 0] = 2.0f; // scale0 + weights[bias_offset + 1] = -1.0f; // scale1 + weights[bias_offset + 2] = 0.5f; // scale2 + weights[bias_offset + 3] = 10.0f; // shift0 + weights[bias_offset + 4] = -20.0f; // shift1 + weights[bias_offset + 5] = 3.0f; // shift2 + + auto it = weights.begin(); + film.set_weights_(it); + + // Test with buffer smaller than maxBufferSize to verify partial frame processing + const int full_buffer_size = 64; + std::vector partial_buffer_sizes{1, 8, 16, 32}; + + for (int buffer_size : partial_buffer_sizes) + { + // Prepare input/condition matrices with full size (allocate before tracking) + Eigen::MatrixXf input(input_dim, full_buffer_size); + 
Eigen::MatrixXf condition(condition_dim, full_buffer_size); + input.setConstant(0.5f); + condition.setConstant(0.5f); + + std::string test_name = "FiLM Process (partial frames: " + std::to_string(buffer_size) + " of " + + std::to_string(full_buffer_size) + " frames)"; + run_allocation_test_no_allocations( + nullptr, // No setup needed + [&]() { + // Call Process() with partial buffer_size - this should not allocate or free + film.Process(input, condition, buffer_size); + }, + nullptr, // No teardown needed + test_name.c_str()); + + // Verify output is valid + auto output = film.GetOutput().leftCols(buffer_size); + assert(output.rows() == input_dim && output.cols() >= buffer_size); + assert(std::isfinite(output(0, 0))); + assert(std::isfinite(output(input_dim - 1, buffer_size - 1))); + } +} + +// Test that FiLM::Process() with varying condition and input dimensions does not allocate or free memory +void test_film_process_varying_dimensions_realtime_safe() +{ + // Test various combinations of condition_dim and input_dim + struct DimConfig + { + int condition_dim; + int input_dim; + bool shift; + }; + + std::vector configs{ + {1, 1, true}, // Minimal dimensions with shift + {1, 1, false}, // Minimal dimensions without shift + {1, 4, true}, // Small condition, larger input + {4, 1, false}, // Larger condition, small input + {4, 4, true}, // Equal dimensions + {3, 5, false}, // Non-power-of-2 dimensions + {7, 11, true}, // Prime dimensions + }; + + const int maxBufferSize = 128; + const int buffer_size = 64; + + for (const auto& config : configs) + { + // Setup: Create a FiLM with specific dimensions + nam::FiLM film(config.condition_dim, config.input_dim, config.shift); + film.SetMaxBufferSize(maxBufferSize); + + // Set weights: all-zero weights with fixed biases + const int output_channels = config.shift ? 
(2 * config.input_dim) : config.input_dim; + std::vector weights; + weights.resize(output_channels * config.condition_dim + output_channels, 0.0f); + + // Set biases + const int bias_offset = output_channels * config.condition_dim; + for (int i = 0; i < output_channels; i++) + { + weights[bias_offset + i] = 1.0f + 0.1f * i; + } + + auto it = weights.begin(); + film.set_weights_(it); + + // Prepare input/condition matrices (allocate before tracking) + Eigen::MatrixXf input(config.input_dim, buffer_size); + Eigen::MatrixXf condition(config.condition_dim, buffer_size); + input.setConstant(0.5f); + condition.setConstant(0.5f); + + std::string shift_str = config.shift ? "true" : "false"; + std::string test_name = "FiLM Process (condition_dim=" + std::to_string(config.condition_dim) + + ", input_dim=" + std::to_string(config.input_dim) + ", shift=" + shift_str + ")"; + run_allocation_test_no_allocations( + nullptr, // No setup needed + [&]() { + // Call Process() - this should not allocate or free + film.Process(input, condition, buffer_size); + }, + nullptr, // No teardown needed + test_name.c_str()); + + // Verify output is valid + auto output = film.GetOutput().leftCols(buffer_size); + assert(output.rows() == config.input_dim && output.cols() >= buffer_size); + assert(std::isfinite(output(0, 0))); + assert(std::isfinite(output(config.input_dim - 1, buffer_size - 1))); + } +} + +// Test that multiple consecutive calls to FiLM::Process() do not allocate or free memory +void test_film_process_consecutive_calls_realtime_safe() +{ + // Setup: Create a FiLM + const int condition_dim = 2; + const int input_dim = 3; + nam::FiLM film(condition_dim, input_dim, /*shift=*/true); + + const int maxBufferSize = 256; + film.SetMaxBufferSize(maxBufferSize); + + // Set weights + std::vector weights; + weights.resize((2 * input_dim) * condition_dim + (2 * input_dim), 0.0f); + + const int bias_offset = (2 * input_dim) * condition_dim; + weights[bias_offset + 0] = 2.0f; + 
weights[bias_offset + 1] = -1.0f; + weights[bias_offset + 2] = 0.5f; + weights[bias_offset + 3] = 10.0f; + weights[bias_offset + 4] = -20.0f; + weights[bias_offset + 5] = 3.0f; + + auto it = weights.begin(); + film.set_weights_(it); + + const int buffer_size = 64; + const int num_consecutive_calls = 10; + + // Prepare input/condition matrices (allocate before tracking) + Eigen::MatrixXf input(input_dim, buffer_size); + Eigen::MatrixXf condition(condition_dim, buffer_size); + input.setConstant(0.5f); + condition.setConstant(0.5f); + + std::string test_name = "FiLM Process (consecutive calls: " + std::to_string(num_consecutive_calls) + " calls)"; + run_allocation_test_no_allocations( + nullptr, // No setup needed + [&]() { + // Call Process() multiple times consecutively - none should allocate or free + for (int i = 0; i < num_consecutive_calls; i++) + { + film.Process(input, condition, buffer_size); + } + }, + nullptr, // No teardown needed + test_name.c_str()); + + // Verify output is valid after all calls + auto output = film.GetOutput().leftCols(buffer_size); + assert(output.rows() == input_dim && output.cols() >= buffer_size); + assert(std::isfinite(output(0, 0))); + assert(std::isfinite(output(input_dim - 1, buffer_size - 1))); +} +} // namespace test_film_realtime_safe diff --git a/tools/test/test_get_dsp.cpp b/tools/test/test_get_dsp.cpp index 863cb10..7540b17 100644 --- a/tools/test/test_get_dsp.cpp +++ b/tools/test/test_get_dsp.cpp @@ -139,8 +139,9 @@ void test_load_and_process_nam_files() { // Test loading and processing three different .nam files // Paths are relative to root directory where tests run (./build/tools/run_tests) - const std::vector nam_files = { - "example_models/wavenet.nam", "example_models/lstm.nam", "example_models/wavenet_condition_dsp.nam"}; + const std::vector nam_files = {"example_models/wavenet.nam", "example_models/lstm.nam", + "example_models/wavenet_condition_dsp.nam", + "example_models/wavenet_a2_max.nam"}; const int 
num_buffers = 3; const int buffer_size = 64; diff --git a/tools/test/test_wavenet/test_condition_processing.cpp b/tools/test/test_wavenet/test_condition_processing.cpp index 23cf7df..afc16f3 100644 --- a/tools/test/test_wavenet/test_condition_processing.cpp +++ b/tools/test/test_wavenet/test_condition_processing.cpp @@ -23,14 +23,15 @@ static nam::wavenet::_FiLMParams make_default_film_params() static nam::wavenet::LayerArrayParams make_layer_array_params( const int input_size, const int condition_size, const int head_size, const int channels, const int bottleneck, const int kernel_size, std::vector&& dilations, const nam::activations::ActivationConfig& activation_config, - const nam::wavenet::GatingMode gating_mode, const bool head_bias, const int groups_input, const int groups_1x1, - const nam::wavenet::Head1x1Params& head1x1_params, const std::string& secondary_activation) + const nam::wavenet::GatingMode gating_mode, const bool head_bias, const int groups_input, + const int groups_input_mixin, const int groups_1x1, const nam::wavenet::Head1x1Params& head1x1_params, + const nam::activations::ActivationConfig& secondary_activation_config) { auto film_params = make_default_film_params(); return nam::wavenet::LayerArrayParams( input_size, condition_size, head_size, channels, bottleneck, kernel_size, std::move(dilations), activation_config, - gating_mode, head_bias, groups_input, groups_1x1, head1x1_params, secondary_activation, film_params, film_params, - film_params, film_params, film_params, film_params, film_params, film_params, film_params); + gating_mode, head_bias, groups_input, groups_input_mixin, groups_1x1, head1x1_params, secondary_activation_config, + film_params, film_params, film_params, film_params, film_params, film_params, film_params, film_params); } // Helper function to create a simple WaveNet with specified input and output channels @@ -51,14 +52,16 @@ std::unique_ptr create_simple_wavenet( const bool head_bias = false; const bool with_head = 
false; const int groups = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; const bool head1x1_active = false; const int head1x1_groups = 1; nam::wavenet::Head1x1Params head1x1_params(head1x1_active, channels, head1x1_groups); - nam::wavenet::LayerArrayParams params = make_layer_array_params( - input_size, condition_size, head_size, channels, bottleneck, kernel_size, std::move(dilations), activation, - gating_mode, head_bias, groups, groups_1x1, head1x1_params, ""); + nam::wavenet::LayerArrayParams params = + make_layer_array_params(input_size, condition_size, head_size, channels, bottleneck, kernel_size, + std::move(dilations), activation, gating_mode, head_bias, groups, groups_input_mixin, + groups_1x1, head1x1_params, nam::activations::ActivationConfig{}); std::vector layer_array_params; layer_array_params.push_back(std::move(params)); diff --git a/tools/test/test_wavenet/test_full.cpp b/tools/test/test_wavenet/test_full.cpp index 9b04d25..20a7af1 100644 --- a/tools/test/test_wavenet/test_full.cpp +++ b/tools/test/test_wavenet/test_full.cpp @@ -22,14 +22,15 @@ static nam::wavenet::_FiLMParams make_default_film_params() static nam::wavenet::LayerArrayParams make_layer_array_params( const int input_size, const int condition_size, const int head_size, const int channels, const int bottleneck, const int kernel_size, std::vector&& dilations, const nam::activations::ActivationConfig& activation_config, - const nam::wavenet::GatingMode gating_mode, const bool head_bias, const int groups_input, const int groups_1x1, - const nam::wavenet::Head1x1Params& head1x1_params, const std::string& secondary_activation) + const nam::wavenet::GatingMode gating_mode, const bool head_bias, const int groups_input, + const int groups_input_mixin, const int groups_1x1, const nam::wavenet::Head1x1Params& head1x1_params, + const nam::activations::ActivationConfig& secondary_activation_config) { auto film_params = make_default_film_params(); return 
nam::wavenet::LayerArrayParams( input_size, condition_size, head_size, channels, bottleneck, kernel_size, std::move(dilations), activation_config, - gating_mode, head_bias, groups_input, groups_1x1, head1x1_params, secondary_activation, film_params, film_params, - film_params, film_params, film_params, film_params, film_params, film_params, film_params); + gating_mode, head_bias, groups_input, groups_input_mixin, groups_1x1, head1x1_params, secondary_activation_config, + film_params, film_params, film_params, film_params, film_params, film_params, film_params, film_params); } // Test full WaveNet model void test_wavenet_model() @@ -47,13 +48,15 @@ void test_wavenet_model() const float head_scale = 1.0f; const bool with_head = false; const int groups = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; const bool head1x1_active = false; nam::wavenet::Head1x1Params head1x1_params(head1x1_active, channels, 1); + nam::activations::ActivationConfig empty_config{}; nam::wavenet::LayerArrayParams params = make_layer_array_params( input_size, condition_size, head_size, channels, bottleneck, kernel_size, std::move(dilations), activation, - gating_mode, head_bias, groups, groups_1x1, head1x1_params, ""); + gating_mode, head_bias, groups, groups_input_mixin, groups_1x1, head1x1_params, empty_config); std::vector layer_array_params; layer_array_params.push_back(std::move(params)); @@ -108,6 +111,7 @@ void test_wavenet_multiple_arrays() const float head_scale = 0.5f; const bool with_head = false; const int groups = 1; + const int groups_input_mixin = 1; std::vector layer_array_params; // First array @@ -119,12 +123,14 @@ void test_wavenet_multiple_arrays() nam::wavenet::Head1x1Params head1x1_params(head1x1_active, channels, 1); layer_array_params.push_back(make_layer_array_params(input_size, condition_size, head_size, channels, bottleneck, kernel_size, std::move(dilations1), activation, gating_mode, - head_bias, groups, groups_1x1, head1x1_params, "")); + 
head_bias, groups, groups_input_mixin, groups_1x1, + head1x1_params, nam::activations::ActivationConfig{})); // Second array (head_size of first must match channels of second) std::vector dilations2{1}; layer_array_params.push_back(make_layer_array_params(head_size, condition_size, head_size, channels, bottleneck, kernel_size, std::move(dilations2), activation, gating_mode, - head_bias, groups, groups_1x1, head1x1_params, "")); + head_bias, groups, groups_input_mixin, groups_1x1, + head1x1_params, nam::activations::ActivationConfig{})); std::vector weights; // Array 0: rechannel, layer, head_rechannel @@ -171,13 +177,15 @@ void test_wavenet_zero_input() const float head_scale = 1.0f; const bool with_head = false; const int groups = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; const bool head1x1_active = false; nam::wavenet::Head1x1Params head1x1_params(head1x1_active, channels, 1); - nam::wavenet::LayerArrayParams params = make_layer_array_params( - input_size, condition_size, head_size, channels, bottleneck, kernel_size, std::move(dilations), activation, - gating_mode, head_bias, groups, groups_1x1, head1x1_params, ""); + nam::wavenet::LayerArrayParams params = + make_layer_array_params(input_size, condition_size, head_size, channels, bottleneck, kernel_size, + std::move(dilations), activation, gating_mode, head_bias, groups, groups_input_mixin, + groups_1x1, head1x1_params, nam::activations::ActivationConfig{}); std::vector layer_array_params; layer_array_params.push_back(std::move(params)); @@ -220,13 +228,15 @@ void test_wavenet_different_buffer_sizes() const float head_scale = 1.0f; const bool with_head = false; const int groups = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; const bool head1x1_active = false; nam::wavenet::Head1x1Params head1x1_params(head1x1_active, channels, 1); - nam::wavenet::LayerArrayParams params = make_layer_array_params( - input_size, condition_size, head_size, channels, bottleneck, kernel_size, 
std::move(dilations), activation, - gating_mode, head_bias, groups, groups_1x1, head1x1_params, ""); + nam::wavenet::LayerArrayParams params = + make_layer_array_params(input_size, condition_size, head_size, channels, bottleneck, kernel_size, + std::move(dilations), activation, gating_mode, head_bias, groups, groups_input_mixin, + groups_1x1, head1x1_params, nam::activations::ActivationConfig{}); std::vector layer_array_params; layer_array_params.push_back(std::move(params)); @@ -272,14 +282,16 @@ void test_wavenet_prewarm() const float head_scale = 1.0f; const bool with_head = false; const int groups = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; const bool head1x1_active = false; nam::wavenet::Head1x1Params head1x1_params(head1x1_active, channels, 1); - nam::wavenet::LayerArrayParams params = make_layer_array_params( - input_size, condition_size, head_size, channels, bottleneck, kernel_size, std::move(dilations), activation, - gating_mode, head_bias, groups, groups_1x1, head1x1_params, ""); + nam::wavenet::LayerArrayParams params = + make_layer_array_params(input_size, condition_size, head_size, channels, bottleneck, kernel_size, + std::move(dilations), activation, gating_mode, head_bias, groups, groups_input_mixin, + groups_1x1, head1x1_params, nam::activations::ActivationConfig{}); std::vector layer_array_params; layer_array_params.push_back(std::move(params)); diff --git a/tools/test/test_wavenet/test_head1x1.cpp b/tools/test/test_wavenet/test_head1x1.cpp index 1ee8c24..be55431 100644 --- a/tools/test/test_wavenet/test_head1x1.cpp +++ b/tools/test/test_wavenet/test_head1x1.cpp @@ -23,14 +23,15 @@ static nam::wavenet::_Layer make_layer(const int condition_size, const int chann const int kernel_size, const int dilation, const nam::activations::ActivationConfig& activation_config, const nam::wavenet::GatingMode gating_mode, const int groups_input, - const int groups_1x1, const nam::wavenet::Head1x1Params& head1x1_params, - const std::string& 
secondary_activation) + const int groups_input_mixin, const int groups_1x1, + const nam::wavenet::Head1x1Params& head1x1_params, + const nam::activations::ActivationConfig& secondary_activation_config) { auto film_params = make_default_film_params(); return nam::wavenet::_Layer(condition_size, channels, bottleneck, kernel_size, dilation, activation_config, - gating_mode, groups_input, groups_1x1, head1x1_params, secondary_activation, film_params, - film_params, film_params, film_params, film_params, film_params, film_params, film_params, - film_params); + gating_mode, groups_input, groups_input_mixin, groups_1x1, head1x1_params, + secondary_activation_config, film_params, film_params, film_params, film_params, + film_params, film_params, film_params, film_params); } void test_head1x1_inactive() @@ -44,12 +45,14 @@ void test_head1x1_inactive() const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU); const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE; const int groups_input = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; const bool head1x1_active = false; nam::wavenet::Head1x1Params head1x1_params(head1x1_active, channels, 1); - auto layer = make_layer(conditionSize, channels, bottleneck, kernelSize, dilation, activation, gating_mode, - groups_input, groups_1x1, head1x1_params, ""); + auto layer = + make_layer(conditionSize, channels, bottleneck, kernelSize, dilation, activation, gating_mode, groups_input, + groups_input_mixin, groups_1x1, head1x1_params, nam::activations::ActivationConfig{}); // Set weights (same as non-gated layer test) // With bottleneck=channels=2: @@ -110,14 +113,16 @@ void test_head1x1_active() const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU); const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE; const int groups_input = 1; + const int groups_input_mixin = 1; const int 
groups_1x1 = 1; const bool head1x1_active = true; const int head1x1_groups = 1; // Create head1x1 with different out_channels to verify it's being used nam::wavenet::Head1x1Params head1x1_params(head1x1_active, channels, head1x1_groups); - auto layer = make_layer(conditionSize, channels, bottleneck, kernelSize, dilation, activation, gating_mode, - groups_input, groups_1x1, head1x1_params, ""); + auto layer = + make_layer(conditionSize, channels, bottleneck, kernelSize, dilation, activation, gating_mode, groups_input, + groups_input_mixin, groups_1x1, head1x1_params, nam::activations::ActivationConfig{}); // Set weights: conv, input_mixin, 1x1, head1x1 // With bottleneck=channels=2: @@ -183,13 +188,15 @@ void test_head1x1_gated() const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU); const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::GATED; const int groups_input = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; const bool head1x1_active = true; const int head1x1_groups = 1; nam::wavenet::Head1x1Params head1x1_params(head1x1_active, channels, head1x1_groups); + auto sigmoid_config = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Sigmoid); auto layer = make_layer(conditionSize, channels, bottleneck, kernelSize, dilation, activation, gating_mode, - groups_input, groups_1x1, head1x1_params, "Sigmoid"); + groups_input, groups_input_mixin, groups_1x1, head1x1_params, sigmoid_config); // For gated: conv outputs 2*bottleneck, input_mixin outputs 2*bottleneck, 1x1 outputs channels // head1x1 outputs channels @@ -273,13 +280,15 @@ void test_head1x1_groups() const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU); const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE; const int groups_input = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; const bool head1x1_active = 
true; const int head1x1_groups = 2; // Grouped head1x1 nam::wavenet::Head1x1Params head1x1_params(head1x1_active, channels, head1x1_groups); - auto layer = make_layer(conditionSize, channels, bottleneck, kernelSize, dilation, activation, gating_mode, - groups_input, groups_1x1, head1x1_params, ""); + auto layer = + make_layer(conditionSize, channels, bottleneck, kernelSize, dilation, activation, gating_mode, groups_input, + groups_input_mixin, groups_1x1, head1x1_params, nam::activations::ActivationConfig{}); // With grouped head1x1, we need to provide weights for each group // For groups=2, channels=4, bottleneck=4: each group has 2 in_channels and 2 out_channels @@ -353,14 +362,16 @@ void test_head1x1_different_out_channels() const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU); const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE; const int groups_input = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; const bool head1x1_active = true; const int head1x1_out_channels = 2; // Different from bottleneck const int head1x1_groups = 1; nam::wavenet::Head1x1Params head1x1_params(head1x1_active, head1x1_out_channels, head1x1_groups); - auto layer = make_layer(conditionSize, channels, bottleneck, kernelSize, dilation, activation, gating_mode, - groups_input, groups_1x1, head1x1_params, ""); + auto layer = + make_layer(conditionSize, channels, bottleneck, kernelSize, dilation, activation, gating_mode, groups_input, + groups_input_mixin, groups_1x1, head1x1_params, nam::activations::ActivationConfig{}); // head1x1 should map from bottleneck to head1x1_out_channels // With channels=4, bottleneck=4, head1x1_out_channels=2: diff --git a/tools/test/test_wavenet/test_layer.cpp b/tools/test/test_wavenet/test_layer.cpp index 6dff3f9..8c6cc12 100644 --- a/tools/test/test_wavenet/test_layer.cpp +++ b/tools/test/test_wavenet/test_layer.cpp @@ -23,14 +23,15 @@ static nam::wavenet::_Layer 
make_layer(const int condition_size, const int chann const int kernel_size, const int dilation, const nam::activations::ActivationConfig& activation_config, const nam::wavenet::GatingMode gating_mode, const int groups_input, - const int groups_1x1, const nam::wavenet::Head1x1Params& head1x1_params, - const std::string& secondary_activation) + const int groups_input_mixin, const int groups_1x1, + const nam::wavenet::Head1x1Params& head1x1_params, + const nam::activations::ActivationConfig& secondary_activation_config) { auto film_params = make_default_film_params(); return nam::wavenet::_Layer(condition_size, channels, bottleneck, kernel_size, dilation, activation_config, - gating_mode, groups_input, groups_1x1, head1x1_params, secondary_activation, film_params, - film_params, film_params, film_params, film_params, film_params, film_params, film_params, - film_params); + gating_mode, groups_input, groups_input_mixin, groups_1x1, head1x1_params, + secondary_activation_config, film_params, film_params, film_params, film_params, + film_params, film_params, film_params, film_params); } void test_gated() { @@ -44,10 +45,12 @@ void test_gated() const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU); const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::GATED; const int groups_input = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; nam::wavenet::Head1x1Params head1x1_params(false, channels, 1); + auto sigmoid_config = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Sigmoid); auto layer = make_layer(conditionSize, channels, bottleneck, kernelSize, dilation, activation, gating_mode, - groups_input, groups_1x1, head1x1_params, "Sigmoid"); + groups_input, groups_input_mixin, groups_1x1, head1x1_params, sigmoid_config); // Conv, input mixin, 1x1 std::vector weights{ @@ -120,11 +123,13 @@ void test_layer_getters() const auto activation = 
nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Tanh); const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE; const int groups_input = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; nam::wavenet::Head1x1Params head1x1_params(false, channels, 1); + nam::activations::ActivationConfig empty_config{}; auto layer = make_layer(conditionSize, channels, bottleneck, kernelSize, dilation, activation, gating_mode, - groups_input, groups_1x1, head1x1_params, ""); + groups_input, groups_input_mixin, groups_1x1, head1x1_params, empty_config); assert(layer.get_channels() == channels); assert(layer.get_kernel_size() == kernelSize); @@ -142,11 +147,13 @@ void test_non_gated_layer() const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU); const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE; const int groups_input = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; nam::wavenet::Head1x1Params head1x1_params(false, channels, 1); + nam::activations::ActivationConfig empty_config{}; auto layer = make_layer(conditionSize, channels, bottleneck, kernelSize, dilation, activation, gating_mode, - groups_input, groups_1x1, head1x1_params, ""); + groups_input, groups_input_mixin, groups_1x1, head1x1_params, empty_config); // For non-gated: conv outputs 1 channel, input_mixin outputs 1 channel, 1x1 outputs 1 channel // Conv: (1,1,1) weight + (1,) bias @@ -211,11 +218,13 @@ void test_layer_activations() { const int bottleneck = channels; const int groups_input = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; nam::wavenet::Head1x1Params head1x1_params(false, channels, 1); auto tanh_config = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Tanh); + nam::activations::ActivationConfig empty_config{}; auto layer = make_layer(conditionSize, channels, bottleneck, kernelSize, dilation, tanh_config, 
gating_mode, - groups_input, groups_1x1, head1x1_params, ""); + groups_input, groups_input_mixin, groups_1x1, head1x1_params, empty_config); std::vector weights{1.0f, 0.0f, 1.0f, 1.0f, 0.0f}; auto it = weights.begin(); layer.set_weights_(it); @@ -248,11 +257,13 @@ void test_layer_multichannel() const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU); const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE; const int groups_input = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; nam::wavenet::Head1x1Params head1x1_params(false, channels, 1); + nam::activations::ActivationConfig empty_config{}; auto layer = make_layer(conditionSize, channels, bottleneck, kernelSize, dilation, activation, gating_mode, - groups_input, groups_1x1, head1x1_params, ""); + groups_input, groups_input_mixin, groups_1x1, head1x1_params, empty_config); assert(layer.get_channels() == channels); @@ -318,11 +329,13 @@ void test_layer_bottleneck() const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU); const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE; const int groups_input = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; nam::wavenet::Head1x1Params head1x1_params(false, channels, 1); + nam::activations::ActivationConfig empty_config{}; auto layer = make_layer(conditionSize, channels, bottleneck, kernelSize, dilation, activation, gating_mode, - groups_input, groups_1x1, head1x1_params, ""); + groups_input, groups_input_mixin, groups_1x1, head1x1_params, empty_config); // With bottleneck < channels, the internal conv and input_mixin should have bottleneck channels, // but the 1x1 should map from bottleneck back to channels @@ -396,11 +409,13 @@ void test_layer_bottleneck_gated() const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::GATED; // gated doubles the internal bottleneck channels const int 
groups_input = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; nam::wavenet::Head1x1Params head1x1_params(false, channels, 1); + auto sigmoid_config = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Sigmoid); auto layer = make_layer(conditionSize, channels, bottleneck, kernelSize, dilation, activation, gating_mode, - groups_input, groups_1x1, head1x1_params, "Sigmoid"); + groups_input, groups_input_mixin, groups_1x1, head1x1_params, sigmoid_config); // With gated=true and bottleneck=2, internal channels should be 2*bottleneck=4 // Conv: (channels, 2*bottleneck, kernelSize=1) = (4, 4, 1) + bias diff --git a/tools/test/test_wavenet/test_layer_array.cpp b/tools/test/test_wavenet/test_layer_array.cpp index 082d990..cd1f762 100644 --- a/tools/test/test_wavenet/test_layer_array.cpp +++ b/tools/test/test_wavenet/test_layer_array.cpp @@ -22,14 +22,15 @@ static nam::wavenet::_FiLMParams make_default_film_params() static nam::wavenet::_LayerArray make_layer_array( const int input_size, const int condition_size, const int head_size, const int channels, const int bottleneck, const int kernel_size, const std::vector& dilations, const nam::activations::ActivationConfig& activation_config, - const nam::wavenet::GatingMode gating_mode, const bool head_bias, const int groups_input, const int groups_1x1, - const nam::wavenet::Head1x1Params& head1x1_params, const std::string& secondary_activation) + const nam::wavenet::GatingMode gating_mode, const bool head_bias, const int groups_input, + const int groups_input_mixin, const int groups_1x1, const nam::wavenet::Head1x1Params& head1x1_params, + const nam::activations::ActivationConfig& secondary_activation_config) { auto film_params = make_default_film_params(); return nam::wavenet::_LayerArray(input_size, condition_size, head_size, channels, bottleneck, kernel_size, dilations, - activation_config, gating_mode, head_bias, groups_input, groups_1x1, head1x1_params, - secondary_activation, 
film_params, film_params, film_params, film_params, - film_params, film_params, film_params, film_params, film_params); + activation_config, gating_mode, head_bias, groups_input, groups_input_mixin, + groups_1x1, head1x1_params, secondary_activation_config, film_params, film_params, + film_params, film_params, film_params, film_params, film_params, film_params); } // Test layer array construction and basic processing void test_layer_array_basic() @@ -45,12 +46,13 @@ void test_layer_array_basic() const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE; const bool head_bias = false; const int groups = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; nam::wavenet::Head1x1Params head1x1_params(false, channels, 1); - auto layer_array = - make_layer_array(input_size, condition_size, head_size, channels, bottleneck, kernel_size, dilations, activation, - gating_mode, head_bias, groups, groups_1x1, head1x1_params, ""); + auto layer_array = make_layer_array(input_size, condition_size, head_size, channels, bottleneck, kernel_size, + dilations, activation, gating_mode, head_bias, groups, groups_input_mixin, + groups_1x1, head1x1_params, nam::activations::ActivationConfig{}); const int numFrames = 4; layer_array.SetMaxBufferSize(numFrames); @@ -104,12 +106,13 @@ void test_layer_array_receptive_field() const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE; const bool head_bias = false; const int groups = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; nam::wavenet::Head1x1Params head1x1_params(false, channels, 1); - auto layer_array = - make_layer_array(input_size, condition_size, head_size, channels, bottleneck, kernel_size, dilations, activation, - gating_mode, head_bias, groups, groups_1x1, head1x1_params, ""); + auto layer_array = make_layer_array(input_size, condition_size, head_size, channels, bottleneck, kernel_size, + dilations, activation, gating_mode, head_bias, groups, groups_input_mixin, + 
groups_1x1, head1x1_params, nam::activations::ActivationConfig{}); long rf = layer_array.get_receptive_field(); // Expected: sum of dilation * (kernel_size - 1) for each layer @@ -135,12 +138,13 @@ void test_layer_array_with_head_input() const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE; const bool head_bias = false; const int groups = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; nam::wavenet::Head1x1Params head1x1_params(false, channels, 1); - auto layer_array = - make_layer_array(input_size, condition_size, head_size, channels, bottleneck, kernel_size, dilations, activation, - gating_mode, head_bias, groups, groups_1x1, head1x1_params, ""); + auto layer_array = make_layer_array(input_size, condition_size, head_size, channels, bottleneck, kernel_size, + dilations, activation, gating_mode, head_bias, groups, groups_input_mixin, + groups_1x1, head1x1_params, nam::activations::ActivationConfig{}); const int numFrames = 2; layer_array.SetMaxBufferSize(numFrames); diff --git a/tools/test/test_wavenet/test_real_time_safe.cpp b/tools/test/test_wavenet/test_real_time_safe.cpp index ff872e2..00cfb5b 100644 --- a/tools/test/test_wavenet/test_real_time_safe.cpp +++ b/tools/test/test_wavenet/test_real_time_safe.cpp @@ -3,105 +3,19 @@ #include #include #include -#include -#include -#include #include #include -#include #include #include #include "NAM/wavenet.h" #include "NAM/conv1d.h" - -// Allocation tracking -namespace -{ -volatile int g_allocation_count = 0; -volatile int g_deallocation_count = 0; -volatile bool g_tracking_enabled = false; - -// Original malloc/free functions -void* (*original_malloc)(size_t) = nullptr; -void (*original_free)(void*) = nullptr; -void* (*original_realloc)(void*, size_t) = nullptr; -} // namespace - -// Override malloc/free to track Eigen allocations (Eigen uses malloc directly) -extern "C" { -void* malloc(size_t size) -{ - if (!original_malloc) - original_malloc = reinterpret_cast(dlsym(RTLD_NEXT, 
"malloc")); - void* ptr = original_malloc(size); - if (g_tracking_enabled && ptr != nullptr) - ++g_allocation_count; - return ptr; -} - -void free(void* ptr) -{ - if (!original_free) - original_free = reinterpret_cast(dlsym(RTLD_NEXT, "free")); - if (g_tracking_enabled && ptr != nullptr) - ++g_deallocation_count; - original_free(ptr); -} - -void* realloc(void* ptr, size_t size) -{ - if (!original_realloc) - original_realloc = reinterpret_cast(dlsym(RTLD_NEXT, "realloc")); - void* new_ptr = original_realloc(ptr, size); - if (g_tracking_enabled) - { - if (ptr != nullptr && new_ptr != ptr) - ++g_deallocation_count; // Old pointer was freed - if (new_ptr != nullptr && new_ptr != ptr) - ++g_allocation_count; // New allocation - } - return new_ptr; -} -} - -// Overload global new/delete operators to track allocations -void* operator new(std::size_t size) -{ - void* ptr = std::malloc(size); - if (!ptr) - throw std::bad_alloc(); - if (g_tracking_enabled) - ++g_allocation_count; - return ptr; -} - -void* operator new[](std::size_t size) -{ - void* ptr = std::malloc(size); - if (!ptr) - throw std::bad_alloc(); - if (g_tracking_enabled) - ++g_allocation_count; - return ptr; -} - -void operator delete(void* ptr) noexcept -{ - if (g_tracking_enabled && ptr != nullptr) - ++g_deallocation_count; - std::free(ptr); -} - -void operator delete[](void* ptr) noexcept -{ - if (g_tracking_enabled && ptr != nullptr) - ++g_deallocation_count; - std::free(ptr); -} +#include "../allocation_tracking.h" namespace test_wavenet { +using namespace allocation_tracking; + // Helper function to create default (inactive) FiLM parameters static nam::wavenet::_FiLMParams make_default_film_params() { @@ -113,42 +27,45 @@ static nam::wavenet::_Layer make_layer(const int condition_size, const int chann const int kernel_size, const int dilation, const nam::activations::ActivationConfig& activation_config, const nam::wavenet::GatingMode gating_mode, const int groups_input, - const int groups_1x1, const 
nam::wavenet::Head1x1Params& head1x1_params, - const std::string& secondary_activation) + const int groups_input_mixin, const int groups_1x1, + const nam::wavenet::Head1x1Params& head1x1_params, + const nam::activations::ActivationConfig& secondary_activation_config) { auto film_params = make_default_film_params(); return nam::wavenet::_Layer(condition_size, channels, bottleneck, kernel_size, dilation, activation_config, - gating_mode, groups_input, groups_1x1, head1x1_params, secondary_activation, film_params, - film_params, film_params, film_params, film_params, film_params, film_params, film_params, - film_params); + gating_mode, groups_input, groups_input_mixin, groups_1x1, head1x1_params, + secondary_activation_config, film_params, film_params, film_params, film_params, + film_params, film_params, film_params, film_params); } // Helper function to create a LayerArray with default FiLM parameters static nam::wavenet::_LayerArray make_layer_array( const int input_size, const int condition_size, const int head_size, const int channels, const int bottleneck, const int kernel_size, const std::vector& dilations, const nam::activations::ActivationConfig& activation_config, - const nam::wavenet::GatingMode gating_mode, const bool head_bias, const int groups_input, const int groups_1x1, - const nam::wavenet::Head1x1Params& head1x1_params, const std::string& secondary_activation) + const nam::wavenet::GatingMode gating_mode, const bool head_bias, const int groups_input, + const int groups_input_mixin, const int groups_1x1, const nam::wavenet::Head1x1Params& head1x1_params, + const nam::activations::ActivationConfig& secondary_activation_config) { auto film_params = make_default_film_params(); return nam::wavenet::_LayerArray(input_size, condition_size, head_size, channels, bottleneck, kernel_size, dilations, - activation_config, gating_mode, head_bias, groups_input, groups_1x1, head1x1_params, - secondary_activation, film_params, film_params, film_params, film_params, - 
film_params, film_params, film_params, film_params, film_params); + activation_config, gating_mode, head_bias, groups_input, groups_input_mixin, + groups_1x1, head1x1_params, secondary_activation_config, film_params, film_params, + film_params, film_params, film_params, film_params, film_params, film_params); } // Helper function to create LayerArrayParams with default FiLM parameters static nam::wavenet::LayerArrayParams make_layer_array_params( const int input_size, const int condition_size, const int head_size, const int channels, const int bottleneck, const int kernel_size, std::vector&& dilations, const nam::activations::ActivationConfig& activation_config, - const nam::wavenet::GatingMode gating_mode, const bool head_bias, const int groups_input, const int groups_1x1, - const nam::wavenet::Head1x1Params& head1x1_params, const std::string& secondary_activation) + const nam::wavenet::GatingMode gating_mode, const bool head_bias, const int groups_input, + const int groups_input_mixin, const int groups_1x1, const nam::wavenet::Head1x1Params& head1x1_params, + const nam::activations::ActivationConfig& secondary_activation_config) { auto film_params = make_default_film_params(); return nam::wavenet::LayerArrayParams( input_size, condition_size, head_size, channels, bottleneck, kernel_size, std::move(dilations), activation_config, - gating_mode, head_bias, groups_input, groups_1x1, head1x1_params, secondary_activation, film_params, film_params, - film_params, film_params, film_params, film_params, film_params, film_params, film_params); + gating_mode, head_bias, groups_input, groups_input_mixin, groups_1x1, head1x1_params, secondary_activation_config, + film_params, film_params, film_params, film_params, film_params, film_params, film_params, film_params); } // Helper function to create a Layer with all FiLMs active @@ -156,95 +73,19 @@ static nam::wavenet::_Layer make_layer_all_films(const int condition_size, const const int kernel_size, const int dilation, const 
nam::activations::ActivationConfig& activation_config, const nam::wavenet::GatingMode gating_mode, const int groups_input, - const int groups_1x1, + const int groups_input_mixin, const int groups_1x1, const nam::wavenet::Head1x1Params& head1x1_params, - const std::string& secondary_activation, const bool shift) + const nam::activations::ActivationConfig& secondary_activation_config, + const bool shift) { nam::wavenet::_FiLMParams film_params(true, shift); + // Don't activate head1x1_post_film if head1x1 is not active (validation will fail) + nam::wavenet::_FiLMParams head1x1_post_film_params = + head1x1_params.active ? film_params : nam::wavenet::_FiLMParams(false, false); return nam::wavenet::_Layer(condition_size, channels, bottleneck, kernel_size, dilation, activation_config, - gating_mode, groups_input, groups_1x1, head1x1_params, secondary_activation, film_params, - film_params, film_params, film_params, film_params, film_params, film_params, film_params, - film_params); -} -// Helper function to run allocation tracking tests -// setup: Function to run before tracking starts (can be nullptr) -// test: Function to run while tracking allocations (required) -// teardown: Function to run after tracking stops (can be nullptr) -// expected_allocations: Expected number of allocations (default 0) -// expected_deallocations: Expected number of deallocations (default 0) -// test_name: Name of the test for error messages -template -void run_allocation_test(std::function setup, TestFunc test, std::function teardown, - int expected_allocations, int expected_deallocations, const char* test_name) -{ - // Run setup if provided - if (setup) - setup(); - - // Reset allocation counters and enable tracking - g_allocation_count = 0; - g_deallocation_count = 0; - g_tracking_enabled = true; - - // Run the test code - test(); - - // Disable tracking before any cleanup - g_tracking_enabled = false; - - // Run teardown if provided - if (teardown) - teardown(); - - // Assert expected 
allocations/deallocations - if (g_allocation_count != expected_allocations || g_deallocation_count != expected_deallocations) - { - std::cerr << "ERROR: " << test_name << " - Expected " << expected_allocations << " allocations, " - << expected_deallocations << " deallocations. Got " << g_allocation_count << " allocations, " - << g_deallocation_count << " deallocations.\n"; - std::abort(); - } -} - -// Convenience wrapper for tests that expect zero allocations (most common case) -template -void run_allocation_test_no_allocations(std::function setup, TestFunc test, std::function teardown, - const char* test_name) -{ - run_allocation_test(setup, test, teardown, 0, 0, test_name); -} - -// Convenience wrapper for tests that expect allocations (for testing the tracking mechanism) -template -void run_allocation_test_expect_allocations(std::function setup, TestFunc test, std::function teardown, - const char* test_name) -{ - // Run setup if provided - if (setup) - setup(); - - // Reset allocation counters and enable tracking - g_allocation_count = 0; - g_deallocation_count = 0; - g_tracking_enabled = true; - - // Run the test code - test(); - - // Disable tracking before any cleanup - g_tracking_enabled = false; - - // Run teardown if provided - if (teardown) - teardown(); - - // Assert that allocations occurred (this test verifies our tracking works) - if (g_allocation_count == 0 && g_deallocation_count == 0) - { - std::cerr << "ERROR: " << test_name - << " - Expected allocations/deallocations but none occurred (tracking may not be working)\n"; - std::abort(); - } + gating_mode, groups_input, groups_input_mixin, groups_1x1, head1x1_params, + secondary_activation_config, film_params, film_params, film_params, film_params, + film_params, film_params, film_params, head1x1_post_film_params); } // Test that pre-allocated Eigen operations with noalias() don't allocate @@ -279,21 +120,15 @@ void test_allocation_tracking_pass() assert(std::abs(c(0, 0) - 2.0f * cols) < 0.001f); } 
-// Test that resizing a matrix causes allocations (should be caught) +// Test that creating a new matrix causes allocations (should be caught) void test_allocation_tracking_fail() { - const int rows = 10; - const int cols = 20; - - // Pre-allocate matrix - Eigen::MatrixXf a(rows, cols); - a.setConstant(1.0f); - run_allocation_test_expect_allocations( nullptr, // No setup needed [&]() { - // This operation should allocate (resizing requires reallocation) - a.resize(rows * 2, cols * 2); + // This operation should allocate (creating new matrix) + Eigen::MatrixXf a(10, 20); + a.setConstant(1.0f); }, nullptr, // No teardown needed "test_allocation_tracking_fail"); @@ -499,11 +334,13 @@ void test_layer_process_realtime_safe() const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU); const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE; const int groups_input = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; nam::wavenet::Head1x1Params head1x1_params(false, channels, 1); - auto layer = make_layer(condition_size, channels, bottleneck, kernel_size, dilation, activation, gating_mode, - groups_input, groups_1x1, head1x1_params, ""); + auto layer = + make_layer(condition_size, channels, bottleneck, kernel_size, dilation, activation, gating_mode, groups_input, + groups_input_mixin, groups_1x1, head1x1_params, nam::activations::ActivationConfig{}); // Set weights std::vector weights{1.0f, 0.0f, // Conv (weight, bias) @@ -555,11 +392,13 @@ void test_layer_bottleneck_process_realtime_safe() const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU); const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE; const int groups_input = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; nam::wavenet::Head1x1Params head1x1_params(false, channels, 1); - auto layer = make_layer(condition_size, channels, bottleneck, 
kernel_size, dilation, activation, gating_mode, - groups_input, groups_1x1, head1x1_params, ""); + auto layer = + make_layer(condition_size, channels, bottleneck, kernel_size, dilation, activation, gating_mode, groups_input, + groups_input_mixin, groups_1x1, head1x1_params, nam::activations::ActivationConfig{}); // Set weights for bottleneck != channels // Conv: (channels, bottleneck, kernelSize=1) = (4, 2, 1) + bias @@ -641,11 +480,13 @@ void test_layer_grouped_process_realtime_safe() const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU); const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE; const int groups_input = 2; // groups_input > 1 + const int groups_input_mixin = 1; const int groups_1x1 = 2; // 1x1 is also grouped nam::wavenet::Head1x1Params head1x1_params(false, channels, 1); - auto layer = make_layer(condition_size, channels, bottleneck, kernel_size, dilation, activation, gating_mode, - groups_input, groups_1x1, head1x1_params, ""); + auto layer = + make_layer(condition_size, channels, bottleneck, kernel_size, dilation, activation, gating_mode, groups_input, + groups_input_mixin, groups_1x1, head1x1_params, nam::activations::ActivationConfig{}); // Set weights for grouped convolution // With groups_input=2, channels=4: each group has 2 in_channels and 2 out_channels @@ -750,11 +591,13 @@ static void test_layer_all_films_realtime_safe_impl(const bool shift) const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU); const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE; const int groups_input = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; nam::wavenet::Head1x1Params head1x1_params(false, channels, 1); auto layer = make_layer_all_films(condition_size, channels, bottleneck, kernel_size, dilation, activation, - gating_mode, groups_input, groups_1x1, head1x1_params, "", shift); + gating_mode, 
groups_input, groups_input_mixin, groups_1x1, head1x1_params, + nam::activations::ActivationConfig{}, shift); // Set weights // Base layer weights: @@ -772,8 +615,8 @@ static void test_layer_all_films_realtime_safe_impl(const bool shift) // FiLM weights (each FiLM uses Conv1x1: condition_size -> (shift ? 2 : 1) * input_dim with bias) // With shift=true: each FiLM needs (2 * input_dim) * condition_size weights + (2 * input_dim) biases = 4 weights // With shift=false: each FiLM needs input_dim * condition_size weights + input_dim biases = 2 weights - // All 8 FiLMs are active (excluding head1x1_post_film since head1x1 is false) - for (int i = 0; i < 8; i++) + // All 7 FiLMs are active (excluding head1x1_post_film since head1x1 is false) + for (int i = 0; i < 7; i++) { if (shift) { @@ -840,6 +683,220 @@ void test_layer_all_films_without_shift_realtime_safe() test_layer_all_films_realtime_safe_impl(false); } +// Test that Layer::Process() with post-activation FiLM (gated mode) does not allocate or free memory +// This specifically tests the case where FiLM::Process() receives _z.topRows(bottleneck) +void test_layer_post_activation_film_gated_realtime_safe() +{ + // Setup: Create a Layer with GATED mode and activation_post_film enabled + // Use simpler dimensions first to verify weight counting + const int condition_size = 1; + const int channels = 2; + const int bottleneck = 1; // bottleneck < channels to trigger topRows() + const int kernel_size = 1; + const int dilation = 1; + const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU); + const auto secondary_activation = + nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Sigmoid); + const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::GATED; + const int groups_input = 1; + const int groups_input_mixin = 1; + const int groups_1x1 = 1; + nam::wavenet::Head1x1Params head1x1_params(false, channels, 1); + + // Create FiLM params 
with activation_post_film enabled + nam::wavenet::_FiLMParams inactive_film(false, false); + nam::wavenet::_FiLMParams active_film(true, true); // activation_post_film will be active + + auto layer = + nam::wavenet::_Layer(condition_size, channels, bottleneck, kernel_size, dilation, activation, gating_mode, + groups_input, groups_input_mixin, groups_1x1, head1x1_params, secondary_activation, + inactive_film, // conv_pre_film + inactive_film, // conv_post_film + inactive_film, // input_mixin_pre_film + inactive_film, // input_mixin_post_film + inactive_film, // activation_pre_film + active_film, // activation_post_film - THIS IS THE KEY ONE + inactive_film, // _1x1_post_film + inactive_film // head1x1_post_film + ); + + // Set weights - Order: conv, input_mixin, 1x1, then FiLMs + // NOTE: In GATED mode, conv and input_mixin output 2*bottleneck channels! + std::vector weights; + + // Conv weights: In GATED mode outputs 2*bottleneck = 2*1 = 2 channels + // Conv: (out_channels, in_channels, kernel_size) + bias = (2, 2, 1) + 2 = 4 + 2 = 6 + weights.push_back(0.5f); // ch0, in0 + weights.push_back(0.5f); // ch0, in1 + weights.push_back(0.5f); // ch1, in0 + weights.push_back(0.5f); // ch1, in1 + weights.push_back(0.0f); // bias ch0 + weights.push_back(0.0f); // bias ch1 + + // Input mixin: outputs 2*bottleneck = 2 channels + // (condition_size, out_channels) = (1, 2) = 2 weights + weights.push_back(0.5f); // ch0 + weights.push_back(0.5f); // ch1 + + // 1x1 weights: (bottleneck, channels) + bias = (1, 2) + 2 = 2 + 2 = 4 + weights.push_back(1.0f); + weights.push_back(1.0f); + weights.push_back(0.0f); // bias + weights.push_back(0.0f); + + // activation_post_film: FiLM(condition_size, bottleneck, shift=true) + // Creates Conv1x1(condition_size, 2*bottleneck, bias=true) internally + // Weight count: (1 * 2) + 2 = 4 weights + weights.push_back(1.0f); // scale weight + weights.push_back(0.0f); // shift weight + weights.push_back(1.0f); // scale bias + weights.push_back(0.0f); 
// shift bias + + auto it = weights.begin(); + layer.set_weights_(it); + assert(it == weights.end()); + + const int maxBufferSize = 256; + layer.SetMaxBufferSize(maxBufferSize); + + // Test with several different buffer sizes + std::vector buffer_sizes{1, 8, 16, 32, 64, 128, 256}; + + for (int buffer_size : buffer_sizes) + { + // Prepare input/condition matrices (allocate before tracking) + Eigen::MatrixXf input(channels, buffer_size); + Eigen::MatrixXf condition(condition_size, buffer_size); + input.setConstant(0.5f); + condition.setConstant(0.5f); + + std::string test_name = + "Layer Process (GATED with activation_post_film) - Buffer size " + std::to_string(buffer_size); + run_allocation_test_no_allocations( + nullptr, // No setup needed + [&]() { + // Call Process() - this should not allocate or free + // This will trigger: _activation_post_film->Process(this->_z.topRows(bottleneck), condition, num_frames) + layer.Process(input, condition, buffer_size); + }, + nullptr, // No teardown needed + test_name.c_str()); + + // Verify output is valid + auto output = layer.GetOutputNextLayer().leftCols(buffer_size); + assert(output.rows() == channels && output.cols() == buffer_size); + assert(std::isfinite(output(0, 0))); + assert(std::isfinite(output(channels - 1, buffer_size - 1))); + } +} + +// Test that Layer::Process() with post-activation FiLM (blended mode) does not allocate or free memory +// This also tests the case where FiLM::Process() receives _z.topRows(bottleneck) +void test_layer_post_activation_film_blended_realtime_safe() +{ + // Setup: Create a Layer with BLENDED mode and activation_post_film enabled + // Use simpler dimensions first to verify weight counting + const int condition_size = 1; + const int channels = 2; + const int bottleneck = 1; // bottleneck < channels to trigger topRows() + const int kernel_size = 1; + const int dilation = 1; + const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU); + 
const auto secondary_activation = + nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Sigmoid); + const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::BLENDED; + const int groups_input = 1; + const int groups_input_mixin = 1; + const int groups_1x1 = 1; + nam::wavenet::Head1x1Params head1x1_params(false, channels, 1); + + // Create FiLM params with activation_post_film enabled + nam::wavenet::_FiLMParams inactive_film(false, false); + nam::wavenet::_FiLMParams active_film(true, true); // activation_post_film will be active + + auto layer = + nam::wavenet::_Layer(condition_size, channels, bottleneck, kernel_size, dilation, activation, gating_mode, + groups_input, groups_input_mixin, groups_1x1, head1x1_params, secondary_activation, + inactive_film, // conv_pre_film + inactive_film, // conv_post_film + inactive_film, // input_mixin_pre_film + inactive_film, // input_mixin_post_film + inactive_film, // activation_pre_film + active_film, // activation_post_film - THIS IS THE KEY ONE + inactive_film, // _1x1_post_film + inactive_film // head1x1_post_film + ); + + // Set weights - Order: conv, input_mixin, 1x1, then FiLMs + // NOTE: In BLENDED mode, conv and input_mixin output 2*bottleneck channels! 
+ std::vector weights; + + // Conv weights: In BLENDED mode outputs 2*bottleneck = 2*1 = 2 channels + // Conv: (out_channels, in_channels, kernel_size) + bias = (2, 2, 1) + 2 = 4 + 2 = 6 + weights.push_back(0.5f); // ch0, in0 + weights.push_back(0.5f); // ch0, in1 + weights.push_back(0.5f); // ch1, in0 + weights.push_back(0.5f); // ch1, in1 + weights.push_back(0.0f); // bias ch0 + weights.push_back(0.0f); // bias ch1 + + // Input mixin: outputs 2*bottleneck = 2 channels + // (condition_size, out_channels) = (1, 2) = 2 weights + weights.push_back(0.5f); // ch0 + weights.push_back(0.5f); // ch1 + + // 1x1 weights: (bottleneck, channels) + bias = (1, 2) + 2 = 2 + 2 = 4 + weights.push_back(1.0f); + weights.push_back(1.0f); + weights.push_back(0.0f); // bias + weights.push_back(0.0f); + + // activation_post_film: FiLM(condition_size, bottleneck, shift=true) + // Creates Conv1x1(condition_size, 2*bottleneck, bias=true) internally + // Weight count: (1 * 2) + 2 = 4 weights + weights.push_back(1.0f); // scale weight + weights.push_back(0.0f); // shift weight + weights.push_back(1.0f); // scale bias + weights.push_back(0.0f); // shift bias + + auto it = weights.begin(); + layer.set_weights_(it); + assert(it == weights.end()); + + const int maxBufferSize = 256; + layer.SetMaxBufferSize(maxBufferSize); + + // Test with several different buffer sizes + std::vector buffer_sizes{1, 8, 16, 32, 64, 128, 256}; + + for (int buffer_size : buffer_sizes) + { + // Prepare input/condition matrices (allocate before tracking) + Eigen::MatrixXf input(channels, buffer_size); + Eigen::MatrixXf condition(condition_size, buffer_size); + input.setConstant(0.5f); + condition.setConstant(0.5f); + + std::string test_name = + "Layer Process (BLENDED with activation_post_film) - Buffer size " + std::to_string(buffer_size); + run_allocation_test_no_allocations( + nullptr, // No setup needed + [&]() { + // Call Process() - this should not allocate or free + // This will trigger: 
_activation_post_film->Process(this->_z.topRows(bottleneck), condition, num_frames) + layer.Process(input, condition, buffer_size); + }, + nullptr, // No teardown needed + test_name.c_str()); + + // Verify output is valid + auto output = layer.GetOutputNextLayer().leftCols(buffer_size); + assert(output.rows() == channels && output.cols() == buffer_size); + assert(std::isfinite(output(0, 0))); + assert(std::isfinite(output(channels - 1, buffer_size - 1))); + } +} + // Test that LayerArray::Process() method does not allocate or free memory void test_layer_array_process_realtime_safe() { @@ -855,12 +912,13 @@ void test_layer_array_process_realtime_safe() const nam::wavenet::GatingMode gating_mode = nam::wavenet::GatingMode::NONE; const bool head_bias = false; const int groups = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; nam::wavenet::Head1x1Params head1x1_params(false, channels, 1); - auto layer_array = - make_layer_array(input_size, condition_size, head_size, channels, bottleneck, kernel_size, dilations, activation, - gating_mode, head_bias, groups, groups_1x1, head1x1_params, ""); + auto layer_array = make_layer_array(input_size, condition_size, head_size, channels, bottleneck, kernel_size, + dilations, activation, gating_mode, head_bias, groups, groups_input_mixin, + groups_1x1, head1x1_params, nam::activations::ActivationConfig{}); // Set weights: rechannel(1), layer(conv:1+1, input_mixin:1, 1x1:1+1), head_rechannel(1) std::vector weights{1.0f, // Rechannel @@ -921,6 +979,7 @@ void test_process_realtime_safe() const float head_scale = 1.0f; const bool with_head = false; const int groups = 1; + const int groups_input_mixin = 1; std::vector layer_array_params; // First layer array @@ -930,12 +989,14 @@ void test_process_realtime_safe() nam::wavenet::Head1x1Params head1x1_params(false, channels, 1); layer_array_params.push_back(make_layer_array_params(input_size, condition_size, head_size, channels, bottleneck, kernel_size, 
std::move(dilations1), activation, gating_mode, - head_bias, groups, groups_1x1, head1x1_params, "")); + head_bias, groups, groups_input_mixin, groups_1x1, + head1x1_params, nam::activations::ActivationConfig{})); // Second layer array (head_size of first must match channels of second) std::vector dilations2{1}; layer_array_params.push_back(make_layer_array_params(head_size, condition_size, head_size, channels, bottleneck, kernel_size, std::move(dilations2), activation, gating_mode, - head_bias, groups, groups_1x1, head1x1_params, "")); + head_bias, groups, groups_input_mixin, groups_1x1, + head1x1_params, nam::activations::ActivationConfig{})); // Weights: Array 0: rechannel(1), layer(conv:1+1, input_mixin:1, 1x1:1+1), head_rechannel(1) // Array 1: same structure @@ -999,6 +1060,7 @@ void test_process_3in_2out_realtime_safe() const float head_scale = 1.0f; const bool with_head = false; const int groups = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; nam::wavenet::Head1x1Params head1x1_params(false, channels, 1); @@ -1007,7 +1069,8 @@ void test_process_3in_2out_realtime_safe() std::vector dilations1{1}; layer_array_params.push_back(make_layer_array_params(input_size, condition_size, head_size, channels, bottleneck, kernel_size, std::move(dilations1), activation, gating_mode, - head_bias, groups, groups_1x1, head1x1_params, "")); + head_bias, groups, groups_input_mixin, groups_1x1, + head1x1_params, nam::activations::ActivationConfig{})); // Calculate weights: // _rechannel: Conv1x1(3, 4, bias=false) = 3*4 = 12 weights diff --git a/tools/test/test_wavenet_configurable_gating.cpp b/tools/test/test_wavenet_configurable_gating.cpp index 2e5895f..6dbd18d 100644 --- a/tools/test/test_wavenet_configurable_gating.cpp +++ b/tools/test/test_wavenet_configurable_gating.cpp @@ -20,42 +20,45 @@ static nam::wavenet::_Layer make_layer(const int condition_size, const int chann const int kernel_size, const int dilation, const nam::activations::ActivationConfig& 
activation_config, const nam::wavenet::GatingMode gating_mode, const int groups_input, - const int groups_1x1, const nam::wavenet::Head1x1Params& head1x1_params, - const std::string& secondary_activation) + const int groups_input_mixin, const int groups_1x1, + const nam::wavenet::Head1x1Params& head1x1_params, + const nam::activations::ActivationConfig& secondary_activation_config) { auto film_params = make_default_film_params(); return nam::wavenet::_Layer(condition_size, channels, bottleneck, kernel_size, dilation, activation_config, - gating_mode, groups_input, groups_1x1, head1x1_params, secondary_activation, film_params, - film_params, film_params, film_params, film_params, film_params, film_params, film_params, - film_params); + gating_mode, groups_input, groups_input_mixin, groups_1x1, head1x1_params, + secondary_activation_config, film_params, film_params, film_params, film_params, + film_params, film_params, film_params, film_params); } // Helper function to create LayerArrayParams with default FiLM parameters static nam::wavenet::LayerArrayParams make_layer_array_params( const int input_size, const int condition_size, const int head_size, const int channels, const int bottleneck, const int kernel_size, std::vector&& dilations, const nam::activations::ActivationConfig& activation_config, - const nam::wavenet::GatingMode gating_mode, const bool head_bias, const int groups_input, const int groups_1x1, - const nam::wavenet::Head1x1Params& head1x1_params, const std::string& secondary_activation) + const nam::wavenet::GatingMode gating_mode, const bool head_bias, const int groups_input, + const int groups_input_mixin, const int groups_1x1, const nam::wavenet::Head1x1Params& head1x1_params, + const nam::activations::ActivationConfig& secondary_activation_config) { auto film_params = make_default_film_params(); return nam::wavenet::LayerArrayParams( input_size, condition_size, head_size, channels, bottleneck, kernel_size, std::move(dilations), activation_config, 
- gating_mode, head_bias, groups_input, groups_1x1, head1x1_params, secondary_activation, film_params, film_params, - film_params, film_params, film_params, film_params, film_params, film_params, film_params); + gating_mode, head_bias, groups_input, groups_input_mixin, groups_1x1, head1x1_params, secondary_activation_config, + film_params, film_params, film_params, film_params, film_params, film_params, film_params, film_params); } // Helper function to create a LayerArray with default FiLM parameters static nam::wavenet::_LayerArray make_layer_array( const int input_size, const int condition_size, const int head_size, const int channels, const int bottleneck, const int kernel_size, const std::vector& dilations, const nam::activations::ActivationConfig& activation_config, - const nam::wavenet::GatingMode gating_mode, const bool head_bias, const int groups_input, const int groups_1x1, - const nam::wavenet::Head1x1Params& head1x1_params, const std::string& secondary_activation) + const nam::wavenet::GatingMode gating_mode, const bool head_bias, const int groups_input, + const int groups_input_mixin, const int groups_1x1, const nam::wavenet::Head1x1Params& head1x1_params, + const nam::activations::ActivationConfig& secondary_activation_config) { auto film_params = make_default_film_params(); return nam::wavenet::_LayerArray(input_size, condition_size, head_size, channels, bottleneck, kernel_size, dilations, - activation_config, gating_mode, head_bias, groups_input, groups_1x1, head1x1_params, - secondary_activation, film_params, film_params, film_params, film_params, - film_params, film_params, film_params, film_params, film_params); + activation_config, gating_mode, head_bias, groups_input, groups_input_mixin, + groups_1x1, head1x1_params, secondary_activation_config, film_params, film_params, + film_params, film_params, film_params, film_params, film_params, film_params); } class TestConfigurableGating @@ -71,16 +74,21 @@ class TestConfigurableGating const int 
dilation = 1; const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Tanh); const int groups_input = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; nam::wavenet::Head1x1Params head1x1_params(false, channels, 1); // Test different gating activation configurations - std::vector gating_activations = {"Sigmoid", "Tanh", "ReLU"}; + std::vector gating_activations = { + nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Sigmoid), + nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Tanh), + nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU)}; for (const auto& gating_act : gating_activations) { auto layer = make_layer(conditionSize, channels, bottleneck, kernelSize, dilation, activation, - nam::wavenet::GatingMode::GATED, groups_input, groups_1x1, head1x1_params, gating_act); + nam::wavenet::GatingMode::GATED, groups_input, groups_input_mixin, groups_1x1, + head1x1_params, gating_act); // Verify that the layer was created successfully and has correct dimensions assert(layer.get_channels() == channels); @@ -97,17 +105,21 @@ class TestConfigurableGating const int dilation = 1; const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Tanh); const int groups_input = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; nam::wavenet::Head1x1Params head1x1_params(false, channels, 1); // Test different blending activation configurations - std::vector blending_activations = {"Sigmoid", "Tanh", "ReLU"}; + std::vector blending_activations = { + nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Sigmoid), + nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Tanh), + nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU)}; for (const auto& blending_act : blending_activations) { - auto layer = - 
make_layer(conditionSize, channels, bottleneck, kernelSize, dilation, activation, - nam::wavenet::GatingMode::BLENDED, groups_input, groups_1x1, head1x1_params, blending_act); + auto layer = make_layer(conditionSize, channels, bottleneck, kernelSize, dilation, activation, + nam::wavenet::GatingMode::BLENDED, groups_input, groups_input_mixin, groups_1x1, + head1x1_params, blending_act); // Verify that the layer was created successfully and has correct dimensions assert(layer.get_channels() == channels); @@ -128,24 +140,29 @@ class TestConfigurableGating const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Tanh); const bool head_bias = false; const int groups_input = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; nam::wavenet::Head1x1Params head1x1_params(false, channels, 1); // Test with different gating activations - auto params_gated = make_layer_array_params( - input_size, condition_size, head_size, channels, bottleneck, kernel_size, std::vector{1, 2}, activation, - nam::wavenet::GatingMode::GATED, head_bias, groups_input, groups_1x1, head1x1_params, "Tanh"); + auto tanh_config = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Tanh); + auto params_gated = + make_layer_array_params(input_size, condition_size, head_size, channels, bottleneck, kernel_size, + std::vector{1, 2}, activation, nam::wavenet::GatingMode::GATED, head_bias, + groups_input, groups_input_mixin, groups_1x1, head1x1_params, tanh_config); assert(params_gated.gating_mode == nam::wavenet::GatingMode::GATED); - assert(params_gated.secondary_activation == "Tanh"); + assert(params_gated.secondary_activation_config.type == nam::activations::ActivationType::Tanh); // Test with different blending activations - auto params_blended = make_layer_array_params( - input_size, condition_size, head_size, channels, bottleneck, kernel_size, std::vector{1, 2}, activation, - nam::wavenet::GatingMode::BLENDED, head_bias, 
groups_input, groups_1x1, head1x1_params, "ReLU"); + auto relu_config = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU); + auto params_blended = + make_layer_array_params(input_size, condition_size, head_size, channels, bottleneck, kernel_size, + std::vector{1, 2}, activation, nam::wavenet::GatingMode::BLENDED, head_bias, + groups_input, groups_input_mixin, groups_1x1, head1x1_params, relu_config); assert(params_blended.gating_mode == nam::wavenet::GatingMode::BLENDED); - assert(params_blended.secondary_activation == "ReLU"); + assert(params_blended.secondary_activation_config.type == nam::activations::ActivationType::ReLU); } static void test_layer_array_construction() @@ -161,12 +178,14 @@ class TestConfigurableGating const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Tanh); const bool head_bias = false; const int groups_input = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; nam::wavenet::Head1x1Params head1x1_params(false, channels, 1); - auto layer_array = make_layer_array(input_size, condition_size, head_size, channels, bottleneck, kernel_size, - std::vector{1}, activation, nam::wavenet::GatingMode::GATED, head_bias, - groups_input, groups_1x1, head1x1_params, "ReLU"); + auto layer_array = make_layer_array( + input_size, condition_size, head_size, channels, bottleneck, kernel_size, std::vector{1}, activation, + nam::wavenet::GatingMode::GATED, head_bias, groups_input, groups_input_mixin, groups_1x1, head1x1_params, + nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU)); // Verify that layers were created correctly by checking receptive field // This should be non-zero for a valid layer array @@ -179,9 +198,9 @@ class TestConfigurableGating // We'll test the parsing logic directly without creating full WaveNet objects // Test the gating mode parsing logic directly - nlohmann::json gated_config = {{"gating_mode", "gated"}, 
{"secondary_activation", "ReLU"}}; + nlohmann::json gated_config = {{"gating_mode", "gated"}, {"secondary_activation_config", "ReLU"}}; - nlohmann::json blended_config = {{"gating_mode", "blended"}, {"secondary_activation", "Sigmoid"}}; + nlohmann::json blended_config = {{"gating_mode", "blended"}, {"secondary_activation_config", "Sigmoid"}}; nlohmann::json none_config = {{"gating_mode", "none"}}; @@ -223,19 +242,25 @@ class TestConfigurableGating const int dilation = 1; const auto activation = nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Tanh); const int groups_input = 1; + const int groups_input_mixin = 1; const int groups_1x1 = 1; nam::wavenet::Head1x1Params head1x1_params(false, channels, 1); // Create layers with different gating activations auto layer_sigmoid = make_layer(conditionSize, channels, bottleneck, kernelSize, dilation, activation, nam::wavenet::GatingMode::GATED, - groups_input, groups_1x1, head1x1_params, "Sigmoid"); + groups_input, groups_input_mixin, groups_1x1, head1x1_params, + nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Sigmoid)); - auto layer_tanh = make_layer(conditionSize, channels, bottleneck, kernelSize, dilation, activation, - nam::wavenet::GatingMode::GATED, groups_input, groups_1x1, head1x1_params, "Tanh"); + auto layer_tanh = + make_layer(conditionSize, channels, bottleneck, kernelSize, dilation, activation, nam::wavenet::GatingMode::GATED, + groups_input, groups_input_mixin, groups_1x1, head1x1_params, + nam::activations::ActivationConfig::simple(nam::activations::ActivationType::Tanh)); - auto layer_relu = make_layer(conditionSize, channels, bottleneck, kernelSize, dilation, activation, - nam::wavenet::GatingMode::GATED, groups_input, groups_1x1, head1x1_params, "ReLU"); + auto layer_relu = + make_layer(conditionSize, channels, bottleneck, kernelSize, dilation, activation, nam::wavenet::GatingMode::GATED, + groups_input, groups_input_mixin, groups_1x1, head1x1_params, 
+ nam::activations::ActivationConfig::simple(nam::activations::ActivationType::ReLU)); // Set max buffer size for all layers const int num_frames = 10;