From c07b136ca94cb7489f0057bd1b8734017c0d1f9d Mon Sep 17 00:00:00 2001
From: Neal Gafter
Date: Thu, 1 Sep 2022 11:56:45 -0700
Subject: [PATCH] bmg: Implement a sigmoid transform for use with distributions like Beta.

Summary: Adds a new transform, Sigmoid, which maps a probability value in the
range (0..1) to the range (-INF..INF). This is one of the blockers preventing
distributions like Beta from working with NUTS.
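The forward (constrained-to-unconstrained) map is y = logit(x) =
log(x / (1 - x)), with inverse x = expit(y) = 1 / (1 + exp(-y)); for example,
an observed probability x = 0.2 maps to y = log(0.25) ~= -1.386.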
requires POS_REAL value."); } break; + case TransformType::SIGMOID: + if (sto_node->value.type.atomic_type != AtomicType::PROBABILITY) { + throw std::invalid_argument( + "Sigmoid transformation requires PROBABILITY value."); + } + break; default: throw std::invalid_argument("Unsupported transformation type."); } diff --git a/src/beanmachine/graph/transform/sigmoidtransform.cpp b/src/beanmachine/graph/transform/sigmoidtransform.cpp new file mode 100644 index 0000000000..c8c9d50416 --- /dev/null +++ b/src/beanmachine/graph/transform/sigmoidtransform.cpp @@ -0,0 +1,148 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include "beanmachine/graph/transform/transform.h" + +namespace beanmachine { +namespace transform { + +// See also https://en.wikipedia.org/wiki/Logit + +// y = f(x) = logit(x) = log(x / (1 - x)) +// dy/dx = 1 / (x - x^2) +void Sigmoid::operator()( + const graph::NodeValue& constrained, + graph::NodeValue& unconstrained) { + assert(constrained.type.atomic_type == graph::AtomicType::PROBABILITY); + if (constrained.type.variable_type == graph::VariableType::SCALAR) { + auto x = constrained._double; + unconstrained._double = std::log(x / (1 - x)); + } else if ( + constrained.type.variable_type == graph::VariableType::BROADCAST_MATRIX) { + auto x = constrained._matrix.array(); + unconstrained._matrix = (x / (1 - x)).log(); + } else { + throw std::invalid_argument( + "Sigmoid transformation requires PROBABILITY values."); + } +} + +// x = f^{-1}(y) = expit(y) = 1 / (1 + exp(-y)) +// dx/dy = exp(-y) / (1 + exp(-y))^2 +void Sigmoid::inverse( + graph::NodeValue& constrained, + const graph::NodeValue& unconstrained) { + assert(constrained.type.atomic_type == graph::AtomicType::PROBABILITY); + if (constrained.type.variable_type == graph::VariableType::SCALAR) { + auto y = unconstrained._double; + constrained._double = 1 / (1 + std::exp(-y)); + } else if ( + constrained.type.variable_type == graph::VariableType::BROADCAST_MATRIX) { + auto y = unconstrained._matrix.array(); + constrained._matrix = 1 / (1 + (-y).exp()); + } else { + throw std::invalid_argument( + "Sigmoid transformation requires PROBABILITY values."); + } +} + +/* +Return the log of the absolute jacobian determinant: + log |det(d x / d y)| +:param constrained: the node value x in constrained space +:param unconstrained: the node value y in unconstrained space + +for scalar, log |det(d x / d y)| + = log |exp(-y) / (1 + exp(-y))^2| + = y - 2 log(1 + exp(y)) +for matrix, log |det(d x / d y)| + = log |prod {y_i - 2 * Log[1 + exp(y_i)]}| + = sum{y_i - 2 * Log[1 + exp(y_i)]} +*/ +double Sigmoid::log_abs_jacobian_determinant( + const graph::NodeValue& constrained, + const graph::NodeValue& unconstrained) { + assert(constrained.type.atomic_type == graph::AtomicType::PROBABILITY); + if (constrained.type.variable_type == graph::VariableType::SCALAR) { + auto y = unconstrained._double; + return y - 2 * util::log1pexp(y); + } else if ( + constrained.type.variable_type == graph::VariableType::BROADCAST_MATRIX) { + /* Because this transformation is applied to each element of the input + matrix independently, the jabobian is a matrix with only values on the + diagnonal corresponding to the derivative of that value with respect to the + corresponding input. 
+
+/*
+Return the log of the absolute jacobian determinant:
+  log |det(d x / d y)|
+:param constrained: the node value x in constrained space
+:param unconstrained: the node value y in unconstrained space
+
+for scalar, log |det(d x / d y)|
+  = log |exp(-y) / (1 + exp(-y))^2|
+  = y - 2 log(1 + exp(y))
+for matrix, log |det(d x / d y)|
+  = log prod_i |exp(-y_i) / (1 + exp(-y_i))^2|
+  = sum_i {y_i - 2 log(1 + exp(y_i))}
+*/
+double Sigmoid::log_abs_jacobian_determinant(
+    const graph::NodeValue& constrained,
+    const graph::NodeValue& unconstrained) {
+  assert(constrained.type.atomic_type == graph::AtomicType::PROBABILITY);
+  if (constrained.type.variable_type == graph::VariableType::SCALAR) {
+    auto y = unconstrained._double;
+    return y - 2 * util::log1pexp(y);
+  } else if (
+      constrained.type.variable_type == graph::VariableType::BROADCAST_MATRIX) {
+    /* Because this transformation is applied to each element of the input
+    matrix independently, the jacobian is a diagonal matrix whose entries are
+    the derivatives of each output with respect to its corresponding input.
+    Also, the function is monotonically increasing, so those derivatives are
+    all positive and the absolute-value bars can be dropped. Therefore, the
+    determinant of the jacobian is the product of the diagonal entries, which
+    is the product of the elementwise derivatives, and the log of that
+    determinant is the sum of the logs of those derivatives. */
+    auto y = unconstrained._matrix.array();
+    return (y - 2 * util::log1pexp(y).array()).sum();
+  } else {
+    throw std::invalid_argument(
+        "Sigmoid transformation requires PROBABILITY values.");
+  }
+  return 0;
+}
+
+/*
+Given the gradient of the joint log prob w.r.t x (constrained), update the
+value so that it is taken w.r.t y (unconstrained).
+  back_grad = back_grad * dx / dy + d(log |det(d x / d y)|) / dy
+
+:param back_grad: the gradient w.r.t x, updated in place to be w.r.t y
+:param constrained: the node value x in constrained space
+:param unconstrained: the node value y in unconstrained space
+
+Given
+x = constrained
+y = unconstrained
+back_grad = d/dx g(x) = g'(x) // for the joint log_prob function g
+
+we want (for the first term) d/dy g(x)
+  where x = expit(y) = 1 / (1 + exp(-y))
+  where expit'(y) = exp(-y) / (1 + exp(-y))^2
+
+d/dy g(x) =
+d/dy g(expit(y)) =
+g'(expit(y)) expit'(y) =
+g'(x) expit'(y) =
+back_grad * exp(-y) / (1 + exp(-y))^2
+
+and for the second term:
+
+d(log |det(d x / d y)|) / dy =
+  d(sum{y_i - 2 log(1 + exp(y_i))}) / dy =
+  {-tanh(y_i / 2)}
+*/
+void Sigmoid::unconstrained_gradient(
+    graph::DoubleMatrix& back_grad,
+    const graph::NodeValue& constrained,
+    const graph::NodeValue& unconstrained) {
+  assert(constrained.type.atomic_type == graph::AtomicType::PROBABILITY);
+  if (constrained.type.variable_type == graph::VariableType::SCALAR) {
+    auto y = unconstrained._double;
+    auto expmy = std::exp(-y); // exp(-y)
+    auto dxdy = expmy / std::pow(1 + expmy, 2);
+    auto dlddy = -std::tanh(y / 2);
+    back_grad = back_grad * dxdy + dlddy;
+  } else if (
+      constrained.type.variable_type == graph::VariableType::BROADCAST_MATRIX) {
+    auto y = unconstrained._matrix.array();
+    auto expmy = (-y).exp(); // exp(-y)
+    auto dxdy = expmy / (1 + expmy).pow(2);
+    auto dlddy = -(y / 2).tanh();
+    back_grad = back_grad.array() * dxdy + dlddy;
+  } else {
+    throw std::invalid_argument(
+        "Sigmoid transformation requires scalar or broadcast matrix values.");
+  }
+}
+
+} // namespace transform
+} // namespace beanmachine
diff --git a/src/beanmachine/graph/transform/transform.h b/src/beanmachine/graph/transform/transform.h
index 87b2eb994d..cc92538460 100644
--- a/src/beanmachine/graph/transform/transform.h
+++ b/src/beanmachine/graph/transform/transform.h
@@ -12,6 +12,10 @@
 namespace beanmachine {
 namespace transform {
 
+// The Log transform maps values from the range (0..INF) to (-INF..INF), for
+// example for the HalfCauchy and HalfNormal distributions. Implements the
+// natural logarithm function (see
+// https://en.wikipedia.org/wiki/Natural_logarithm).
 class Log : public graph::Transformation {
  public:
   Log() : Transformation(graph::TransformType::LOG) {}
@@ -32,5 +36,28 @@ class Log : public graph::Transformation {
       const graph::NodeValue& unconstrained) override;
 };
 
+// The Sigmoid transform maps values from the range (0..1) to (-INF..INF), for
+// example for the Beta distribution. The forward (constrained-to-
+// unconstrained) map is the logit function, the inverse of the sigmoid
+// ("expit") function (see https://en.wikipedia.org/wiki/Logit).
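+//
+// A minimal usage sketch, adapted from the tests in transform_test.cpp:
+//
+//   Graph g;
+//   auto two = g.add_constant_pos_real(2.0);
+//   auto beta = g.add_distribution(
+//       DistributionType::BETA, AtomicType::PROBABILITY, {two, two});
+//   auto p = g.add_operator(OperatorType::SAMPLE, {beta});
+//   g.customize_transformation(TransformType::SIGMOID, {p});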
+class Sigmoid : public graph::Transformation {
+ public:
+  Sigmoid() : Transformation(graph::TransformType::SIGMOID) {}
+  ~Sigmoid() override {}
+
+  void operator()(
+      const graph::NodeValue& constrained,
+      graph::NodeValue& unconstrained) override;
+  void inverse(
+      graph::NodeValue& constrained,
+      const graph::NodeValue& unconstrained) override;
+  double log_abs_jacobian_determinant(
+      const graph::NodeValue& constrained,
+      const graph::NodeValue& unconstrained) override;
+  void unconstrained_gradient(
+      graph::DoubleMatrix& back_grad,
+      const graph::NodeValue& constrained,
+      const graph::NodeValue& unconstrained) override;
+};
+
 } // namespace transform
 } // namespace beanmachine
diff --git a/src/beanmachine/graph/transform/transform_test.cpp b/src/beanmachine/graph/transform/transform_test.cpp
index 30e1e73862..9b67d49efa 100644
--- a/src/beanmachine/graph/transform/transform_test.cpp
+++ b/src/beanmachine/graph/transform/transform_test.cpp
@@ -7,6 +7,7 @@
 
 #include <gtest/gtest.h>
 
+#include <cmath>
 #include "beanmachine/graph/distribution/flat.h"
 #include "beanmachine/graph/graph.h"
 #include "beanmachine/graph/operator/stochasticop.h"
@@ -94,14 +95,29 @@ TEST(test_transform, log) {
   xobs << 0.5, 1.5;
   g2.observe(x2, xobs);
 
-  // To verify the results with pyTorch:
-  // log_a = torch.tensor(np.log(10.0), requires_grad=True)
-  // log_b = torch.tensor(np.log(1.2), requires_grad=True)
-  // log_x = torch.tensor(np.log([2.5, 0.5, 1.5]), requires_grad=True)
-  // log_p = torch.distributions.Gamma(
-  //     log_a.exp(), log_b.exp()).log_prob(log_x.exp()).sum()
-  // log_q = log_p + log_a + log_b + log_x.sum()
-  // torch.autograd.grad(log_q, log_x)[0]
+  /*
+  # To verify the results with pyTorch:
+  from torch import tensor
+  from torch import log
+  from torch.distributions import Gamma
+  from torch.autograd import grad
+
+  shape1 = tensor(10.0, requires_grad=True)  # Gamma concentration (shape)
+  log_shape1 = log(shape1)
+  rate1 = tensor(1.2, requires_grad=True)  # Gamma rate
+  log_rate1 = log(rate1)
+  x = tensor([2.5, 0.5, 1.5])
+  log_x = log(x).detach().requires_grad_(True)
+  gamma = Gamma(
+      log_shape1.exp(),
+      log_rate1.exp())
+  log_p = gamma.log_prob(log_x.exp()).sum()
+  log_q = log_p + log_shape1 + log_rate1 + log_x.sum()
+  print("shape1", grad(log_q, log_shape1, retain_graph=True))
+  print("rate1", grad(log_q, log_rate1, retain_graph=True))
+  print("x", grad(log_q, log_x, retain_graph=True))
+  print("full_log_prob", log_q)
+  */
   std::vector<DoubleMatrix*> back_grad;
   g2.eval_and_grad(back_grad);
   EXPECT_EQ(back_grad.size(), 4);
@@ -113,7 +129,180 @@ TEST(test_transform, log) {
   EXPECT_NEAR(g2.full_log_prob(), -29.5648, 1e-3);
 }
 
-TEST(test_transform, unconstrained_type) {
+namespace {
+double logit(double p) {
+  return std::log(p / (1 - p));
+}
+double expit(double x) {
+  return 1 / (1 + std::exp(-x));
+}
+} // namespace
+
+TEST(test_transform, sigmoid_flat) {
+  std::mt19937 generator(1234);
+  Graph g1;
+  NodeValue *x, *y;
+  auto size = g1.add_constant((natural_t)2);
+  auto flat_real =
+      g1.add_distribution(DistributionType::FLAT, AtomicType::REAL, {});
+  auto real1 = g1.add_operator(OperatorType::SAMPLE, {flat_real});
+  // negative test: Sigmoid only applies to PROBABILITY
+  EXPECT_THROW(
+      g1.customize_transformation(TransformType::SIGMOID, {real1}),
+      std::invalid_argument);
+  // test transform and inverse transform
+  auto flat_pos =
+      g1.add_distribution(DistributionType::FLAT, AtomicType::PROBABILITY, {});
+  auto pos1 = g1.add_operator(OperatorType::SAMPLE, {flat_pos});
+  auto pos2 = g1.add_operator(OperatorType::IID_SAMPLE, {flat_pos, size});
+  g1.customize_transformation(TransformType::SIGMOID, {pos1, pos2});
+  g1.observe(pos1, 0.2);
+  Eigen::MatrixXd pos2_obs(2, 1);
+  pos2_obs << 0.4, 0.5;
+  g1.observe(pos2, pos2_obs);
+
+  // scalar transform
+  auto n1 = static_cast<oper::StochasticOperator*>(
+      g1.check_node(pos1, NodeType::OPERATOR));
+  y = n1->get_unconstrained_value(false);
+  EXPECT_NEAR(y->_double, 0, 0.001);
+  x = n1->get_original_value(false);
+  EXPECT_NEAR(x->_double, 0.2, 0.001);
+  y = n1->get_unconstrained_value(true);
+  EXPECT_NEAR(y->_double, logit(0.2), 0.001);
+  y->_double = 0.0;
+  x = n1->get_original_value(true);
+  EXPECT_NEAR(x->_double, expit(0.0), 0.001);
+
+  // vector transform
+  auto n2 = static_cast<oper::StochasticOperator*>(
+      g1.check_node(pos2, NodeType::OPERATOR));
+  y = n2->get_unconstrained_value(false);
+  EXPECT_NEAR(y->_matrix.squaredNorm(), 0, 0.001);
+  x = n2->get_original_value(false);
+  EXPECT_NEAR(x->_matrix.coeff(0), 0.4, 0.001);
+  EXPECT_NEAR(x->_matrix.coeff(1), 0.5, 0.001);
+  y = n2->get_unconstrained_value(true);
+  EXPECT_NEAR(y->_matrix.coeff(0), logit(0.4), 0.001);
+  EXPECT_NEAR(y->_matrix.coeff(1), logit(0.5), 0.001);
+  y->_matrix.setZero();
+  x = n2->get_original_value(true);
+  EXPECT_NEAR(x->_matrix.coeff(0), expit(0.0), 0.001);
+  EXPECT_NEAR(x->_matrix.coeff(1), expit(0.0), 0.001);
+
+  Graph g2;
+  size = g2.add_constant((natural_t)2);
+  auto dist =
+      g2.add_distribution(DistributionType::FLAT, AtomicType::PROBABILITY, {});
+  auto x1 = g2.add_operator(OperatorType::SAMPLE, {dist});
+  auto x2 = g2.add_operator(OperatorType::IID_SAMPLE, {dist, size});
+  g2.customize_transformation(TransformType::SIGMOID, {x1, x2});
+  g2.observe(x1, 0.2);
+  Eigen::MatrixXd xobs(2, 1);
+  xobs << 0.4, 0.5;
+  g2.observe(x2, xobs);
+}
+
+TEST(test_transform, sigmoid_beta) {
+  std::mt19937 generator(1234);
+  Graph g1;
+  NodeValue *x, *y;
+  auto size = g1.add_constant((natural_t)2);
+  auto two = g1.add_constant_pos_real(2.0);
+  auto beta = g1.add_distribution(
+      DistributionType::BETA, AtomicType::PROBABILITY, {two, two});
+  auto pos1 = g1.add_operator(OperatorType::SAMPLE, {beta});
+  auto pos2 = g1.add_operator(OperatorType::IID_SAMPLE, {beta, size});
+  g1.customize_transformation(TransformType::SIGMOID, {pos1, pos2});
+  g1.observe(pos1, 0.2);
+  Eigen::MatrixXd pos2_obs(2, 1);
+  pos2_obs << 0.4, 0.5;
+  g1.observe(pos2, pos2_obs);
+  // scalar transform
+  auto n1 = static_cast<oper::StochasticOperator*>(
+      g1.check_node(pos1, NodeType::OPERATOR));
+  y = n1->get_unconstrained_value(false);
+  EXPECT_NEAR(y->_double, 0, 0.001);
+  x = n1->get_original_value(false);
+  EXPECT_NEAR(x->_double, 0.2, 0.001);
+  y = n1->get_unconstrained_value(true);
+  EXPECT_NEAR(y->_double, std::log(0.25), 0.001);
+  y->_double = 0.0;
+  x = n1->get_original_value(true);
+  EXPECT_NEAR(x->_double, expit(0.0), 0.001);
+  // vector transform
+  auto n2 = static_cast<oper::StochasticOperator*>(
+      g1.check_node(pos2, NodeType::OPERATOR));
+  y = n2->get_unconstrained_value(false);
+  EXPECT_NEAR(y->_matrix.squaredNorm(), 0, 0.001);
+  x = n2->get_original_value(false);
+  EXPECT_NEAR(x->_matrix.coeff(0), 0.4, 0.001);
+  EXPECT_NEAR(x->_matrix.coeff(1), 0.5, 0.001);
+  y = n2->get_unconstrained_value(true);
+  EXPECT_NEAR(y->_matrix.coeff(0), std::log(2.0 / 3), 0.001);
+  EXPECT_NEAR(y->_matrix.coeff(1), 0, 0.001);
+  y->_matrix.setZero();
+  x = n2->get_original_value(true);
+  EXPECT_NEAR(x->_matrix.coeff(0), expit(0), 0.001);
+  EXPECT_NEAR(x->_matrix.coeff(1), expit(0), 0.001);
+}
+
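+// A hand derivation of the expectations in the next test: under the sigmoid
+// transform, x = expit(y) and dx/dy = x * (1 - x), so the transformed
+// Beta(a, b) log density is
+//   log q(y) = a * log(x) + b * log(1 - x) - log B(a, b)
+// and d(log q)/dy = a - (a + b) * x. With a = 0.25 and b = 0.75:
+//   x = 0.5 -> 0.25 - 1.0 * 0.5 = -0.25
+//   x = 0.2 -> 0.25 - 1.0 * 0.2 = 0.05
+//   x = 0.3 -> 0.25 - 1.0 * 0.3 = -0.05
+// Summing log q over x = {0.5, 0.2, 0.3} gives about -6.3052.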
+TEST(test_transform, sigmoid_beta_2) {
+  // test log_abs_jacobian_determinant and unconstrained_gradient
+  Graph g2;
+  auto dist = g2.add_distribution(
+      DistributionType::BETA,
+      AtomicType::PROBABILITY,
+      {g2.add_constant_pos_real(0.25), g2.add_constant_pos_real(0.75)});
+
+  auto x1 = g2.add_operator(OperatorType::SAMPLE, {dist});
+  g2.customize_transformation(TransformType::SIGMOID, {x1});
+  g2.observe(x1, 0.5);
+
+  auto x2 = g2.add_operator(
+      OperatorType::IID_SAMPLE, {dist, g2.add_constant((natural_t)2)});
+  g2.customize_transformation(TransformType::SIGMOID, {x2});
+  Eigen::MatrixXd xobs(2, 1);
+  xobs << 0.2, 0.3;
+  g2.observe(x2, xobs);
+
+  /*
+  # To verify the results with pyTorch:
+  from torch import tensor
+  from torch import logit
+  from torch.special import expit
+  from torch.autograd import grad
+  from torch.distributions import Beta
+  from torch.distributions import TransformedDistribution
+  from torch.distributions import SigmoidTransform
+
+  a = tensor(0.25, requires_grad=True)
+  b = tensor(0.75, requires_grad=True)
+  x = tensor([0.5, 0.2, 0.3])
+  logit_x = logit(x).detach().requires_grad_(True)
+
+  beta_dist = Beta(a, b)
+  # untransformed log prob, for reference
+  log_p = beta_dist.log_prob(expit(logit_x)).sum()
+
+  transformed_dist = TransformedDistribution(
+      beta_dist, [SigmoidTransform().inv])
+  full_log_prob = transformed_dist.log_prob(logit_x).sum()
+
+  print("xgrad", grad(full_log_prob, logit_x, retain_graph=True))
+  print("full_log_prob", full_log_prob)
+  */
+  std::vector<DoubleMatrix*> back_grad;
+  g2.eval_and_grad(back_grad);
+  EXPECT_EQ(back_grad.size(), 2);
+  // The indices below are of the random variables in the model, not node
+  // indices.
+  EXPECT_NEAR(g2.full_log_prob(), -6.3053, 1e-3);
+  EXPECT_NEAR((*back_grad[0]), -0.2500, 1e-3); // x1
+  EXPECT_NEAR(back_grad[1]->coeff(0), 0.0500, 1e-3); // x2[0]
+  EXPECT_NEAR(back_grad[1]->coeff(1), -0.0500, 1e-3); // x2[1]
+}
+
+TEST(test_transform, log_unconstrained_type) {
   Graph g1;
   auto size = g1.add_constant((natural_t)2);
@@ -141,3 +330,29 @@ TEST(test_transform, log_unconstrained_type) {
   EXPECT_EQ(
       n2->get_unconstrained_value(true)->type.atomic_type, AtomicType::REAL);
 }
+
+TEST(test_transform, sigmoid_unconstrained_type) {
+  Graph g1;
+  auto size = g1.add_constant((natural_t)2);
+
+  // test transform types
+  auto flat_pos =
+      g1.add_distribution(DistributionType::FLAT, AtomicType::PROBABILITY, {});
+  auto sample = g1.add_operator(OperatorType::SAMPLE, {flat_pos});
+  auto iid_sample = g1.add_operator(OperatorType::IID_SAMPLE, {flat_pos, size});
+  g1.customize_transformation(TransformType::SIGMOID, {sample, iid_sample});
+
+  auto n1 = static_cast<oper::StochasticOperator*>(
+      g1.check_node(sample, NodeType::OPERATOR));
+  EXPECT_EQ(n1->value.type.atomic_type, AtomicType::PROBABILITY);
+  EXPECT_EQ(n1->unconstrained_value.type.atomic_type, AtomicType::REAL);
+
+  auto n2 = static_cast<oper::StochasticOperator*>(
+      g1.check_node(iid_sample, NodeType::OPERATOR));
+  EXPECT_EQ(n2->value.type.atomic_type, AtomicType::PROBABILITY);
+  // check type is unknown before calling "get_unconstrained_value"
+  EXPECT_EQ(n2->unconstrained_value.type.atomic_type, AtomicType::UNKNOWN);
+  // check that the type is initialized properly
+  EXPECT_EQ(
+      n2->get_unconstrained_value(true)->type.atomic_type, AtomicType::REAL);
+}
diff --git a/src/beanmachine/graph/transformation.h b/src/beanmachine/graph/transformation.h
index ebebe6f231..c144d304c8 100644
--- a/src/beanmachine/graph/transformation.h
+++ b/src/beanmachine/graph/transformation.h
@@ -12,7 +12,7 @@ namespace beanmachine::graph {
 class NodeValue;
 struct DoubleMatrix;
 
-enum class TransformType { NONE = 0, LOG = 1 };
+enum class TransformType { NONE = 0, LOG = 1, SIGMOID = 2 };
 
 class Transformation {
  public:
@@ -62,9 +62,13 @@ class Transformation {
     return 0;
   }
   /*
-  Given the gradient of the joint log prob w.r.t x, update the value so
-  that it is taken w.r.t y:
+  Given the gradient of the joint log prob of the untransformed distribution
+  w.r.t x (the constrained value), update the value so that it is the gradient
+  of the joint log prob of the transformed distribution taken w.r.t y (the
+  unconstrained value):
+
+  back_grad = back_grad * dx / dy + d(log |det(d x / d y)|) / dy
+
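+  For example, for the Log transform, x = exp(y), so dx / dy = x and
+  log |det(d x / d y)| = y; the update above reduces to
+  back_grad = back_grad * x + 1.
+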
   :param back_grad: the gradient w.r.t x
   :param constrained: the node value x in constrained space
   :param unconstrained: the node value y in unconstrained space
diff --git a/src/beanmachine/graph/util.cpp b/src/beanmachine/graph/util.cpp
index 485dd4c6c4..0403c992e2 100644
--- a/src/beanmachine/graph/util.cpp
+++ b/src/beanmachine/graph/util.cpp
@@ -6,15 +6,13 @@
  */
 
 #define _USE_MATH_DEFINES
-#include <cmath>
-
+#include "beanmachine/graph/util.h"
 #include <random>
-
+#include <cmath>
+#include <Eigen/Dense>
 #include "beanmachine/graph/graph.h"
 
-namespace beanmachine {
-namespace util {
+namespace beanmachine::util {
 
 // see https://core.ac.uk/download/pdf/41787448.pdf
 const double PHI_APPROX_GAMMA = 1.702;
@@ -114,6 +112,10 @@ double log1pexp(double x) {
   }
 }
 
+Eigen::MatrixXd log1pexp(const Eigen::MatrixXd& x) {
+  return x.unaryExpr([](double e) { return log1pexp(e); });
+}
+
 double log1mexp(double x) {
   assert(x <= 0);
   if (x < -0.693) {
@@ -123,5 +125,4 @@ double log1mexp(double x) {
   }
 }
 
-} // namespace util
-} // namespace beanmachine
+} // namespace beanmachine::util
diff --git a/src/beanmachine/graph/util.h b/src/beanmachine/graph/util.h
index 9bf9008eec..bbf1683ecb 100644
--- a/src/beanmachine/graph/util.h
+++ b/src/beanmachine/graph/util.h
@@ -6,6 +6,7 @@
  */
 
 #pragma once
+#include <Eigen/Dense>
 #include <random>
 #include <vector>
 
@@ -116,6 +117,8 @@ See: https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
 */
 double log1pexp(double x);
 
+Eigen::MatrixXd log1pexp(const Eigen::MatrixXd& x);
+
 /*
 Compute `log(1 - exp(x))` with numerical stability.
 See: https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf