Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NAM/dsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ Eigen::MatrixXf nam::Conv1x1::process(const Eigen::MatrixXf& input, const int nu
return result;
}

void nam::Conv1x1::process_(const Eigen::MatrixXf& input, const int num_frames)
void nam::Conv1x1::process_(const Eigen::Ref<const Eigen::MatrixXf>& input, const int num_frames)
{
assert(num_frames <= _output.cols());

Expand Down
3 changes: 2 additions & 1 deletion NAM/dsp.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,8 @@ class Conv1x1
Eigen::MatrixXf process(const Eigen::MatrixXf& input) const { return process(input, (int)input.cols()); };
Eigen::MatrixXf process(const Eigen::MatrixXf& input, const int num_frames) const;
// Store output to pre-allocated _output; access with GetOutput()
void process_(const Eigen::MatrixXf& input, const int num_frames);
// Uses Eigen::Ref to accept matrices and block expressions without creating temporaries (real-time safe)
void process_(const Eigen::Ref<const Eigen::MatrixXf>& input, const int num_frames);

long get_out_channels() const { return this->_weight.rows(); };
long get_in_channels() const { return this->_weight.cols(); };
Expand Down
8 changes: 6 additions & 2 deletions NAM/film.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ class FiLM
// :param input: (input_dim x num_frames)
// :param condition: (condition_dim x num_frames)
// Writes (input_dim x num_frames) into internal output buffer; access via GetOutput().
void Process(const Eigen::MatrixXf& input, const Eigen::MatrixXf& condition, const int num_frames)
// Uses Eigen::Ref to accept matrices and block expressions without creating temporaries (real-time safe)
void Process(const Eigen::Ref<const Eigen::MatrixXf>& input, const Eigen::Ref<const Eigen::MatrixXf>& condition,
const int num_frames)
{
assert(get_input_dim() == input.rows());
assert(get_condition_dim() == condition.rows());
Expand All @@ -72,7 +74,9 @@ class FiLM
}

// in-place
void Process_(Eigen::MatrixXf& input, const Eigen::MatrixXf& condition, const int num_frames)
// Uses Eigen::Ref to accept matrices and block expressions without creating temporaries (real-time safe)
void Process_(Eigen::Ref<Eigen::MatrixXf> input, const Eigen::Ref<const Eigen::MatrixXf>& condition,
const int num_frames)
{
Process(input, condition, num_frames);
input.leftCols(num_frames).noalias() = _output.leftCols(num_frames);
Expand Down
42 changes: 23 additions & 19 deletions NAM/gating_activations.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ class GatingActivation
{
throw std::invalid_argument("GatingActivation: number of input channels must be positive");
}
// Initialize input buffer with correct size
// Initialize buffers with correct size
// Note: current code copies column-by-column so we only need (num_channels, 1)
input_buffer.resize(num_channels, 1);
gating_buffer.resize(num_channels, 1);
}

~GatingActivation() = default;
Expand All @@ -64,23 +65,20 @@ class GatingActivation
assert(output.cols() == input.cols());

// Process column-by-column to ensure memory contiguity (important for column-major matrices)
// Uses pre-allocated buffers to avoid allocations in the loop (real-time safe)
const int num_samples = input.cols();
for (int i = 0; i < num_samples; i++)
{
// Store pre-activation input values in buffer to avoid overwriting issues
// Copy to pre-allocated buffers and apply activations in-place
input_buffer = input.block(0, i, num_channels, 1);
input_activation->apply(input_buffer);

// Apply activation to input channels
Eigen::MatrixXf input_block = input.block(0, i, num_channels, 1);
input_activation->apply(input_block);

// Apply activation to gating channels
Eigen::MatrixXf gating_block = input.block(num_channels, i, num_channels, 1);
gating_activation->apply(gating_block);
gating_buffer = input.block(num_channels, i, num_channels, 1);
gating_activation->apply(gating_buffer);

// Element-wise multiplication and store result
// For wavenet compatibility, we assume one-to-one mapping
output.block(0, i, num_channels, 1) = input_block.array() * gating_block.array();
output.block(0, i, num_channels, 1) = input_buffer.array() * gating_buffer.array();
}
}

Expand All @@ -99,6 +97,7 @@ class GatingActivation
activations::Activation::Ptr gating_activation;
int num_channels;
Eigen::MatrixXf input_buffer;
Eigen::MatrixXf gating_buffer;
};

class BlendingActivation
Expand All @@ -118,9 +117,11 @@ class BlendingActivation
{
assert(num_channels > 0);

// Initialize input buffer with correct size
// Initialize buffers with correct size
// Note: current code copies column-by-column so we only need (num_channels, 1)
pre_activation_buffer.resize(num_channels, 1);
input_buffer.resize(num_channels, 1);
blend_buffer.resize(num_channels, 1);
}

