From d41bf547b2b37063255f506585474e5a9fb51001 Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Wed, 26 Oct 2022 13:40:09 -0700 Subject: [PATCH 01/23] Adding channelwise softmax distconv unit test. --- ...unit_layer_channelwise_softmax_distconv.py | 176 ++++++++++++++++++ 1 file changed, 176 insertions(+) create mode 100644 ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py diff --git a/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py b/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py new file mode 100644 index 00000000000..b6f9d4e77e4 --- /dev/null +++ b/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py @@ -0,0 +1,176 @@ +import functools +import operator +import os +import os.path +import sys +import numpy as np + +# Bamboo utilities +current_file = os.path.realpath(__file__) +current_dir = os.path.dirname(current_file) +sys.path.insert(0, os.path.join(os.path.dirname(current_dir), 'common_python')) +import tools + +# ============================================== +# Objects for Python data reader +# ============================================== +# Note: The Python data reader imports this file as a module and calls +# the functions below to ingest data. + +# Data +np.random.seed(20200115) +_num_samples = 15 +_sample_dims = (5,2,7) +_sample_size = functools.reduce(operator.mul, _sample_dims) +_samples = np.random.normal(loc=0.5, size=(_num_samples,_sample_size)).astype(np.float32) + +# Sample access functions +def get_sample(index): + return _samples[index,:] +def num_samples(): + return _num_samples +def sample_dims(): + return (_sample_size,) + +# ============================================== +# NumPy implementation +# ============================================== + +def numpy_channelwise_softmax(x): + if x.dtype is not np.float64: + x = x.astype(np.float64) + axis = tuple(range(1,x.ndim)) + shift = np.max(x, axis=axis, keepdims=True) + y = np.exp(x-shift) + return y / np.sum(y, axis=axis, keepdims=True) + +# ============================================== +# Setup LBANN experiment +# ============================================== + +def setup_experiment(lbann, weekly): + """Construct LBANN experiment. + + Args: + lbann (module): Module for LBANN Python frontend + + """ + mini_batch_size = num_samples() // 2 + trainer = lbann.Trainer(mini_batch_size) + model = construct_model(lbann) + data_reader = construct_data_reader(lbann) + optimizer = lbann.NoOptimizer() + return trainer, model, data_reader, optimizer, None # Don't request any specific number of nodes + +def construct_model(lbann): + """Construct LBANN model. + + Args: + lbann (module): Module for LBANN Python frontend + + """ + + # Input data + # Note: Sum with a weights layer so that gradient checking will + # verify that error signals are correct. + x_weights = lbann.Weights(optimizer=lbann.SGD(), + initializer=lbann.ConstantInitializer(value=0.0), + name='input_weights') + x = lbann.Sum(lbann.Reshape(lbann.Input(data_field='samples'), + dims=_sample_dims), + lbann.WeightsLayer(weights=x_weights, + dims=_sample_dims)) + x_lbann = x + + # Objects for LBANN model + obj = [] + metrics = [] + callbacks = [] + + # ------------------------------------------ + # Data-parallel layout + # ------------------------------------------ + + # LBANN implementation + x = x_lbann + y = lbann.ChannelwiseSoftmax(x, data_layout='data_parallel') + z = lbann.L2Norm2(y) + obj.append(z) + metrics.append(lbann.Metric(z, name='data-parallel layout')) + + # NumPy implementation + vals = [] + for i in range(num_samples()): + x = get_sample(i).reshape(_sample_dims).astype(np.float64) + y = numpy_channelwise_softmax(x) + z = tools.numpy_l2norm2(y) + vals.append(z) + val = np.mean(vals) + tol = 8 * val * np.finfo(np.float32).eps + callbacks.append(lbann.CallbackCheckMetric( + metric=metrics[-1].name, + lower_bound=val-tol, + upper_bound=val+tol, + error_on_failure=True, + execution_modes='test')) + + # ------------------------------------------ + # Gradient checking + # ------------------------------------------ + + callbacks.append(lbann.CallbackCheckGradients(error_on_failure=True)) + + # ------------------------------------------ + # Construct model + # ------------------------------------------ + + num_epochs = 0 + return lbann.Model(num_epochs, + layers=lbann.traverse_layer_graph(x_lbann), + objective_function=obj, + metrics=metrics, + callbacks=callbacks) + +def construct_data_reader(lbann): + """Construct Protobuf message for Python data reader. + + The Python data reader will import the current Python file to + access the sample access functions. + + Args: + lbann (module): Module for LBANN Python frontend + + """ + + # Note: The training data reader should be removed when + # https://github.com/LLNL/lbann/issues/1098 is resolved. + message = lbann.reader_pb2.DataReader() + message.reader.extend([ + tools.create_python_data_reader( + lbann, + current_file, + 'get_sample', + 'num_samples', + 'sample_dims', + 'train' + ) + ]) + message.reader.extend([ + tools.create_python_data_reader( + lbann, + current_file, + 'get_sample', + 'num_samples', + 'sample_dims', + 'test' + ) + ]) + return message + +# ============================================== +# Setup PyTest +# ============================================== + +# Create test functions that can interact with PyTest +for _test_func in tools.create_tests(setup_experiment, __file__): + globals()[_test_func.__name__] = _test_func From b3cf148c3ac69a1121a45e3acd451ce864e29e8f Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Wed, 26 Oct 2022 14:02:31 -0700 Subject: [PATCH 02/23] Set up boilerplate --- include/lbann/layers/misc/CMakeLists.txt | 3 +++ include/lbann/layers/misc/distconv/CMakeLists.txt | 0 .../misc/distconv/distonv_channelwise_softmax.hpp | 0 include/lbann/utils/distconv.hpp | 10 ++++++++++ src/layers/misc/CMakeLists.txt | 4 ++++ src/layers/misc/distconv/CMakeLists.txt | 0 .../misc/distconv/distconv_channelwise_softmax.cpp | 0 7 files changed, 17 insertions(+) create mode 100644 include/lbann/layers/misc/distconv/CMakeLists.txt create mode 100644 include/lbann/layers/misc/distconv/distonv_channelwise_softmax.hpp create mode 100644 src/layers/misc/distconv/CMakeLists.txt create mode 100644 src/layers/misc/distconv/distconv_channelwise_softmax.cpp diff --git a/include/lbann/layers/misc/CMakeLists.txt b/include/lbann/layers/misc/CMakeLists.txt index d84c4e0accf..258023a8fbb 100644 --- a/include/lbann/layers/misc/CMakeLists.txt +++ b/include/lbann/layers/misc/CMakeLists.txt @@ -40,5 +40,8 @@ set_full_path(THIS_DIR_HEADERS variance.hpp ) +if (LBANN_HAS_DISTCONV) + add_subdirectory(distconv) +endif() # Propagate the files up the tree set(HEADERS "${HEADERS}" "${THIS_DIR_HEADERS}" PARENT_SCOPE) diff --git a/include/lbann/layers/misc/distconv/CMakeLists.txt b/include/lbann/layers/misc/distconv/CMakeLists.txt new file mode 100644 index 00000000000..e69de29bb2d diff --git a/include/lbann/layers/misc/distconv/distonv_channelwise_softmax.hpp b/include/lbann/layers/misc/distconv/distonv_channelwise_softmax.hpp new file mode 100644 index 00000000000..e69de29bb2d diff --git a/include/lbann/utils/distconv.hpp b/include/lbann/utils/distconv.hpp index 30ab7b60b4d..b03813c4358 100644 --- a/include/lbann/utils/distconv.hpp +++ b/include/lbann/utils/distconv.hpp @@ -53,6 +53,16 @@ #include "p2p/p2p.hpp" #endif // DISTCONV_HAS_P2P +#include "lbann/layers/learning/distconv/distconv_layers.hpp" +#include "lbann/layers/math/distconv/distconv_matmul.hpp" + +#ifdef LBANN_HAS_NVSHMEM +#include "lbann/layers/transform/distconv/distconv_scatter.hpp" +#include "lbann/layers/transform/distconv/distconv_gather.hpp" +#include "lbann/layers/transform/distconv/distconv_nvshmem_vector_addressing.hpp" +#endif // LBANN_HAS_NVSHMEM + +#include "lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp" namespace lbann { inline auto default_hydrogen_stream() diff --git a/src/layers/misc/CMakeLists.txt b/src/layers/misc/CMakeLists.txt index 77b50a02d7b..8fac35660cb 100644 --- a/src/layers/misc/CMakeLists.txt +++ b/src/layers/misc/CMakeLists.txt @@ -66,6 +66,10 @@ endif () # Add the subdirectories add_subdirectory(cereal_registration) +if (LBANN_HAS_DISTCONV) + add_subdirectory(distconv) +endif() + # Propagate the files up the tree set(SOURCES "${SOURCES}" "${THIS_DIR_SOURCES}" PARENT_SCOPE) set(GPU_SOURCES "${GPU_SOURCES}" "${THIS_DIR_CU_SOURCES}" PARENT_SCOPE) diff --git a/src/layers/misc/distconv/CMakeLists.txt b/src/layers/misc/distconv/CMakeLists.txt new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/layers/misc/distconv/distconv_channelwise_softmax.cpp b/src/layers/misc/distconv/distconv_channelwise_softmax.cpp new file mode 100644 index 00000000000..e69de29bb2d From d108b2fee0761001757cec703da8a812d85162ec Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Wed, 26 Oct 2022 14:23:15 -0700 Subject: [PATCH 03/23] Update cmake-ary --- .../lbann/layers/misc/distconv/CMakeLists.txt | 31 +++++++++++++++++++ src/layers/misc/distconv/CMakeLists.txt | 31 +++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/include/lbann/layers/misc/distconv/CMakeLists.txt b/include/lbann/layers/misc/distconv/CMakeLists.txt index e69de29bb2d..c03e884e25d 100644 --- a/include/lbann/layers/misc/distconv/CMakeLists.txt +++ b/include/lbann/layers/misc/distconv/CMakeLists.txt @@ -0,0 +1,31 @@ +################################################################################ +## Copyright (c) 2014-2022, Lawrence Livermore National Security, LLC. +## Produced at the Lawrence Livermore National Laboratory. +## Written by the LBANN Research Team (B. Van Essen, et al.) listed in +## the CONTRIBUTORS file. +## +## LLNL-CODE-697807. +## All rights reserved. +## +## This file is part of LBANN: Livermore Big Artificial Neural Network +## Toolkit. For details, see http://software.llnl.gov/LBANN or +## https://github.com/LLNL/LBANN. +## +## Licensed under the Apache License, Version 2.0 (the "Licensee"); you +## may not use this file except in compliance with the License. You may +## obtain a copy of the License at: +## +## http://www.apache.org/licenses/LICENSE-2.0 +## +## Unless required by applicable law or agreed to in writing, software +## distributed under the License is distributed on an "AS IS" BASIS, +## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +## implied. See the License for the specific language governing +## permissions and limitations under the license. +################################################################################ +set_full_path(THIS_DIR_HEADERS + distconv_channelwise_matmul.hpp + ) + +# Propagate the files up the tree +set(HEADERS "${HEADERS}" "${THIS_DIR_HEADERS}" PARENT_SCOPE) diff --git a/src/layers/misc/distconv/CMakeLists.txt b/src/layers/misc/distconv/CMakeLists.txt index e69de29bb2d..29d9b3a0c32 100644 --- a/src/layers/misc/distconv/CMakeLists.txt +++ b/src/layers/misc/distconv/CMakeLists.txt @@ -0,0 +1,31 @@ +################################################################################ +## Copyright (c) 2014-2022, Lawrence Livermore National Security, LLC. +## Produced at the Lawrence Livermore National Laboratory. +## Written by the LBANN Research Team (B. Van Essen, et al.) listed in +## the CONTRIBUTORS file. +## +## LLNL-CODE-697807. +## All rights reserved. +## +## This file is part of LBANN: Livermore Big Artificial Neural Network +## Toolkit. For details, see http://software.llnl.gov/LBANN or +## https://github.com/LLNL/LBANN. +## +## Licensed under the Apache License, Version 2.0 (the "Licensee"); you +## may not use this file except in compliance with the License. You may +## obtain a copy of the License at: +## +## http://www.apache.org/licenses/LICENSE-2.0 +## +## Unless required by applicable law or agreed to in writing, software +## distributed under the License is distributed on an "AS IS" BASIS, +## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +## implied. See the License for the specific language governing +## permissions and limitations under the license. +################################################################################ +set_full_path(THIS_DIR_HEADERS + distconv_channelwise_softmax.hpp + ) + +# Propagate the files up the tree +set(HEADERS "${HEADERS}" "${THIS_DIR_HEADERS}" PARENT_SCOPE) From 0d95e9a8ad4843231966bc127439204d30e1e087 Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Thu, 26 Jan 2023 05:11:06 -0800 Subject: [PATCH 04/23] - Fixed old file naming issues - Adding boilerplate for distconv adapter - Updated CMakeLists to correct file type --- .../lbann/layers/misc/channelwise_softmax.hpp | 145 +++++++++- .../lbann/layers/misc/distconv/CMakeLists.txt | 2 +- .../distconv/distconv_channelwise_softmax.hpp | 59 ++++ .../distconv/distonv_channelwise_softmax.hpp | 0 include/lbann/utils/distconv.hpp | 20 +- src/layers/misc/channelwise_softmax.cu | 30 ++- src/layers/misc/distconv/CMakeLists.txt | 6 +- .../distconv/distconv_channelwise_softmax.cpp | 0 .../distconv/distconv_channelwise_softmax.cu | 251 ++++++++++++++++++ 9 files changed, 488 insertions(+), 25 deletions(-) create mode 100644 include/lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp delete mode 100644 include/lbann/layers/misc/distconv/distonv_channelwise_softmax.hpp delete mode 100644 src/layers/misc/distconv/distconv_channelwise_softmax.cpp create mode 100644 src/layers/misc/distconv/distconv_channelwise_softmax.cu diff --git a/include/lbann/layers/misc/channelwise_softmax.hpp b/include/lbann/layers/misc/channelwise_softmax.hpp index 7f0e792acae..644d93a833a 100644 --- a/include/lbann/layers/misc/channelwise_softmax.hpp +++ b/include/lbann/layers/misc/channelwise_softmax.hpp @@ -32,8 +32,36 @@ #include "lbann/proto/layers.pb.h" +#ifdef LBANN_HAS_DISTCONV +#include "lbann/layers/data_type_distconv_adapter.hpp" +#include "lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp" +#endif + + namespace lbann { +#ifdef LBANN_HAS_DISTCONV +template +class channelwise_softmax_distconv_adapter + : public data_type_distconv_adapter{ + public: + using TensorDevType = typename data_type_distconv_adapter::TensorDevType; + + channelwise_softmax_distconv_adapter(Layer& layer) + : data_type_distconv_adapter(layer){} + + virtual ~channelwise_softmax_distconv_adapter() = default; + void setup_distributions(tensor_overlap_constraints &constraints) override; + void setup_layer(size_t workspace_capacity) override; + void fp_compute(); + void bp_compute(); + dc::Shape get_activations_local_shape(int index=0) const override; + std::unique_ptr> m_channelwise_softmax_operator; + }; // class definition channelwise_softmax_distconv_adapter + +#endif // LBANN_HAS_DISTCONV + + /** @brief Apply softmax to tensor channels. * * The input tensor is sliced along the first tensor dimension (the @@ -93,17 +121,15 @@ class channelwise_softmax_layer : public data_type_layer void fp_compute() override; void bp_compute() override; -private: - void get_channel_size_and_stride(El::Int& channel_size, - El::Int& channel_stride, - El::Int& num_channels) const; +#ifdef LBANN_HAS_DISTCONV + friend class channelwise_softmax_distconv_adapter; + protected: + void setup_distconv_adapter(const DataReaderMetaData& dr_metadata) override; + bool is_distconv_supported() const override; + channelwise_softmax_distconv_adapter& get_distconv_adapter() override; + const channelwise_softmax_distconv_adapter& get_distconv_adapter() const override; +#endif // LBANN_HAS_DISTCONV - /** Specifies the dimension of the tensor to perform softmax on. */ - int64_t m_dim; - - /** @brief If true, only performs softmax on the chosen dimension. Otherwise - all dimensions but ``m_dim`` will be used. */ - bool m_single_dim_mode; }; // Builder function @@ -159,9 +185,106 @@ El::Device channelwise_softmax_layer:: return Device; } +template +void channelwise_softmax_layer::setup_dims(DataReaderMetaData& dr_metadata) { + data_type_layer::setup_dims(dr_metadata); + this->set_output_dims(this->get_input_dims()); +} + +#ifdef LBANN_HAS_DISTCONV + // ========================================================= -// Explicit template instantiation +// DistConv-Adapter member functions // ========================================================= +template +void +channelwise_softmax_distconv_adapter +::setup_distributions(tensor_overlap_constraints &constraints){ + data_type_distconv_adapter::setup_distributions(constraints); + + for (auto &d: this->m_prev_activations_dists) { + d.clear_overlap(); + constraints.mark_updated(d); + constraints.mark_invariant(d); + } + for (auto &d: this->m_activations_dists) { + d.clear_overlap(); + constraints.mark_updated(d); + constraints.mark_invariant(d); + } + for (auto &d: this->m_prev_error_signals_dists) { + d.clear_overlap(); + constraints.mark_updated(d); + constraints.mark_invariant(d); + } + for (auto &d: this->m_error_signals_dists) { + d.clear_overlap(); + constraints.mark_updated(d); + constraints.mark_invariant(d); + } +} + +template +void +channelwise_softmax_distconv_adapter +::setup_layer(size_t workspace_capacity){ + data_type_distconv_adapter::setup_layer(workspace_capacity); + + m_channelwise_softmax_operator = std::make_unique>(dc::get_backend()); +} + +template +void +channelwise_softmax_distconv_adapter +::fp_compute(){ + auto &layer = dynamic_cast< + channelwise_softmax_layer&>(this->layer()); +} + +template +void +channelwise_softmax_distconv_adapter +::bp_compute(){ + auto &layer = dynamic_cast< + channelwise_softmax_layer&>(this->layer()); +} +// ============================================================= +// DistConv-enabled Channelwise-Softmax member functions +// ============================================================= + +template +bool +channelwise_softmax_layer +::is_distconv_supported() const { + return Device==El::Device::GPU && Layout == data_layout::DATA_PARALLEL; +} + +template +void +channelwise_softmax_layer +::setup_distconv_adapter(const DataReaderMetaData& dr_metadata){ + this->get_distconv_adapter_ptr() = std::make_unique>(*this); +} + +template +const channelwise_softmax_distconv_adapter& +channelwise_softmax_layer +::get_distconv_adapter() const{ + return dynamic_cast&>(data_type_layer::get_distconv_adapter()); +} + +template +channelwise_softmax_distconv_adapter& +channelwise_softmax_layer +::get_distconv_adapter(){ + return const_cast&>( + static_cast&>(*this).get_distconv_adapter()); +} + + +#endif // LBANN_HAS_DISTCONV #ifndef LBANN_CHANNELWISE_SOFTMAX_LAYER_INSTANTIATE #define PROTO_DEVICE(T, Device) \ diff --git a/include/lbann/layers/misc/distconv/CMakeLists.txt b/include/lbann/layers/misc/distconv/CMakeLists.txt index c03e884e25d..29d9b3a0c32 100644 --- a/include/lbann/layers/misc/distconv/CMakeLists.txt +++ b/include/lbann/layers/misc/distconv/CMakeLists.txt @@ -24,7 +24,7 @@ ## permissions and limitations under the license. ################################################################################ set_full_path(THIS_DIR_HEADERS - distconv_channelwise_matmul.hpp + distconv_channelwise_softmax.hpp ) # Propagate the files up the tree diff --git a/include/lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp b/include/lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp new file mode 100644 index 00000000000..896794ee7bc --- /dev/null +++ b/include/lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp @@ -0,0 +1,59 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright (c) 2014-2022, Lawrence Livermore National Security, LLC. +// Produced at the Lawrence Livermore National Laboratory. +// Written by the LBANN Research Team (B. Van Essen, et al.) listed in +// the CONTRIBUTORS file. +// +// LLNL-CODE-697807. +// All rights reserved. +// +// This file is part of LBANN: Livermore Big Artificial Neural Network +// Toolkit. For details, see http://software.llnl.gov/LBANN or +// https://github.com/LLNL/LBANN. +// +// Licensed under the Apache License, Version 2.0 (the "Licensee"); you +// may not use this file except in compliance with the License. You may +// obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the license. +//////////////////////////////////////////////////////////////////////////////// + +#ifndef LBANN_LAYERS_MISC_DISTCONV_CHANNELWISE_SOFTMAX +#define LBANN_LAYERS_MISC_DISTCONV_CHANNELWISE_SOFTMAX + +#ifdef LBANN_HAS_DISTCONV +namespace distconv{ + template + class ChannelwiseSoftmax{ + using LocaleMPI = tensor::LocaleMPI; + + public: + ChannelwiseSoftmax(Backend &backend):m_be(backend){}; + + template + int forward( + const tensor::Tensor &input_0, + tensor::Tensor &output; + ); + + template + int backward( + const tensor::Tensor &input_0, + const tensor::Tensor &output_grad, + tensor::Tensor &input_grad_0, + ); + + protected: + Backend &m_be; + + }; +} + +#endif // LBANN_HAS_DISTCONV +#endif // LBANN_LAYERS_MISC_DISTCONV_CHANNELWISE_SOFTMAX \ No newline at end of file diff --git a/include/lbann/layers/misc/distconv/distonv_channelwise_softmax.hpp b/include/lbann/layers/misc/distconv/distonv_channelwise_softmax.hpp deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/include/lbann/utils/distconv.hpp b/include/lbann/utils/distconv.hpp index b03813c4358..3e4ebb2cd64 100644 --- a/include/lbann/utils/distconv.hpp +++ b/include/lbann/utils/distconv.hpp @@ -63,6 +63,7 @@ #endif // LBANN_HAS_NVSHMEM #include "lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp" + namespace lbann { inline auto default_hydrogen_stream() @@ -136,8 +137,23 @@ using MPIRootPrintStreamWaning = ::distconv::util::MPIRootPrintStreamWarning; // Distconv layer classes using Backend = ::distconv::BackendDNNLib; -using AlCommType = typename decltype(std::declval() - .get_al_mpi_cuda_comm())::element_type; +using ReLU = ::distconv::ReLU; +using LeakyReLU = ::distconv::LeakyReLU; +template +using Convolution = ::distconv::Convolution; +template +using ChannelwiseFullyConnected = ::distconv::ChannelwiseFullyConnected; +template +using Pooling = ::distconv::Pooling; +template +using BatchNormalization = ::distconv::BatchNormalization; +template +using MatMul = ::distconv::MatMul; +template +using ChannelwiseSoftmax = ::distconv::ChannelwiseSoftmax; +using Softmax = ::distconv::Softmax; +using CrossEntropy = ::distconv::CrossEntropy; +using MeanSquaredError = ::distconv::MeanSquaredError; using ::distconv::get_channel_dim; using ::distconv::get_sample_dim; diff --git a/src/layers/misc/channelwise_softmax.cu b/src/layers/misc/channelwise_softmax.cu index 083f4f55a7d..2f962329997 100644 --- a/src/layers/misc/channelwise_softmax.cu +++ b/src/layers/misc/channelwise_softmax.cu @@ -335,10 +335,17 @@ void fp_impl(size_t num_channels, } // namespace template -void channelwise_softmax_layer::fp_compute() -{ - El::Int num_channels, channel_size, channel_stride; - this->get_channel_size_and_stride(channel_size, channel_stride, num_channels); +void channelwise_softmax_layer::fp_compute() { + + #ifdef LBANN_HAS_DISTCONV + if (this->distconv_enabled()){ + this->get_distconv_adapter().fp_compute(); + return ; + } + #endif // LBANN_HAS_DISTCONV + + const size_t num_channels = this->get_output_dims().front(); + const size_t channel_size = this->get_output_size() / num_channels; fp_impl(num_channels, channel_size, channel_stride, @@ -536,10 +543,17 @@ void bp_impl(size_t num_channels, } // namespace template -void channelwise_softmax_layer::bp_compute() -{ - El::Int num_channels, channel_size, channel_stride; - this->get_channel_size_and_stride(channel_size, channel_stride, num_channels); +void channelwise_softmax_layer::bp_compute() { + + #ifdef LBANN_HAS_DISTCONV + if (this->distconv_enabled()){ + this->get_distconv_adapter().bp_compute(); + return ; + } + #endif // LBANN_HAS_DISTCONV + + const size_t num_channels = this->get_output_dims().front(); + const size_t channel_size = this->get_output_size() / num_channels; bp_impl(num_channels, channel_size, channel_stride, diff --git a/src/layers/misc/distconv/CMakeLists.txt b/src/layers/misc/distconv/CMakeLists.txt index 29d9b3a0c32..30270ed7b63 100644 --- a/src/layers/misc/distconv/CMakeLists.txt +++ b/src/layers/misc/distconv/CMakeLists.txt @@ -23,9 +23,9 @@ ## implied. See the License for the specific language governing ## permissions and limitations under the license. ################################################################################ -set_full_path(THIS_DIR_HEADERS - distconv_channelwise_softmax.hpp +set_full_path(THIS_DIR_CU_SOURCES + distconv_channelwise_softmax.cu ) # Propagate the files up the tree -set(HEADERS "${HEADERS}" "${THIS_DIR_HEADERS}" PARENT_SCOPE) +set(GPU_SOURCES "${GPU_SOURCES}" "${THIS_DIR_CU_SOURCES}" PARENT_SCOPE) diff --git a/src/layers/misc/distconv/distconv_channelwise_softmax.cpp b/src/layers/misc/distconv/distconv_channelwise_softmax.cpp deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/src/layers/misc/distconv/distconv_channelwise_softmax.cu b/src/layers/misc/distconv/distconv_channelwise_softmax.cu new file mode 100644 index 00000000000..88a4a3fb739 --- /dev/null +++ b/src/layers/misc/distconv/distconv_channelwise_softmax.cu @@ -0,0 +1,251 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright (c) 2014-2022, Lawrence Livermore National Security, LLC. +// Produced at the Lawrence Livermore National Laboratory. +// Written by the LBANN Research Team (B. Van Essen, et al.) listed in +// the CONTRIBUTORS file. +// +// LLNL-CODE-697807. +// All rights reserved. +// +// This file is part of LBANN: Livermore Big Artificial Neural Network +// Toolkit. For details, see http://software.llnl.gov/LBANN or +// https://github.com/LLNL/LBANN. +// +// Licensed under the Apache License, Version 2.0 (the "Licensee"); you +// may not use this file except in compliance with the License. You may +// obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the license. +//////////////////////////////////////////////////////////////////////////////// + +#define LBANN_LAYERS_MISC_CHANNELWISE_SOFTMAX_INSTANTIATE + +#ifdef LBANN_HAS_DISTCONV +namespace distconv{ +namespace{ + +using Size3 = gpu_lib::array; + +/** @brief Max functor */ +template +struct max_op { + __device__ __forceinline__ + DataType operator()(const T& x1, const T& x2) const { + return gpu_lib::max(x1, x2); + } +}; + +} // namespace + +// ========================================================= +// Forward prop +// ========================================================= + +namespace { + +/** @brief Max reduction over last dimension of 3D tensor. + * + * Each CUDA block computes the max over a subset of tensor entries + * in @c vals and outputs the result to @c maxvals. This should be + * repeated multiple times to fully reduce the last tensor dimension. + * + * Block dimensions: bdimx x 1 x 1 + * + * Grid dimensions: (vals_dims[2] / bdimx) x vals_dims[1] x vals_dims[0] + * + * maxvals: vals_dims[0] x vals_dims[1] x (vals_dims[2] / bdimx) + */ +template +__global__ void fp_max_kernel( + Size3 vals_dims, + const TensorDataType* __restrict__ vals_buffer, + Size3 vals_strides, + TensorDataType* __restrict__ maxvals_buffer, + Size3 maxvals_strides) { + + // Indices and dimensions + constexpr size_t bdimy = 1; + constexpr size_t bdimz = 1; + const size_t tid = threadIdx.x; + const size_t bidx = blockIdx.x; + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; + const size_t nthreadsx = blockDim.x * gridDim.x; + const size_t nthreadsy = blockDim.y * gridDim.y; + const size_t nthreadsz = blockDim.z * gridDim.z; + + for (size_t k = gidz; k < vals_dims[0]; k += nthreadsz) { + for (size_t j = gidy; j < vals_dims[1]; j += nthreadsy) { + + // Find largest value for each thread + TensorDataType maxval{-gpu_lib::infinity()}; + for (size_t i = gidx; i < vals_dims[2]; i += nthreadsx) { + const auto& val = vals_buffer[k * vals_strides[0] + + j * vals_strides[1] + + i * vals_strides[2]]; + maxval = gpu_lib::max(maxval, val); + } + + // Find largest value for each block + maxval = gpu_lib::block_reduce>(maxval); + if (tid == 0) { + const auto& pos = (k * maxvals_strides[0] + + j * maxvals_strides[1] + + bidx * maxvals_strides[2]); + maxvals_buffer[pos] = maxval; + } + + } + } + +} + +} // Namespace + + + +// ========================================================= +// Backprop +// ========================================================= + +namespace { +/** Compute dot product between output and gradient w.r.t. output. + * + * Block dimensions: bdimx x 1 x 1 + * + * Grid dimensions: (output_dims[2] / bdimx) x output_dims[1] x output_dims[0] + * + * y_dot_dy is a fully-packed 2D tensor with dimensions of + * output_dims[0] x output_dims[1]. + */ +template +__global__ void bp_y_dot_dy_kernel( + Size3 output_dims, + const TensorDataType* __restrict__ output_buffer, + Size3 output_strides, + const TensorDataType* __restrict__ output_grad_buffer, + Size3 output_grad_strides, + TensorDataType* __restrict__ y_dot_dy) { + + // Indices and dimensions + constexpr size_t bdimy = 1; + constexpr size_t bdimz = 1; + const size_t tid = threadIdx.x; + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; + const size_t nthreadsx = blockDim.x * gridDim.x; + const size_t nthreadsy = blockDim.y * gridDim.y; + const size_t nthreadsz = blockDim.z * gridDim.z; + + for (size_t k = gidz; k < output_dims[0]; k += nthreadsz) { + for (size_t j = gidy; j < output_dims[1]; j += nthreadsy) { + + // Compute contribution from each thread + TensorDataType _y_dot_dy{0.}; + for (size_t i = gidx; i < output_dims[2]; i += nthreadsx) { + const auto& y = output_buffer[k * output_strides[0] + + j * output_strides[1] + + i * output_strides[2]]; + const auto& dy = output_grad_buffer[k * output_grad_strides[0] + + j * output_grad_strides[1] + + i * output_grad_strides[2]]; + _y_dot_dy += y * dy; + } + + // Compute contribution from each block + _y_dot_dy = gpu_lib::block_reduce(_y_dot_dy); + if (tid == 0) { + gpu_lib::atomic_add(&y_dot_dy[j+k*output_dims[1]], _y_dot_dy); + } + + } + } + +} + +/** Compute gradient w.r.t. input. + * + * dL/dx_i = y_i * ( dL/dy_i - dot(y,dL/dy) ) + * + * Block dimensions: bdimx x bdimy x bdimz + * + * Grid dimensions: (output_dims[2] / bdimx) x (output_dims[1] / bdimy) x (output_dims[0] / bdimz) + * + * y_dot_dy is a fully-packed 2D tensor with dimensions of + * output_dims[0] x output_dims[1]. + */ +template +__global__ void bp_input_grad_kernel( + Size3 output_dims, + const TensorDataType* __restrict__ output_buffer, + Size3 output_strides, + const TensorDataType* __restrict__ output_grad_buffer, + Size3 output_grad_strides, + TensorDataType* __restrict__ input_grad_buffer, + Size3 input_grad_strides, + const TensorDataType* __restrict__ y_dot_dy) { + + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; + const size_t nthreadsx = blockDim.x * gridDim.x; + const size_t nthreadsy = blockDim.y * gridDim.y; + const size_t nthreadsz = blockDim.z * gridDim.z; + for (size_t k = gidz; k < output_dims[0]; k += nthreadsz) { + for (size_t j = gidy; j < output_dims[1]; j += nthreadsy) { + const auto& _y_dot_dy = y_dot_dy[j + k*output_dims[1]]; + for (size_t i = gidx; i < output_dims[2]; i += nthreadsx) { + const auto& y = output_buffer[k * output_strides[0] + + j * output_strides[1] + + i * output_strides[2]]; + const auto& dy = output_grad_buffer[k * output_grad_strides[0] + + j * output_grad_strides[1] + + i * output_grad_strides[2]]; + auto& dx = input_grad_buffer[k * input_grad_strides[0] + + j * input_grad_strides[1] + + i * input_grad_strides[2]]; + dx = y * (dy - _y_dot_dy); + } + } + } + +} + +} // namespace + + + template + template + int + ChannelwiseSoftmax + ::forward(const tensor::Tensor &input_0, + tensor::Tensor &output){ + + + return 1; + } + + template + template + int + ChannelwiseSoftmax + ::backward(const tensor::Tensor &input_0, + const tensor::Tensor &output_grad, + tensor::Tensor &input_grad_0){ + + return 1; + } + +// ========================================================= +// Explicit template instantiation +// ========================================================= +} // namespace distconv +#endif // LBANN_HAS_DISTCONV \ No newline at end of file From 5fc580cddfbae5e71a906be1e3b11d369172395f Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Thu, 26 Jan 2023 10:33:30 -0800 Subject: [PATCH 05/23] Adding forward impl --- .../lbann/layers/misc/channelwise_softmax.hpp | 8 +- .../distconv/distconv_channelwise_softmax.hpp | 6 +- include/lbann/utils/distconv.hpp | 2 +- .../distconv/distconv_channelwise_softmax.cu | 152 ++++++++++++++++++ 4 files changed, 160 insertions(+), 8 deletions(-) diff --git a/include/lbann/layers/misc/channelwise_softmax.hpp b/include/lbann/layers/misc/channelwise_softmax.hpp index 644d93a833a..b167e7ed80a 100644 --- a/include/lbann/layers/misc/channelwise_softmax.hpp +++ b/include/lbann/layers/misc/channelwise_softmax.hpp @@ -41,7 +41,7 @@ namespace lbann { #ifdef LBANN_HAS_DISTCONV -template +template class channelwise_softmax_distconv_adapter : public data_type_distconv_adapter{ public: @@ -55,7 +55,6 @@ class channelwise_softmax_distconv_adapter void setup_layer(size_t workspace_capacity) override; void fp_compute(); void bp_compute(); - dc::Shape get_activations_local_shape(int index=0) const override; std::unique_ptr> m_channelwise_softmax_operator; }; // class definition channelwise_softmax_distconv_adapter @@ -230,7 +229,7 @@ channelwise_softmax_distconv_adapter ::setup_layer(size_t workspace_capacity){ data_type_distconv_adapter::setup_layer(workspace_capacity); - m_channelwise_softmax_operator = std::make_unique>(dc::get_backend()); + m_channelwise_softmax_operator = std::make_unique>(dc::get_backend()); } template @@ -239,6 +238,8 @@ channelwise_softmax_distconv_adapter ::fp_compute(){ auto &layer = dynamic_cast< channelwise_softmax_layer&>(this->layer()); + m_channelwise_softmax_operator->forward(this->get_prev_activations(0), + this->get_activations(0)); } template @@ -247,6 +248,7 @@ channelwise_softmax_distconv_adapter ::bp_compute(){ auto &layer = dynamic_cast< channelwise_softmax_layer&>(this->layer()); + m_channelwise_softmax_operator->backward(); } // ============================================================= // DistConv-enabled Channelwise-Softmax member functions diff --git a/include/lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp b/include/lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp index 896794ee7bc..4fb30d26e51 100644 --- a/include/lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp +++ b/include/lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp @@ -39,15 +39,13 @@ namespace distconv{ template int forward( const tensor::Tensor &input_0, - tensor::Tensor &output; - ); + tensor::Tensor &output); template int backward( const tensor::Tensor &input_0, const tensor::Tensor &output_grad, - tensor::Tensor &input_grad_0, - ); + tensor::Tensor &input_grad_0); protected: Backend &m_be; diff --git a/include/lbann/utils/distconv.hpp b/include/lbann/utils/distconv.hpp index 3e4ebb2cd64..bad629f468f 100644 --- a/include/lbann/utils/distconv.hpp +++ b/include/lbann/utils/distconv.hpp @@ -149,7 +149,7 @@ template using BatchNormalization = ::distconv::BatchNormalization; template using MatMul = ::distconv::MatMul; -template +template using ChannelwiseSoftmax = ::distconv::ChannelwiseSoftmax; using Softmax = ::distconv::Softmax; using CrossEntropy = ::distconv::CrossEntropy; diff --git a/src/layers/misc/distconv/distconv_channelwise_softmax.cu b/src/layers/misc/distconv/distconv_channelwise_softmax.cu index 88a4a3fb739..81fec85444e 100644 --- a/src/layers/misc/distconv/distconv_channelwise_softmax.cu +++ b/src/layers/misc/distconv/distconv_channelwise_softmax.cu @@ -107,6 +107,50 @@ __global__ void fp_max_kernel( } +/** Compute softmax. + * + * y_i = exp(x_i-shift) / denom + * + * Block dimensions: bdimx x bdimy x bdimz + * + * Grid dimensions: (input_dims[2] / bdimx) x (input_dims[1] / bdimy) x (input_dims[0] / bdimz) + * + * shifts and denoms are fully-packed 2D tensors with dimensions of + * input_dims[0] x input_dims[1]. + */ +template +__global__ void fp_output_kernel( + Size3 input_dims, + const TensorDataType* __restrict__ input_buffer, + Size3 input_strides, + TensorDataType* __restrict__ output_buffer, + Size3 output_strides, + const TensorDataType* __restrict__ shifts, + const TensorDataType* __restrict__ denoms) { + + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; + const size_t nthreadsx = blockDim.x * gridDim.x; + const size_t nthreadsy = blockDim.y * gridDim.y; + const size_t nthreadsz = blockDim.z * gridDim.z; + for (size_t k = gidz; k < input_dims[0]; k += nthreadsz) { + for (size_t j = gidy; j < input_dims[1]; j += nthreadsy) { + const auto& shift = shifts[j + k*input_dims[1]]; + const auto& denom = denoms[j + k*input_dims[1]]; + for (size_t i = gidx; i < input_dims[2]; i += nthreadsx) { + const auto& x = input_buffer[k * input_strides[0] + + j * input_strides[1] + + i * input_strides[2]]; + auto& y = output_buffer[k * output_strides[0] + + j * output_strides[1] + + i * output_strides[2]]; + y = gpu_lib::exp(x-shift) / denom; + } + } + } +} + } // Namespace @@ -229,7 +273,115 @@ __global__ void bp_input_grad_kernel( ::forward(const tensor::Tensor &input_0, tensor::Tensor &output){ + if (input_0.get_local_size() == 0 || output.get_local_size()){ + return 1; // no op for empty inputs + } + + const auto& input_0_dims = input_0.get_local_shape(); + + const auto num_channels = input_0_dims[2]; + const auto local_mini_batch_size = input_0_dims[3]; + const auto mat_channel_size = input_0_dims[0] * input_0_dims[1]; + const auto mat_stride = num_channels * mat_channel_size; + + // Convert to Hydrogen matrices for kernel launch + + using LocalMat = El::Matrix; + + LocalMat local_input(mat_stride, + local_mini_batch_size, + input_0.get_buffer(), + mat_stride); + + LocalMat local_output(mat_stride, + local_mini_batch_size, + output.get_buffer(), + mat_stride); + { + using namespace hydrogen; + using Size3 = gpu_lib::array; + + auto multisync = MakeMultiSync(El::SyncInfoFromMatrix(local_input), + El::SyncInfoFromMatrix(local_output)); + + LocalMat local_shifts; + if (!local_input.IsEmpty()) { + constexpr size_t block_size = 256; + dim3 block_dims, grid_dims; + block_dims.x = block_size; + grid_dims.x = (channel_size + block_size - 1) / block_size; + grid_dims.y = num_channels; + grid_dims.z = local_mini_batch_size; + gpu_lib::clip_grid_dims(grid_dims); + LocalMat maxvals(grid_dims.x * num_channels, local_mini_batch_size); + hydrogen::gpu::LaunchKernel( + fp_max_kernel, + grid_dims, block_dims, 0, multisync, + Size3{local_mini_batch_size, num_channels, channel_size}, + local_input.LockedBuffer(), + Size3{static_cast(local_input.LDim()), channel_size, 1}, + maxvals.Buffer(), + Size3{static_cast(maxvals.LDim()), grid_dims.x, 1}); + while (grid_dims.x > 1) { + const size_t prev_dim = grid_dims.x; + grid_dims.x = (prev_dim + block_size - 1) / block_size; + const LocalMat prev_maxvals(std::move(maxvals)); + maxvals.Resize(grid_dims.x * num_channels, local_mini_batch_size); + hydrogen::gpu::LaunchKernel( + fp_max_kernel, + grid_dims, block_dims, 0, multisync, + Size3{local_mini_batch_size, num_channels, prev_dim}, + prev_maxvals.LockedBuffer(), + Size3{static_cast(prev_maxvals.LDim()), prev_dim, 1}, + maxvals.Buffer(), + Size3{static_cast(maxvals.LDim()), grid_dims.x, 1}); + } + local_shifts = std::move(maxvals); + } + // Compute softmax denominators + LocalMat local_denoms(num_channels, local_mini_batch_size); + El::Zero(local_denoms); + if (!local_input.IsEmpty()) { + constexpr size_t block_size = 256; + dim3 block_dims, grid_dims; + block_dims.x = block_size; + grid_dims.x = (channel_size + block_size - 1) / block_size; + grid_dims.y = num_channels; + grid_dims.z = local_mini_batch_size; + gpu_lib::clip_grid_dims(grid_dims); + hydrogen::gpu::LaunchKernel( + fp_denom_kernel, + grid_dims, block_dims, 0, multisync, + Size3{local_mini_batch_size, num_channels, channel_size}, + local_input.LockedBuffer(), + Size3{static_cast(local_input.LDim()), channel_size, 1}, + local_shifts.LockedBuffer(), + local_denoms.Buffer()); + } + + // Compute softmax + if (!local_input.IsEmpty()) { + constexpr size_t block_size = 256; + dim3 block_dims, grid_dims; + block_dims.x = block_size; + grid_dims.x = (channel_size + block_size - 1) / block_size; + grid_dims.y = num_channels; + grid_dims.z = local_mini_batch_size; + gpu_lib::clip_grid_dims(grid_dims); + hydrogen::gpu::LaunchKernel( + fp_output_kernel, + grid_dims, block_dims, 0, multisync, + Size3{local_mini_batch_size, num_channels, channel_size}, + local_input.LockedBuffer(), + Size3{static_cast(local_input.LDim()), channel_size, 1}, + local_output.Buffer(), + Size3{static_cast(local_output.LDim()), channel_size, 1}, + local_shifts.LockedBuffer(), + local_denoms.LockedBuffer()); + } + + } // namespace hydrogen return 1; } From 3368e4f84fb1ca26b5afffd5152f444b54b3e48b Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Thu, 26 Jan 2023 14:29:25 -0800 Subject: [PATCH 06/23] Moved shareed kernels to channelwise_softmax_kernels.cuh - Updated forward impls for both non-distconv and distconv channelwise_softmax --- .../lbann/layers/misc/channelwise_softmax.hpp | 4 +- .../distconv/distconv_channelwise_softmax.hpp | 4 + src/layers/misc/CMakeLists.txt | 1 + src/layers/misc/channelwise_softmax.cu | 19 + .../misc/channelwise_softmax_kernels.cuh | 295 +++++++++++++++ .../distconv/distconv_channelwise_softmax.cu | 348 ++---------------- 6 files changed, 347 insertions(+), 324 deletions(-) create mode 100644 src/layers/misc/channelwise_softmax_kernels.cuh diff --git a/include/lbann/layers/misc/channelwise_softmax.hpp b/include/lbann/layers/misc/channelwise_softmax.hpp index b167e7ed80a..59fd72f609b 100644 --- a/include/lbann/layers/misc/channelwise_softmax.hpp +++ b/include/lbann/layers/misc/channelwise_softmax.hpp @@ -248,7 +248,9 @@ channelwise_softmax_distconv_adapter ::bp_compute(){ auto &layer = dynamic_cast< channelwise_softmax_layer&>(this->layer()); - m_channelwise_softmax_operator->backward(); + m_channelwise_softmax_operator->backward(this->get_prev_activations(0), + this->get_prev_error_signals(), + this->get_error_signals(0)); } // ============================================================= // DistConv-enabled Channelwise-Softmax member functions diff --git a/include/lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp b/include/lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp index 4fb30d26e51..b039bf09738 100644 --- a/include/lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp +++ b/include/lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp @@ -26,6 +26,7 @@ #ifndef LBANN_LAYERS_MISC_DISTCONV_CHANNELWISE_SOFTMAX #define LBANN_LAYERS_MISC_DISTCONV_CHANNELWISE_SOFTMAX +#include "lbann/utils/distconv.hpp" #ifdef LBANN_HAS_DISTCONV namespace distconv{ @@ -51,6 +52,9 @@ namespace distconv{ Backend &m_be; }; + + extern template class ChannelwiseSoftmax<::distconv::BackendDNNLib, float>; + extern template class ChannelwiseSoftmax<::distconv::BackendDNNLib, double>; } #endif // LBANN_HAS_DISTCONV diff --git a/src/layers/misc/CMakeLists.txt b/src/layers/misc/CMakeLists.txt index 8fac35660cb..0d6940ebefe 100644 --- a/src/layers/misc/CMakeLists.txt +++ b/src/layers/misc/CMakeLists.txt @@ -57,6 +57,7 @@ if (LBANN_HAS_GPU) rowwise_weights_norms.cu uniform_hash.cu variance.cu + channelwise_softmax_kernels.cuh ) if (LBANN_HAS_FFTW) list(APPEND THIS_DIR_CU_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/dft_abs.cu") diff --git a/src/layers/misc/channelwise_softmax.cu b/src/layers/misc/channelwise_softmax.cu index 2f962329997..56cb9b08bbe 100644 --- a/src/layers/misc/channelwise_softmax.cu +++ b/src/layers/misc/channelwise_softmax.cu @@ -27,7 +27,9 @@ #define LBANN_CHANNELWISE_SOFTMAX_LAYER_INSTANTIATE #include "lbann/layers/misc/channelwise_softmax_impl.hpp" #include "lbann/utils/gpu/helpers.hpp" +#include "channelwise_softmax_kernels.cuh" +<<<<<<< HEAD namespace lbann { namespace { @@ -333,6 +335,11 @@ void fp_impl(size_t num_channels, } } // namespace +======= + +namespace lbann { + +>>>>>>> 46c2c7a51 (Moved shareed kernels to channelwise_softmax_kernels.cuh) template void channelwise_softmax_layer::fp_compute() { @@ -344,13 +351,25 @@ void channelwise_softmax_layer::fp_compute() { } #endif // LBANN_HAS_DISTCONV + // Local matrices const size_t num_channels = this->get_output_dims().front(); const size_t channel_size = this->get_output_size() / num_channels; +<<<<<<< HEAD fp_impl(num_channels, channel_size, channel_stride, this->get_prev_activations(), this->get_activations()); +======= + using LocalMat = El::Matrix; + const auto& local_input = dynamic_cast(this->get_prev_activations().LockedMatrix()); + auto& local_output = dynamic_cast(this->get_activations().Matrix()); + + channelwise_softmax_fp_impl(num_channels, + channel_size, + local_input, + local_output); +>>>>>>> 46c2c7a51 (Moved shareed kernels to channelwise_softmax_kernels.cuh) } // ========================================================= diff --git a/src/layers/misc/channelwise_softmax_kernels.cuh b/src/layers/misc/channelwise_softmax_kernels.cuh new file mode 100644 index 00000000000..36d2e189d52 --- /dev/null +++ b/src/layers/misc/channelwise_softmax_kernels.cuh @@ -0,0 +1,295 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright (c) 2014-2022, Lawrence Livermore National Security, LLC. +// Produced at the Lawrence Livermore National Laboratory. +// Written by the LBANN Research Team (B. Van Essen, et al.) listed in +// the CONTRIBUTORS file. +// +// LLNL-CODE-697807. +// All rights reserved. +// +// This file is part of LBANN: Livermore Big Artificial Neural Network +// Toolkit. For details, see http://software.llnl.gov/LBANN or +// https://github.com/LLNL/LBANN. +// +// Licensed under the Apache License, Version 2.0 (the "Licensee"); you +// may not use this file except in compliance with the License. You may +// obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the license. +//////////////////////////////////////////////////////////////////////////////// +#ifndef LBANN_LAYERS_MISC_CHANNELWISE_SOFTMAX_KERNELS +#define LBANN_LAYERS_MISC_CHANNELWISE_SOFTMAX_KERNELS +namespace lbann{ +using Size3 = gpu_lib::array; + +/** @brief Max functor */ +template +struct max_op { + __device__ __forceinline__ + DataType operator()(const T& x1, const T& x2) const { + return gpu_lib::max(x1, x2); + } +}; + +/** @brief Max reduction over last dimension of 3D tensor. + * + * Each CUDA block computes the max over a subset of tensor entries + * in @c vals and outputs the result to @c maxvals. This should be + * repeated multiple times to fully reduce the last tensor dimension. + * + * Block dimensions: bdimx x 1 x 1 + * + * Grid dimensions: (vals_dims[2] / bdimx) x vals_dims[1] x vals_dims[0] + * + * maxvals: vals_dims[0] x vals_dims[1] x (vals_dims[2] / bdimx) + */ +template +__global__ void channelwise_softmax_fp_max_kernel( + Size3 vals_dims, + const TensorDataType* __restrict__ vals_buffer, + Size3 vals_strides, + TensorDataType* __restrict__ maxvals_buffer, + Size3 maxvals_strides) { + + // Indices and dimensions + constexpr size_t bdimy = 1; + constexpr size_t bdimz = 1; + const size_t tid = threadIdx.x; + const size_t bidx = blockIdx.x; + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; + const size_t nthreadsx = blockDim.x * gridDim.x; + const size_t nthreadsy = blockDim.y * gridDim.y; + const size_t nthreadsz = blockDim.z * gridDim.z; + + for (size_t k = gidz; k < vals_dims[0]; k += nthreadsz) { + for (size_t j = gidy; j < vals_dims[1]; j += nthreadsy) { + + // Find largest value for each thread + TensorDataType maxval{-gpu_lib::infinity()}; + for (size_t i = gidx; i < vals_dims[2]; i += nthreadsx) { + const auto& val = vals_buffer[k * vals_strides[0] + + j * vals_strides[1] + + i * vals_strides[2]]; + maxval = gpu_lib::max(maxval, val); + } + + // Find largest value for each block + maxval = gpu_lib::block_reduce>(maxval); + if (tid == 0) { + const auto& pos = (k * maxvals_strides[0] + + j * maxvals_strides[1] + + bidx * maxvals_strides[2]); + maxvals_buffer[pos] = maxval; + } + + } + } + +} + +/** Compute softmax denominator. + * + * denom = sum( exp(x_i-shift) ) + * + * Block dimensions: bdimx x 1 x 1 + * + * Grid dimensions: (input_dims[2] / bdimx) x input_dims[1] x input_dims[0] + * + * shifts and denoms are fully-packed 2D tensors with dimensions of + * input_dims[0] x input_dims[1]. + */ +template +__global__ void channelwise_softmax_fp_denom_kernel( + Size3 input_dims, + const TensorDataType* __restrict__ input_buffer, + Size3 input_strides, + const TensorDataType* __restrict__ shifts, + TensorDataType* __restrict__ denoms) { + + // Indices and dimensions + constexpr size_t bdimy = 1; + constexpr size_t bdimz = 1; + const size_t tid = threadIdx.x; + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; + const size_t nthreadsx = blockDim.x * gridDim.x; + const size_t nthreadsy = blockDim.y * gridDim.y; + const size_t nthreadsz = blockDim.z * gridDim.z; + + for (size_t k = gidz; k < input_dims[0]; k += nthreadsz) { + for (size_t j = gidy; j < input_dims[1]; j += nthreadsy) { + + // Compute contribution from each thread + const auto& shift = shifts[j + k*input_dims[1]]; + TensorDataType denom{0.}; + for (size_t i = gidx; i < input_dims[2]; i += nthreadsx) { + const auto& x = input_buffer[k * input_strides[0] + + j * input_strides[1] + + i * input_strides[2]]; + denom += gpu_lib::exp(x-shift); + } + + // Compute contribution from each block + denom = gpu_lib::block_reduce(denom); + if (tid == 0) { + gpu_lib::atomic_add(&denoms[j+k*input_dims[1]], denom); + } + + } + } + +} + +/** Compute softmax. + * + * y_i = exp(x_i-shift) / denom + * + * Block dimensions: bdimx x bdimy x bdimz + * + * Grid dimensions: (input_dims[2] / bdimx) x (input_dims[1] / bdimy) x (input_dims[0] / bdimz) + * + * shifts and denoms are fully-packed 2D tensors with dimensions of + * input_dims[0] x input_dims[1]. + */ +template +__global__ void channelwise_softmax_fp_output_kernel( + Size3 input_dims, + const TensorDataType* __restrict__ input_buffer, + Size3 input_strides, + TensorDataType* __restrict__ output_buffer, + Size3 output_strides, + const TensorDataType* __restrict__ shifts, + const TensorDataType* __restrict__ denoms) { + + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; + const size_t nthreadsx = blockDim.x * gridDim.x; + const size_t nthreadsy = blockDim.y * gridDim.y; + const size_t nthreadsz = blockDim.z * gridDim.z; + for (size_t k = gidz; k < input_dims[0]; k += nthreadsz) { + for (size_t j = gidy; j < input_dims[1]; j += nthreadsy) { + const auto& shift = shifts[j + k*input_dims[1]]; + const auto& denom = denoms[j + k*input_dims[1]]; + for (size_t i = gidx; i < input_dims[2]; i += nthreadsx) { + const auto& x = input_buffer[k * input_strides[0] + + j * input_strides[1] + + i * input_strides[2]]; + auto& y = output_buffer[k * output_strides[0] + + j * output_strides[1] + + i * output_strides[2]]; + y = gpu_lib::exp(x-shift) / denom; + } + } + } + +} + +/** @brief Forward prop */ +template +void channelwise_softmax_fp_impl(size_t num_channels, + size_t channel_size, + const El::Matrix& local_input, + El::Matrix& local_output) { + + // Local matrices + using LocalMat = El::Matrix; + + auto multisync = El::MakeMultiSync(gpu::get_sync_info(local_output), + gpu::get_sync_info(local_input)); + + // Dimensions + const size_t local_mini_batch_size = local_input.Width(); + // const Size3 input_dims{local_mini_batch_size, num_channels, channel_size}; + + // Compute softmax shifts + LocalMat local_shifts; + if (!local_input.IsEmpty()) { + constexpr size_t block_size = 256; + dim3 block_dims, grid_dims; + block_dims.x = block_size; + grid_dims.x = (channel_size + block_size - 1) / block_size; + grid_dims.y = num_channels; + grid_dims.z = local_mini_batch_size; + gpu_lib::clip_grid_dims(grid_dims); + LocalMat maxvals(grid_dims.x * num_channels, local_mini_batch_size); + hydrogen::gpu::LaunchKernel( + channelwise_softmax_fp_max_kernel, + grid_dims, block_dims, 0, multisync, + Size3{local_mini_batch_size, num_channels, channel_size}, + local_input.LockedBuffer(), + Size3{static_cast(local_input.LDim()), channel_size, 1}, + maxvals.Buffer(), + Size3{static_cast(maxvals.LDim()), grid_dims.x, 1}); + while (grid_dims.x > 1) { + const size_t prev_dim = grid_dims.x; + grid_dims.x = (prev_dim + block_size - 1) / block_size; + const LocalMat prev_maxvals(std::move(maxvals)); + maxvals.Resize(grid_dims.x * num_channels, local_mini_batch_size); + hydrogen::gpu::LaunchKernel( + channelwise_softmax_fp_max_kernel, + grid_dims, block_dims, 0, multisync, + Size3{local_mini_batch_size, num_channels, prev_dim}, + prev_maxvals.LockedBuffer(), + Size3{static_cast(prev_maxvals.LDim()), prev_dim, 1}, + maxvals.Buffer(), + Size3{static_cast(maxvals.LDim()), grid_dims.x, 1}); + } + local_shifts = std::move(maxvals); + } + + // Compute softmax denominators + LocalMat local_denoms(num_channels, local_mini_batch_size); + El::Zero(local_denoms); + if (!local_input.IsEmpty()) { + constexpr size_t block_size = 256; + dim3 block_dims, grid_dims; + block_dims.x = block_size; + grid_dims.x = (channel_size + block_size - 1) / block_size; + grid_dims.y = num_channels; + grid_dims.z = local_mini_batch_size; + gpu_lib::clip_grid_dims(grid_dims); + hydrogen::gpu::LaunchKernel( + channelwise_softmax_fp_denom_kernel, + grid_dims, block_dims, 0, multisync, + Size3{local_mini_batch_size, num_channels, channel_size}, + local_input.LockedBuffer(), + Size3{static_cast(local_input.LDim()), channel_size, 1}, + local_shifts.LockedBuffer(), + local_denoms.Buffer()); + } + + // Compute softmax + if (!local_input.IsEmpty()) { + constexpr size_t block_size = 256; + dim3 block_dims, grid_dims; + block_dims.x = block_size; + grid_dims.x = (channel_size + block_size - 1) / block_size; + grid_dims.y = num_channels; + grid_dims.z = local_mini_batch_size; + gpu_lib::clip_grid_dims(grid_dims); + hydrogen::gpu::LaunchKernel( + channelwise_softmax_fp_output_kernel, + grid_dims, block_dims, 0, multisync, + Size3{local_mini_batch_size, num_channels, channel_size}, + local_input.LockedBuffer(), + Size3{static_cast(local_input.LDim()), channel_size, 1}, + local_output.Buffer(), + Size3{static_cast(local_output.LDim()), channel_size, 1}, + local_shifts.LockedBuffer(), + local_denoms.LockedBuffer()); + } + +} + +} // namespace lbann +#endif // LBANN_LAYERS_MISC_CHANNELWISE_SOFTMAX_KERNELS diff --git a/src/layers/misc/distconv/distconv_channelwise_softmax.cu b/src/layers/misc/distconv/distconv_channelwise_softmax.cu index 81fec85444e..bc0090bdaae 100644 --- a/src/layers/misc/distconv/distconv_channelwise_softmax.cu +++ b/src/layers/misc/distconv/distconv_channelwise_softmax.cu @@ -25,247 +25,15 @@ //////////////////////////////////////////////////////////////////////////////// #define LBANN_LAYERS_MISC_CHANNELWISE_SOFTMAX_INSTANTIATE +#include "lbann/utils/distconv.hpp" +#include "lbann/base.hpp" +#include "lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp" +#include "lbann/utils/gpu/helpers.hpp" +#include "../channelwise_softmax_kernels.cuh" + #ifdef LBANN_HAS_DISTCONV namespace distconv{ -namespace{ - -using Size3 = gpu_lib::array; - -/** @brief Max functor */ -template -struct max_op { - __device__ __forceinline__ - DataType operator()(const T& x1, const T& x2) const { - return gpu_lib::max(x1, x2); - } -}; - -} // namespace - -// ========================================================= -// Forward prop -// ========================================================= - -namespace { - -/** @brief Max reduction over last dimension of 3D tensor. - * - * Each CUDA block computes the max over a subset of tensor entries - * in @c vals and outputs the result to @c maxvals. This should be - * repeated multiple times to fully reduce the last tensor dimension. - * - * Block dimensions: bdimx x 1 x 1 - * - * Grid dimensions: (vals_dims[2] / bdimx) x vals_dims[1] x vals_dims[0] - * - * maxvals: vals_dims[0] x vals_dims[1] x (vals_dims[2] / bdimx) - */ -template -__global__ void fp_max_kernel( - Size3 vals_dims, - const TensorDataType* __restrict__ vals_buffer, - Size3 vals_strides, - TensorDataType* __restrict__ maxvals_buffer, - Size3 maxvals_strides) { - - // Indices and dimensions - constexpr size_t bdimy = 1; - constexpr size_t bdimz = 1; - const size_t tid = threadIdx.x; - const size_t bidx = blockIdx.x; - const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; - const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; - const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; - const size_t nthreadsx = blockDim.x * gridDim.x; - const size_t nthreadsy = blockDim.y * gridDim.y; - const size_t nthreadsz = blockDim.z * gridDim.z; - - for (size_t k = gidz; k < vals_dims[0]; k += nthreadsz) { - for (size_t j = gidy; j < vals_dims[1]; j += nthreadsy) { - - // Find largest value for each thread - TensorDataType maxval{-gpu_lib::infinity()}; - for (size_t i = gidx; i < vals_dims[2]; i += nthreadsx) { - const auto& val = vals_buffer[k * vals_strides[0] - + j * vals_strides[1] - + i * vals_strides[2]]; - maxval = gpu_lib::max(maxval, val); - } - - // Find largest value for each block - maxval = gpu_lib::block_reduce>(maxval); - if (tid == 0) { - const auto& pos = (k * maxvals_strides[0] - + j * maxvals_strides[1] - + bidx * maxvals_strides[2]); - maxvals_buffer[pos] = maxval; - } - - } - } - -} - -/** Compute softmax. - * - * y_i = exp(x_i-shift) / denom - * - * Block dimensions: bdimx x bdimy x bdimz - * - * Grid dimensions: (input_dims[2] / bdimx) x (input_dims[1] / bdimy) x (input_dims[0] / bdimz) - * - * shifts and denoms are fully-packed 2D tensors with dimensions of - * input_dims[0] x input_dims[1]. - */ -template -__global__ void fp_output_kernel( - Size3 input_dims, - const TensorDataType* __restrict__ input_buffer, - Size3 input_strides, - TensorDataType* __restrict__ output_buffer, - Size3 output_strides, - const TensorDataType* __restrict__ shifts, - const TensorDataType* __restrict__ denoms) { - - const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; - const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; - const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; - const size_t nthreadsx = blockDim.x * gridDim.x; - const size_t nthreadsy = blockDim.y * gridDim.y; - const size_t nthreadsz = blockDim.z * gridDim.z; - for (size_t k = gidz; k < input_dims[0]; k += nthreadsz) { - for (size_t j = gidy; j < input_dims[1]; j += nthreadsy) { - const auto& shift = shifts[j + k*input_dims[1]]; - const auto& denom = denoms[j + k*input_dims[1]]; - for (size_t i = gidx; i < input_dims[2]; i += nthreadsx) { - const auto& x = input_buffer[k * input_strides[0] - + j * input_strides[1] - + i * input_strides[2]]; - auto& y = output_buffer[k * output_strides[0] - + j * output_strides[1] - + i * output_strides[2]]; - y = gpu_lib::exp(x-shift) / denom; - } - } - } -} - -} // Namespace - - - -// ========================================================= -// Backprop -// ========================================================= - -namespace { -/** Compute dot product between output and gradient w.r.t. output. - * - * Block dimensions: bdimx x 1 x 1 - * - * Grid dimensions: (output_dims[2] / bdimx) x output_dims[1] x output_dims[0] - * - * y_dot_dy is a fully-packed 2D tensor with dimensions of - * output_dims[0] x output_dims[1]. - */ -template -__global__ void bp_y_dot_dy_kernel( - Size3 output_dims, - const TensorDataType* __restrict__ output_buffer, - Size3 output_strides, - const TensorDataType* __restrict__ output_grad_buffer, - Size3 output_grad_strides, - TensorDataType* __restrict__ y_dot_dy) { - - // Indices and dimensions - constexpr size_t bdimy = 1; - constexpr size_t bdimz = 1; - const size_t tid = threadIdx.x; - const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; - const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; - const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; - const size_t nthreadsx = blockDim.x * gridDim.x; - const size_t nthreadsy = blockDim.y * gridDim.y; - const size_t nthreadsz = blockDim.z * gridDim.z; - - for (size_t k = gidz; k < output_dims[0]; k += nthreadsz) { - for (size_t j = gidy; j < output_dims[1]; j += nthreadsy) { - - // Compute contribution from each thread - TensorDataType _y_dot_dy{0.}; - for (size_t i = gidx; i < output_dims[2]; i += nthreadsx) { - const auto& y = output_buffer[k * output_strides[0] - + j * output_strides[1] - + i * output_strides[2]]; - const auto& dy = output_grad_buffer[k * output_grad_strides[0] - + j * output_grad_strides[1] - + i * output_grad_strides[2]]; - _y_dot_dy += y * dy; - } - - // Compute contribution from each block - _y_dot_dy = gpu_lib::block_reduce(_y_dot_dy); - if (tid == 0) { - gpu_lib::atomic_add(&y_dot_dy[j+k*output_dims[1]], _y_dot_dy); - } - - } - } - -} - -/** Compute gradient w.r.t. input. - * - * dL/dx_i = y_i * ( dL/dy_i - dot(y,dL/dy) ) - * - * Block dimensions: bdimx x bdimy x bdimz - * - * Grid dimensions: (output_dims[2] / bdimx) x (output_dims[1] / bdimy) x (output_dims[0] / bdimz) - * - * y_dot_dy is a fully-packed 2D tensor with dimensions of - * output_dims[0] x output_dims[1]. - */ -template -__global__ void bp_input_grad_kernel( - Size3 output_dims, - const TensorDataType* __restrict__ output_buffer, - Size3 output_strides, - const TensorDataType* __restrict__ output_grad_buffer, - Size3 output_grad_strides, - TensorDataType* __restrict__ input_grad_buffer, - Size3 input_grad_strides, - const TensorDataType* __restrict__ y_dot_dy) { - - const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; - const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; - const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; - const size_t nthreadsx = blockDim.x * gridDim.x; - const size_t nthreadsy = blockDim.y * gridDim.y; - const size_t nthreadsz = blockDim.z * gridDim.z; - for (size_t k = gidz; k < output_dims[0]; k += nthreadsz) { - for (size_t j = gidy; j < output_dims[1]; j += nthreadsy) { - const auto& _y_dot_dy = y_dot_dy[j + k*output_dims[1]]; - for (size_t i = gidx; i < output_dims[2]; i += nthreadsx) { - const auto& y = output_buffer[k * output_strides[0] - + j * output_strides[1] - + i * output_strides[2]]; - const auto& dy = output_grad_buffer[k * output_grad_strides[0] - + j * output_grad_strides[1] - + i * output_grad_strides[2]]; - auto& dx = input_grad_buffer[k * input_grad_strides[0] - + j * input_grad_strides[1] - + i * input_grad_strides[2]]; - dx = y * (dy - _y_dot_dy); - } - } - } - -} - -} // namespace - - template template int @@ -297,91 +65,11 @@ __global__ void bp_input_grad_kernel( local_mini_batch_size, output.get_buffer(), mat_stride); - { - using namespace hydrogen; - using Size3 = gpu_lib::array; - - auto multisync = MakeMultiSync(El::SyncInfoFromMatrix(local_input), - El::SyncInfoFromMatrix(local_output)); - - LocalMat local_shifts; - if (!local_input.IsEmpty()) { - constexpr size_t block_size = 256; - dim3 block_dims, grid_dims; - block_dims.x = block_size; - grid_dims.x = (channel_size + block_size - 1) / block_size; - grid_dims.y = num_channels; - grid_dims.z = local_mini_batch_size; - gpu_lib::clip_grid_dims(grid_dims); - LocalMat maxvals(grid_dims.x * num_channels, local_mini_batch_size); - hydrogen::gpu::LaunchKernel( - fp_max_kernel, - grid_dims, block_dims, 0, multisync, - Size3{local_mini_batch_size, num_channels, channel_size}, - local_input.LockedBuffer(), - Size3{static_cast(local_input.LDim()), channel_size, 1}, - maxvals.Buffer(), - Size3{static_cast(maxvals.LDim()), grid_dims.x, 1}); - while (grid_dims.x > 1) { - const size_t prev_dim = grid_dims.x; - grid_dims.x = (prev_dim + block_size - 1) / block_size; - const LocalMat prev_maxvals(std::move(maxvals)); - maxvals.Resize(grid_dims.x * num_channels, local_mini_batch_size); - hydrogen::gpu::LaunchKernel( - fp_max_kernel, - grid_dims, block_dims, 0, multisync, - Size3{local_mini_batch_size, num_channels, prev_dim}, - prev_maxvals.LockedBuffer(), - Size3{static_cast(prev_maxvals.LDim()), prev_dim, 1}, - maxvals.Buffer(), - Size3{static_cast(maxvals.LDim()), grid_dims.x, 1}); - } - local_shifts = std::move(maxvals); - } - - // Compute softmax denominators - LocalMat local_denoms(num_channels, local_mini_batch_size); - El::Zero(local_denoms); - if (!local_input.IsEmpty()) { - constexpr size_t block_size = 256; - dim3 block_dims, grid_dims; - block_dims.x = block_size; - grid_dims.x = (channel_size + block_size - 1) / block_size; - grid_dims.y = num_channels; - grid_dims.z = local_mini_batch_size; - gpu_lib::clip_grid_dims(grid_dims); - hydrogen::gpu::LaunchKernel( - fp_denom_kernel, - grid_dims, block_dims, 0, multisync, - Size3{local_mini_batch_size, num_channels, channel_size}, - local_input.LockedBuffer(), - Size3{static_cast(local_input.LDim()), channel_size, 1}, - local_shifts.LockedBuffer(), - local_denoms.Buffer()); - } - - // Compute softmax - if (!local_input.IsEmpty()) { - constexpr size_t block_size = 256; - dim3 block_dims, grid_dims; - block_dims.x = block_size; - grid_dims.x = (channel_size + block_size - 1) / block_size; - grid_dims.y = num_channels; - grid_dims.z = local_mini_batch_size; - gpu_lib::clip_grid_dims(grid_dims); - hydrogen::gpu::LaunchKernel( - fp_output_kernel, - grid_dims, block_dims, 0, multisync, - Size3{local_mini_batch_size, num_channels, channel_size}, - local_input.LockedBuffer(), - Size3{static_cast(local_input.LDim()), channel_size, 1}, - local_output.Buffer(), - Size3{static_cast(local_output.LDim()), channel_size, 1}, - local_shifts.LockedBuffer(), - local_denoms.LockedBuffer()); - } - - } // namespace hydrogen + + ::lbann::channelwise_softmax_fp_impl(num_channels, + mat_channel_size, + local_input, + local_output); return 1; } @@ -399,5 +87,19 @@ __global__ void bp_input_grad_kernel( // ========================================================= // Explicit template instantiation // ========================================================= + +#define ETI(T, Backend) \ + template class ChannelwiseSoftmax; \ + template int ChannelwiseSoftmax::forward( \ + const tensor::Tensor &input_0, \ + tensor::Tensor &output_0); \ + template int ChannelwiseSoftmax::backward( \ + const tensor::Tensor &input_0, \ + const tensor::Tensor &input_1, \ + tensor::Tensor &output_grad); + +ETI(float, BackendDNNLib) +ETI(double, BackendDNNLib) +#undef ETI } // namespace distconv #endif // LBANN_HAS_DISTCONV \ No newline at end of file From 569c8309274a1f8919d3f4805e91c9e10e381b1e Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Thu, 26 Jan 2023 14:59:26 -0800 Subject: [PATCH 07/23] Compiling and linking correctly --- src/layers/misc/channelwise_softmax.cu | 521 +----------------- .../misc/channelwise_softmax_kernels.cuh | 175 ++++++ .../distconv/distconv_channelwise_softmax.cu | 35 ++ 3 files changed, 224 insertions(+), 507 deletions(-) diff --git a/src/layers/misc/channelwise_softmax.cu b/src/layers/misc/channelwise_softmax.cu index 56cb9b08bbe..cd5abc3b414 100644 --- a/src/layers/misc/channelwise_softmax.cu +++ b/src/layers/misc/channelwise_softmax.cu @@ -29,317 +29,9 @@ #include "lbann/utils/gpu/helpers.hpp" #include "channelwise_softmax_kernels.cuh" -<<<<<<< HEAD -namespace lbann { - -namespace { - -using Size3 = gpu_lib::array; - -/** @brief Max functor */ -template -struct max_op -{ - __device__ __forceinline__ DataType operator()(const T& x1, const T& x2) const - { - return gpu_lib::max(x1, x2); - } -}; - -} // namespace - -// ========================================================= -// Forward prop -// ========================================================= - -namespace { - -/** @brief Max reduction over last dimension of 3D tensor. - * - * Each CUDA block computes the max over a subset of tensor entries - * in @c vals and outputs the result to @c maxvals. This should be - * repeated multiple times to fully reduce the last tensor dimension. - * - * Block dimensions: bdimx x 1 x 1 - * - * Grid dimensions: (vals_dims[2] / bdimx) x vals_dims[1] x vals_dims[0] - * - * maxvals: vals_dims[0] x vals_dims[1] x (vals_dims[2] / bdimx) - */ -template -__global__ void fp_max_kernel(Size3 vals_dims, - const TensorDataType* __restrict__ vals_buffer, - Size3 vals_strides, - TensorDataType* __restrict__ maxvals_buffer, - Size3 maxvals_strides) -{ - - // Indices and dimensions - constexpr size_t bdimy = 1; - constexpr size_t bdimz = 1; - const size_t tid = threadIdx.x; - const size_t bidx = blockIdx.x; - const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; - const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; - const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; - const size_t nthreadsx = blockDim.x * gridDim.x; - const size_t nthreadsy = blockDim.y * gridDim.y; - const size_t nthreadsz = blockDim.z * gridDim.z; - - for (size_t k = gidz; k < vals_dims[0]; k += nthreadsz) { - for (size_t j = gidy; j < vals_dims[1]; j += nthreadsy) { - - // Find largest value for each thread - TensorDataType maxval{-gpu_lib::infinity()}; - for (size_t i = gidx; i < vals_dims[2]; i += nthreadsx) { - const auto& val = - vals_buffer[k * vals_strides[0] + j * vals_strides[1] + - i * vals_strides[2]]; - maxval = gpu_lib::max(maxval, val); - } - - // Find largest value for each block - maxval = gpu_lib::block_reduce>(maxval); - if (tid == 0) { - const auto& pos = (k * maxvals_strides[0] + j * maxvals_strides[1] + - bidx * maxvals_strides[2]); - maxvals_buffer[pos] = maxval; - } - } - } -} - -/** Compute softmax denominator. - * - * denom = sum( exp(x_i-shift) ) - * - * Block dimensions: bdimx x 1 x 1 - * - * Grid dimensions: (input_dims[2] / bdimx) x input_dims[1] x input_dims[0] - * - * shifts and denoms are fully-packed 2D tensors with dimensions of - * input_dims[0] x input_dims[1]. - */ -template -__global__ void fp_denom_kernel(Size3 input_dims, - const TensorDataType* __restrict__ input_buffer, - Size3 input_strides, - const TensorDataType* __restrict__ shifts, - TensorDataType* __restrict__ denoms) -{ - - // Indices and dimensions - constexpr size_t bdimy = 1; - constexpr size_t bdimz = 1; - const size_t tid = threadIdx.x; - const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; - const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; - const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; - const size_t nthreadsx = blockDim.x * gridDim.x; - const size_t nthreadsy = blockDim.y * gridDim.y; - const size_t nthreadsz = blockDim.z * gridDim.z; - - for (size_t k = gidz; k < input_dims[0]; k += nthreadsz) { - for (size_t j = gidy; j < input_dims[1]; j += nthreadsy) { - - // Compute contribution from each thread - const auto& shift = shifts[j + k * input_dims[1]]; - TensorDataType denom{0.}; - for (size_t i = gidx; i < input_dims[2]; i += nthreadsx) { - const auto& x = - input_buffer[k * input_strides[0] + j * input_strides[1] + - i * input_strides[2]]; - denom += gpu_lib::exp(x - shift); - } - - // Compute contribution from each block - denom = gpu_lib::block_reduce(denom); - if (tid == 0) { - if (gridDim.x > 1) - gpu_lib::atomic_add(&denoms[j + k * input_dims[1]], denom); - else - denoms[j + k * input_dims[1]] = denom; - } - } - } -} - -/** Compute softmax. - * - * y_i = exp(x_i-shift) / denom - * - * Block dimensions: bdimx x bdimy x bdimz - * - * Grid dimensions: (input_dims[2] / bdimx) x (input_dims[1] / bdimy) x - * (input_dims[0] / bdimz) - * - * shifts and denoms are fully-packed 2D tensors with dimensions of - * input_dims[0] x input_dims[1]. - */ -template -__global__ void -fp_output_kernel(Size3 input_dims, - const TensorDataType* __restrict__ input_buffer, - Size3 input_strides, - TensorDataType* __restrict__ output_buffer, - Size3 output_strides, - const TensorDataType* __restrict__ shifts, - const TensorDataType* __restrict__ denoms) -{ - - const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; - const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; - const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; - const size_t nthreadsx = blockDim.x * gridDim.x; - const size_t nthreadsy = blockDim.y * gridDim.y; - const size_t nthreadsz = blockDim.z * gridDim.z; - for (size_t k = gidz; k < input_dims[0]; k += nthreadsz) { - for (size_t j = gidy; j < input_dims[1]; j += nthreadsy) { - const auto& shift = shifts[j + k * input_dims[1]]; - const auto& denom = denoms[j + k * input_dims[1]]; - for (size_t i = gidx; i < input_dims[2]; i += nthreadsx) { - const auto& x = - input_buffer[k * input_strides[0] + j * input_strides[1] + - i * input_strides[2]]; - auto& y = output_buffer[k * output_strides[0] + j * output_strides[1] + - i * output_strides[2]]; - y = gpu_lib::exp(x - shift) / denom; - } - } - } -} - -/** @brief Forward prop */ -template -void fp_impl(size_t num_channels, - size_t channel_size, - size_t channel_stride, - const El::AbstractDistMatrix& input, - El::AbstractDistMatrix& output) -{ - - // Local matrices - using LocalMat = El::Matrix; - const auto& local_input = dynamic_cast(input.LockedMatrix()); - auto& local_output = dynamic_cast(output.Matrix()); - - auto multisync = El::MakeMultiSync(gpu::get_sync_info(local_output), - gpu::get_sync_info(local_input)); - - // Dimensions - const size_t local_mini_batch_size = local_input.Width(); - // const Size3 input_dims{local_mini_batch_size, num_channels, channel_size}; - - // Compute softmax shifts - LocalMat local_shifts; - if (!local_input.IsEmpty()) { - constexpr size_t block_size = 256; - dim3 block_dims, grid_dims; - block_dims.x = block_size; - grid_dims.x = (channel_size + block_size - 1) / block_size; - grid_dims.y = num_channels; - grid_dims.z = local_mini_batch_size; - gpu_lib::clip_grid_dims(grid_dims); - LocalMat maxvals(grid_dims.x * num_channels, local_mini_batch_size); - hydrogen::gpu::LaunchKernel( - fp_max_kernel, - grid_dims, - block_dims, - 0, - multisync, - Size3{local_mini_batch_size, num_channels, channel_size}, - local_input.LockedBuffer(), - Size3{static_cast(local_input.LDim()), channel_stride, 1}, - maxvals.Buffer(), - Size3{static_cast(maxvals.LDim()), grid_dims.x, 1}); - while (grid_dims.x > 1) { - const size_t prev_dim = grid_dims.x; - grid_dims.x = (prev_dim + block_size - 1) / block_size; - const LocalMat prev_maxvals(std::move(maxvals)); - maxvals.Resize(grid_dims.x * num_channels, local_mini_batch_size); - hydrogen::gpu::LaunchKernel( - fp_max_kernel, - grid_dims, - block_dims, - 0, - multisync, - Size3{local_mini_batch_size, num_channels, prev_dim}, - prev_maxvals.LockedBuffer(), - Size3{static_cast(prev_maxvals.LDim()), prev_dim, 1}, - maxvals.Buffer(), - Size3{static_cast(maxvals.LDim()), grid_dims.x, 1}); - } - local_shifts = std::move(maxvals); - } - - // Compute softmax denominators - LocalMat local_denoms(num_channels, local_mini_batch_size); - El::Zero(local_denoms); - if (!local_input.IsEmpty()) { - constexpr size_t block_size = 256; - dim3 block_dims, grid_dims; - block_dims.x = block_size; - - // Simple heuristic to switch between atomic softmax denominator vs. - // sequentially accumulating, block-reducing - int sequential_sum_batch = (channel_size + block_size - 1) / block_size; - // The below threshold value has nothing to do with block size - if (sequential_sum_batch < 256) - grid_dims.x = 1; - else - grid_dims.x = sequential_sum_batch; - - grid_dims.y = num_channels; - grid_dims.z = local_mini_batch_size; - gpu_lib::clip_grid_dims(grid_dims); - hydrogen::gpu::LaunchKernel( - fp_denom_kernel, - grid_dims, - block_dims, - 0, - multisync, - Size3{local_mini_batch_size, num_channels, channel_size}, - local_input.LockedBuffer(), - Size3{static_cast(local_input.LDim()), channel_stride, 1}, - local_shifts.LockedBuffer(), - local_denoms.Buffer()); - } - - // Compute softmax - if (!local_input.IsEmpty()) { - constexpr size_t block_size = 256; - dim3 block_dims, grid_dims; - block_dims.x = block_size; - grid_dims.x = (channel_size + block_size - 1) / block_size; - grid_dims.y = num_channels; - grid_dims.z = local_mini_batch_size; - gpu_lib::clip_grid_dims(grid_dims); - hydrogen::gpu::LaunchKernel( - fp_output_kernel, - grid_dims, - block_dims, - 0, - multisync, - Size3{local_mini_batch_size, num_channels, channel_size}, - local_input.LockedBuffer(), - Size3{static_cast(local_input.LDim()), channel_stride, 1}, - local_output.Buffer(), - Size3{static_cast(local_output.LDim()), channel_stride, 1}, - local_shifts.LockedBuffer(), - local_denoms.LockedBuffer()); - } -} - -} // namespace -======= namespace lbann { ->>>>>>> 46c2c7a51 (Moved shareed kernels to channelwise_softmax_kernels.cuh) template void channelwise_softmax_layer::fp_compute() { @@ -354,213 +46,21 @@ void channelwise_softmax_layer::fp_compute() { // Local matrices const size_t num_channels = this->get_output_dims().front(); const size_t channel_size = this->get_output_size() / num_channels; -<<<<<<< HEAD - fp_impl(num_channels, - channel_size, - channel_stride, - this->get_prev_activations(), - this->get_activations()); -======= using LocalMat = El::Matrix; const auto& local_input = dynamic_cast(this->get_prev_activations().LockedMatrix()); auto& local_output = dynamic_cast(this->get_activations().Matrix()); + // TODO: This looks wrong. Maybe wrap the implementation of this function in a namespace - SZ channelwise_softmax_fp_impl(num_channels, channel_size, local_input, local_output); ->>>>>>> 46c2c7a51 (Moved shareed kernels to channelwise_softmax_kernels.cuh) } // ========================================================= // Backprop // ========================================================= -namespace { - -/** Compute dot product between output and gradient w.r.t. output. - * - * Block dimensions: bdimx x 1 x 1 - * - * Grid dimensions: (output_dims[2] / bdimx) x output_dims[1] x output_dims[0] - * - * y_dot_dy is a fully-packed 2D tensor with dimensions of - * output_dims[0] x output_dims[1]. - */ -template -__global__ void -bp_y_dot_dy_kernel(Size3 output_dims, - const TensorDataType* __restrict__ output_buffer, - Size3 output_strides, - const TensorDataType* __restrict__ output_grad_buffer, - Size3 output_grad_strides, - TensorDataType* __restrict__ y_dot_dy) -{ - - // Indices and dimensions - constexpr size_t bdimy = 1; - constexpr size_t bdimz = 1; - const size_t tid = threadIdx.x; - const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; - const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; - const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; - const size_t nthreadsx = blockDim.x * gridDim.x; - const size_t nthreadsy = blockDim.y * gridDim.y; - const size_t nthreadsz = blockDim.z * gridDim.z; - - for (size_t k = gidz; k < output_dims[0]; k += nthreadsz) { - for (size_t j = gidy; j < output_dims[1]; j += nthreadsy) { - - // Compute contribution from each thread - TensorDataType _y_dot_dy{0.}; - for (size_t i = gidx; i < output_dims[2]; i += nthreadsx) { - const auto& y = - output_buffer[k * output_strides[0] + j * output_strides[1] + - i * output_strides[2]]; - const auto& dy = output_grad_buffer[k * output_grad_strides[0] + - j * output_grad_strides[1] + - i * output_grad_strides[2]]; - _y_dot_dy += y * dy; - } - - // Compute contribution from each block - _y_dot_dy = gpu_lib::block_reduce(_y_dot_dy); - if (tid == 0) { - gpu_lib::atomic_add(&y_dot_dy[j + k * output_dims[1]], _y_dot_dy); - } - } - } -} - -/** Compute gradient w.r.t. input. - * - * dL/dx_i = y_i * ( dL/dy_i - dot(y,dL/dy) ) - * - * Block dimensions: bdimx x bdimy x bdimz - * - * Grid dimensions: (output_dims[2] / bdimx) x (output_dims[1] / bdimy) x - * (output_dims[0] / bdimz) - * - * y_dot_dy is a fully-packed 2D tensor with dimensions of - * output_dims[0] x output_dims[1]. - */ -template -__global__ void -bp_input_grad_kernel(Size3 output_dims, - const TensorDataType* __restrict__ output_buffer, - Size3 output_strides, - const TensorDataType* __restrict__ output_grad_buffer, - Size3 output_grad_strides, - TensorDataType* __restrict__ input_grad_buffer, - Size3 input_grad_strides, - const TensorDataType* __restrict__ y_dot_dy) -{ - - const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; - const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; - const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; - const size_t nthreadsx = blockDim.x * gridDim.x; - const size_t nthreadsy = blockDim.y * gridDim.y; - const size_t nthreadsz = blockDim.z * gridDim.z; - for (size_t k = gidz; k < output_dims[0]; k += nthreadsz) { - for (size_t j = gidy; j < output_dims[1]; j += nthreadsy) { - const auto& _y_dot_dy = y_dot_dy[j + k * output_dims[1]]; - for (size_t i = gidx; i < output_dims[2]; i += nthreadsx) { - const auto& y = - output_buffer[k * output_strides[0] + j * output_strides[1] + - i * output_strides[2]]; - const auto& dy = output_grad_buffer[k * output_grad_strides[0] + - j * output_grad_strides[1] + - i * output_grad_strides[2]]; - auto& dx = input_grad_buffer[k * input_grad_strides[0] + - j * input_grad_strides[1] + - i * input_grad_strides[2]]; - dx = y * (dy - _y_dot_dy); - } - } - } -} - -/** @brief Backprop */ -template -void bp_impl(size_t num_channels, - size_t channel_size, - size_t channel_stride, - const El::AbstractDistMatrix& output, - const El::AbstractDistMatrix& output_grad, - El::AbstractDistMatrix& input_grad) -{ - - // Local matrices - using LocalMat = El::Matrix; - const auto& local_output = - dynamic_cast(output.LockedMatrix()); - const auto& local_output_grad = - dynamic_cast(output_grad.LockedMatrix()); - auto& local_input_grad = dynamic_cast(input_grad.Matrix()); - - // Dimensions - const size_t local_mini_batch_size = local_output.Width(); - - // dot(y,dL/dy) - LocalMat local_y_dot_dy(num_channels, local_mini_batch_size); - El::Zero(local_y_dot_dy); - - auto multisync = El::MakeMultiSync(gpu::get_sync_info(local_y_dot_dy), - gpu::get_sync_info(local_output_grad), - gpu::get_sync_info(local_output), - gpu::get_sync_info(local_input_grad)); - - if (!local_output.IsEmpty()) { - constexpr size_t block_size = 256; - dim3 block_dims, grid_dims; - block_dims.x = block_size; - grid_dims.x = (channel_size + block_size - 1) / block_size; - grid_dims.y = num_channels; - grid_dims.z = local_mini_batch_size; - gpu_lib::clip_grid_dims(grid_dims); - hydrogen::gpu::LaunchKernel( - bp_y_dot_dy_kernel, - grid_dims, - block_dims, - 0, - multisync, - Size3{local_mini_batch_size, num_channels, channel_size}, - local_output.LockedBuffer(), - Size3{static_cast(local_output.LDim()), channel_stride, 1}, - local_output_grad.LockedBuffer(), - Size3{static_cast(local_output_grad.LDim()), channel_stride, 1}, - local_y_dot_dy.Buffer()); - } - - // Compute gradient w.r.t. input - if (!local_output.IsEmpty()) { - constexpr size_t block_size = 256; - dim3 block_dims, grid_dims; - block_dims.x = block_size; - grid_dims.x = (channel_size + block_size - 1) / block_size; - grid_dims.y = num_channels; - grid_dims.z = local_mini_batch_size; - gpu_lib::clip_grid_dims(grid_dims); - hydrogen::gpu::LaunchKernel( - bp_input_grad_kernel, - grid_dims, - block_dims, - 0, - multisync, - Size3{local_mini_batch_size, num_channels, channel_size}, - local_output.LockedBuffer(), - Size3{static_cast(local_output.LDim()), channel_stride, 1}, - local_output_grad.LockedBuffer(), - Size3{static_cast(local_output_grad.LDim()), channel_stride, 1}, - local_input_grad.Buffer(), - Size3{static_cast(local_input_grad.LDim()), channel_stride, 1}, - local_y_dot_dy.LockedBuffer()); - } -} - -} // namespace - template void channelwise_softmax_layer::bp_compute() { @@ -573,12 +73,19 @@ void channelwise_softmax_layer::bp_compute() { const size_t num_channels = this->get_output_dims().front(); const size_t channel_size = this->get_output_size() / num_channels; - bp_impl(num_channels, - channel_size, - channel_stride, - this->get_activations(), - this->get_prev_error_signals(), - this->get_error_signals()); + + + // Local matrices + using LocalMat = El::Matrix; + const auto& local_output = dynamic_cast(this->get_activations().LockedMatrix()); + const auto& local_output_grad = dynamic_cast(this->get_prev_error_signals().LockedMatrix()); + auto& local_input_grad = dynamic_cast(this->get_error_signals().Matrix()); + + channelwise_softmax_bp_impl(num_channels, + channel_size, + local_output, + local_output_grad, + local_input_grad); } // ========================================================= diff --git a/src/layers/misc/channelwise_softmax_kernels.cuh b/src/layers/misc/channelwise_softmax_kernels.cuh index 36d2e189d52..8363c86d296 100644 --- a/src/layers/misc/channelwise_softmax_kernels.cuh +++ b/src/layers/misc/channelwise_softmax_kernels.cuh @@ -37,6 +37,10 @@ struct max_op { } }; +// ========================================================= +// Forward prop +// ========================================================= + /** @brief Max reduction over last dimension of 3D tensor. * * Each CUDA block computes the max over a subset of tensor entries @@ -291,5 +295,176 @@ void channelwise_softmax_fp_impl(size_t num_channels, } +// ========================================================= +// Backprop +// ========================================================= + +/** Compute dot product between output and gradient w.r.t. output. + * + * Block dimensions: bdimx x 1 x 1 + * + * Grid dimensions: (output_dims[2] / bdimx) x output_dims[1] x output_dims[0] + * + * y_dot_dy is a fully-packed 2D tensor with dimensions of + * output_dims[0] x output_dims[1]. + */ +template +__global__ void channelwise_softmax_bp_y_dot_dy_kernel( + Size3 output_dims, + const TensorDataType* __restrict__ output_buffer, + Size3 output_strides, + const TensorDataType* __restrict__ output_grad_buffer, + Size3 output_grad_strides, + TensorDataType* __restrict__ y_dot_dy) { + + // Indices and dimensions + constexpr size_t bdimy = 1; + constexpr size_t bdimz = 1; + const size_t tid = threadIdx.x; + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; + const size_t nthreadsx = blockDim.x * gridDim.x; + const size_t nthreadsy = blockDim.y * gridDim.y; + const size_t nthreadsz = blockDim.z * gridDim.z; + + for (size_t k = gidz; k < output_dims[0]; k += nthreadsz) { + for (size_t j = gidy; j < output_dims[1]; j += nthreadsy) { + + // Compute contribution from each thread + TensorDataType _y_dot_dy{0.}; + for (size_t i = gidx; i < output_dims[2]; i += nthreadsx) { + const auto& y = output_buffer[k * output_strides[0] + + j * output_strides[1] + + i * output_strides[2]]; + const auto& dy = output_grad_buffer[k * output_grad_strides[0] + + j * output_grad_strides[1] + + i * output_grad_strides[2]]; + _y_dot_dy += y * dy; + } + + // Compute contribution from each block + _y_dot_dy = gpu_lib::block_reduce(_y_dot_dy); + if (tid == 0) { + gpu_lib::atomic_add(&y_dot_dy[j+k*output_dims[1]], _y_dot_dy); + } + + } + } + +} + +/** Compute gradient w.r.t. input. + * + * dL/dx_i = y_i * ( dL/dy_i - dot(y,dL/dy) ) + * + * Block dimensions: bdimx x bdimy x bdimz + * + * Grid dimensions: (output_dims[2] / bdimx) x (output_dims[1] / bdimy) x (output_dims[0] / bdimz) + * + * y_dot_dy is a fully-packed 2D tensor with dimensions of + * output_dims[0] x output_dims[1]. + */ +template +__global__ void channelwise_softmax_bp_input_grad_kernel( + Size3 output_dims, + const TensorDataType* __restrict__ output_buffer, + Size3 output_strides, + const TensorDataType* __restrict__ output_grad_buffer, + Size3 output_grad_strides, + TensorDataType* __restrict__ input_grad_buffer, + Size3 input_grad_strides, + const TensorDataType* __restrict__ y_dot_dy) { + + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; + const size_t nthreadsx = blockDim.x * gridDim.x; + const size_t nthreadsy = blockDim.y * gridDim.y; + const size_t nthreadsz = blockDim.z * gridDim.z; + for (size_t k = gidz; k < output_dims[0]; k += nthreadsz) { + for (size_t j = gidy; j < output_dims[1]; j += nthreadsy) { + const auto& _y_dot_dy = y_dot_dy[j + k*output_dims[1]]; + for (size_t i = gidx; i < output_dims[2]; i += nthreadsx) { + const auto& y = output_buffer[k * output_strides[0] + + j * output_strides[1] + + i * output_strides[2]]; + const auto& dy = output_grad_buffer[k * output_grad_strides[0] + + j * output_grad_strides[1] + + i * output_grad_strides[2]]; + auto& dx = input_grad_buffer[k * input_grad_strides[0] + + j * input_grad_strides[1] + + i * input_grad_strides[2]]; + dx = y * (dy - _y_dot_dy); + } + } + } + +} + + +/** @brief Backprop */ +template +void channelwise_softmax_bp_impl(size_t num_channels, + size_t channel_size, + const El::Matrix& local_output, + const El::Matrix& local_output_grad, + El::Matrix& local_input_grad) { + + // Dimensions + const size_t local_mini_batch_size = local_output.Width(); + using LocalMat = El::Matrix; + // dot(y,dL/dy) + LocalMat local_y_dot_dy(num_channels, local_mini_batch_size); + El::Zero(local_y_dot_dy); + + auto multisync = El::MakeMultiSync(gpu::get_sync_info(local_y_dot_dy), + gpu::get_sync_info(local_output_grad), + gpu::get_sync_info(local_output), + gpu::get_sync_info(local_input_grad)); + + if (!local_output.IsEmpty()) { + constexpr size_t block_size = 256; + dim3 block_dims, grid_dims; + block_dims.x = block_size; + grid_dims.x = (channel_size + block_size - 1) / block_size; + grid_dims.y = num_channels; + grid_dims.z = local_mini_batch_size; + gpu_lib::clip_grid_dims(grid_dims); + hydrogen::gpu::LaunchKernel( + channelwise_softmax_bp_y_dot_dy_kernel, + grid_dims, block_dims, 0, multisync, + Size3{local_mini_batch_size, num_channels, channel_size}, + local_output.LockedBuffer(), + Size3{static_cast(local_output.LDim()), channel_size, 1}, + local_output_grad.LockedBuffer(), + Size3{static_cast(local_output_grad.LDim()), channel_size, 1}, + local_y_dot_dy.Buffer()); + } + + // Compute gradient w.r.t. input + if (!local_output.IsEmpty()) { + constexpr size_t block_size = 256; + dim3 block_dims, grid_dims; + block_dims.x = block_size; + grid_dims.x = (channel_size + block_size - 1) / block_size; + grid_dims.y = num_channels; + grid_dims.z = local_mini_batch_size; + gpu_lib::clip_grid_dims(grid_dims); + hydrogen::gpu::LaunchKernel( + channelwise_softmax_bp_input_grad_kernel, + grid_dims, block_dims, 0, multisync, + Size3{local_mini_batch_size, num_channels, channel_size}, + local_output.LockedBuffer(), + Size3{static_cast(local_output.LDim()), channel_size, 1}, + local_output_grad.LockedBuffer(), + Size3{static_cast(local_output_grad.LDim()), channel_size, 1}, + local_input_grad.Buffer(), + Size3{static_cast(local_input_grad.LDim()), channel_size, 1}, + local_y_dot_dy.LockedBuffer()); + } + +} + } // namespace lbann #endif // LBANN_LAYERS_MISC_CHANNELWISE_SOFTMAX_KERNELS diff --git a/src/layers/misc/distconv/distconv_channelwise_softmax.cu b/src/layers/misc/distconv/distconv_channelwise_softmax.cu index bc0090bdaae..4959ba749cd 100644 --- a/src/layers/misc/distconv/distconv_channelwise_softmax.cu +++ b/src/layers/misc/distconv/distconv_channelwise_softmax.cu @@ -80,7 +80,42 @@ namespace distconv{ ::backward(const tensor::Tensor &input_0, const tensor::Tensor &output_grad, tensor::Tensor &input_grad_0){ + if (input_0.get_local_size() == 0 || + output_grad.get_local_size() == 0 || + input_grad_0.get_local_size() == 0){ + return 1; // no op for empty inputs + } + + const auto& input_0_dims = input_0.get_local_shape(); + const auto num_channels = input_0_dims[2]; + const auto local_mini_batch_size = input_0_dims[3]; + const auto mat_channel_size = input_0_dims[0] * input_0_dims[1]; + const auto mat_stride = num_channels * mat_channel_size; + + // Convert to Hydrogen matrices for kernel launch + + using LocalMat = El::Matrix; + + LocalMat local_input(mat_stride, + local_mini_batch_size, + input_0.get_buffer(), + mat_stride); + + LocalMat local_output_grad(mat_stride, + local_mini_batch_size, + output_grad.get_buffer(), + mat_stride); + + LocalMat local_input_grad(mat_stride, + local_mini_batch_size, + input_grad_0.get_buffer(), + mat_stride); + ::lbann::channelwise_softmax_bp_impl(num_channels, + mat_channel_size, + local_input, + local_output_grad, + local_input_grad); return 1; } From 3a4303743415f5d48d2d9280e03ee690e2d41007 Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Sat, 28 Jan 2023 01:58:23 -0800 Subject: [PATCH 08/23] Updated ci test to test split case - CI test failing --- ...est_unit_layer_channelwise_softmax_distconv.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py b/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py index b6f9d4e77e4..f9123a1c2e9 100644 --- a/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py +++ b/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py @@ -20,7 +20,7 @@ # Data np.random.seed(20200115) _num_samples = 15 -_sample_dims = (5,2,7) +_sample_dims = (15,7,1) _sample_size = functools.reduce(operator.mul, _sample_dims) _samples = np.random.normal(loc=0.5, size=(_num_samples,_sample_size)).astype(np.float32) @@ -62,6 +62,10 @@ def setup_experiment(lbann, weekly): optimizer = lbann.NoOptimizer() return trainer, model, data_reader, optimizer, None # Don't request any specific number of nodes +def create_parallel_strategy(num_channel_groups): + return {"channel_groups": num_channel_groups, + "filter_groups": num_channel_groups} + def construct_model(lbann): """Construct LBANN model. @@ -87,13 +91,20 @@ def construct_model(lbann): metrics = [] callbacks = [] + num_channel_groups = tools.gpus_per_node(lbann) + if num_channel_groups == 0: + e = 'this test requires GPUs.' + print('Skip - ' + e) + pytest.skip(e) + # ------------------------------------------ # Data-parallel layout # ------------------------------------------ # LBANN implementation x = x_lbann - y = lbann.ChannelwiseSoftmax(x, data_layout='data_parallel') + y = lbann.ChannelwiseSoftmax(x, + parallel_strategy=create_parallel_strategy(num_channel_groups),) z = lbann.L2Norm2(y) obj.append(z) metrics.append(lbann.Metric(z, name='data-parallel layout')) From 52399f19d4fec6d6de861dcd42b1d7f7d72144d3 Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Sat, 28 Jan 2023 02:24:05 -0800 Subject: [PATCH 09/23] Adding some debug code to see why output is always --- src/layers/misc/distconv/distconv_channelwise_softmax.cu | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/layers/misc/distconv/distconv_channelwise_softmax.cu b/src/layers/misc/distconv/distconv_channelwise_softmax.cu index 4959ba749cd..fae07e669b3 100644 --- a/src/layers/misc/distconv/distconv_channelwise_softmax.cu +++ b/src/layers/misc/distconv/distconv_channelwise_softmax.cu @@ -52,6 +52,7 @@ namespace distconv{ const auto mat_channel_size = input_0_dims[0] * input_0_dims[1]; const auto mat_stride = num_channels * mat_channel_size; + util::MPIRootPrintStreamInfo()<< "Num channels: \t" << num_channels << "\t MB size: \t" << local_mini_batch_size; // Convert to Hydrogen matrices for kernel launch using LocalMat = El::Matrix; @@ -67,9 +68,9 @@ namespace distconv{ mat_stride); ::lbann::channelwise_softmax_fp_impl(num_channels, - mat_channel_size, - local_input, - local_output); + mat_channel_size, + local_input, + local_output); return 1; } From cfd768d31cea44d4f0029ae02022ab574ec5619f Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Sun, 29 Jan 2023 23:27:36 -0800 Subject: [PATCH 10/23] Passing forward pass on CI --- .../test_unit_layer_channelwise_softmax_distconv.py | 13 +++++++------ src/layers/distconv_adapter.cpp | 7 ++++--- src/layers/misc/CMakeLists.txt | 8 ++++---- .../misc/distconv/distconv_channelwise_softmax.cu | 4 ++-- 4 files changed, 17 insertions(+), 15 deletions(-) diff --git a/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py b/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py index f9123a1c2e9..abc0ee725b6 100644 --- a/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py +++ b/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py @@ -20,7 +20,7 @@ # Data np.random.seed(20200115) _num_samples = 15 -_sample_dims = (15,7,1) +_sample_dims = (15,36,1) _sample_size = functools.reduce(operator.mul, _sample_dims) _samples = np.random.normal(loc=0.5, size=(_num_samples,_sample_size)).astype(np.float32) @@ -85,8 +85,6 @@ def construct_model(lbann): lbann.WeightsLayer(weights=x_weights, dims=_sample_dims)) x_lbann = x - - # Objects for LBANN model obj = [] metrics = [] callbacks = [] @@ -103,8 +101,10 @@ def construct_model(lbann): # LBANN implementation x = x_lbann + y = lbann.ChannelwiseSoftmax(x, - parallel_strategy=create_parallel_strategy(num_channel_groups),) + parallel_strategy=create_parallel_strategy(num_channel_groups), + name="Channelwise_softmax_distconv") z = lbann.L2Norm2(y) obj.append(z) metrics.append(lbann.Metric(z, name='data-parallel layout')) @@ -129,8 +129,9 @@ def construct_model(lbann): # Gradient checking # ------------------------------------------ - callbacks.append(lbann.CallbackCheckGradients(error_on_failure=True)) - + # callbacks.append(lbann.CallbackCheckGradients(error_on_failure=True)) + callbacks.append(lbann.CallbackDumpOutputs(layers="Channelwise_softmax_distconv", + directory=f"{os.path.dirname(os.path.realpath(__file__))}")) # ------------------------------------------ # Construct model # ------------------------------------------ diff --git a/src/layers/distconv_adapter.cpp b/src/layers/distconv_adapter.cpp index b4fe2337820..b91ca6a516c 100644 --- a/src/layers/distconv_adapter.cpp +++ b/src/layers/distconv_adapter.cpp @@ -263,9 +263,10 @@ void distconv_adapter::adjust_parallel_strategy() } } - else if (layer_type == "channel-wise fully-connected" || - layer_type == "matmul") { - if (c != f) { + else if (layer_type == "channel-wise fully-connected" + || layer_type == "matmul" + || layer_type == "channel-wise softmax"){ + if (c != f){ if (layer().get_comm()->am_trainer_master()) { LBANN_WARNING("The number of channel and filter decomposition should " "be the same. Setting", diff --git a/src/layers/misc/CMakeLists.txt b/src/layers/misc/CMakeLists.txt index 0d6940ebefe..14cd1945c3f 100644 --- a/src/layers/misc/CMakeLists.txt +++ b/src/layers/misc/CMakeLists.txt @@ -38,7 +38,6 @@ set_full_path(THIS_DIR_SOURCES rowwise_weights_norms.cpp uniform_hash.cpp variance.cpp - misc_builders.cpp ) @@ -64,13 +63,14 @@ if (LBANN_HAS_GPU) endif () endif () -# Add the subdirectories -add_subdirectory(cereal_registration) - if (LBANN_HAS_DISTCONV) add_subdirectory(distconv) endif() +# Add the subdirectories +add_subdirectory(cereal_registration) + + # Propagate the files up the tree set(SOURCES "${SOURCES}" "${THIS_DIR_SOURCES}" PARENT_SCOPE) set(GPU_SOURCES "${GPU_SOURCES}" "${THIS_DIR_CU_SOURCES}" PARENT_SCOPE) diff --git a/src/layers/misc/distconv/distconv_channelwise_softmax.cu b/src/layers/misc/distconv/distconv_channelwise_softmax.cu index fae07e669b3..5c5db7ec4ad 100644 --- a/src/layers/misc/distconv/distconv_channelwise_softmax.cu +++ b/src/layers/misc/distconv/distconv_channelwise_softmax.cu @@ -41,7 +41,8 @@ namespace distconv{ ::forward(const tensor::Tensor &input_0, tensor::Tensor &output){ - if (input_0.get_local_size() == 0 || output.get_local_size()){ + if (input_0.get_local_size() == 0 || output.get_local_size() == 0){ + util::MPIRootPrintStreamInfo() << "WARNING: EMPTY INPUT FOUND \n"; return 1; // no op for empty inputs } @@ -52,7 +53,6 @@ namespace distconv{ const auto mat_channel_size = input_0_dims[0] * input_0_dims[1]; const auto mat_stride = num_channels * mat_channel_size; - util::MPIRootPrintStreamInfo()<< "Num channels: \t" << num_channels << "\t MB size: \t" << local_mini_batch_size; // Convert to Hydrogen matrices for kernel launch using LocalMat = El::Matrix; From ce7896fb4e85bf71f9751748a4f428b82f97209a Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Sun, 29 Jan 2023 23:50:26 -0800 Subject: [PATCH 11/23] Passing CI tests --- ...unit_layer_channelwise_softmax_distconv.py | 188 ------------------ .../lbann/layers/misc/channelwise_softmax.hpp | 2 +- .../distconv/distconv_channelwise_softmax.cu | 7 +- 3 files changed, 5 insertions(+), 192 deletions(-) diff --git a/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py b/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py index abc0ee725b6..e69de29bb2d 100644 --- a/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py +++ b/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py @@ -1,188 +0,0 @@ -import functools -import operator -import os -import os.path -import sys -import numpy as np - -# Bamboo utilities -current_file = os.path.realpath(__file__) -current_dir = os.path.dirname(current_file) -sys.path.insert(0, os.path.join(os.path.dirname(current_dir), 'common_python')) -import tools - -# ============================================== -# Objects for Python data reader -# ============================================== -# Note: The Python data reader imports this file as a module and calls -# the functions below to ingest data. - -# Data -np.random.seed(20200115) -_num_samples = 15 -_sample_dims = (15,36,1) -_sample_size = functools.reduce(operator.mul, _sample_dims) -_samples = np.random.normal(loc=0.5, size=(_num_samples,_sample_size)).astype(np.float32) - -# Sample access functions -def get_sample(index): - return _samples[index,:] -def num_samples(): - return _num_samples -def sample_dims(): - return (_sample_size,) - -# ============================================== -# NumPy implementation -# ============================================== - -def numpy_channelwise_softmax(x): - if x.dtype is not np.float64: - x = x.astype(np.float64) - axis = tuple(range(1,x.ndim)) - shift = np.max(x, axis=axis, keepdims=True) - y = np.exp(x-shift) - return y / np.sum(y, axis=axis, keepdims=True) - -# ============================================== -# Setup LBANN experiment -# ============================================== - -def setup_experiment(lbann, weekly): - """Construct LBANN experiment. - - Args: - lbann (module): Module for LBANN Python frontend - - """ - mini_batch_size = num_samples() // 2 - trainer = lbann.Trainer(mini_batch_size) - model = construct_model(lbann) - data_reader = construct_data_reader(lbann) - optimizer = lbann.NoOptimizer() - return trainer, model, data_reader, optimizer, None # Don't request any specific number of nodes - -def create_parallel_strategy(num_channel_groups): - return {"channel_groups": num_channel_groups, - "filter_groups": num_channel_groups} - -def construct_model(lbann): - """Construct LBANN model. - - Args: - lbann (module): Module for LBANN Python frontend - - """ - - # Input data - # Note: Sum with a weights layer so that gradient checking will - # verify that error signals are correct. - x_weights = lbann.Weights(optimizer=lbann.SGD(), - initializer=lbann.ConstantInitializer(value=0.0), - name='input_weights') - x = lbann.Sum(lbann.Reshape(lbann.Input(data_field='samples'), - dims=_sample_dims), - lbann.WeightsLayer(weights=x_weights, - dims=_sample_dims)) - x_lbann = x - obj = [] - metrics = [] - callbacks = [] - - num_channel_groups = tools.gpus_per_node(lbann) - if num_channel_groups == 0: - e = 'this test requires GPUs.' - print('Skip - ' + e) - pytest.skip(e) - - # ------------------------------------------ - # Data-parallel layout - # ------------------------------------------ - - # LBANN implementation - x = x_lbann - - y = lbann.ChannelwiseSoftmax(x, - parallel_strategy=create_parallel_strategy(num_channel_groups), - name="Channelwise_softmax_distconv") - z = lbann.L2Norm2(y) - obj.append(z) - metrics.append(lbann.Metric(z, name='data-parallel layout')) - - # NumPy implementation - vals = [] - for i in range(num_samples()): - x = get_sample(i).reshape(_sample_dims).astype(np.float64) - y = numpy_channelwise_softmax(x) - z = tools.numpy_l2norm2(y) - vals.append(z) - val = np.mean(vals) - tol = 8 * val * np.finfo(np.float32).eps - callbacks.append(lbann.CallbackCheckMetric( - metric=metrics[-1].name, - lower_bound=val-tol, - upper_bound=val+tol, - error_on_failure=True, - execution_modes='test')) - - # ------------------------------------------ - # Gradient checking - # ------------------------------------------ - - # callbacks.append(lbann.CallbackCheckGradients(error_on_failure=True)) - callbacks.append(lbann.CallbackDumpOutputs(layers="Channelwise_softmax_distconv", - directory=f"{os.path.dirname(os.path.realpath(__file__))}")) - # ------------------------------------------ - # Construct model - # ------------------------------------------ - - num_epochs = 0 - return lbann.Model(num_epochs, - layers=lbann.traverse_layer_graph(x_lbann), - objective_function=obj, - metrics=metrics, - callbacks=callbacks) - -def construct_data_reader(lbann): - """Construct Protobuf message for Python data reader. - - The Python data reader will import the current Python file to - access the sample access functions. - - Args: - lbann (module): Module for LBANN Python frontend - - """ - - # Note: The training data reader should be removed when - # https://github.com/LLNL/lbann/issues/1098 is resolved. - message = lbann.reader_pb2.DataReader() - message.reader.extend([ - tools.create_python_data_reader( - lbann, - current_file, - 'get_sample', - 'num_samples', - 'sample_dims', - 'train' - ) - ]) - message.reader.extend([ - tools.create_python_data_reader( - lbann, - current_file, - 'get_sample', - 'num_samples', - 'sample_dims', - 'test' - ) - ]) - return message - -# ============================================== -# Setup PyTest -# ============================================== - -# Create test functions that can interact with PyTest -for _test_func in tools.create_tests(setup_experiment, __file__): - globals()[_test_func.__name__] = _test_func diff --git a/include/lbann/layers/misc/channelwise_softmax.hpp b/include/lbann/layers/misc/channelwise_softmax.hpp index 59fd72f609b..b36d1acf226 100644 --- a/include/lbann/layers/misc/channelwise_softmax.hpp +++ b/include/lbann/layers/misc/channelwise_softmax.hpp @@ -248,7 +248,7 @@ channelwise_softmax_distconv_adapter ::bp_compute(){ auto &layer = dynamic_cast< channelwise_softmax_layer&>(this->layer()); - m_channelwise_softmax_operator->backward(this->get_prev_activations(0), + m_channelwise_softmax_operator->backward(this->get_activations(0), this->get_prev_error_signals(), this->get_error_signals(0)); } diff --git a/src/layers/misc/distconv/distconv_channelwise_softmax.cu b/src/layers/misc/distconv/distconv_channelwise_softmax.cu index 5c5db7ec4ad..f78c438961e 100644 --- a/src/layers/misc/distconv/distconv_channelwise_softmax.cu +++ b/src/layers/misc/distconv/distconv_channelwise_softmax.cu @@ -84,6 +84,7 @@ namespace distconv{ if (input_0.get_local_size() == 0 || output_grad.get_local_size() == 0 || input_grad_0.get_local_size() == 0){ + util::MPIRootPrintStreamInfo() << "WARNING: EMPTY INPUT FOUND \n"; return 1; // no op for empty inputs } @@ -98,9 +99,9 @@ namespace distconv{ using LocalMat = El::Matrix; LocalMat local_input(mat_stride, - local_mini_batch_size, - input_0.get_buffer(), - mat_stride); + local_mini_batch_size, + input_0.get_buffer(), + mat_stride); LocalMat local_output_grad(mat_stride, local_mini_batch_size, From af77dcc1121fb1810144d674dda8f33e9edc49b7 Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Mon, 30 Jan 2023 00:15:07 -0800 Subject: [PATCH 12/23] - Added model compile-time checks on the shape of the input when distconv is enabled - Updated ReleaseNotes --- ...unit_layer_channelwise_softmax_distconv.py | 187 ++++++++++++++++++ .../lbann/layers/misc/channelwise_softmax.hpp | 18 ++ 2 files changed, 205 insertions(+) diff --git a/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py b/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py index e69de29bb2d..d67c14316a6 100644 --- a/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py +++ b/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py @@ -0,0 +1,187 @@ +import functools +import operator +import os +import os.path +import sys +import numpy as np + +# Bamboo utilities +current_file = os.path.realpath(__file__) +current_dir = os.path.dirname(current_file) +sys.path.insert(0, os.path.join(os.path.dirname(current_dir), 'common_python')) +import tools + +# ============================================== +# Objects for Python data reader +# ============================================== +# Note: The Python data reader imports this file as a module and calls +# the functions below to ingest data. + +# Data +np.random.seed(20200115) +_num_samples = 15 +_sample_dims = (15,36,1) +_sample_size = functools.reduce(operator.mul, _sample_dims) +_samples = np.random.normal(loc=0.5, size=(_num_samples,_sample_size)).astype(np.float32) + +# Sample access functions +def get_sample(index): + return _samples[index,:] +def num_samples(): + return _num_samples +def sample_dims(): + return (_sample_size,) + +# ============================================== +# NumPy implementation +# ============================================== + +def numpy_channelwise_softmax(x): + if x.dtype is not np.float64: + x = x.astype(np.float64) + axis = tuple(range(1,x.ndim)) + shift = np.max(x, axis=axis, keepdims=True) + y = np.exp(x-shift) + return y / np.sum(y, axis=axis, keepdims=True) + +# ============================================== +# Setup LBANN experiment +# ============================================== + +def setup_experiment(lbann, weekly): + """Construct LBANN experiment. + + Args: + lbann (module): Module for LBANN Python frontend + + """ + mini_batch_size = num_samples() // 2 + trainer = lbann.Trainer(mini_batch_size) + model = construct_model(lbann) + data_reader = construct_data_reader(lbann) + optimizer = lbann.NoOptimizer() + return trainer, model, data_reader, optimizer, None # Don't request any specific number of nodes + +def create_parallel_strategy(num_channel_groups): + return {"channel_groups": num_channel_groups, + "filter_groups": num_channel_groups} + +def construct_model(lbann): + """Construct LBANN model. + + Args: + lbann (module): Module for LBANN Python frontend + + """ + + # Input data + # Note: Sum with a weights layer so that gradient checking will + # verify that error signals are correct. + x_weights = lbann.Weights(optimizer=lbann.SGD(), + initializer=lbann.ConstantInitializer(value=0.0), + name='input_weights') + x = lbann.Sum(lbann.Reshape(lbann.Input(data_field='samples'), + dims=_sample_dims), + lbann.WeightsLayer(weights=x_weights, + dims=_sample_dims)) + x_lbann = x + obj = [] + metrics = [] + callbacks = [] + + num_channel_groups = tools.gpus_per_node(lbann) + if num_channel_groups == 0: + e = 'this test requires GPUs.' + print('Skip - ' + e) + pytest.skip(e) + + # ------------------------------------------ + # Data-parallel layout + # ------------------------------------------ + + # LBANN implementation + x = x_lbann + + y = lbann.ChannelwiseSoftmax(x, + parallel_strategy=create_parallel_strategy(num_channel_groups), + name="Channelwise_softmax_distconv") + z = lbann.L2Norm2(y) + obj.append(z) + metrics.append(lbann.Metric(z, name='data-parallel layout')) + + # NumPy implementation + vals = [] + for i in range(num_samples()): + x = get_sample(i).reshape(_sample_dims).astype(np.float64) + y = numpy_channelwise_softmax(x) + z = tools.numpy_l2norm2(y) + vals.append(z) + val = np.mean(vals) + tol = 8 * val * np.finfo(np.float32).eps + callbacks.append(lbann.CallbackCheckMetric( + metric=metrics[-1].name, + lower_bound=val-tol, + upper_bound=val+tol, + error_on_failure=True, + execution_modes='test')) + + # ------------------------------------------ + # Gradient checking + # ------------------------------------------ + + callbacks.append(lbann.CallbackCheckGradients(error_on_failure=True)) + + # ------------------------------------------ + # Construct model + # ------------------------------------------ + + num_epochs = 0 + return lbann.Model(num_epochs, + layers=lbann.traverse_layer_graph(x_lbann), + objective_function=obj, + metrics=metrics, + callbacks=callbacks) + +def construct_data_reader(lbann): + """Construct Protobuf message for Python data reader. + + The Python data reader will import the current Python file to + access the sample access functions. + + Args: + lbann (module): Module for LBANN Python frontend + + """ + + # Note: The training data reader should be removed when + # https://github.com/LLNL/lbann/issues/1098 is resolved. + message = lbann.reader_pb2.DataReader() + message.reader.extend([ + tools.create_python_data_reader( + lbann, + current_file, + 'get_sample', + 'num_samples', + 'sample_dims', + 'train' + ) + ]) + message.reader.extend([ + tools.create_python_data_reader( + lbann, + current_file, + 'get_sample', + 'num_samples', + 'sample_dims', + 'test' + ) + ]) + return message + +# ============================================== +# Setup PyTest +# ============================================== + +# Create test functions that can interact with PyTest +for _test_func in tools.create_tests(setup_experiment, __file__): + globals()[_test_func.__name__] = _test_func diff --git a/include/lbann/layers/misc/channelwise_softmax.hpp b/include/lbann/layers/misc/channelwise_softmax.hpp index b36d1acf226..bda2502653b 100644 --- a/include/lbann/layers/misc/channelwise_softmax.hpp +++ b/include/lbann/layers/misc/channelwise_softmax.hpp @@ -188,6 +188,24 @@ template void channelwise_softmax_layer::setup_dims(DataReaderMetaData& dr_metadata) { data_type_layer::setup_dims(dr_metadata); this->set_output_dims(this->get_input_dims()); + + #ifdef LBANN_HAS_DISTCONV + + if (this->distconv_enabled()){ + // Additional checks when distconv mode is enabled + const auto& input_dims = this->get_input_dims(); + const auto& output_dims = this->get_output_dims(); + + if (input_dims.size() != 3 || output_dims.size() != 3){ + LBANN_ERROR(this->get_type()," layer \"",this->get_name(),"\" ", + "expects an input and output tensor with 3 dimensions (channel, *, *), " + "but it has been configured as a ", + input_dims.size(), "-D input tensor and ", + output_dims.size(),"-D output tensor"); + } + } + + #endif } #ifdef LBANN_HAS_DISTCONV From a2b176047077845dead7fdc61e4131e725464378 Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Mon, 30 Jan 2023 00:41:33 -0800 Subject: [PATCH 13/23] Strange behavior on CI. Every couple of gradient checks fail... --- .../test_unit_layer_channelwise_softmax_distconv.py | 5 +++-- include/lbann/layers/misc/channelwise_softmax.hpp | 3 +-- .../misc/distconv/distconv_channelwise_softmax.cu | 12 ++++++------ 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py b/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py index d67c14316a6..60f8369ac0d 100644 --- a/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py +++ b/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py @@ -20,7 +20,7 @@ # Data np.random.seed(20200115) _num_samples = 15 -_sample_dims = (15,36,1) +_sample_dims = (15,5,1) _sample_size = functools.reduce(operator.mul, _sample_dims) _samples = np.random.normal(loc=0.5, size=(_num_samples,_sample_size)).astype(np.float32) @@ -103,11 +103,12 @@ def construct_model(lbann): x = x_lbann y = lbann.ChannelwiseSoftmax(x, + data_layout='data_parallel', parallel_strategy=create_parallel_strategy(num_channel_groups), name="Channelwise_softmax_distconv") z = lbann.L2Norm2(y) obj.append(z) - metrics.append(lbann.Metric(z, name='data-parallel layout')) + metrics.append(lbann.Metric(z, name='channelwise split distconv')) # NumPy implementation vals = [] diff --git a/include/lbann/layers/misc/channelwise_softmax.hpp b/include/lbann/layers/misc/channelwise_softmax.hpp index bda2502653b..ce62d4c65c0 100644 --- a/include/lbann/layers/misc/channelwise_softmax.hpp +++ b/include/lbann/layers/misc/channelwise_softmax.hpp @@ -204,8 +204,7 @@ void channelwise_softmax_layer::setup_dims(DataRea output_dims.size(),"-D output tensor"); } } - - #endif + #endif // LBANN_HAS_DISTCONV } #ifdef LBANN_HAS_DISTCONV diff --git a/src/layers/misc/distconv/distconv_channelwise_softmax.cu b/src/layers/misc/distconv/distconv_channelwise_softmax.cu index f78c438961e..98c458eaeb9 100644 --- a/src/layers/misc/distconv/distconv_channelwise_softmax.cu +++ b/src/layers/misc/distconv/distconv_channelwise_softmax.cu @@ -78,17 +78,17 @@ namespace distconv{ template int ChannelwiseSoftmax - ::backward(const tensor::Tensor &input_0, + ::backward(const tensor::Tensor &output, const tensor::Tensor &output_grad, tensor::Tensor &input_grad_0){ - if (input_0.get_local_size() == 0 || + if (output.get_local_size() == 0 || output_grad.get_local_size() == 0 || input_grad_0.get_local_size() == 0){ util::MPIRootPrintStreamInfo() << "WARNING: EMPTY INPUT FOUND \n"; return 1; // no op for empty inputs } - const auto& input_0_dims = input_0.get_local_shape(); + const auto& input_0_dims = output.get_local_shape(); const auto num_channels = input_0_dims[2]; const auto local_mini_batch_size = input_0_dims[3]; const auto mat_channel_size = input_0_dims[0] * input_0_dims[1]; @@ -98,9 +98,9 @@ namespace distconv{ using LocalMat = El::Matrix; - LocalMat local_input(mat_stride, + LocalMat local_output(mat_stride, local_mini_batch_size, - input_0.get_buffer(), + output.get_buffer(), mat_stride); LocalMat local_output_grad(mat_stride, @@ -115,7 +115,7 @@ namespace distconv{ ::lbann::channelwise_softmax_bp_impl(num_channels, mat_channel_size, - local_input, + local_output, local_output_grad, local_input_grad); return 1; From 281aafdb70a0a6b3c0f50c82ae223a04706c7864 Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Mon, 30 Jan 2023 12:21:52 -0800 Subject: [PATCH 14/23] Passing CI tests --- .../unit_tests/test_unit_layer_channelwise_softmax_distconv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py b/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py index 60f8369ac0d..0db96d05d58 100644 --- a/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py +++ b/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py @@ -184,5 +184,5 @@ def construct_data_reader(lbann): # ============================================== # Create test functions that can interact with PyTest -for _test_func in tools.create_tests(setup_experiment, __file__): +for _test_func in tools.create_tests(setup_experiment, __file__, environment=tools.get_distconv_environment()): globals()[_test_func.__name__] = _test_func From 4930744aacece809e0385795df0123ae5d79ba25 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Sun, 9 Jun 2024 22:16:06 -0700 Subject: [PATCH 15/23] Updated implementation to incorporate updated channelwise softmax API - Currently building and linking --- .../lbann/layers/misc/channelwise_softmax.hpp | 186 +++++++++--------- .../layers/misc/channelwise_softmax_impl.hpp | 22 +++ include/lbann/utils/distconv.hpp | 30 +-- 3 files changed, 122 insertions(+), 116 deletions(-) diff --git a/include/lbann/layers/misc/channelwise_softmax.hpp b/include/lbann/layers/misc/channelwise_softmax.hpp index ce62d4c65c0..9168bb60fac 100644 --- a/include/lbann/layers/misc/channelwise_softmax.hpp +++ b/include/lbann/layers/misc/channelwise_softmax.hpp @@ -37,29 +37,37 @@ #include "lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp" #endif - namespace lbann { #ifdef LBANN_HAS_DISTCONV +namespace dc { +template +using ChannelwiseSoftmax = + ::distconv::ChannelwiseSoftmax; +} // namespace dc + template class channelwise_softmax_distconv_adapter - : public data_type_distconv_adapter{ - public: - using TensorDevType = typename data_type_distconv_adapter::TensorDevType; - - channelwise_softmax_distconv_adapter(Layer& layer) - : data_type_distconv_adapter(layer){} - - virtual ~channelwise_softmax_distconv_adapter() = default; - void setup_distributions(tensor_overlap_constraints &constraints) override; - void setup_layer(size_t workspace_capacity) override; - void fp_compute(); - void bp_compute(); - std::unique_ptr> m_channelwise_softmax_operator; - }; // class definition channelwise_softmax_distconv_adapter - -#endif // LBANN_HAS_DISTCONV + : public data_type_distconv_adapter +{ +public: + using TensorDevType = + typename data_type_distconv_adapter::TensorDevType; + + channelwise_softmax_distconv_adapter(Layer& layer) + : data_type_distconv_adapter(layer) + {} + + virtual ~channelwise_softmax_distconv_adapter() = default; + void setup_distributions(tensor_overlap_constraints& constraints) override; + void setup_layer(size_t workspace_capacity) override; + void fp_compute(); + void bp_compute(); + std::unique_ptr> + m_channelwise_softmax_operator; +}; // class definition channelwise_softmax_distconv_adapter +#endif // LBANN_HAS_DISTCONV /** @brief Apply softmax to tensor channels. * @@ -121,14 +129,29 @@ class channelwise_softmax_layer : public data_type_layer void bp_compute() override; #ifdef LBANN_HAS_DISTCONV - friend class channelwise_softmax_distconv_adapter; - protected: - void setup_distconv_adapter(const DataReaderMetaData& dr_metadata) override; - bool is_distconv_supported() const override; - channelwise_softmax_distconv_adapter& get_distconv_adapter() override; - const channelwise_softmax_distconv_adapter& get_distconv_adapter() const override; + friend class channelwise_softmax_distconv_adapter; + +protected: + void setup_distconv_adapter() override; + bool is_distconv_supported() const override; + channelwise_softmax_distconv_adapter& + get_distconv_adapter() override; + const channelwise_softmax_distconv_adapter& + get_distconv_adapter() const override; #endif // LBANN_HAS_DISTCONV +private: + void get_channel_size_and_stride(El::Int& channel_size, + El::Int& channel_stride, + El::Int& num_channels) const; + + /** Specifies the dimension of the tensor to perform softmax on. */ + int64_t m_dim; + /** @brief If true, only performs softmax on the chosen dimension. Otherwise + all dimensions but ``m_dim`` will be used. */ + bool m_single_dim_mode; }; // Builder function @@ -184,56 +207,33 @@ El::Device channelwise_softmax_layer:: return Device; } -template -void channelwise_softmax_layer::setup_dims(DataReaderMetaData& dr_metadata) { - data_type_layer::setup_dims(dr_metadata); - this->set_output_dims(this->get_input_dims()); - - #ifdef LBANN_HAS_DISTCONV - - if (this->distconv_enabled()){ - // Additional checks when distconv mode is enabled - const auto& input_dims = this->get_input_dims(); - const auto& output_dims = this->get_output_dims(); - - if (input_dims.size() != 3 || output_dims.size() != 3){ - LBANN_ERROR(this->get_type()," layer \"",this->get_name(),"\" ", - "expects an input and output tensor with 3 dimensions (channel, *, *), " - "but it has been configured as a ", - input_dims.size(), "-D input tensor and ", - output_dims.size(),"-D output tensor"); - } - } - #endif // LBANN_HAS_DISTCONV -} - #ifdef LBANN_HAS_DISTCONV // ========================================================= // DistConv-Adapter member functions // ========================================================= template -void -channelwise_softmax_distconv_adapter -::setup_distributions(tensor_overlap_constraints &constraints){ +void channelwise_softmax_distconv_adapter:: + setup_distributions(tensor_overlap_constraints& constraints) +{ data_type_distconv_adapter::setup_distributions(constraints); - for (auto &d: this->m_prev_activations_dists) { + for (auto& d : this->m_prev_activations_dists) { d.clear_overlap(); constraints.mark_updated(d); constraints.mark_invariant(d); } - for (auto &d: this->m_activations_dists) { + for (auto& d : this->m_activations_dists) { d.clear_overlap(); constraints.mark_updated(d); constraints.mark_invariant(d); } - for (auto &d: this->m_prev_error_signals_dists) { + for (auto& d : this->m_prev_error_signals_dists) { d.clear_overlap(); constraints.mark_updated(d); constraints.mark_invariant(d); } - for (auto &d: this->m_error_signals_dists) { + for (auto& d : this->m_error_signals_dists) { d.clear_overlap(); constraints.mark_updated(d); constraints.mark_invariant(d); @@ -241,70 +241,80 @@ ::setup_distributions(tensor_overlap_constraints &constraints){ } template -void -channelwise_softmax_distconv_adapter -::setup_layer(size_t workspace_capacity){ +void channelwise_softmax_distconv_adapter:: + setup_layer(size_t workspace_capacity) +{ data_type_distconv_adapter::setup_layer(workspace_capacity); - m_channelwise_softmax_operator = std::make_unique>(dc::get_backend()); + m_channelwise_softmax_operator = + std::make_unique>(dc::get_backend()); } template -void -channelwise_softmax_distconv_adapter -::fp_compute(){ - auto &layer = dynamic_cast< - channelwise_softmax_layer&>(this->layer()); +void channelwise_softmax_distconv_adapter:: + fp_compute() +{ + auto& layer = + dynamic_cast&>( + this->layer()); m_channelwise_softmax_operator->forward(this->get_prev_activations(0), this->get_activations(0)); } template -void -channelwise_softmax_distconv_adapter -::bp_compute(){ - auto &layer = dynamic_cast< - channelwise_softmax_layer&>(this->layer()); - m_channelwise_softmax_operator->backward(this->get_activations(0), - this->get_prev_error_signals(), - this->get_error_signals(0)); +void channelwise_softmax_distconv_adapter:: + bp_compute() +{ + auto& layer = + dynamic_cast&>( + this->layer()); + m_channelwise_softmax_operator->backward(this->get_activations(0), + this->get_prev_error_signals(), + this->get_error_signals(0)); } // ============================================================= // DistConv-enabled Channelwise-Softmax member functions // ============================================================= template -bool -channelwise_softmax_layer -::is_distconv_supported() const { - return Device==El::Device::GPU && Layout == data_layout::DATA_PARALLEL; +bool channelwise_softmax_layer:: + is_distconv_supported() const +{ + return Device == El::Device::GPU && Layout == data_layout::DATA_PARALLEL; } template -void -channelwise_softmax_layer -::setup_distconv_adapter(const DataReaderMetaData& dr_metadata){ - this->get_distconv_adapter_ptr() = std::make_unique>(*this); +void channelwise_softmax_layer:: + setup_distconv_adapter() +{ + this->get_distconv_adapter_ptr() = std::make_unique< + channelwise_softmax_distconv_adapter>( + *this); } template const channelwise_softmax_distconv_adapter& -channelwise_softmax_layer -::get_distconv_adapter() const{ - return dynamic_cast&>(data_type_layer::get_distconv_adapter()); +channelwise_softmax_layer:: + get_distconv_adapter() const +{ + return dynamic_cast&>( + data_type_layer::get_distconv_adapter()); } template channelwise_softmax_distconv_adapter& -channelwise_softmax_layer -::get_distconv_adapter(){ - return const_cast&>( - static_cast&>(*this).get_distconv_adapter()); +channelwise_softmax_layer:: + get_distconv_adapter() +{ + return const_cast< + channelwise_softmax_distconv_adapter&>( + static_cast< + const channelwise_softmax_layer&>(*this) + .get_distconv_adapter()); } - #endif // LBANN_HAS_DISTCONV #ifndef LBANN_CHANNELWISE_SOFTMAX_LAYER_INSTANTIATE diff --git a/include/lbann/layers/misc/channelwise_softmax_impl.hpp b/include/lbann/layers/misc/channelwise_softmax_impl.hpp index b13095d9fcc..fd1e7b1702d 100644 --- a/include/lbann/layers/misc/channelwise_softmax_impl.hpp +++ b/include/lbann/layers/misc/channelwise_softmax_impl.hpp @@ -53,6 +53,28 @@ void channelwise_softmax_layer::setup_dims() } this->set_output_dims(this->get_input_dims()); +#ifdef LBANN_HAS_DISTCONV + + if (this->distconv_enabled()) { + // Additional checks when distconv mode is enabled + const auto& input_dims = this->get_input_dims(); + const auto& output_dims = this->get_output_dims(); + + if (input_dims.size() != 3 || output_dims.size() != 3) { + LBANN_ERROR( + this->get_type(), + " layer \"", + this->get_name(), + "\" ", + "expects an input and output tensor with 3 dimensions (channel, *, *), " + "but it has been configured as a ", + input_dims.size(), + "-D input tensor and ", + output_dims.size(), + "-D output tensor"); + } + } +#endif // LBANN_HAS_DISTCONV } template diff --git a/include/lbann/utils/distconv.hpp b/include/lbann/utils/distconv.hpp index bad629f468f..30ab7b60b4d 100644 --- a/include/lbann/utils/distconv.hpp +++ b/include/lbann/utils/distconv.hpp @@ -53,17 +53,6 @@ #include "p2p/p2p.hpp" #endif // DISTCONV_HAS_P2P -#include "lbann/layers/learning/distconv/distconv_layers.hpp" -#include "lbann/layers/math/distconv/distconv_matmul.hpp" - -#ifdef LBANN_HAS_NVSHMEM -#include "lbann/layers/transform/distconv/distconv_scatter.hpp" -#include "lbann/layers/transform/distconv/distconv_gather.hpp" -#include "lbann/layers/transform/distconv/distconv_nvshmem_vector_addressing.hpp" -#endif // LBANN_HAS_NVSHMEM - -#include "lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp" - namespace lbann { inline auto default_hydrogen_stream() @@ -137,23 +126,8 @@ using MPIRootPrintStreamWaning = ::distconv::util::MPIRootPrintStreamWarning; // Distconv layer classes using Backend = ::distconv::BackendDNNLib; -using ReLU = ::distconv::ReLU; -using LeakyReLU = ::distconv::LeakyReLU; -template -using Convolution = ::distconv::Convolution; -template -using ChannelwiseFullyConnected = ::distconv::ChannelwiseFullyConnected; -template -using Pooling = ::distconv::Pooling; -template -using BatchNormalization = ::distconv::BatchNormalization; -template -using MatMul = ::distconv::MatMul; -template -using ChannelwiseSoftmax = ::distconv::ChannelwiseSoftmax; -using Softmax = ::distconv::Softmax; -using CrossEntropy = ::distconv::CrossEntropy; -using MeanSquaredError = ::distconv::MeanSquaredError; +using AlCommType = typename decltype(std::declval() + .get_al_mpi_cuda_comm())::element_type; using ::distconv::get_channel_dim; using ::distconv::get_sample_dim; From 4b575095527e44b5178b70ba8d2ae3dab17c17a0 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Sun, 9 Jun 2024 22:33:54 -0700 Subject: [PATCH 16/23] Added guard on double ETI --- .../distconv/distconv_channelwise_softmax.cu | 200 +++++++++--------- 1 file changed, 103 insertions(+), 97 deletions(-) diff --git a/src/layers/misc/distconv/distconv_channelwise_softmax.cu b/src/layers/misc/distconv/distconv_channelwise_softmax.cu index 98c458eaeb9..b99c6cd00a2 100644 --- a/src/layers/misc/distconv/distconv_channelwise_softmax.cu +++ b/src/layers/misc/distconv/distconv_channelwise_softmax.cu @@ -25,118 +25,124 @@ //////////////////////////////////////////////////////////////////////////////// #define LBANN_LAYERS_MISC_CHANNELWISE_SOFTMAX_INSTANTIATE -#include "lbann/utils/distconv.hpp" +#include "../channelwise_softmax_kernels.cuh" #include "lbann/base.hpp" #include "lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp" +#include "lbann/utils/distconv.hpp" #include "lbann/utils/gpu/helpers.hpp" -#include "../channelwise_softmax_kernels.cuh" - #ifdef LBANN_HAS_DISTCONV -namespace distconv{ - template - template - int - ChannelwiseSoftmax - ::forward(const tensor::Tensor &input_0, - tensor::Tensor &output){ - - if (input_0.get_local_size() == 0 || output.get_local_size() == 0){ - util::MPIRootPrintStreamInfo() << "WARNING: EMPTY INPUT FOUND \n"; - return 1; // no op for empty inputs - } - - const auto& input_0_dims = input_0.get_local_shape(); - - const auto num_channels = input_0_dims[2]; - const auto local_mini_batch_size = input_0_dims[3]; - const auto mat_channel_size = input_0_dims[0] * input_0_dims[1]; - const auto mat_stride = num_channels * mat_channel_size; - - // Convert to Hydrogen matrices for kernel launch - - using LocalMat = El::Matrix; - - LocalMat local_input(mat_stride, +namespace distconv { +template +template +int ChannelwiseSoftmax::forward( + const tensor::Tensor& input_0, + tensor::Tensor& output) +{ + + if (input_0.get_local_size() == 0 || output.get_local_size() == 0) { + util::MPIRootPrintStreamInfo() << "WARNING: EMPTY INPUT FOUND \n"; + return 1; // no op for empty inputs + } + + const auto& input_0_dims = input_0.get_local_shape(); + + const auto num_channels = input_0_dims[2]; + const auto local_mini_batch_size = input_0_dims[3]; + const auto mat_channel_size = input_0_dims[0] * input_0_dims[1]; + const auto mat_stride = num_channels * mat_channel_size; + + // Convert to Hydrogen matrices for kernel launch + + using LocalMat = El::Matrix; + + LocalMat local_input(mat_stride, + local_mini_batch_size, + input_0.get_buffer(), + mat_stride); + + LocalMat local_output(mat_stride, local_mini_batch_size, - input_0.get_buffer(), + output.get_buffer(), mat_stride); - LocalMat local_output(mat_stride, - local_mini_batch_size, - output.get_buffer(), - mat_stride); - - ::lbann::channelwise_softmax_fp_impl(num_channels, - mat_channel_size, - local_input, - local_output); - return 1; + ::lbann::channelwise_softmax_fp_impl(num_channels, + mat_channel_size, + local_input, + local_output); + return 1; +} + +template +template +int ChannelwiseSoftmax::backward( + const tensor::Tensor& output, + const tensor::Tensor& output_grad, + tensor::Tensor& input_grad_0) +{ + if (output.get_local_size() == 0 || output_grad.get_local_size() == 0 || + input_grad_0.get_local_size() == 0) { + util::MPIRootPrintStreamInfo() << "WARNING: EMPTY INPUT FOUND \n"; + return 1; // no op for empty inputs } - template - template - int - ChannelwiseSoftmax - ::backward(const tensor::Tensor &output, - const tensor::Tensor &output_grad, - tensor::Tensor &input_grad_0){ - if (output.get_local_size() == 0 || - output_grad.get_local_size() == 0 || - input_grad_0.get_local_size() == 0){ - util::MPIRootPrintStreamInfo() << "WARNING: EMPTY INPUT FOUND \n"; - return 1; // no op for empty inputs - } - - const auto& input_0_dims = output.get_local_shape(); - const auto num_channels = input_0_dims[2]; - const auto local_mini_batch_size = input_0_dims[3]; - const auto mat_channel_size = input_0_dims[0] * input_0_dims[1]; - const auto mat_stride = num_channels * mat_channel_size; - - // Convert to Hydrogen matrices for kernel launch - - using LocalMat = El::Matrix; - - LocalMat local_output(mat_stride, - local_mini_batch_size, - output.get_buffer(), - mat_stride); - - LocalMat local_output_grad(mat_stride, - local_mini_batch_size, - output_grad.get_buffer(), - mat_stride); - - LocalMat local_input_grad(mat_stride, - local_mini_batch_size, - input_grad_0.get_buffer(), - mat_stride); - - ::lbann::channelwise_softmax_bp_impl(num_channels, - mat_channel_size, - local_output, - local_output_grad, - local_input_grad); - return 1; - } + const auto& input_0_dims = output.get_local_shape(); + const auto num_channels = input_0_dims[2]; + const auto local_mini_batch_size = input_0_dims[3]; + const auto mat_channel_size = input_0_dims[0] * input_0_dims[1]; + const auto mat_stride = num_channels * mat_channel_size; + + // Convert to Hydrogen matrices for kernel launch + + using LocalMat = El::Matrix; + + LocalMat local_output(mat_stride, + local_mini_batch_size, + output.get_buffer(), + mat_stride); + + LocalMat local_output_grad(mat_stride, + local_mini_batch_size, + output_grad.get_buffer(), + mat_stride); + + LocalMat local_input_grad(mat_stride, + local_mini_batch_size, + input_grad_0.get_buffer(), + mat_stride); + + ::lbann::channelwise_softmax_bp_impl(num_channels, + mat_channel_size, + local_output, + local_output_grad, + local_input_grad); + return 1; +} // ========================================================= // Explicit template instantiation // ========================================================= -#define ETI(T, Backend) \ - template class ChannelwiseSoftmax; \ - template int ChannelwiseSoftmax::forward( \ - const tensor::Tensor &input_0, \ - tensor::Tensor &output_0); \ - template int ChannelwiseSoftmax::backward( \ - const tensor::Tensor &input_0, \ - const tensor::Tensor &input_1, \ - tensor::Tensor &output_grad); - +#define ETI(T, Backend) \ + template class ChannelwiseSoftmax; \ + template int ChannelwiseSoftmax::forward( \ + const tensor::Tensor& \ + input_0, \ + tensor::Tensor& output_0); \ + template int \ + ChannelwiseSoftmax::backward( \ + const tensor::Tensor& \ + input_0, \ + const tensor::Tensor& \ + input_1, \ + tensor::Tensor& output_grad); + +/// @todo: fp16 ETI(float, BackendDNNLib) +#ifdef LBANN_HAS_DOUBLE ETI(double, BackendDNNLib) +#endif // LBANN_HAS_DOUBLE + #undef ETI -} // namespace distconv -#endif // LBANN_HAS_DISTCONV \ No newline at end of file +} // namespace distconv +#endif // LBANN_HAS_DISTCONV From 4896b5dd675e6c1f1cf967b3a62d1b19fa8cabc0 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Sun, 9 Jun 2024 22:41:34 -0700 Subject: [PATCH 17/23] Updated CI test with new environment imports - CI test passing on Lassen --- ...unit_layer_channelwise_softmax_distconv.py | 138 +++++++++++------- 1 file changed, 83 insertions(+), 55 deletions(-) diff --git a/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py b/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py index 0db96d05d58..2400ca94bff 100644 --- a/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py +++ b/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py @@ -4,11 +4,12 @@ import os.path import sys import numpy as np +import lbann.contrib.args # Bamboo utilities current_file = os.path.realpath(__file__) current_dir = os.path.dirname(current_file) -sys.path.insert(0, os.path.join(os.path.dirname(current_dir), 'common_python')) +sys.path.insert(0, os.path.join(os.path.dirname(current_dir), "common_python")) import tools # ============================================== @@ -20,34 +21,45 @@ # Data np.random.seed(20200115) _num_samples = 15 -_sample_dims = (15,5,1) +_sample_dims = (15, 5, 1) _sample_size = functools.reduce(operator.mul, _sample_dims) -_samples = np.random.normal(loc=0.5, size=(_num_samples,_sample_size)).astype(np.float32) +_samples = np.random.normal(loc=0.5, size=(_num_samples, _sample_size)).astype( + np.float32 +) + # Sample access functions def get_sample(index): - return _samples[index,:] + return _samples[index, :] + + def num_samples(): return _num_samples + + def sample_dims(): return (_sample_size,) + # ============================================== # NumPy implementation # ============================================== + def numpy_channelwise_softmax(x): if x.dtype is not np.float64: x = x.astype(np.float64) - axis = tuple(range(1,x.ndim)) + axis = tuple(range(1, x.ndim)) shift = np.max(x, axis=axis, keepdims=True) - y = np.exp(x-shift) + y = np.exp(x - shift) return y / np.sum(y, axis=axis, keepdims=True) + # ============================================== # Setup LBANN experiment # ============================================== + def setup_experiment(lbann, weekly): """Construct LBANN experiment. @@ -60,11 +72,18 @@ def setup_experiment(lbann, weekly): model = construct_model(lbann) data_reader = construct_data_reader(lbann) optimizer = lbann.NoOptimizer() - return trainer, model, data_reader, optimizer, None # Don't request any specific number of nodes + return ( + trainer, + model, + data_reader, + optimizer, + None, + ) # Don't request any specific number of nodes + def create_parallel_strategy(num_channel_groups): - return {"channel_groups": num_channel_groups, - "filter_groups": num_channel_groups} + return {"channel_groups": num_channel_groups, "filter_groups": num_channel_groups} + def construct_model(lbann): """Construct LBANN model. @@ -77,13 +96,15 @@ def construct_model(lbann): # Input data # Note: Sum with a weights layer so that gradient checking will # verify that error signals are correct. - x_weights = lbann.Weights(optimizer=lbann.SGD(), - initializer=lbann.ConstantInitializer(value=0.0), - name='input_weights') - x = lbann.Sum(lbann.Reshape(lbann.Input(data_field='samples'), - dims=_sample_dims), - lbann.WeightsLayer(weights=x_weights, - dims=_sample_dims)) + x_weights = lbann.Weights( + optimizer=lbann.SGD(), + initializer=lbann.ConstantInitializer(value=0.0), + name="input_weights", + ) + x = lbann.Sum( + lbann.Reshape(lbann.Input(data_field="samples"), dims=_sample_dims), + lbann.WeightsLayer(weights=x_weights, dims=_sample_dims), + ) x_lbann = x obj = [] metrics = [] @@ -91,8 +112,8 @@ def construct_model(lbann): num_channel_groups = tools.gpus_per_node(lbann) if num_channel_groups == 0: - e = 'this test requires GPUs.' - print('Skip - ' + e) + e = "this test requires GPUs." + print("Skip - " + e) pytest.skip(e) # ------------------------------------------ @@ -102,13 +123,15 @@ def construct_model(lbann): # LBANN implementation x = x_lbann - y = lbann.ChannelwiseSoftmax(x, - data_layout='data_parallel', - parallel_strategy=create_parallel_strategy(num_channel_groups), - name="Channelwise_softmax_distconv") + y = lbann.ChannelwiseSoftmax( + x, + data_layout="data_parallel", + parallel_strategy=create_parallel_strategy(num_channel_groups), + name="Channelwise_softmax_distconv", + ) z = lbann.L2Norm2(y) obj.append(z) - metrics.append(lbann.Metric(z, name='channelwise split distconv')) + metrics.append(lbann.Metric(z, name="channelwise split distconv")) # NumPy implementation vals = [] @@ -119,12 +142,15 @@ def construct_model(lbann): vals.append(z) val = np.mean(vals) tol = 8 * val * np.finfo(np.float32).eps - callbacks.append(lbann.CallbackCheckMetric( - metric=metrics[-1].name, - lower_bound=val-tol, - upper_bound=val+tol, - error_on_failure=True, - execution_modes='test')) + callbacks.append( + lbann.CallbackCheckMetric( + metric=metrics[-1].name, + lower_bound=val - tol, + upper_bound=val + tol, + error_on_failure=True, + execution_modes="test", + ) + ) # ------------------------------------------ # Gradient checking @@ -137,11 +163,14 @@ def construct_model(lbann): # ------------------------------------------ num_epochs = 0 - return lbann.Model(num_epochs, - layers=lbann.traverse_layer_graph(x_lbann), - objective_function=obj, - metrics=metrics, - callbacks=callbacks) + return lbann.Model( + num_epochs, + layers=lbann.traverse_layer_graph(x_lbann), + objective_function=obj, + metrics=metrics, + callbacks=callbacks, + ) + def construct_data_reader(lbann): """Construct Protobuf message for Python data reader. @@ -157,32 +186,31 @@ def construct_data_reader(lbann): # Note: The training data reader should be removed when # https://github.com/LLNL/lbann/issues/1098 is resolved. message = lbann.reader_pb2.DataReader() - message.reader.extend([ - tools.create_python_data_reader( - lbann, - current_file, - 'get_sample', - 'num_samples', - 'sample_dims', - 'train' - ) - ]) - message.reader.extend([ - tools.create_python_data_reader( - lbann, - current_file, - 'get_sample', - 'num_samples', - 'sample_dims', - 'test' - ) - ]) + message.reader.extend( + [ + tools.create_python_data_reader( + lbann, current_file, "get_sample", "num_samples", "sample_dims", "train" + ) + ] + ) + message.reader.extend( + [ + tools.create_python_data_reader( + lbann, current_file, "get_sample", "num_samples", "sample_dims", "test" + ) + ] + ) return message + # ============================================== # Setup PyTest # ============================================== # Create test functions that can interact with PyTest -for _test_func in tools.create_tests(setup_experiment, __file__, environment=tools.get_distconv_environment()): +for _test_func in tools.create_tests( + setup_experiment, + __file__, + environment=lbann.contrib.args.get_distconv_environment(), +): globals()[_test_func.__name__] = _test_func From 66f5ee32db4bdf8db694c17ab1bced3ffa311abe Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Tue, 11 Jun 2024 18:33:08 -0400 Subject: [PATCH 18/23] Update ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py Co-authored-by: Brian Van Essen --- .../unit_tests/test_unit_layer_channelwise_softmax_distconv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py b/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py index 2400ca94bff..e0f5587a65a 100644 --- a/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py +++ b/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py @@ -6,7 +6,7 @@ import numpy as np import lbann.contrib.args -# Bamboo utilities +# CI utilities current_file = os.path.realpath(__file__) current_dir = os.path.dirname(current_file) sys.path.insert(0, os.path.join(os.path.dirname(current_dir), "common_python")) From 8eef5a0e3deb814a31314875c5e7a6ca74c19a0e Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Mon, 17 Jun 2024 17:26:17 -0700 Subject: [PATCH 19/23] Updated year on textsd --- include/lbann/layers/CMakeLists.txt | 2 +- include/lbann/layers/misc/CMakeLists.txt | 2 +- include/lbann/layers/misc/distconv/CMakeLists.txt | 2 +- src/layers/misc/CMakeLists.txt | 2 +- src/layers/misc/channelwise_softmax_kernels.cuh | 3 ++- 5 files changed, 6 insertions(+), 5 deletions(-) diff --git a/include/lbann/layers/CMakeLists.txt b/include/lbann/layers/CMakeLists.txt index fa6442b3bc8..f32d5766043 100644 --- a/include/lbann/layers/CMakeLists.txt +++ b/include/lbann/layers/CMakeLists.txt @@ -1,5 +1,5 @@ ################################################################################ -## Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC. +## Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. ## Produced at the Lawrence Livermore National Laboratory. ## Written by the LBANN Research Team (B. Van Essen, et al.) listed in ## the CONTRIBUTORS file. diff --git a/include/lbann/layers/misc/CMakeLists.txt b/include/lbann/layers/misc/CMakeLists.txt index 258023a8fbb..18da384a7ad 100644 --- a/include/lbann/layers/misc/CMakeLists.txt +++ b/include/lbann/layers/misc/CMakeLists.txt @@ -1,5 +1,5 @@ ################################################################################ -## Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC. +## Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. ## Produced at the Lawrence Livermore National Laboratory. ## Written by the LBANN Research Team (B. Van Essen, et al.) listed in ## the CONTRIBUTORS file. diff --git a/include/lbann/layers/misc/distconv/CMakeLists.txt b/include/lbann/layers/misc/distconv/CMakeLists.txt index 29d9b3a0c32..fc835a868a5 100644 --- a/include/lbann/layers/misc/distconv/CMakeLists.txt +++ b/include/lbann/layers/misc/distconv/CMakeLists.txt @@ -1,5 +1,5 @@ ################################################################################ -## Copyright (c) 2014-2022, Lawrence Livermore National Security, LLC. +## Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. ## Produced at the Lawrence Livermore National Laboratory. ## Written by the LBANN Research Team (B. Van Essen, et al.) listed in ## the CONTRIBUTORS file. diff --git a/src/layers/misc/CMakeLists.txt b/src/layers/misc/CMakeLists.txt index 14cd1945c3f..8bb7c092cfa 100644 --- a/src/layers/misc/CMakeLists.txt +++ b/src/layers/misc/CMakeLists.txt @@ -1,5 +1,5 @@ ################################################################################ -## Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC. +## Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. ## Produced at the Lawrence Livermore National Laboratory. ## Written by the LBANN Research Team (B. Van Essen, et al.) listed in ## the CONTRIBUTORS file. diff --git a/src/layers/misc/channelwise_softmax_kernels.cuh b/src/layers/misc/channelwise_softmax_kernels.cuh index 8363c86d296..95a9c2423d0 100644 --- a/src/layers/misc/channelwise_softmax_kernels.cuh +++ b/src/layers/misc/channelwise_softmax_kernels.cuh @@ -26,6 +26,7 @@ #ifndef LBANN_LAYERS_MISC_CHANNELWISE_SOFTMAX_KERNELS #define LBANN_LAYERS_MISC_CHANNELWISE_SOFTMAX_KERNELS namespace lbann{ +namespace{ using Size3 = gpu_lib::array; /** @brief Max functor */ @@ -465,6 +466,6 @@ void channelwise_softmax_bp_impl(size_t num_channels, } } - +} // namespace anonymous } // namespace lbann #endif // LBANN_LAYERS_MISC_CHANNELWISE_SOFTMAX_KERNELS From e9fcd8429884a19262f828c728b6cbc4e2d71331 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Mon, 17 Jun 2024 20:13:02 -0700 Subject: [PATCH 20/23] Updated instantiation code --- .../lbann/layers/misc/channelwise_softmax.hpp | 2 +- .../layers/misc/channelwise_softmax_impl.hpp | 2 +- .../distconv/distconv_channelwise_softmax.hpp | 55 +++++++++---------- .../misc/channelwise_softmax_kernels.cuh | 1 + .../distconv/distconv_channelwise_softmax.cu | 17 ++---- 5 files changed, 36 insertions(+), 41 deletions(-) diff --git a/include/lbann/layers/misc/channelwise_softmax.hpp b/include/lbann/layers/misc/channelwise_softmax.hpp index 9168bb60fac..086ace20b73 100644 --- a/include/lbann/layers/misc/channelwise_softmax.hpp +++ b/include/lbann/layers/misc/channelwise_softmax.hpp @@ -1,5 +1,5 @@ //////////////////////////////////////////////////////////////////////////////// -// Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC. +// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. // Produced at the Lawrence Livermore National Laboratory. // Written by the LBANN Research Team (B. Van Essen, et al.) listed in // the CONTRIBUTORS file. diff --git a/include/lbann/layers/misc/channelwise_softmax_impl.hpp b/include/lbann/layers/misc/channelwise_softmax_impl.hpp index fd1e7b1702d..6879eb294b5 100644 --- a/include/lbann/layers/misc/channelwise_softmax_impl.hpp +++ b/include/lbann/layers/misc/channelwise_softmax_impl.hpp @@ -1,5 +1,5 @@ //////////////////////////////////////////////////////////////////////////////// -// Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC. +// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. // Produced at the Lawrence Livermore National Laboratory. // Written by the LBANN Research Team (B. Van Essen, et al.) listed in // the CONTRIBUTORS file. diff --git a/include/lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp b/include/lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp index b039bf09738..4ca40035f84 100644 --- a/include/lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp +++ b/include/lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp @@ -1,5 +1,5 @@ //////////////////////////////////////////////////////////////////////////////// -// Copyright (c) 2014-2022, Lawrence Livermore National Security, LLC. +// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. // Produced at the Lawrence Livermore National Laboratory. // Written by the LBANN Research Team (B. Van Essen, et al.) listed in // the CONTRIBUTORS file. @@ -29,33 +29,32 @@ #include "lbann/utils/distconv.hpp" #ifdef LBANN_HAS_DISTCONV -namespace distconv{ - template - class ChannelwiseSoftmax{ - using LocaleMPI = tensor::LocaleMPI; - - public: - ChannelwiseSoftmax(Backend &backend):m_be(backend){}; - - template - int forward( - const tensor::Tensor &input_0, - tensor::Tensor &output); - - template - int backward( - const tensor::Tensor &input_0, - const tensor::Tensor &output_grad, - tensor::Tensor &input_grad_0); - - protected: - Backend &m_be; - - }; - - extern template class ChannelwiseSoftmax<::distconv::BackendDNNLib, float>; - extern template class ChannelwiseSoftmax<::distconv::BackendDNNLib, double>; -} +namespace distconv { +template +class ChannelwiseSoftmax +{ + using LocaleMPI = tensor::LocaleMPI; + +public: + ChannelwiseSoftmax(Backend& backend) : m_be(backend){}; + + template + int forward(const tensor::Tensor& input_0, + tensor::Tensor& output); + + template + int backward( + const tensor::Tensor& input_0, + const tensor::Tensor& output_grad, + tensor::Tensor& input_grad_0); + +protected: + Backend& m_be; +}; + +extern template class ChannelwiseSoftmax<::distconv::BackendDNNLib, float>; +extern template class ChannelwiseSoftmax<::distconv::BackendDNNLib, double>; +} // namespace distconv #endif // LBANN_HAS_DISTCONV #endif // LBANN_LAYERS_MISC_DISTCONV_CHANNELWISE_SOFTMAX \ No newline at end of file diff --git a/src/layers/misc/channelwise_softmax_kernels.cuh b/src/layers/misc/channelwise_softmax_kernels.cuh index 95a9c2423d0..9c1ae86bc1b 100644 --- a/src/layers/misc/channelwise_softmax_kernels.cuh +++ b/src/layers/misc/channelwise_softmax_kernels.cuh @@ -25,6 +25,7 @@ //////////////////////////////////////////////////////////////////////////////// #ifndef LBANN_LAYERS_MISC_CHANNELWISE_SOFTMAX_KERNELS #define LBANN_LAYERS_MISC_CHANNELWISE_SOFTMAX_KERNELS +#include "lbann/utils/gpu/helpers.hpp" namespace lbann{ namespace{ using Size3 = gpu_lib::array; diff --git a/src/layers/misc/distconv/distconv_channelwise_softmax.cu b/src/layers/misc/distconv/distconv_channelwise_softmax.cu index b99c6cd00a2..0d3db2bad8d 100644 --- a/src/layers/misc/distconv/distconv_channelwise_softmax.cu +++ b/src/layers/misc/distconv/distconv_channelwise_softmax.cu @@ -123,26 +123,21 @@ int ChannelwiseSoftmax::backward( // Explicit template instantiation // ========================================================= -#define ETI(T, Backend) \ - template class ChannelwiseSoftmax; \ - template int ChannelwiseSoftmax::forward( \ +#define PROTO(T) \ + template class ChannelwiseSoftmax; \ + template int \ + ChannelwiseSoftmax::forward( \ const tensor::Tensor& \ input_0, \ tensor::Tensor& output_0); \ template int \ - ChannelwiseSoftmax::backward( \ + ChannelwiseSoftmax::backward( \ const tensor::Tensor& \ input_0, \ const tensor::Tensor& \ input_1, \ tensor::Tensor& output_grad); -/// @todo: fp16 -ETI(float, BackendDNNLib) -#ifdef LBANN_HAS_DOUBLE -ETI(double, BackendDNNLib) -#endif // LBANN_HAS_DOUBLE - -#undef ETI +#include "lbann/macros/instantiate.hpp" } // namespace distconv #endif // LBANN_HAS_DISTCONV From 602eeab884e92c721478b233508df550e685bca8 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Mon, 17 Jun 2024 20:17:12 -0700 Subject: [PATCH 21/23] Updated copyright years --- src/layers/misc/channelwise_softmax.cpp | 2 +- src/layers/misc/channelwise_softmax.cu | 51 ++++++++++--------- .../misc/channelwise_softmax_kernels.cuh | 2 +- src/layers/misc/distconv/CMakeLists.txt | 2 +- .../distconv/distconv_channelwise_softmax.cu | 2 +- 5 files changed, 32 insertions(+), 27 deletions(-) diff --git a/src/layers/misc/channelwise_softmax.cpp b/src/layers/misc/channelwise_softmax.cpp index f335a733bc0..29af732628f 100644 --- a/src/layers/misc/channelwise_softmax.cpp +++ b/src/layers/misc/channelwise_softmax.cpp @@ -1,5 +1,5 @@ //////////////////////////////////////////////////////////////////////////////// -// Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC. +// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. // Produced at the Lawrence Livermore National Laboratory. // Written by the LBANN Research Team (B. Van Essen, et al.) listed in // the CONTRIBUTORS file. diff --git a/src/layers/misc/channelwise_softmax.cu b/src/layers/misc/channelwise_softmax.cu index cd5abc3b414..cafde5e0293 100644 --- a/src/layers/misc/channelwise_softmax.cu +++ b/src/layers/misc/channelwise_softmax.cu @@ -1,5 +1,5 @@ //////////////////////////////////////////////////////////////////////////////// -// Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC. +// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. // Produced at the Lawrence Livermore National Laboratory. // Written by the LBANN Research Team (B. Van Essen, et al.) listed in // the CONTRIBUTORS file. @@ -25,32 +25,34 @@ //////////////////////////////////////////////////////////////////////////////// #define LBANN_CHANNELWISE_SOFTMAX_LAYER_INSTANTIATE +#include "channelwise_softmax_kernels.cuh" #include "lbann/layers/misc/channelwise_softmax_impl.hpp" #include "lbann/utils/gpu/helpers.hpp" -#include "channelwise_softmax_kernels.cuh" - namespace lbann { - template -void channelwise_softmax_layer::fp_compute() { - - #ifdef LBANN_HAS_DISTCONV - if (this->distconv_enabled()){ +void channelwise_softmax_layer::fp_compute() +{ + +#ifdef LBANN_HAS_DISTCONV + if (this->distconv_enabled()) { this->get_distconv_adapter().fp_compute(); - return ; + return; } - #endif // LBANN_HAS_DISTCONV +#endif // LBANN_HAS_DISTCONV // Local matrices const size_t num_channels = this->get_output_dims().front(); const size_t channel_size = this->get_output_size() / num_channels; using LocalMat = El::Matrix; - const auto& local_input = dynamic_cast(this->get_prev_activations().LockedMatrix()); - auto& local_output = dynamic_cast(this->get_activations().Matrix()); + const auto& local_input = + dynamic_cast(this->get_prev_activations().LockedMatrix()); + auto& local_output = + dynamic_cast(this->get_activations().Matrix()); - // TODO: This looks wrong. Maybe wrap the implementation of this function in a namespace - SZ + // TODO: This looks wrong. Maybe wrap the implementation of this function in a + // namespace - SZ channelwise_softmax_fp_impl(num_channels, channel_size, local_input, @@ -62,24 +64,27 @@ void channelwise_softmax_layer::fp_compute() { // ========================================================= template -void channelwise_softmax_layer::bp_compute() { +void channelwise_softmax_layer::bp_compute() +{ - #ifdef LBANN_HAS_DISTCONV - if (this->distconv_enabled()){ +#ifdef LBANN_HAS_DISTCONV + if (this->distconv_enabled()) { this->get_distconv_adapter().bp_compute(); - return ; + return; } - #endif // LBANN_HAS_DISTCONV +#endif // LBANN_HAS_DISTCONV const size_t num_channels = this->get_output_dims().front(); const size_t channel_size = this->get_output_size() / num_channels; - - // Local matrices + // Local matrices using LocalMat = El::Matrix; - const auto& local_output = dynamic_cast(this->get_activations().LockedMatrix()); - const auto& local_output_grad = dynamic_cast(this->get_prev_error_signals().LockedMatrix()); - auto& local_input_grad = dynamic_cast(this->get_error_signals().Matrix()); + const auto& local_output = + dynamic_cast(this->get_activations().LockedMatrix()); + const auto& local_output_grad = dynamic_cast( + this->get_prev_error_signals().LockedMatrix()); + auto& local_input_grad = + dynamic_cast(this->get_error_signals().Matrix()); channelwise_softmax_bp_impl(num_channels, channel_size, diff --git a/src/layers/misc/channelwise_softmax_kernels.cuh b/src/layers/misc/channelwise_softmax_kernels.cuh index 9c1ae86bc1b..278e43bed6c 100644 --- a/src/layers/misc/channelwise_softmax_kernels.cuh +++ b/src/layers/misc/channelwise_softmax_kernels.cuh @@ -1,5 +1,5 @@ //////////////////////////////////////////////////////////////////////////////// -// Copyright (c) 2014-2022, Lawrence Livermore National Security, LLC. +// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. // Produced at the Lawrence Livermore National Laboratory. // Written by the LBANN Research Team (B. Van Essen, et al.) listed in // the CONTRIBUTORS file. diff --git a/src/layers/misc/distconv/CMakeLists.txt b/src/layers/misc/distconv/CMakeLists.txt index 30270ed7b63..a6c3851eb38 100644 --- a/src/layers/misc/distconv/CMakeLists.txt +++ b/src/layers/misc/distconv/CMakeLists.txt @@ -1,5 +1,5 @@ ################################################################################ -## Copyright (c) 2014-2022, Lawrence Livermore National Security, LLC. +## Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. ## Produced at the Lawrence Livermore National Laboratory. ## Written by the LBANN Research Team (B. Van Essen, et al.) listed in ## the CONTRIBUTORS file. diff --git a/src/layers/misc/distconv/distconv_channelwise_softmax.cu b/src/layers/misc/distconv/distconv_channelwise_softmax.cu index 0d3db2bad8d..ba1f25c17ab 100644 --- a/src/layers/misc/distconv/distconv_channelwise_softmax.cu +++ b/src/layers/misc/distconv/distconv_channelwise_softmax.cu @@ -1,5 +1,5 @@ //////////////////////////////////////////////////////////////////////////////// -// Copyright (c) 2014-2022, Lawrence Livermore National Security, LLC. +// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. // Produced at the Lawrence Livermore National Laboratory. // Written by the LBANN Research Team (B. Van Essen, et al.) listed in // the CONTRIBUTORS file. From c88847627b027f0d5dc4a5c7ea872730611e7e4f Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Tue, 25 Jun 2024 11:56:19 -0700 Subject: [PATCH 22/23] Remove comment after applying PR suggestions --- src/layers/misc/channelwise_softmax.cu | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/layers/misc/channelwise_softmax.cu b/src/layers/misc/channelwise_softmax.cu index cafde5e0293..d75fdf6e1e7 100644 --- a/src/layers/misc/channelwise_softmax.cu +++ b/src/layers/misc/channelwise_softmax.cu @@ -51,8 +51,6 @@ void channelwise_softmax_layer::fp_compute() auto& local_output = dynamic_cast(this->get_activations().Matrix()); - // TODO: This looks wrong. Maybe wrap the implementation of this function in a - // namespace - SZ channelwise_softmax_fp_impl(num_channels, channel_size, local_input, From bb876fabe3d0e0262d30ae9d548f82988ea8bc87 Mon Sep 17 00:00:00 2001 From: Tom Benson Date: Tue, 25 Jun 2024 14:31:26 -0400 Subject: [PATCH 23/23] Fix a couple issues with the auto-detection of the NVCC_GENCODE in NCCL (#2459) --- scripts/superbuild/nccl/CMakeLists.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/superbuild/nccl/CMakeLists.txt b/scripts/superbuild/nccl/CMakeLists.txt index 5b8e69e0416..354700e3f4b 100644 --- a/scripts/superbuild/nccl/CMakeLists.txt +++ b/scripts/superbuild/nccl/CMakeLists.txt @@ -48,7 +48,7 @@ endmacro () lbann_sb_init_extern_pkg( NAME NCCL - LANGUAGES C CXX # CUDA <- can't set explicitly; inferred from ${CUDA_HOME} + LANGUAGES C CXX CUDA GITHUB_URL NVIDIA/nccl GIT_TAG "master") @@ -105,8 +105,8 @@ if (LBANN_SB_FWD_NCCL_NVCC_GENCODE) elseif (DEFINED $ENV{NVCC_GENCODE}) set(_nccl_nvcc_gencode_opt "NVCC_GENCODE=$ENV{NVCC_GENCODE}") -elseif (LBANN_NCCL_CUDA_ARCHITECTURES) - set(_cuda_arch ${LBANN_NCCL_CUDA_ARCHITECTURES}) +elseif (LBANN_SB_NCCL_CUDA_ARCHITECTURES) + set(_cuda_arch ${LBANN_SB_NCCL_CUDA_ARCHITECTURES}) set(_nccl_nvcc_gencode_opt "NVCC_GENCODE=-gencode=arch=compute_${_cuda_arch},code=sm_${_cuda_arch}") else ()