diff --git a/include/lbann/layers/operator_layer.hpp b/include/lbann/layers/operator_layer.hpp
index 4546fa5fd8e..8929f4520d8 100644
--- a/include/lbann/layers/operator_layer.hpp
+++ b/include/lbann/layers/operator_layer.hpp
@@ -81,6 +81,10 @@ class OperatorLayer final : public data_type_layer
   data_layout get_data_layout() const final;
   El::Device get_device_allocation() const final;
 
+#ifdef LBANN_HAS_ONNX
+  void fill_onnx_node(onnx::GraphProto& graph) const override;
+#endif // LBANN_HAS_ONNX
+
   void fp_compute() final;
   void bp_compute() final;
 
diff --git a/include/lbann/operators/declare_stateless_op.hpp b/include/lbann/operators/declare_stateless_op.hpp
index 7ff0ebd531a..10a0de0431a 100644
--- a/include/lbann/operators/declare_stateless_op.hpp
+++ b/include/lbann/operators/declare_stateless_op.hpp
@@ -32,6 +32,16 @@
 
 #include "lbann/proto/operators.pb.h"
 
+#ifdef LBANN_HAS_ONNX
+#define ADD_GET_ONNX_NODES_API() \
+  std::vector get_onnx_nodes() const final \
+  { \
+    return get_onnx_nodes_impl(*this); \
+  }
+#else
+#define ADD_GET_ONNX_NODES_API()
+#endif // LBANN_HAS_ONNX
+
 // These are all single-type operators.
 
 #define LBANN_DECLARE_STATELESS_OPERATOR(OP_NAME, OP_STRING) \
@@ -64,6 +74,7 @@
       ar(::cereal::make_nvp("Operator", \
                             ::cereal::base_class(this))); \
     } \
+    ADD_GET_ONNX_NODES_API() \
     void fp_compute(std::vector const& inputs, \
                     std::vector const& outputs) const final; \
     void bp_compute( \
@@ -113,6 +124,7 @@
       ar(::cereal::make_nvp("ElementwiseOperator", \
                             ::cereal::base_class(this))); \
     } \
+    ADD_GET_ONNX_NODES_API() \
     \
   private: \
     void \
@@ -128,6 +140,21 @@
     } \
     void do_fill_description(description&) const final \
     {} \
-  }
+  };
+
+namespace lbann {
+
+#ifdef LBANN_HAS_ONNX
+// ADD_GET_ONNX_NODES_API() makes get_onnx_nodes() forward to this function;
+// an operator gains an ONNX mapping by providing a more specific overload.
+template
+std::vector get_onnx_nodes_impl(OperatorT const& op)
+{
+  // Generic fallback: no ONNX representation is known for this operator
+  // yet, so an empty node list is returned.
+  return {};
+}
+#endif // LBANN_HAS_ONNX
+} // namespace lbann
 
 #endif // LBANN_INCLUDE_LBANN_OPERATORS_DECLARE_STATELESS_OP_HPP_INCLUDED
diff --git a/include/lbann/operators/math/binary.hpp b/include/lbann/operators/math/binary.hpp
index d9db70c21e9..9a639190736 100644
--- a/include/lbann/operators/math/binary.hpp
+++ b/include/lbann/operators/math/binary.hpp
@@ -29,34 +29,57 @@
 
 #include "lbann/operators/declare_stateless_op.hpp"
 
+#ifdef LBANN_HAS_ONNX
+#define LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(OP_NAME, \
+                                              OP_STRING, \
+                                              OP_ONNX_NAME) \
+  LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(OP_NAME, OP_STRING); \
+  template \
+  std::vector get_onnx_nodes_impl( \
+    OP_NAME##Operator const& op) \
+  { \
+    std::vector nodes(1UL); \
+    nodes.front().set_op_type(OP_ONNX_NAME); \
+    return nodes; \
+  }
+#else
+#define LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(OP_NAME, \
+                                              OP_STRING, \
+                                              OP_ONNX_NAME) \
+  LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(OP_NAME, OP_STRING)
+#endif // LBANN_HAS_ONNX
+
 namespace lbann {
 
 // Arithmetic operations
-LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Add, "add");
-LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Subtract, "subtract");
-LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Multiply, "multiply");
-LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Divide, "divide");
-LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Mod, "modulo");
-LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Pow, "power");
+LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Add, "add", "Add")
+LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Subtract, "subtract", "Sub")
+LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Multiply, "multiply", "Mul")
+LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Divide, "divide", "Div")
+LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Mod, "modulo", "Mod")
+LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Pow, "power", "Pow")
 LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(SafeDivide, "safe divide");
 LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(SquaredDifference,
                                              "squared difference");
 
 // Comparison operations
-LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Max, "maximum"); -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Min, "minimum"); -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Equal, "equal"); +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Max, "maximum", "Max") +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Min, "minimum", "Min") +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Equal, "equal", "Equal") LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(NotEqual, "not equal"); -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Less, "less than"); -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(LessEqual, "less than or equal"); -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Greater, "greater than"); -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(GreaterEqual, - "greater than or equal"); +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Less, "less than", "Less") +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(LessEqual, + "less than or equal", + "LessOrEqual") +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Greater, "greater than", "Greater") +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(GreaterEqual, + "greater than or equal", + "GreaterOrEqual") // Logical operations -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(LogicalAnd, "logical and"); -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(LogicalOr, "logical or"); -LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(LogicalXor, "logical xor"); +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(LogicalAnd, "logical and", "And") +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(LogicalOr, "logical or", "Or") +LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(LogicalXor, "logical xor", "Xor") } // namespace lbann diff --git a/include/lbann/operators/math/binary_with_constant.hpp b/include/lbann/operators/math/binary_with_constant.hpp index c4d6dd7d6cc..e021d9e1438 100644 --- a/include/lbann/operators/math/binary_with_constant.hpp +++ b/include/lbann/operators/math/binary_with_constant.hpp @@ -32,7 +32,10 @@ #include "lbann/operators/elementwise_operator.hpp" #include "lbann/utils/cloneable.hpp" -#include "lbann/proto/operators.pb.h" +#ifdef 
LBANN_HAS_ONNX
+#include
+#endif // LBANN_HAS_ONNX
+
 
 /** @file
  *
@@ -50,6 +53,16 @@
 
 #include "lbann/proto/operators.pb.h"
 
+#ifdef LBANN_HAS_ONNX
+#define ADD_GET_ONNX_NODES_API() \
+  std::vector get_onnx_nodes() const final \
+  { \
+    return get_onnx_nodes_impl(*this); \
+  }
+#else
+#define ADD_GET_ONNX_NODES_API()
+#endif // LBANN_HAS_ONNX
+
 // These are all single-type operators.
 
 #define LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(OP_NAME, OP_STRING) \
@@ -88,6 +101,7 @@
                             ::cereal::base_class(this)), \
          CEREAL_NVP(m_constant)); \
     } \
+    ADD_GET_ONNX_NODES_API() \
     DataT get_constant() const noexcept \
     { \
       return m_constant; \
@@ -123,7 +137,7 @@ namespace lbann {
 // x + c -- treated as commutative.
 LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(AddConstant, "add constant");
 
-// x + c -- treated as commutative.
+// x * c -- treated as commutative.
 LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(Scale, "scale");
 
 // x - C -- yes, could be "plus -C", but so could 7-4 be 7+-4, but
@@ -149,5 +163,168 @@ LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(GreaterEqualConstant,
 LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(GreaterConstant,
                                             "greater than constant");
 
+#ifdef LBANN_HAS_ONNX
+/** @brief Build the constant-operand node shared by the overloads below.
+ *
+ *  The node has one output named "const_val" and one FLOAT attribute,
+ *  "value_float", holding @c val. Its op_type is left unset; each
+ *  caller stores a "PreConstant" or "PostConstant" tag there.
+ */
+inline onnx::NodeProto get_constant_node(float val)
+{
+  onnx::NodeProto const_node;
+  const_node.add_output("const_val");
+  const_node.set_domain("");
+  const_node.set_doc_string("Const value for binary with constant operations");
+  auto* const_val = const_node.add_attribute();
+  const_val->set_name("value_float");
+  const_val->set_type(onnx::AttributeProto::FLOAT);
+  const_val->set_f(val);
+  return const_node;
+}
+
+// Each overload returns the ONNX operator node first and the constant
+// node last. The constant's op_type is a placement tag rather than an
+// ONNX op: OperatorLayer::fill_onnx_node() wires the constant after the
+// layer input for "PostConstant" and before it for "PreConstant", then
+// replaces the tag with "Constant".
+template
+std::vector
+get_onnx_nodes_impl(AddConstantOperator const& op)
+{
+  std::vector nodes(2UL);
+  nodes.front().set_op_type("Add");
+  nodes.back() = get_constant_node(El::To(op.get_constant()));
+  nodes.back().set_op_type("PostConstant");
+  return nodes;
+}
+
+template
+std::vector get_onnx_nodes_impl(ScaleOperator const& op)
+{
+  std::vector nodes(2UL);
+  nodes.front().set_op_type("Mul");
+  nodes.back() = get_constant_node(El::To(op.get_constant()));
+  nodes.back().set_op_type("PostConstant");
+  return nodes;
+}
+
+template
+std::vector
+get_onnx_nodes_impl(SubtractConstantOperator const& op)
+{
+  std::vector nodes(2UL);
+  nodes.front().set_op_type("Sub");
+  nodes.back() = get_constant_node(El::To(op.get_constant()));
+  nodes.back().set_op_type("PostConstant");
+  return nodes;
+}
+
+template
+std::vector
+get_onnx_nodes_impl(ConstantSubtractOperator const& op)
+{
+  std::vector nodes(2UL);
+  nodes.front().set_op_type("Sub");
+  nodes.back() = get_constant_node(El::To(op.get_constant()));
+  nodes.back().set_op_type("PreConstant");
+  return nodes;
+}
+
+template
+std::vector
+get_onnx_nodes_impl(MaxConstantOperator const& op)
+{
+  std::vector nodes(2UL);
+  nodes.front().set_op_type("Max");
+  nodes.back() = get_constant_node(El::To(op.get_constant()));
+  nodes.back().set_op_type("PreConstant");
+  return nodes;
+}
+
+template
+std::vector
+get_onnx_nodes_impl(MinConstantOperator const& op)
+{
+  std::vector nodes(2UL);
+  nodes.front().set_op_type("Min");
+  nodes.back() = get_constant_node(El::To(op.get_constant()));
+  nodes.back().set_op_type("PreConstant");
+  return nodes;
+}
+
+template
+std::vector
+get_onnx_nodes_impl(EqualConstantOperator const& op)
+{
+  std::vector nodes(2UL);
+  nodes.front().set_op_type("Equal");
+  nodes.back() = get_constant_node(El::To(op.get_constant()));
+  nodes.back().set_op_type("PreConstant");
+  return nodes;
+}
+
+template
+std::vector
+get_onnx_nodes_impl(NotEqualConstantOperator const& op)
+{
+  // Three nodes: Equal first, Not in the middle, the constant last.
+  std::vector nodes(3UL);
+  nodes.front().set_op_type("Equal");
+  nodes.back() = get_constant_node(El::To(op.get_constant()));
+  nodes.back().set_op_type("PreConstant");
+  nodes.at(1).set_op_type("Not");
+  return nodes;
+}
+
+template
+std::vector
+get_onnx_nodes_impl(LessConstantOperator const& op)
+{
+  std::vector nodes(2UL);
+  nodes.front().set_op_type("Less");
+  nodes.back() = get_constant_node(El::To(op.get_constant()));
+  nodes.back().set_op_type("PostConstant");
+  return nodes;
+}
+
+template
+std::vector
+get_onnx_nodes_impl(LessEqualConstantOperator const& op)
+{
+  std::vector nodes(2UL);
+  nodes.front().set_op_type("LessOrEqual");
+  nodes.back() = get_constant_node(El::To(op.get_constant()));
+  nodes.back().set_op_type("PostConstant");
+  return nodes;
+}
+
+template
+std::vector
+get_onnx_nodes_impl(GreaterConstantOperator const& op)
+{
+  std::vector nodes(2UL);
+  nodes.front().set_op_type("Greater");
+  nodes.back() = get_constant_node(El::To(op.get_constant()));
+  // NOTE(review): "PreConstant" puts the constant first, whereas
+  // LessConstant uses "PostConstant" -- confirm this order is intended.
+  nodes.back().set_op_type("PreConstant");
+  return nodes;
+}
+
+template
+std::vector
+get_onnx_nodes_impl(GreaterEqualConstantOperator const& op)
+{
+  std::vector nodes(2UL);
+  nodes.front().set_op_type("GreaterOrEqual");
+  nodes.back() = get_constant_node(El::To(op.get_constant()));
+  // NOTE(review): "PreConstant" puts the constant first, whereas
+  // LessEqualConstant uses "PostConstant" -- confirm this is intended.
+  nodes.back().set_op_type("PreConstant");
+  return nodes;
+}
+#endif // LBANN_HAS_ONNX
+
 } // namespace lbann
 #endif // LBANN_INCLUDE_LBANN_OPERATORS_BINARY_WITH_CONSTANT_HPP_INCLUDED
