diff --git a/ci_test/unit_tests/test_unit_layer_layer_norm_distconv.py b/ci_test/unit_tests/test_unit_layer_layer_norm_distconv.py new file mode 100644 index 00000000000..19e586dc8a2 --- /dev/null +++ b/ci_test/unit_tests/test_unit_layer_layer_norm_distconv.py @@ -0,0 +1,185 @@ +import functools +import operator +import os +import os.path +import sys +import math +import numpy as np + +# Bamboo utilities +current_file = os.path.realpath(__file__) +current_dir = os.path.dirname(current_file) +sys.path.insert(0, os.path.join(os.path.dirname(current_dir), 'common_python')) +import tools + +# ============================================== +# Objects for Python data reader +# ============================================== +# Note: The Python data reader imports this file as a module and calls +# the functions below to ingest data. + +# Data +np.random.seed(20191114) +_num_samples = 31 +_sample_shape = (17, 4, 2) +_sample_size = math.prod(_sample_shape) +_samples = np.random.normal(size=(_num_samples, _sample_size)).astype(np.float32) + +# Sample access functions +def get_sample(index): + return _samples[index,:] +def num_samples(): + return _num_samples +def sample_dims(): + return (_sample_size,) + +# ============================================== +# NumPy softmax +# ============================================== + +def numpy_layer_norm(x, epsilon=1e-5): + if x.dtype is not np.float64: + x = x.astype(np.float64) + mean = np.mean(x) + var = np.var(x, ddof=0) + return (x - mean) / np.sqrt(var + epsilon) + +# ============================================== +# Setup LBANN experiment +# ============================================== + +def setup_experiment(lbann, weekly): + """Construct LBANN experiment. + + Args: + lbann (module): Module for LBANN Python frontend + + """ + mini_batch_size = num_samples() // 2 + trainer = lbann.Trainer(mini_batch_size) + model = construct_model(lbann) + data_reader = construct_data_reader(lbann) + optimizer = lbann.NoOptimizer() + return trainer, model, data_reader, optimizer, None # Don't request any specific number of nodes + +def create_parallel_strategy(num_groups): + return {"channel_groups": num_groups} + +def construct_model(lbann): + """Construct LBANN model. + + Args: + lbann (module): Module for LBANN Python frontend + + """ + + # Input data + # Note: Sum with a weights layer so that gradient checking will + # verify that error signals are correct. 
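+    # The weights start at zero, so the sum leaves the sample values
+    # unchanged; it only gives CallbackCheckGradients a parameter whose
+    # error signal must pass back through the layer under test.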
+ x_weights = lbann.Weights(optimizer=lbann.SGD(), + initializer=lbann.ConstantInitializer(value=0.0), + name='input_weights') + x = lbann.Sum(lbann.Reshape(lbann.Input(data_field='samples'), + dims=_sample_size), + lbann.WeightsLayer(weights=x_weights, + dims=_sample_size)) + x_lbann = x + + # Objects for LBANN model + obj = [] + metrics = [] + callbacks = [] + + # ------------------------------------------ + # Data-parallel layout + # ------------------------------------------ + + num_groups = tools.gpus_per_node(lbann) + # LBANN implementation + x = x_lbann + # The input tensor must be 4-D so reshape and add an extra axis + x = lbann.Reshape(x, dims=_sample_shape+(1, )) + y = lbann.LayerNorm(x, + data_layout='data_parallel', + parallel_strategy=create_parallel_strategy( + num_groups)) + z = lbann.L2Norm2(y) + obj.append(z) + metrics.append(lbann.Metric(z, name='data-parallel layout')) + + # NumPy implementation + vals = [] + for i in range(num_samples()): + x = get_sample(i).astype(np.float64) + y = numpy_layer_norm(x) + z = tools.numpy_l2norm2(y) + vals.append(z) + val = np.mean(vals) + tol = 8 * val * np.finfo(np.float32).eps + callbacks.append(lbann.CallbackCheckMetric( + metric=metrics[-1].name, + lower_bound=val-tol, + upper_bound=val+tol, + error_on_failure=True, + execution_modes='test')) + + # ------------------------------------------ + # Gradient checking + # ------------------------------------------ + + callbacks.append(lbann.CallbackCheckGradients(error_on_failure=True)) + + # ------------------------------------------ + # Construct model + # ------------------------------------------ + + num_epochs = 0 + return lbann.Model(num_epochs, + layers=lbann.traverse_layer_graph(x_lbann), + objective_function=obj, + metrics=metrics, + callbacks=callbacks) + +def construct_data_reader(lbann): + """Construct Protobuf message for Python data reader. + + The Python data reader will import the current Python file to + access the sample access functions. + + Args: + lbann (module): Module for LBANN Python frontend + + """ + + # Note: The training data reader should be removed when + # https://github.com/LLNL/lbann/issues/1098 is resolved. 
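+    # Both readers below reuse this module's get_sample/num_samples/
+    # sample_dims functions; only the execution mode ('train' vs 'test')
+    # differs.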
+ message = lbann.reader_pb2.DataReader() + message.reader.extend([ + tools.create_python_data_reader( + lbann, + current_file, + 'get_sample', + 'num_samples', + 'sample_dims', + 'train' + ) + ]) + message.reader.extend([ + tools.create_python_data_reader( + lbann, + current_file, + 'get_sample', + 'num_samples', + 'sample_dims', + 'test' + ) + ]) + return message + +# ============================================== +# Setup PyTest +# ============================================== + +# Create test functions that can interact with PyTest +for _test_func in tools.create_tests(setup_experiment, __file__): + globals()[_test_func.__name__] = _test_func diff --git a/include/lbann/layers/regularizers/CMakeLists.txt b/include/lbann/layers/regularizers/CMakeLists.txt index 01c84147c62..2f52c97a4ed 100644 --- a/include/lbann/layers/regularizers/CMakeLists.txt +++ b/include/lbann/layers/regularizers/CMakeLists.txt @@ -34,5 +34,9 @@ set_full_path(THIS_DIR_HEADERS selu_dropout.hpp ) +if (LBANN_HAS_DISTCONV) + add_subdirectory(distconv) +endif() + # Propagate the files up the tree set(HEADERS "${HEADERS}" "${THIS_DIR_HEADERS}" PARENT_SCOPE) diff --git a/include/lbann/layers/regularizers/distconv/CMakeLists.txt b/include/lbann/layers/regularizers/distconv/CMakeLists.txt new file mode 100644 index 00000000000..1901867017e --- /dev/null +++ b/include/lbann/layers/regularizers/distconv/CMakeLists.txt @@ -0,0 +1,31 @@ +################################################################################ +## Copyright (c) 2014-2022, Lawrence Livermore National Security, LLC. +## Produced at the Lawrence Livermore National Laboratory. +## Written by the LBANN Research Team (B. Van Essen, et al.) listed in +## the CONTRIBUTORS file. +## +## LLNL-CODE-697807. +## All rights reserved. +## +## This file is part of LBANN: Livermore Big Artificial Neural Network +## Toolkit. For details, see http://software.llnl.gov/LBANN or +## https://github.com/LLNL/LBANN. +## +## Licensed under the Apache License, Version 2.0 (the "Licensee"); you +## may not use this file except in compliance with the License. You may +## obtain a copy of the License at: +## +## http://www.apache.org/licenses/LICENSE-2.0 +## +## Unless required by applicable law or agreed to in writing, software +## distributed under the License is distributed on an "AS IS" BASIS, +## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +## implied. See the License for the specific language governing +## permissions and limitations under the license. +################################################################################ +set_full_path(THIS_DIR_HEADERS + distconv_layer_norm.hpp + ) + +# Propagate the files up the tree +set(HEADERS "${HEADERS}" "${THIS_DIR_HEADERS}" PARENT_SCOPE) \ No newline at end of file diff --git a/include/lbann/layers/regularizers/distconv/distconv_layer_norm.hpp b/include/lbann/layers/regularizers/distconv/distconv_layer_norm.hpp new file mode 100644 index 00000000000..3206d7b87ac --- /dev/null +++ b/include/lbann/layers/regularizers/distconv/distconv_layer_norm.hpp @@ -0,0 +1,78 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright (c) 2014-2022, Lawrence Livermore National Security, LLC. +// Produced at the Lawrence Livermore National Laboratory. +// Written by the LBANN Research Team (B. Van Essen, et al.) listed in +// the CONTRIBUTORS file. +// +// LLNL-CODE-697807. +// All rights reserved. 
+// +// This file is part of LBANN: Livermore Big Artificial Neural Network +// Toolkit. For details, see http://software.llnl.gov/LBANN or +// https://github.com/LLNL/LBANN. +// +// Licensed under the Apache License, Version 2.0 (the "Licensee"); you +// may not use this file except in compliance with the License. You may +// obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the license. +//////////////////////////////////////////////////////////////////////////////// + +#ifndef LBANN_LAYERSE_REGULARIZERS_DISTCONV_LAYER_NORM +#define LBANN_LAYERSE_REGULARIZERS_DISTCONV_LAYER_NORM + +#ifdef LBANN_HAS_DISTCONV + +namespace distconv { +template +class LayerNormalization +{ + using LocaleMPI = tensor::LocaleMPI; + + template + using DCTensor = tensor::Tensor; + +public: + LayerNormalization(Backend& backend, DataType epsilon) + : m_backend(backend), m_epsilon(epsilon) + {} + + template + void calculate_forward_stats(const DCTensor& input, + DCTensor& statistics); + + template + void apply_normalization(const DCTensor& input, + DCTensor& statistics, + DCTensor& output); + + template + void calculate_backward_stats(const DCTensor& input, + const DCTensor& output_grad, + const DCTensor& statistics, + DCTensor& statistics_grad); + + template + void apply_grad(const DCTensor& input, + const DCTensor& output_grad, + const DCTensor& statistics, + const DCTensor& statistics_grad, + DCTensor& input_grad); + +protected: + Backend& m_backend; + +private: + DataType m_epsilon; + +}; // class definition LayerNorm +} // namespace distconv + +#endif // LBANN_HAS_DISTCONV +#endif // LBANN_LAYERS_REGULARIZERS_DISTCONV_LAYER_NORM \ No newline at end of file diff --git a/include/lbann/layers/regularizers/layer_norm.hpp b/include/lbann/layers/regularizers/layer_norm.hpp index e3542dbdcbf..15c56cb590d 100644 --- a/include/lbann/layers/regularizers/layer_norm.hpp +++ b/include/lbann/layers/regularizers/layer_norm.hpp @@ -35,8 +35,50 @@ #include "lbann/proto/layers.pb.h" #include +#ifdef LBANN_HAS_DISTCONV +#include "lbann/layers/data_type_distconv_adapter.hpp" +#include "lbann/layers/regularizers/distconv/distconv_layer_norm.hpp" +#include "lbann/utils/distconv.hpp" +#endif // LBANN_HAS_DISTCONV + namespace lbann { +#ifdef LBANN_HAS_DISTCONV +namespace dc { +using Shape = ::distconv::tensor::Shape; +using Backend = ::distconv::BackendDNNLib; +template +using LayerNormalization = + ::distconv::LayerNormalization; +} // namespace dc + +template +class layer_norm_distconv_adapter + : public data_type_distconv_adapter +{ + +public: + using TensorDevType = + typename data_type_distconv_adapter::TensorDevType; + + layer_norm_distconv_adapter(Layer& layer) + : data_type_distconv_adapter(layer) + {} + virtual ~layer_norm_distconv_adapter() = default; + + void setup_distributions(tensor_overlap_constraints& constraints) override; + void setup_layer(size_t workspace_capacity) override; + + void fp_compute(); + void bp_compute(); + + TensorDevType m_statistics; + TensorDevType m_statistics_grad; + std::unique_ptr> m_layer_norm_operator; +}; // class definition channelwise_fully_connected_distconv_adapter + +#endif // LBANN_HAS_DISTCONV + /** @brief Normalize over data samples * * 
Each data sample is normalized to have zero mean and unit standard @@ -103,6 +145,18 @@ class layer_norm_layer : public data_type_layer void fp_compute() override; void bp_compute() override; +#ifdef LBANN_HAS_DISTCONV + friend class layer_norm_distconv_adapter; + +protected: + void setup_distconv_adapter(const DataReaderMetaData& dr_metadata) override; + bool is_distconv_supported() const override; + layer_norm_distconv_adapter& + get_distconv_adapter() override; + const layer_norm_distconv_adapter& + get_distconv_adapter() const override; +#endif // LBANN_HAS_DISTCONV + private: using AbsDistMatType = El::AbstractDistMatrix; @@ -359,6 +413,7 @@ void layer_norm_layer::setup_data( } } + template void layer_norm_layer::get_normdims( El::Int& normalization_size, @@ -404,11 +459,54 @@ void layer_norm_layer::get_normdims( global_normalization_size = normalization_size; } +#ifdef LBANN_HAS_DISTCONV + +// ============================================================= +// DistConv-enabled Scatter member functions +// ============================================================= + +template +bool +layer_norm_layer +::is_distconv_supported() const { + return Device==El::Device::GPU && Layout == data_layout::DATA_PARALLEL; +} + +template +void +layer_norm_layer +::setup_distconv_adapter(const DataReaderMetaData& dr_metadata){ + this->get_distconv_adapter_ptr() = std::make_unique>(*this); +} + +template +const layer_norm_distconv_adapter & +layer_norm_layer +::get_distconv_adapter() const{ + return dynamic_cast&>(data_type_layer::get_distconv_adapter()); +} + +template +layer_norm_distconv_adapter & +layer_norm_layer +::get_distconv_adapter(){ + return const_cast&>( + static_cast&>(*this).get_distconv_adapter()); + + +// ============================================================= +// Scatter DistConv Adapter implementation +// ============================================================= + +#endif // LBANN_HAS_DISTCONV + LBANN_DEFINE_LAYER_BUILDER(layer_norm); -// ========================================================= -// Explicit template instantiation -// ========================================================= + // ========================================================= + // Explicit template instantiation + // ========================================================= #ifndef LBANN_LAYER_NORM_LAYER_INSTANTIATE #define PROTO_DEVICE(T, Device) \ diff --git a/include/lbann/layers/regularizers/layer_norm_impl.hpp b/include/lbann/layers/regularizers/layer_norm_impl.hpp new file mode 100644 index 00000000000..14d7076ccbb --- /dev/null +++ b/include/lbann/layers/regularizers/layer_norm_impl.hpp @@ -0,0 +1,224 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright (c) 2014-2022, Lawrence Livermore National Security, LLC. +// Produced at the Lawrence Livermore National Laboratory. +// Written by the LBANN Research Team (B. Van Essen, et al.) listed in +// the CONTRIBUTORS file. +// +// LLNL-CODE-697807. +// All rights reserved. +// +// This file is part of LBANN: Livermore Big Artificial Neural Network +// Toolkit. For details, see http://software.llnl.gov/LBANN or +// https://github.com/LLNL/LBANN. +// +// Licensed under the Apache License, Version 2.0 (the "Licensee"); you +// may not use this file except in compliance with the License. 
You may +// obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the license. +//////////////////////////////////////////////////////////////////////////////// + +#ifndef LBANN_LAYER_REGULARIZER_LAYER_NORM_IMPL_HPP_INCLUDED +#define LBANN_LAYER_REGULARIZER_LAYER_NORM_IMPL_HPP_INCLUDED + +#include "lbann/layers/regularizers/layer_norm.hpp" + +#ifdef LBANN_HAS_DISTONV +#include "lbann/layers/data_type_distconv_adapter.hpp" +#endif + +namespace lbann { + +// ========================================================= +// Implementation +// ========================================================= + +template +void layer_norm_layer::write_specific_proto( + lbann_data::Layer& proto) const +{ + proto.set_datatype(proto::ProtoDataType); + auto* msg = proto.mutable_layer_norm(); + msg->mutable_epsilon()->set_value(m_epsilon); +} + +template +layer_norm_layer::layer_norm_layer( + TensorDataType epsilon) + : data_type_layer(nullptr), m_epsilon(epsilon) +{} + +template +layer_norm_layer::layer_norm_layer( + const layer_norm_layer& other) + : data_type_layer(other), + m_epsilon(other.m_epsilon), + m_statistics(other.m_statistics ? other.m_statistics->Copy() : nullptr), + m_statistics_gradient(other.m_statistics_gradient + ? other.m_statistics_gradient->Copy() + : nullptr) +{} + +template +layer_norm_layer& +layer_norm_layer::operator=( + const layer_norm_layer& other) +{ + data_type_layer::operator=(other); + m_epsilon = other.m_epsilon; + m_statistics.reset(other.m_statistics ? other.m_statistics->Copy() : nullptr); + m_statistics_gradient.reset(other.m_statistics_gradient + ? 
other.m_statistics_gradient->Copy() + : nullptr); + return *this; +} + +template +layer_norm_layer* +layer_norm_layer::copy() const +{ + return new layer_norm_layer(*this); +} + +template +std::string layer_norm_layer::get_type() const +{ + return "layer norm"; +} + +template +data_layout +layer_norm_layer::get_data_layout() const +{ + return Layout; +} + +template +El::Device +layer_norm_layer::get_device_allocation() const +{ + return Device; +} + +template +description +layer_norm_layer::get_description() const +{ + auto desc = data_type_layer::get_description(); + desc.add("Epsilon", m_epsilon); + return desc; +} + +template +void layer_norm_layer::setup_dims( + DataReaderMetaData& dr_metadata) +{ + data_type_layer::setup_dims(dr_metadata); + this->set_output_dims(this->get_input_dims()); +} + +template +void layer_norm_layer::setup_data( + size_t max_mini_batch_size) +{ + data_type_layer::setup_data(max_mini_batch_size); + auto dist = this->get_prev_activations().DistData(); + dist.colDist = El::STAR; + m_statistics.reset(AbsDistMatrixType::Instantiate(dist)); + m_statistics_gradient.reset(AbsDistMatrixType::Instantiate(dist)); +} + +#ifdef LBANN_HAS_DISTCONV + +// ============================================================= +// DistConv-enabled Scatter member functions +// ============================================================= + +template +bool layer_norm_layer::is_distconv_supported() + const +{ + return Device == El::Device::GPU && Layout == data_layout::DATA_PARALLEL; +} + +template +void layer_norm_layer::setup_distconv_adapter( + const DataReaderMetaData& dr_metadata) +{ + this->get_distconv_adapter_ptr() = std::make_unique< + layer_norm_distconv_adapter>(*this); +} + +template +const layer_norm_distconv_adapter& +layer_norm_layer::get_distconv_adapter() const +{ + return dynamic_cast< + const layer_norm_distconv_adapter&>( + data_type_layer::get_distconv_adapter()); +} + +template +layer_norm_distconv_adapter& +layer_norm_layer::get_distconv_adapter() +{ + return const_cast< + layer_norm_distconv_adapter&>( + static_cast&>(*this) + .get_distconv_adapter()); +} + +// ============================================================= +// LayerNorm DistConv Adapter implementation +// ============================================================= + +template +void layer_norm_distconv_adapter:: + setup_distributions(tensor_overlap_constraints& constraints) +{ + data_type_distconv_adapter::setup_distributions(constraints); + // no overlap needed + for (auto& d : this->m_prev_activations_dists) { + d.clear_overlap(); + constraints.mark_updated(d); + constraints.mark_invariant(d); + } + for (auto& d : this->m_activations_dists) { + d.clear_overlap(); + constraints.mark_updated(d); + constraints.mark_invariant(d); + } + for (auto& d : this->m_prev_error_signals_dists) { + d.clear_overlap(); + constraints.mark_updated(d); + constraints.mark_invariant(d); + } + for (auto& d : this->m_error_signals_dists) { + d.clear_overlap(); + constraints.mark_updated(d); + constraints.mark_invariant(d); + } +} + +template +void layer_norm_distconv_adapter::setup_layer( + size_t workspace_capacity) +{ + data_type_distconv_adapter::setup_layer(workspace_capacity); + auto& layer = dynamic_cast&>( + this->layer()); + + m_layer_norm_operator = + make_unique>(dc::get_backend(), + layer.m_epsilon); +} + +#endif // LBANN_HAS_DISTCONV +} // namespace lbann +#endif // LBANN_LAYER_REGULARIZER_LAYER_NORM_IMPL_HPP_INCLUDED \ No newline at end of file diff --git a/src/layers/regularizers/CMakeLists.txt 
b/src/layers/regularizers/CMakeLists.txt index a012f373525..3e2aaec4e78 100644 --- a/src/layers/regularizers/CMakeLists.txt +++ b/src/layers/regularizers/CMakeLists.txt @@ -48,6 +48,10 @@ if (LBANN_HAS_GPU) ) endif () +if (LBANN_HAS_DISTCONV) + add_subdirectory(distconv) +endif() + add_subdirectory(cereal_registration) # Propagate the files up the tree diff --git a/src/layers/regularizers/distconv/CMakeLists.txt b/src/layers/regularizers/distconv/CMakeLists.txt new file mode 100644 index 00000000000..6b148db6bcb --- /dev/null +++ b/src/layers/regularizers/distconv/CMakeLists.txt @@ -0,0 +1,31 @@ +################################################################################ +## Copyright (c) 2014-2022, Lawrence Livermore National Security, LLC. +## Produced at the Lawrence Livermore National Laboratory. +## Written by the LBANN Research Team (B. Van Essen, et al.) listed in +## the CONTRIBUTORS file. +## +## LLNL-CODE-697807. +## All rights reserved. +## +## This file is part of LBANN: Livermore Big Artificial Neural Network +## Toolkit. For details, see http://software.llnl.gov/LBANN or +## https://github.com/LLNL/LBANN. +## +## Licensed under the Apache License, Version 2.0 (the "Licensee"); you +## may not use this file except in compliance with the License. You may +## obtain a copy of the License at: +## +## http://www.apache.org/licenses/LICENSE-2.0 +## +## Unless required by applicable law or agreed to in writing, software +## distributed under the License is distributed on an "AS IS" BASIS, +## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +## implied. See the License for the specific language governing +## permissions and limitations under the license. +################################################################################ +set_full_path(THIS_DIR_CU_SOURCES + distconv_layer_norm.cu + ) + +# Propagate the files up the tree +set(GPU_SOURCES "${GPU_SOURCES}" "${THIS_DIR_CU_SOURCES}" PARENT_SCOPE) diff --git a/src/layers/regularizers/distconv/distconv_layer_norm.cu b/src/layers/regularizers/distconv/distconv_layer_norm.cu new file mode 100644 index 00000000000..19788003f68 --- /dev/null +++ b/src/layers/regularizers/distconv/distconv_layer_norm.cu @@ -0,0 +1,380 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright (c) 2014-2022, Lawrence Livermore National Security, LLC. +// Produced at the Lawrence Livermore National Laboratory. +// Written by the LBANN Research Team (B. Van Essen, et al.) listed in +// the CONTRIBUTORS file. +// +// LLNL-CODE-697807. +// All rights reserved. +// +// This file is part of LBANN: Livermore Big Artificial Neural Network +// Toolkit. For details, see http://software.llnl.gov/LBANN or +// https://github.com/LLNL/LBANN. +// +// Licensed under the Apache License, Version 2.0 (the "Licensee"); you +// may not use this file except in compliance with the License. You may +// obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the license. 
+//////////////////////////////////////////////////////////////////////////////// + +#define LBANN_LAYERS_REGULARIZERS_DISTCONV_LAYER_NORM_INSTANTIATE + +#include "../layer_norm_kernels.cuh" +#include "lbann/layers/regularizers/distconv/distconv_layer_norm.hpp" + +#ifdef LBANN_HAS_DISTCONV + +namespace distconv { + +template +template +void LayerNormalization::calculate_forward_stats( + const DCTensor& input, + DCTensor& statistics) +{ + if (input.get_local_size() == 0) { + util::MPIRootPrintStreamInfo() << "WARNING: EMPTY INPUT FOUND \n"; + return; // no op for empty inputs + } + + const auto& input_dims = input.get_local_shape(); + const auto& statistics_dims = statistics.get_local_shape(); + const auto local_num_samples = input_dims[3]; + const auto global_num_samples = statistics_dims[3]; + + const auto local_sample_size = std::accumulate(input_dims.begin(), + input_dims.end() - 1, + 1, + std::multiplies()); + + using LocalMat = El::Matrix; + LocalMat local_input(local_sample_size, + local_num_samples, + input.get_buffer(), + local_sample_size); + + LocalMat local_statistics(2, global_num_samples, statistics.get_buffer(), 2); + + El::Zero(local_statistics); + auto local_means = El::View(local_statistics, El::IR(0), El::ALL); + auto local_vars = El::View(local_statistics, El::IR(1), El::ALL); + + { + using namespace hydrogen; + auto multisync = El::MakeMultiSync(El::SyncInfoFromMatrix(local_statistics), + El::SyncInfoFromMatrix(local_input)); + constexpr size_t block_size = 256; + dim3 block_dims, grid_dims; + block_dims.x = block_size; + grid_dims.x = (local_sample_size + block_size - 1) / block_size; + grid_dims.y = local_num_samples; + hydrogen::gpu::LaunchKernel( + ::lbann::layer_norm_fp_sums_kernel, + grid_dims, + block_dims, + 0, + multisync, + local_num_samples, + local_sample_size, + local_input.LockedBuffer(), + local_input.LDim(), + local_means.Buffer(), + local_means.LDim(), + local_vars.Buffer(), + local_vars.LDim()); + } +} + +template +template +void LayerNormalization::apply_normalization( + const DCTensor& input, + DCTensor& statistics, + DCTensor& output) +{ + const auto& input_dims = input.get_local_shape(); + const auto& statistics_dims = statistics.get_local_shape(); + const auto local_num_samples = input_dims[3]; + const auto global_num_samples = statistics_dims[3]; + const auto local_sample_size = std::accumulate(input_dims.begin(), + input_dims.end() - 1, + 1, + std::multiplies()); + + using LocalMat = El::Matrix; + const LocalMat local_input(local_sample_size, + local_num_samples, + input.get_buffer(), + local_sample_size); + + LocalMat local_statistics(2, global_num_samples, statistics.get_buffer(), 2); + + LocalMat local_output(local_sample_size, + local_num_samples, + output.get_buffer(), + local_sample_size); + + auto local_means = El::View(local_statistics, El::IR(0), El::ALL); + auto local_vars = El::View(local_statistics, El::IR(1), El::ALL); + + { + using namespace hydrogen; + auto sync_info = El::SyncInfoFromMatrix(local_statistics); + constexpr size_t block_size = 256; + dim3 block_dims, grid_dims; + block_dims.x = block_size; + grid_dims.x = (local_num_samples + block_size - 1) / block_size; + hydrogen::gpu::LaunchKernel( + ::lbann::layer_norm_fp_statistics_kernel, + grid_dims, + block_dims, + 0, + sync_info, + local_sample_size, + local_num_samples, + local_means.Buffer(), + local_means.LDim(), + local_vars.Buffer(), + local_vars.LDim()); + + auto multisync = El::MakeMultiSync(El::SyncInfoFromMatrix(local_output), + 
El::SyncInfoFromMatrix(local_statistics), + El::SyncInfoFromMatrix(local_input)); + + constexpr size_t block_size_output_kernel = 256; + dim3 block_dims_output_kernel, grid_dims_output_kernel; + block_dims_output_kernel.x = block_size_output_kernel; + grid_dims_output_kernel.x = + (local_sample_size + block_size - 1) / block_size_output_kernel; + grid_dims_output_kernel.y = local_num_samples; + hydrogen::gpu::LaunchKernel(::lbann::layer_norm_fp_output_kernel, + grid_dims_output_kernel, + block_dims_output_kernel, + 0, + multisync, + local_num_samples, + local_sample_size, + m_epsilon, + local_input.LockedBuffer(), + local_input.LDim(), + local_output.Buffer(), + local_output.LDim(), + local_means.Buffer(), + local_means.LDim(), + local_vars.Buffer(), + local_vars.LDim()); + } +} + +template +template +void LayerNormalization::calculate_backward_stats( + const DCTensor& input, + const DCTensor& output_grad, + const DCTensor& statistics, + DCTensor& statistics_grad) +{ + + const auto& input_dims = input.get_local_shape(); + const auto& statistics_dims = statistics.get_local_shape(); + const auto local_num_samples = input_dims[3]; + const auto global_num_samples = statistics_dims[3]; + const auto local_sample_size = std::accumulate(input_dims.begin(), + input_dims.end() - 1, + 1, + std::multiplies()); + using LocalMat = El::Matrix; + const LocalMat local_input(local_sample_size, + local_num_samples, + input.get_buffer(), + local_sample_size); + const LocalMat local_output_grad(local_sample_size, + local_num_samples, + output_grad.get_buffer(), + local_sample_size); + + const LocalMat local_statistics(2, + global_num_samples, + statistics.get_buffer(), + 2); + + LocalMat local_statistics_grad(2, + global_num_samples, + statistics_grad.get_buffer(), + 2); + const auto local_means = El::LockedView(local_statistics, El::IR(0), El::ALL); + const auto local_vars = El::LockedView(local_statistics, El::IR(1), El::ALL); + + auto local_means_grad = El::View(local_statistics_grad, El::IR(0), El::ALL); + auto local_vars_grad = El::View(local_statistics_grad, El::IR(1), El::ALL); + + { + using namespace hydrogen; + auto multisync = + El::MakeMultiSync(El::SyncInfoFromMatrix(local_statistics_grad), + El::SyncInfoFromMatrix(local_output_grad), + El::SyncInfoFromMatrix(local_statistics), + El::SyncInfoFromMatrix(local_input)); + constexpr size_t block_size = 256; + dim3 block_dims, grid_dims; + block_dims.x = block_size; + grid_dims.x = (local_sample_size + block_size - 1) / block_size; + grid_dims.y = local_num_samples; + hydrogen::gpu::LaunchKernel( + ::lbann::layer_norm_bp_statistics_grad_kernel, + grid_dims, + block_dims, + 0, + multisync, + local_num_samples, + local_sample_size, + m_epsilon, + local_input.LockedBuffer(), + local_input.LDim(), + local_output_grad.LockedBuffer(), + local_output_grad.LDim(), + local_means.LockedBuffer(), + local_means.LDim(), + local_vars.LockedBuffer(), + local_vars.LDim(), + local_means_grad.Buffer(), + local_means_grad.LDim(), + local_vars_grad.Buffer(), + local_vars_grad.LDim()); + } +} + +template +template +void LayerNormalization::apply_grad( + const DCTensor& input, + const DCTensor& output_grad, + const DCTensor& statistics, + const DCTensor& statistics_grad, + DCTensor& input_grad) +{ + const auto& input_dims = input.get_local_shape(); + const auto& statistics_dims = statistics.get_local_shape(); + const auto local_num_samples = input_dims[3]; + const auto global_num_samples = statistics_dims[3]; + const auto local_sample_size = 
std::accumulate(input_dims.begin(), + input_dims.end() - 1, + 1, + std::multiplies()); + + const auto global_sample_size = local_sample_size; + + using LocalMat = El::Matrix; + const LocalMat local_input(local_sample_size, + local_num_samples, + input.get_buffer(), + local_sample_size); + const LocalMat local_output_grad(local_sample_size, + local_num_samples, + output_grad.get_buffer(), + local_sample_size); + + const LocalMat local_statistics(2, + global_num_samples, + statistics.get_buffer(), + 2); + + const LocalMat local_statistics_grad(2, + global_num_samples, + statistics_grad.get_buffer(), + 2); + + LocalMat local_input_grad(local_sample_size, + local_num_samples, + input_grad.get_buffer(), + local_sample_size); + + const auto local_means = El::LockedView(local_statistics, El::IR(0), El::ALL); + const auto local_vars = El::LockedView(local_statistics, El::IR(1), El::ALL); + const auto local_means_grad = + El::LockedView(local_statistics_grad, El::IR(0), El::ALL); + const auto local_vars_grad = + El::LockedView(local_statistics_grad, El::IR(1), El::ALL); + + { + using namespace hydrogen; + auto multisync = + El::MakeMultiSync(El::SyncInfoFromMatrix(local_statistics_grad), + El::SyncInfoFromMatrix(local_output_grad), + El::SyncInfoFromMatrix(local_statistics), + El::SyncInfoFromMatrix(local_input)); + constexpr size_t block_size = 256; + dim3 block_dims, grid_dims; + block_dims.x = block_size; + grid_dims.x = (local_sample_size + block_size - 1) / block_size; + grid_dims.y = local_num_samples; + hydrogen::gpu::LaunchKernel( + ::lbann::layer_norm_bp_input_grad_kernel, + grid_dims, + block_dims, + 0, + multisync, + global_sample_size, + local_num_samples, + local_sample_size, + m_epsilon, + local_input.LockedBuffer(), + local_input.LDim(), + local_output_grad.LockedBuffer(), + local_output_grad.LDim(), + local_input_grad.Buffer(), + local_input_grad.LDim(), + local_means.LockedBuffer(), + local_means.LDim(), + local_vars.LockedBuffer(), + local_vars.LDim(), + local_means_grad.LockedBuffer(), + local_means_grad.LDim(), + local_vars_grad.LockedBuffer(), + local_vars_grad.LDim()); + } +} + +#define ETI(T, Backend) \ + template class LayerNormalization; \ + template void LayerNormalization::calculate_forward_stats< \ + tensor::CUDAAllocator>( \ + const tensor::Tensor& input, \ + tensor::Tensor& statistics); \ + template void \ + LayerNormalization::apply_normalization( \ + const tensor::Tensor& input, \ + tensor::Tensor& statistics, \ + tensor::Tensor& output); \ + template void LayerNormalization::calculate_backward_stats< \ + tensor::CUDAAllocator>( \ + const tensor::Tensor& input, \ + const tensor::Tensor& \ + output_grad, \ + const tensor::Tensor& \ + statistics, \ + tensor::Tensor& \ + statistics_grad); \ + template void \ + LayerNormalization::apply_grad( \ + const tensor::Tensor& input, \ + const tensor::Tensor& \ + output_grad, \ + const tensor::Tensor& \ + statistics, \ + const tensor::Tensor& \ + statistics_grad, \ + tensor::Tensor& input_grad); + +ETI(float, BackendDNNLib) +ETI(double, BackendDNNLib) +#undef ETI +} // namespace distconv +#endif // LBANN_HAS_DISTCONV \ No newline at end of file diff --git a/src/layers/regularizers/layer_norm.cpp b/src/layers/regularizers/layer_norm.cpp index d0dbd44c097..dc345047b85 100644 --- a/src/layers/regularizers/layer_norm.cpp +++ b/src/layers/regularizers/layer_norm.cpp @@ -25,9 +25,9 @@ //////////////////////////////////////////////////////////////////////////////// #define LBANN_LAYER_NORM_LAYER_INSTANTIATE -#include 
"lbann/layers/regularizers/layer_norm.hpp" #include "lbann/comm_impl.hpp" #include "lbann/optimizers/optimizer.hpp" +#include "lbann/layers/regularizers/layer_norm_impl.hpp" #ifdef LBANN_HAS_DISTCONV #include "lbann/layers/data_type_distconv_adapter.hpp" diff --git a/src/layers/regularizers/layer_norm.cu b/src/layers/regularizers/layer_norm.cu index 596cf2c3a9f..3e7290072cc 100644 --- a/src/layers/regularizers/layer_norm.cu +++ b/src/layers/regularizers/layer_norm.cu @@ -1,5 +1,5 @@ //////////////////////////////////////////////////////////////////////////////// -// Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC. +// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. // Produced at the Lawrence Livermore National Laboratory. // Written by the LBANN Research Team (B. Van Essen, et al.) listed in // the CONTRIBUTORS file. @@ -25,8 +25,10 @@ //////////////////////////////////////////////////////////////////////////////// #define LBANN_LAYER_NORM_LAYER_INSTANTIATE +#include "layer_norm_kernels.cuh" #include "lbann/comm_impl.hpp" #include "lbann/layers/regularizers/layer_norm.hpp" +#include "lbann/layers/regularizers/layer_norm_impl.hpp" #include "lbann/optimizers/optimizer.hpp" #include "lbann/utils/gpu/helpers.hpp" @@ -316,26 +318,13 @@ void fp_impl(lbann_comm& comm, El::Int block_size = min(El::Int(256), normalization_size); dim3 block_dims, grid_dims; block_dims.x = block_size; - grid_dims.x = (normalization_size + block_size - 1) / block_size; - grid_dims.y = num_normalized; - grid_dims.z = local_num_samples; - auto kernel = - ((!local_scale && !local_bias) - ? fp_output_kernel - : ((local_scale && !local_bias) - ? fp_output_kernel - : ((!local_scale && local_bias) - ? fp_output_kernel - : fp_output_kernel))); - hydrogen::gpu::LaunchKernel(kernel, + hydrogen::gpu::LaunchKernel(layer_norm_fp_output_kernel, grid_dims, block_dims, 0, multisync, local_num_samples, - normalization_size, - num_normalized, - normalization_stride, + local_sample_size, epsilon, local_input.LockedBuffer(), local_input.LDim(), @@ -344,170 +333,10 @@ void fp_impl(lbann_comm& comm, local_means.LockedBuffer(), local_means.LDim(), local_vars.LockedBuffer(), - local_vars.LDim(), - local_scale, - local_bias); - } -} - -/** Compute gradients w.r.t. per-sample statistics. - * - * dL/dmean = - sum(dL/dy_i) / sqrt(var+epsilon) - * - * dL/dvar = - sum(dL/dy_i * (x_i-mean)) * (var+epsilon)^(-3/2) / 2 - * - * On input, means_grad and vars_grad are filled with zeros. 
- * - * Block dimensions: bsize x 1 x 1 - * - * Grid dimensions: (normalization_size / bsize) x num_normalized x - * local_num_samples - */ -template -__global__ void -bp_statistics_grad_kernel(size_t local_num_samples, - size_t normalization_size, - size_t num_normalized, - size_t normalization_stride, - TensorDataType epsilon, - const TensorDataType* __restrict__ input, - size_t input_ldim, - const TensorDataType* __restrict__ output_grad, - size_t output_grad_ldim, - const TensorDataType* __restrict__ means, - size_t means_stride, - const TensorDataType* __restrict__ vars, - size_t vars_stride, - TensorDataType* __restrict__ means_grad, - size_t means_grad_stride, - TensorDataType* __restrict__ vars_grad, - size_t vars_grad_stride, - const TensorDataType* __restrict__ scale, - TensorDataType* __restrict__ scale_grad, - TensorDataType* __restrict__ bias_grad) -{ - - // Indices and dimensions - constexpr size_t bdimy = 1; - constexpr size_t bdimz = 1; - const size_t tid = threadIdx.x; - const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; - const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; - const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; - const size_t nthreadsx = blockDim.x * gridDim.x; - const size_t nthreadsy = blockDim.y * gridDim.y; - const size_t nthreadsz = blockDim.z * gridDim.z; - - for (size_t i = gidz; i < local_num_samples; i += nthreadsz) { - for (size_t j = gidy; j < num_normalized; j += nthreadsy) { - - const auto& var = vars[i * vars_stride + j]; - const auto& inv_stdev = gpu_lib::rsqrt(var + epsilon); - - // Accumulate sums and perform block-wide reduction - using pair_t = thrust::pair; - using pair_sum_t = pair_sum; - pair_t sums(0, 0); - const auto& mean = means[i * means_stride + j]; - for (size_t k = gidx; k < normalization_size; k += nthreadsx) { - const auto& x = input[i * input_ldim + j * normalization_stride + k]; - auto dy = - output_grad[i * output_grad_ldim + j * normalization_stride + k]; - if constexpr (HAS_BIAS) - gpu_lib::atomic_add(bias_grad + k, dy); - - if constexpr (HAS_SCALE) { - gpu_lib::atomic_add(scale_grad + k, dy * (x - mean) * inv_stdev); - dy *= scale[k]; - } - - sums.first += dy; - sums.second += dy * (x - mean); - } - sums = - gpu_lib::block_reduce(sums); - - // Output result to global memory - if (tid == 0) { - const TensorDataType dmean = -sums.first * inv_stdev; - const TensorDataType dvar = - -sums.second * inv_stdev * inv_stdev * inv_stdev / TensorDataType(2); - gpu_lib::atomic_add(&means_grad[i * means_grad_stride + j], dmean); - gpu_lib::atomic_add(&vars_grad[i * vars_grad_stride + j], dvar); - } - } + local_vars.LDim()); } } -/** Compute gradients w.r.t. input. 
- * - * dL/dx_i = ( dL/dy_i / sqrt(var+epsilon) - * + dL/dmean / n - * + dL/dvar * (x_i - mean) * 2/(n-1) ) - * - * Block dimensions: bdimx x bdimy x 1 - * - * Grid dimensions: (local_sample_size / bdimx) x (local_num_samples / bdimy) x - * 1 - */ -template -__global__ void -bp_input_grad_kernel(size_t local_num_samples, - size_t global_normalization_size, - size_t normalization_size, - size_t num_normalized, - size_t normalization_stride, - TensorDataType epsilon, - const TensorDataType* __restrict__ input, - size_t input_ldim, - const TensorDataType* __restrict__ output_grad, - size_t output_grad_ldim, - TensorDataType* __restrict__ input_grad, - size_t input_grad_ldim, - const TensorDataType* __restrict__ means, - size_t means_stride, - const TensorDataType* __restrict__ vars, - size_t vars_stride, - const TensorDataType* __restrict__ means_grad, - size_t means_grad_stride, - const TensorDataType* __restrict__ vars_grad, - size_t vars_grad_stride, - const TensorDataType* __restrict__ scale) -{ - - const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; - const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; - const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; - const size_t nthreadsx = blockDim.x * gridDim.x; - const size_t nthreadsy = blockDim.y * gridDim.y; - const size_t nthreadsz = blockDim.z * gridDim.z; - for (size_t i = gidz; i < local_num_samples; i += nthreadsz) { - for (size_t j = gidy; j < num_normalized; j += nthreadsy) { - const auto& mean = means[i * means_stride + j]; - const auto& var = vars[i * vars_stride + j]; - const auto& inv_stdev = gpu_lib::rsqrt(var + epsilon); - const auto& dmean = means_grad[i * means_grad_stride + j]; - const auto& dvar = vars_grad[i * vars_grad_stride + j]; - for (size_t k = gidx; k < normalization_size; k += nthreadsx) { - const auto& x = input[i * input_ldim + j * normalization_stride + k]; - auto dy = - output_grad[i * output_grad_ldim + j * normalization_stride + k]; - - if constexpr (HAS_SCALE) { - const auto& lscale = scale[k]; - dy *= lscale; - } - - auto& dx = - input_grad[i * input_grad_ldim + j * normalization_stride + k]; - dx = - (dy * inv_stdev + dmean / TensorDataType(global_normalization_size) + - dvar * (x - mean) * TensorDataType(2) / - TensorDataType(global_normalization_size)); - } - } - } -} /** @brief Backprop */ template @@ -576,9 +405,8 @@ void bp_impl(lbann_comm& comm, constexpr size_t block_size = 256; dim3 block_dims, grid_dims; block_dims.x = block_size; - grid_dims.x = (normalization_size + block_size - 1) / block_size; - grid_dims.y = num_normalized; - grid_dims.z = local_num_samples; + grid_dims.x = (local_sample_size + block_size - 1) / block_size; + grid_dims.y = local_num_samples; auto kernel = ((!scale_grad && !bias_grad) ? bp_statistics_grad_kernel @@ -636,9 +464,8 @@ void bp_impl(lbann_comm& comm, El::Int block_size = min(El::Int(256), normalization_size); dim3 block_dims, grid_dims; block_dims.x = block_size; - grid_dims.x = (normalization_size + block_size - 1) / block_size; - grid_dims.y = num_normalized; - grid_dims.z = local_num_samples; + grid_dims.x = (local_sample_size + block_size - 1) / block_size; + grid_dims.y = local_num_samples; auto kernel = (local_scale ? 
bp_input_grad_kernel : bp_input_grad_kernel); hydrogen::gpu::LaunchKernel(kernel, @@ -672,10 +499,70 @@ void bp_impl(lbann_comm& comm, } // namespace +// ========================================================= +// DistConv-Adapter member implementation +// ========================================================= + +#ifdef LBANN_HAS_DISTCONV +template +void layer_norm_distconv_adapter::fp_compute() +{ + auto& l = dynamic_cast&>( + this->layer()); + lbann_comm& comm = *(l.get_comm()); + + auto& statistics = *l.m_statistics; + assert0(dc::tensor::View(m_statistics, statistics.Buffer())); + + using GPUMatType = El::Matrix; + m_layer_norm_operator->calculate_forward_stats(this->get_prev_activations(), + m_statistics); + comm.allreduce(statistics, statistics.RedundantComm(), El::mpi::SUM); + m_layer_norm_operator->apply_normalization(this->get_prev_activations(), + m_statistics, + this->get_activations()); +} + +template +void layer_norm_distconv_adapter::bp_compute() +{ + auto& l = dynamic_cast&>( + this->layer()); + lbann_comm& comm = *(l.get_comm()); + + auto& statistics = *l.m_statistics; + auto& statistics_grad = *l.m_statistics_gradient; + assert0(dc::tensor::View(m_statistics, statistics.Buffer())); + assert0(dc::tensor::View(m_statistics_grad, statistics_grad.Buffer())); + + using GPUMatType = El::Matrix; + m_layer_norm_operator->calculate_backward_stats( + this->get_prev_activations(), + this->get_prev_error_signals(), + m_statistics, + m_statistics_grad); + comm.allreduce(statistics_grad, + statistics_grad.RedundantComm(), + El::mpi::SUM); + m_layer_norm_operator->apply_grad(this->get_prev_activations(), + this->get_prev_error_signals(), + m_statistics, + m_statistics_grad, + this->get_error_signals()); +} +#endif // LBANN_HAS_DISTCONV + // Template instantiation template void layer_norm_layer::fp_compute() { +#ifdef LBANN_HAS_DISTCONV + if (this->distconv_enabled()) { + this->get_distconv_adapter().fp_compute(); + return; + } +#endif // LBANN_HAS_DISTCONV + int weight_idx = 0; const TensorDataType* scale_weights = nullptr; const TensorDataType* bias_weights = nullptr; @@ -685,10 +572,22 @@ void layer_norm_layer::fp_compute() if (m_bias) bias_weights = this->weights_values(weight_idx).LockedMatrix().LockedBuffer(); - El::Int norm_size, global_norm_size, num_norm, norm_stride; this->get_normdims(norm_size, global_norm_size, num_norm, norm_stride); +<<<<<<< HEAD +<<<<<<< HEAD +#ifdef LBANN_HAS_DISTCONV + if (this->distconv_enabled()) { + this->get_distconv_adapter().fp_compute(); + return; + } +#endif // LBANN_HAS_DISTCONV +======= + +>>>>>>> f02146109 (Updated implementation with updating statistics tensors) +======= +>>>>>>> ecac28c9f (Updating layer norm impl) fp_impl(*this->get_comm(), this->m_epsilon, norm_size, @@ -705,6 +604,13 @@ void layer_norm_layer::fp_compute() template void layer_norm_layer::bp_compute() { +#ifdef LBANN_HAS_DISTCONV + if (this->distconv_enabled()) { + this->get_distconv_adapter().bp_compute(); + return; + } +#endif // LBANN_HAS_DISTCONV + // Obtain optional buffers const TensorDataType* scale_weights = nullptr; TensorDataType* scale_grad = nullptr; @@ -721,10 +627,23 @@ void layer_norm_layer::bp_compute() bias_grad = this->m_bias_gradient->Buffer(); } +<<<<<<< HEAD El::Int norm_size, global_norm_size, num_norm, norm_stride; this->get_normdims(norm_size, global_norm_size, num_norm, norm_stride); // Compute backpropagation + +void layer_norm_layer::bp_compute() +{ +#ifdef LBANN_HAS_DISTCONV + if (this->distconv_enabled()) { + 
this->get_distconv_adapter().bp_compute(); + return; + } +#endif // LBANN_HAS_DISTCONV + +======= +>>>>>>> f02146109 (Updated implementation with updating statistics tensors) bp_impl(*this->get_comm(), this->m_epsilon, norm_size, diff --git a/src/layers/regularizers/layer_norm_kernels.cuh b/src/layers/regularizers/layer_norm_kernels.cuh new file mode 100644 index 00000000000..9317f2e6086 --- /dev/null +++ b/src/layers/regularizers/layer_norm_kernels.cuh @@ -0,0 +1,301 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright (c) 2014-2022, Lawrence Livermore National Security, LLC. +// Produced at the Lawrence Livermore National Laboratory. +// Written by the LBANN Research Team (B. Van Essen, et al.) listed in +// the CONTRIBUTORS file. +// +// LLNL-CODE-697807. +// All rights reserved. +// +// This file is part of LBANN: Livermore Big Artificial Neural Network +// Toolkit. For details, see http://software.llnl.gov/LBANN or +// https://github.com/LLNL/LBANN. +// +// Licensed under the Apache License, Version 2.0 (the "Licensee"); you +// may not use this file except in compliance with the License. You may +// obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the license. +//////////////////////////////////////////////////////////////////////////////// + +#ifndef LBANN_LAYERS_REGULARIZERS_NORM_LAYER_KERNELS +#define LBANN_LAYERS_REGULARIZERS_NORM_LAYER_KERNELS +#include "lbann/comm_impl.hpp" +#include "lbann/layers/regularizers/layer_norm.hpp" +#include "lbann/utils/gpu/helpers.hpp" +#include + +namespace lbann { + +/** Functor for adding @c thrust::pair objects. */ +template +struct pair_sum +{ + __device__ __forceinline__ Pair operator()(const Pair& x, const Pair& y) + { + return Pair(x.first + y.first, x.second + y.second); + } +}; + +// ========================================================= +// Forward prop +// ========================================================= + +/** Accumulate sums and sums of squares for each data sample. + * + * On input, sums and sqsums are filled with zeros. 
+ * + * Block dimensions: bsize x 1 x 1 + * + * Grid dimensions: (local_sample_size / bsize) x local_num_samples x 1 + */ +template +__global__ void +layer_norm_fp_sums_kernel(size_t local_num_samples, + size_t local_sample_size, + const TensorDataType* __restrict__ vals, + size_t vals_ldim, + TensorDataType* sums, + size_t sums_stride, + TensorDataType* sqsums, + size_t sqsums_stride) +{ + + // Indices and dimensions + constexpr size_t bdimy = 1; + constexpr size_t bdimz = 1; + const size_t tid = threadIdx.x + blockDim.x * threadIdx.y; + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t nthreadsx = blockDim.x * gridDim.x; + const size_t nthreadsy = blockDim.y * gridDim.y; + + for (size_t i = gidy; i < local_num_samples; i += nthreadsy) { + + // Accumulate sums and perform block-wide reduction + using pair_t = thrust::pair; + using pair_sum_t = pair_sum; + pair_t sum_sqsum(0, 0); + for (size_t j = gidx; j < local_sample_size; j += nthreadsx) { + const auto& x = vals[i * vals_ldim + j]; + sum_sqsum.first += x; + sum_sqsum.second += x * x; + } + sum_sqsum = + gpu_lib::block_reduce(sum_sqsum); + + // Output result to global memory + if (tid == 0) { + gpu_lib::atomic_add(&sums[i * sums_stride], sum_sqsum.first); + gpu_lib::atomic_add(&sqsums[i * sqsums_stride], sum_sqsum.second); + } + } +} + +/** Compute per-sample statistics. + * + * mean = sum(x_i) / n + * + * var = ( sum(x_i^2)/n - mean^2 ) + * + * On input, means contains per-sample sums and vars contains + * per-sample sums of squares. + * + * Block dimensions: bsize x 1 x 1 + * + * Grid dimensions: (local_num_samples / bsize) x 1 x 1 + */ +template +__global__ void layer_norm_fp_statistics_kernel(unsigned long long sample_size, + size_t local_num_samples, + TensorDataType* means, + size_t means_stride, + TensorDataType* vars, + size_t vars_stride) +{ + + const size_t gid = threadIdx.x + blockIdx.x * blockDim.x; + const size_t nthreads = blockDim.x * gridDim.x; + for (size_t i = gid; i < local_num_samples; i += nthreads) { + const auto sum = means[i * means_stride]; + const auto sqsum = vars[i * means_stride]; + const TensorDataType sample_size_dt = TensorDataType(sample_size); + const auto& mean = sum / sample_size_dt; + const auto& sqmean = sqsum / sample_size_dt; + const auto& var = (sqmean - mean * mean); + means[i * means_stride] = mean; + vars[i * vars_stride] = gpu_lib::max(var, TensorDataType(0.0)); + } +} + +/** Compute outputs. 
+ * + * y_i = (x_i - mean) / sqrt(var + epsilon) + * + * Block dimensions: bdimx x bdimy x 1 + * + * Grid dimensions: (local_sample_size / bdimx) x (local_num_samples / bdimy) x + * 1 + */ +template +__global__ void +layer_norm_fp_output_kernel(size_t local_num_samples, + size_t local_sample_size, + TensorDataType epsilon, + const TensorDataType* __restrict__ input, + size_t input_ldim, + TensorDataType* __restrict__ output, + size_t output_ldim, + const TensorDataType* means, + size_t means_stride, + const TensorDataType* vars, + size_t vars_stride) +{ + + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t nthreadsx = blockDim.x * gridDim.x; + const size_t nthreadsy = blockDim.y * gridDim.y; + for (size_t i = gidy; i < local_num_samples; i += nthreadsy) { + const auto& mean = means[i * means_stride]; + const auto& var = vars[i * vars_stride]; + const auto& inv_stdev = gpu_lib::rsqrt(var + epsilon); + for (size_t j = gidx; j < local_sample_size; j += nthreadsx) { + const auto& x = input[i * input_ldim + j]; + auto& y = output[i * output_ldim + j]; + y = (x - mean) * inv_stdev; + } + } +} + +/** Compute gradients w.r.t. per-sample statistics. + * + * dL/dmean = - sum(dL/dy_i) / sqrt(var+epsilon) + * + * dL/dvar = - sum(dL/dy_i * (x_i-mean)) * (var+epsilon)^(-3/2) / 2 + * + * On input, means_grad and vars_grad are filled with zeros. + * + * Block dimensions: bsize x 1 x 1 + * + * Grid dimensions: (local_sample_size / bsize) x local_num_samples x 1 + */ +template +__global__ void layer_norm_bp_statistics_grad_kernel( + size_t local_num_samples, + size_t local_sample_size, + TensorDataType epsilon, + const TensorDataType* __restrict__ input, + size_t input_ldim, + const TensorDataType* __restrict__ output_grad, + size_t output_grad_ldim, + const TensorDataType* means, + size_t means_stride, + const TensorDataType* vars, + size_t vars_stride, + TensorDataType* means_grad, + size_t means_grad_stride, + TensorDataType* vars_grad, + size_t vars_grad_stride) +{ + + // Indices and dimensions + constexpr size_t bdimy = 1; + constexpr size_t bdimz = 1; + const size_t tid = threadIdx.x + blockDim.x * threadIdx.y; + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t nthreadsx = blockDim.x * gridDim.x; + const size_t nthreadsy = blockDim.y * gridDim.y; + + for (size_t i = gidy; i < local_num_samples; i += nthreadsy) { + + // Accumulate sums and perform block-wide reduction + using pair_t = thrust::pair; + using pair_sum_t = pair_sum; + pair_t sums(0, 0); + const auto& mean = means[i * means_stride]; + for (size_t j = gidx; j < local_sample_size; j += nthreadsx) { + const auto& x = input[i * input_ldim + j]; + const auto& dy = output_grad[i * output_grad_ldim + j]; + sums.first += dy; + sums.second += dy * (x - mean); + } + sums = gpu_lib::block_reduce(sums); + + // Output result to global memory + if (tid == 0) { + const auto& var = vars[i * vars_stride]; + const auto& inv_stdev = gpu_lib::rsqrt(var + epsilon); + const TensorDataType dmean = -sums.first * inv_stdev; + const TensorDataType dvar = + -sums.second * inv_stdev * inv_stdev * inv_stdev / TensorDataType(2); + gpu_lib::atomic_add(&means_grad[i * means_grad_stride], dmean); + gpu_lib::atomic_add(&vars_grad[i * vars_grad_stride], dvar); + } + } +} + +/** Compute gradients w.r.t. input. 
+ * + * dL/dx_i = ( dL/dy_i / sqrt(var+epsilon) + * + dL/dmean / n + * + dL/dvar * (x_i - mean) * 2/(n-1) ) + * + * Block dimensions: bdimx x bdimy x 1 + * + * Grid dimensions: (local_sample_size / bdimx) x (local_num_samples / bdimy) x + * 1 + */ +template +__global__ void +layer_norm_bp_input_grad_kernel(unsigned long long sample_size, + size_t local_num_samples, + size_t local_sample_size, + TensorDataType epsilon, + const TensorDataType* __restrict__ input, + size_t input_ldim, + const TensorDataType* __restrict__ output_grad, + size_t output_grad_ldim, + TensorDataType* __restrict__ input_grad, + size_t input_grad_ldim, + const TensorDataType* __restrict__ means, + size_t means_stride, + const TensorDataType* __restrict__ vars, + size_t vars_stride, + const TensorDataType* means_grad, + size_t means_grad_stride, + const TensorDataType* vars_grad, + size_t vars_grad_stride) +{ + + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t nthreadsx = blockDim.x * gridDim.x; + const size_t nthreadsy = blockDim.y * gridDim.y; + for (size_t i = gidy; i < local_num_samples; i += nthreadsy) { + const auto& mean = means[i * means_stride]; + const auto& var = vars[i * vars_stride]; + const auto& inv_stdev = gpu_lib::rsqrt(var + epsilon); + const auto& dmean = means_grad[i * means_grad_stride]; + const auto& dvar = vars_grad[i * vars_grad_stride]; + for (size_t j = gidx; j < local_sample_size; j += nthreadsx) { + const auto& x = input[i * input_ldim + j]; + const auto& dy = output_grad[i * output_grad_ldim + j]; + auto& dx = input_grad[i * input_grad_ldim + j]; + dx = + (dy * inv_stdev + dmean / TensorDataType(sample_size) + + dvar * (x - mean) * TensorDataType(2) / TensorDataType(sample_size)); + } + } +} + +} // namespace lbann + +#endif // LBANN_LAYERS_REGULARIZERS_NORM_LAYER_KERNELS \ No newline at end of file
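For reference, the distconv forward path above factors layer normalization into three steps: calculate_forward_stats reduces each rank's local slice of a sample into per-sample sums and sums of squares, the adapter's fp_compute allreduces the 2-row statistics tensor across the channel groups, and apply_normalization finalizes the mean/variance and normalizes the local slice. The sketch below is not part of the patch; it is a minimal single-process NumPy illustration of that flow under those assumptions, with the allreduce replaced by a sum over simulated shards (num_shards is an arbitrary illustrative value).

import numpy as np

def distributed_layer_norm(x, num_shards=4, epsilon=1e-5):
    """Single-process sketch of the distconv layer-norm forward pass.

    x has shape (num_samples, sample_size). Each shard owns a slice of
    every sample, loosely mirroring how the channel groups partition the
    tensor.
    """
    shards = np.array_split(x, num_shards, axis=1)

    # calculate_forward_stats: per-rank partial sums and sums of squares
    local_stats = [
        np.stack([s.sum(axis=1), (s * s).sum(axis=1)]) for s in shards
    ]

    # fp_compute: allreduce of the (2, num_samples) statistics tensor
    stats = np.sum(local_stats, axis=0)

    # layer_norm_fp_statistics_kernel: finalize mean and (clamped) variance
    n = x.shape[1]
    mean = stats[0] / n
    var = np.maximum(stats[1] / n - mean * mean, 0.0)

    # apply_normalization: each rank normalizes its local slice
    inv_stdev = 1.0 / np.sqrt(var + epsilon)
    out_shards = [(s - mean[:, None]) * inv_stdev[:, None] for s in shards]
    return np.concatenate(out_shards, axis=1)

# Quick check against the per-sample reference used by the unit test above
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.normal(size=(8, 136)).astype(np.float64)
    y = distributed_layer_norm(x)
    ref = (x - x.mean(axis=1, keepdims=True)) / np.sqrt(
        x.var(axis=1, keepdims=True) + 1e-5)
    assert np.allclose(y, ref)

The backward path mirrors this structure: calculate_backward_stats produces local gradients of the per-sample mean and variance, the adapter's bp_compute allreduces them, and apply_grad applies the combined statistics gradients to each local slice.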