~BlendingActivation() = default;
Expand All @@ -140,23 +141,24 @@ class BlendingActivation
assert(output.cols() == input.cols());

// Process column-by-column to ensure memory contiguity
// Uses pre-allocated buffers to avoid allocations in the loop (real-time safe)
const int num_samples = input.cols();
for (int i = 0; i < num_samples; i++)
{
// Store pre-activation input values in buffer
input_buffer = input.block(0, i, num_channels, 1);
pre_activation_buffer = input.block(0, i, num_channels, 1);

// Apply activation to input channels
Eigen::MatrixXf input_block = input.block(0, i, num_channels, 1);
input_activation->apply(input_block);
// Copy to pre-allocated buffer and apply activation to input channels
input_buffer = input.block(0, i, num_channels, 1);
input_activation->apply(input_buffer);

// Apply activation to blend channels to compute alpha
Eigen::MatrixXf blend_block = input.block(num_channels, i, num_channels, 1);
blending_activation->apply(blend_block);
// Copy to pre-allocated buffer and apply activation to blend channels to compute alpha
blend_buffer = input.block(num_channels, i, num_channels, 1);
blending_activation->apply(blend_buffer);

// Weighted blending: alpha * activated_input + (1 - alpha) * pre_activation_input
output.block(0, i, num_channels, 1) =
blend_block.array() * input_block.array() + (1.0f - blend_block.array()) * input_buffer.array();
blend_buffer.array() * input_buffer.array() + (1.0f - blend_buffer.array()) * pre_activation_buffer.array();
}
}

Expand All @@ -174,7 +176,9 @@ class BlendingActivation
activations::Activation::Ptr input_activation;
activations::Activation::Ptr blending_activation;
int num_channels;
Eigen::MatrixXf pre_activation_buffer;
Eigen::MatrixXf input_buffer;
Eigen::MatrixXf blend_buffer;
};


Expand Down
88 changes: 44 additions & 44 deletions NAM/wavenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ void nam::wavenet::_Layer::SetMaxBufferSize(const int maxBufferSize)
this->_activation_pre_film->SetMaxBufferSize(maxBufferSize);
if (this->_activation_post_film)
this->_activation_post_film->SetMaxBufferSize(maxBufferSize);
if (this->_gating_activation_post_film)
this->_gating_activation_post_film->SetMaxBufferSize(maxBufferSize);
if (this->_1x1_post_film)
this->_1x1_post_film->SetMaxBufferSize(maxBufferSize);
if (this->_head1x1_post_film)
Expand Down Expand Up @@ -77,8 +75,6 @@ void nam::wavenet::_Layer::set_weights_(std::vector<float>::iterator& weights)
this->_activation_pre_film->set_weights_(weights);
if (this->_activation_post_film)
this->_activation_post_film->set_weights_(weights);
if (this->_gating_activation_post_film)
this->_gating_activation_post_film->set_weights_(weights);
if (this->_1x1_post_film)
this->_1x1_post_film->set_weights_(weights);
if (this->_head1x1_post_film)
Expand Down Expand Up @@ -150,12 +146,12 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
auto input_block = this->_z.leftCols(num_frames);
auto output_block = this->_z.topRows(bottleneck).leftCols(num_frames);
this->_gating_activation->apply(input_block, output_block);
if (this->_gating_activation_post_film)
if (this->_activation_post_film)
{
// Use Process() for blocks and copy result back
this->_gating_activation_post_film->Process(this->_z.topRows(bottleneck), condition, num_frames);
this->_activation_post_film->Process(this->_z.topRows(bottleneck), condition, num_frames);
this->_z.topRows(bottleneck).leftCols(num_frames).noalias() =
this->_gating_activation_post_film->GetOutput().leftCols(num_frames);
this->_activation_post_film->GetOutput().leftCols(num_frames);
}
_1x1.process_(this->_z.topRows(bottleneck), num_frames);
}
Expand Down Expand Up @@ -219,23 +215,23 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
nam::wavenet::_LayerArray::_LayerArray(
const int input_size, const int condition_size, const int head_size, const int channels, const int bottleneck,
const int kernel_size, const std::vector<int>& dilations, const activations::ActivationConfig& activation_config,
const GatingMode gating_mode, const bool head_bias, const int groups_input, const int groups_1x1,
const Head1x1Params& head1x1_params, const std::string& secondary_activation, const _FiLMParams& conv_pre_film_params,
const GatingMode gating_mode, const bool head_bias, const int groups_input, const int groups_input_mixin,
const int groups_1x1, const Head1x1Params& head1x1_params,
const activations::ActivationConfig& secondary_activation_config, const _FiLMParams& conv_pre_film_params,
const _FiLMParams& conv_post_film_params, const _FiLMParams& input_mixin_pre_film_params,
const _FiLMParams& input_mixin_post_film_params, const _FiLMParams& activation_pre_film_params,
const _FiLMParams& activation_post_film_params, const _FiLMParams& gating_activation_post_film_params,
const _FiLMParams& _1x1_post_film_params, const _FiLMParams& head1x1_post_film_params)
const _FiLMParams& activation_post_film_params, const _FiLMParams& _1x1_post_film_params,
const _FiLMParams& head1x1_post_film_params)
: _rechannel(input_size, channels, false)
, _head_rechannel(bottleneck, head_size, head_bias)
, _bottleneck(bottleneck)
, _head_rechannel(head1x1_params.active ? head1x1_params.out_channels : bottleneck, head_size, head_bias)
, _head_output_size(head1x1_params.active ? head1x1_params.out_channels : bottleneck)
{
for (size_t i = 0; i < dilations.size(); i++)
this->_layers.push_back(_Layer(condition_size, channels, bottleneck, kernel_size, dilations[i], activation_config,
gating_mode, groups_input, groups_1x1, head1x1_params, secondary_activation,
conv_pre_film_params, conv_post_film_params, input_mixin_pre_film_params,
input_mixin_post_film_params, activation_pre_film_params,
activation_post_film_params, gating_activation_post_film_params,
_1x1_post_film_params, head1x1_post_film_params));
this->_layers.push_back(
_Layer(condition_size, channels, bottleneck, kernel_size, dilations[i], activation_config, gating_mode,
groups_input, groups_input_mixin, groups_1x1, head1x1_params, secondary_activation_config,
conv_pre_film_params, conv_post_film_params, input_mixin_pre_film_params, input_mixin_post_film_params,
activation_pre_film_params, activation_post_film_params, _1x1_post_film_params, head1x1_post_film_params));
}