diff --git a/include/lbann/operators/operator.hpp b/include/lbann/operators/operator.hpp
index 0bbd838f0dd..d1e53fbf1cf 100644
--- a/include/lbann/operators/operator.hpp
+++ b/include/lbann/operators/operator.hpp
@@ -43,6 +43,10 @@
 
 #include
 
+#ifdef LBANN_HAS_ONNX
+#include
+#endif
+
 #include
 #include
 
@@ -130,6 +134,10 @@ class Operator : public AbstractCloneableBase>,
   template
   void serialize(ArchiveT& ar);
 
+#ifdef LBANN_HAS_ONNX
+  virtual std::vector get_onnx_nodes() const;
+#endif
+
   ///@}
   /** @name Computational interface */
   ///@{
@@ -164,7 +172,7 @@ class Operator : public AbstractCloneableBase>,
   virtual void set_proto_params(lbann_data::Operator&) const = 0;
   /** @brief Concrete operator description.
*/
   virtual void do_fill_description(Description&) const = 0;
-};
+}; // class Operator
 
 template
 void Operator::write_proto(
@@ -208,5 +216,16 @@ template
 void Operator::serialize(ArchiveT& ar)
 {}
 
+#ifdef LBANN_HAS_ONNX
+template
+std::vector
+Operator::get_onnx_nodes() const
+{
+  // The default assumption is that we don't know how to represent
+  // this operator in ONNX terms yet.
+  return {};
+}
+#endif
+
 } // namespace lbann
 #endif // LBANN_OPERATORS_OPERATOR_HPP_INCLUDED
diff --git a/src/layers/operator_layer.cpp b/src/layers/operator_layer.cpp
index b3e91a514e2..5d007590de4 100644
--- a/src/layers/operator_layer.cpp
+++ b/src/layers/operator_layer.cpp
@@ -57,4 +57,95 @@ void OperatorLayer::write_specific_proto(
   op->set_device_allocation(proto::ProtoDevice);
 }
 
+#ifdef LBANN_HAS_ONNX
+/** @brief Append the ONNX node(s) for this layer's operator to @c graph.
+ *
+ *  The size of the operator's get_onnx_nodes() result picks the layout:
+ *   - 1 node: plain operator; each parent's output becomes an input.
+ *   - 2 nodes: operator plus a constant (the last node). The constant
+ *     arrives tagged with op_type "PreConstant" or "PostConstant" to
+ *     say which operand it is; the tag is replaced by "Constant" here.
+ *   - 3 nodes: as for 2, followed by a "Not" node fed by the operator.
+ *
+ *  Any other size raises LBANN_ERROR. That includes the empty list
+ *  returned by operators that have no ONNX representation.
+ */
+template
+void OperatorLayer::fill_onnx_node(onnx::GraphProto& graph) const
+{
+  const auto& parents = this->get_parent_layers();
+  auto nodes = m_ops.front()->get_onnx_nodes();
+
+  // Check the count before nodes.front() is touched: calling front()
+  // on the empty list of an unsupported operator is undefined behavior.
+  if (nodes.empty() || nodes.size() > 3)
+    LBANN_ERROR("Expected 1-3 ONNX nodes for binary operation, received ",
+                nodes.size());
+
+  auto* op_node = graph.add_node();
+  *op_node = nodes.front();
+
+  op_node->set_name(this->get_name());
+  op_node->set_domain("");
+  op_node->set_doc_string(this->get_name());
+
+  // binary operators
+  if (nodes.size() == 1) {
+    for (auto* parent : parents) {
+      size_t idx = parent->find_child_layer_index(*this);
+      op_node->add_input(parent->get_name() + "_" + std::to_string(idx));
+    }
+  }
+  // Binary w/ constant operators
+  else {
+    if (parents.empty())
+      LBANN_ERROR("No parent layer feeds operator layer \"",
+                  this->get_name(),
+                  "\".");
+
+    // Build the parent's output name the way the branch above does,
+    // instead of assuming this layer is the parent's child 0.
+    auto const parent_idx = parents[0]->find_child_layer_index(*this);
+    auto const parent_output =
+      parents[0]->get_name() + "_" + std::to_string(parent_idx);
+
+    // NOTE(review): the constant's output name ("const_val") and the
+    // "EqualOperator"/"Not operator" names below are fixed strings, so
+    // they repeat if a graph holds several such layers -- confirm that
+    // is acceptable for the exported model.
+    auto* const_node = graph.add_node();
+    *const_node = nodes.back();
+    if (const_node->op_type() == "PostConstant") {
+      op_node->add_input(parent_output);
+      op_node->add_input(const_node->output(0));
+    }
+    else if (const_node->op_type() == "PreConstant") {
+      op_node->add_input(const_node->output(0));
+      op_node->add_input(parent_output);
+    }
+    else
+      LBANN_ERROR("Unknown onnx op type for constant.");
+
+    const_node->set_op_type("Constant");
+  }
+
+  // Not equal operator
+  if (nodes.size() == 3) {
+    op_node->add_output("EqualOperator");
+    auto* not_node = graph.add_node();
+    not_node->add_input(op_node->output(0));
+    not_node->set_name("Not operator");
+    not_node->set_op_type("Not");
+    not_node->set_domain("");
+    not_node->set_doc_string("Not node for not equal operation.");
+    op_node = not_node;
+  }
+
+  for (auto const* child : this->get_child_layers()) {
+    auto idx = this->find_child_layer_index(*child);
+    op_node->add_output(this->get_name() + "_" + std::to_string(idx));
+  }
+}
+#endif // LBANN_HAS_ONNX
+
 } // namespace lbann