void nam::wavenet::_LayerArray::SetMaxBufferSize(const int maxBufferSize)
Expand All @@ -249,7 +245,8 @@ void nam::wavenet::_LayerArray::SetMaxBufferSize(const int maxBufferSize)
// Pre-allocate output buffers
const long channels = this->_get_channels();
this->_layer_outputs.resize(channels, maxBufferSize);
this->_head_inputs.resize(this->_bottleneck, maxBufferSize);
// _head_inputs size matches actual head output: head1x1.out_channels if active, else bottleneck
this->_head_inputs.resize(this->_head_output_size, maxBufferSize);
}


Expand Down Expand Up @@ -386,12 +383,12 @@ nam::wavenet::WaveNet::WaveNet(const int in_channels,
layer_array_params[i].input_size, layer_array_params[i].condition_size, layer_array_params[i].head_size,
layer_array_params[i].channels, layer_array_params[i].bottleneck, layer_array_params[i].kernel_size,
layer_array_params[i].dilations, layer_array_params[i].activation_config, layer_array_params[i].gating_mode,
layer_array_params[i].head_bias, layer_array_params[i].groups_input, layer_array_params[i].groups_1x1,
layer_array_params[i].head1x1_params, layer_array_params[i].secondary_activation,
layer_array_params[i].conv_pre_film_params, layer_array_params[i].conv_post_film_params,
layer_array_params[i].input_mixin_pre_film_params, layer_array_params[i].input_mixin_post_film_params,
layer_array_params[i].activation_pre_film_params, layer_array_params[i].activation_post_film_params,
layer_array_params[i].gating_activation_post_film_params, layer_array_params[i]._1x1_post_film_params,
layer_array_params[i].head_bias, layer_array_params[i].groups_input, layer_array_params[i].groups_input_mixin,
layer_array_params[i].groups_1x1, layer_array_params[i].head1x1_params,
layer_array_params[i].secondary_activation_config, layer_array_params[i].conv_pre_film_params,
layer_array_params[i].conv_post_film_params, layer_array_params[i].input_mixin_pre_film_params,
layer_array_params[i].input_mixin_post_film_params, layer_array_params[i].activation_pre_film_params,
layer_array_params[i].activation_post_film_params, layer_array_params[i]._1x1_post_film_params,
layer_array_params[i].head1x1_post_film_params));
if (i > 0)
if (layer_array_params[i].channels != layer_array_params[i - 1].head_size)
Expand Down Expand Up @@ -583,7 +580,8 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
{
nlohmann::json layer_config = config["layers"][i];

const int groups = layer_config.value("groups", 1); // defaults to 1
const int groups = layer_config.value("groups_input", 1); // defaults to 1
const int groups_input_mixin = layer_config.value("groups_input_mixin", 1); // defaults to 1
const int groups_1x1 = layer_config.value("groups_1x1", 1); // defaults to 1

const int channels = layer_config["channels"];
Expand All @@ -599,25 +597,25 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
activations::ActivationConfig::from_json(layer_config["activation"]);
// Parse gating mode - support both old "gated" boolean and new "gating_mode" string
GatingMode gating_mode = GatingMode::NONE;
std::string secondary_activation;
activations::ActivationConfig secondary_activation_config;

if (layer_config.find("gating_mode") != layer_config.end())
{
std::string gating_mode_str = layer_config["gating_mode"].get<std::string>();
if (gating_mode_str == "gated")
{
gating_mode = GatingMode::GATED;
secondary_activation = layer_config["secondary_activation"].get<std::string>();
secondary_activation_config = activations::ActivationConfig::from_json(layer_config["secondary_activation"]);
}
else if (gating_mode_str == "blended")
{
gating_mode = GatingMode::BLENDED;
secondary_activation = layer_config["secondary_activation"].get<std::string>();
secondary_activation_config = activations::ActivationConfig::from_json(layer_config["secondary_activation"]);
}
else if (gating_mode_str == "none")
{
gating_mode = GatingMode::NONE;
secondary_activation.clear();
// Leave secondary_activation_config with empty type
}
else
throw std::runtime_error("Invalid gating_mode: " + gating_mode_str);
Expand All @@ -629,12 +627,9 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
gating_mode = gated ? GatingMode::GATED : GatingMode::NONE;
if (gated)
{
secondary_activation = "Sigmoid";
}
else
{
secondary_activation.clear();
secondary_activation_config = activations::ActivationConfig::simple(activations::ActivationType::Sigmoid);
}
// else: leave secondary_activation_config uninitialized
}
else
{
Expand All @@ -644,9 +639,16 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
const bool head_bias = layer_config["head_bias"];

// Parse head1x1 parameters
bool head1x1_active = layer_config.value("head1x1_active", false);
int head1x1_out_channels = layer_config.value("head1x1_out_channels", channels);
int head1x1_groups = layer_config.value("head1x1_groups", 1);
bool head1x1_active = false;
int head1x1_out_channels = channels;
int head1x1_groups = 1;
if (layer_config.find("head_1x1") != layer_config.end())
{
const auto& head1x1_config = layer_config["head_1x1"];
head1x1_active = head1x1_config["active"];
head1x1_out_channels = head1x1_config["out_channels"];
head1x1_groups = head1x1_config["groups"];
}
nam::wavenet::Head1x1Params head1x1_params(head1x1_active, head1x1_out_channels, head1x1_groups);

// Helper function to parse FiLM parameters
Expand All @@ -668,16 +670,14 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
nam::wavenet::_FiLMParams input_mixin_post_film_params = parse_film_params("input_mixin_post_film");
nam::wavenet::_FiLMParams activation_pre_film_params = parse_film_params("activation_pre_film");
nam::wavenet::_FiLMParams activation_post_film_params = parse_film_params("activation_post_film");
nam::wavenet::_FiLMParams gating_activation_post_film_params = parse_film_params("gating_activation_post_film");
nam::wavenet::_FiLMParams _1x1_post_film_params = parse_film_params("1x1_post_film");
nam::wavenet::_FiLMParams head1x1_post_film_params = parse_film_params("head1x1_post_film");

layer_array_params.push_back(nam::wavenet::LayerArrayParams(
input_size, condition_size, head_size, channels, bottleneck, kernel_size, dilations, activation_config,
gating_mode, head_bias, groups, groups_1x1, head1x1_params, secondary_activation, conv_pre_film_params,
conv_post_film_params, input_mixin_pre_film_params, input_mixin_post_film_params, activation_pre_film_params,
activation_post_film_params, gating_activation_post_film_params, _1x1_post_film_params,
head1x1_post_film_params));
gating_mode, head_bias, groups, groups_input_mixin, groups_1x1, head1x1_params, secondary_activation_config,
conv_pre_film_params, conv_post_film_params, input_mixin_pre_film_params, input_mixin_post_film_params,
activation_pre_film_params, activation_post_film_params, _1x1_post_film_params, head1x1_post_film_params));
}
const bool with_head = !config["head"].is_null();
const float head_scale = config["head_scale"];
Expand Down
Loading