Skip to content
This repository was archived by the owner on Dec 18, 2023. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/beanmachine/graph/global/tests/util_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ TEST(testglobal, global_default_transform) {
g.add_operator(OperatorType::SAMPLE, {probability_dist});
g.query(probability_sample);

// TODO: add support for simplex distributions
EXPECT_THROW(set_default_transforms(g), std::runtime_error);
// test support for simplex distributions
set_default_transforms(g); // should run with no issues

Graph g1;
uint natural_dist =
Expand Down
7 changes: 6 additions & 1 deletion src/beanmachine/graph/global/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ void set_default_transforms(Graph& g) {
// add default transforms based on constraints
// to transform all variables to the unconstrained space
// POS_REAL variables -> LOG transform
// TODO: add simplex transform
// PROBABILITY variables -> SIGMOID transform
for (uint node_id : g.compute_ordered_support_node_ids()) {
// @lint-ignore CLANGTIDY
auto node = g.nodes[node_id].get();
Expand All @@ -28,6 +28,11 @@ void set_default_transforms(Graph& g) {
// initialize the type of the unconstrained value
// TODO: rename method to be more clear
sto_node->get_unconstrained_value(true);
} else if (node->value.type.atomic_type == AtomicType::PROBABILITY) {
g.customize_transformation(TransformType::SIGMOID, {node_id});
// initialize the type of the unconstrained value
// TODO: rename method to be more clear
sto_node->get_unconstrained_value(true);
} else if (node->value.type.atomic_type != AtomicType::REAL) {
throw std::runtime_error(
"Node " + std::to_string(node_id) +
Expand Down
8 changes: 8 additions & 0 deletions src/beanmachine/graph/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,8 @@ void Graph::customize_transformation(
if (common_transformations.empty()) {
common_transformations[TransformType::LOG] =
std::make_unique<transform::Log>();
common_transformations[TransformType::SIGMOID] =
std::make_unique<transform::Sigmoid>();
}
auto iter = common_transformations.find(customized_type);
if (iter == common_transformations.end()) {
Expand All @@ -894,6 +896,12 @@ void Graph::customize_transformation(
"Log transformation requires POS_REAL value.");
}
break;
case TransformType::SIGMOID:
if (sto_node->value.type.atomic_type != AtomicType::PROBABILITY) {
throw std::invalid_argument(
"Sigmoid transformation requires PROBABILITY value.");
}
break;
default:
throw std::invalid_argument("Unsupported transformation type.");
}
Expand Down
148 changes: 148 additions & 0 deletions src/beanmachine/graph/transform/sigmoidtransform.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <beanmachine/graph/util.h>
#include "beanmachine/graph/transform/transform.h"

namespace beanmachine {
namespace transform {

// See also https://en.wikipedia.org/wiki/Logit

// Forward transform: constrained value x in (0, 1) -> unconstrained value y.
// y = f(x) = logit(x) = log(x / (1 - x))
// dy/dx = 1 / (x - x^2)
//
// :param constrained: the node value x in constrained space (read only)
// :param unconstrained: the node value y in unconstrained space; its
//   `_double` (scalar case) or `_matrix` (broadcast matrix case) is written
// :throws std::invalid_argument: if the variable type of `constrained` is
//   neither SCALAR nor BROADCAST_MATRIX
void Sigmoid::operator()(
    const graph::NodeValue& constrained,
    graph::NodeValue& unconstrained) {
  // The atomic type is a precondition and is only checked in debug builds.
  assert(constrained.type.atomic_type == graph::AtomicType::PROBABILITY);
  if (constrained.type.variable_type == graph::VariableType::SCALAR) {
    auto x = constrained._double;
    unconstrained._double = std::log(x / (1 - x));
  } else if (
      constrained.type.variable_type == graph::VariableType::BROADCAST_MATRIX) {
    // Elementwise logit over the whole matrix.
    auto x = constrained._matrix.array();
    unconstrained._matrix = (x / (1 - x)).log();
  } else {
    // Reaching this branch means the *variable* type is unsupported, not the
    // atomic type, so report it the same way unconstrained_gradient does.
    throw std::invalid_argument(
        "Sigmoid transformation requires scalar or broadcast matrix values.");
  }
}

// Inverse transform: unconstrained value y -> constrained value x in (0, 1).
// x = f^{-1}(y) = expit(y) = 1 / (1 + exp(-y))
// dx/dy = exp(-y) / (1 + exp(-y))^2
//
// :param constrained: the node value x in constrained space; its `_double`
//   (scalar case) or `_matrix` (broadcast matrix case) is written. Its type
//   field selects which case is used.
// :param unconstrained: the node value y in unconstrained space (read only)
// :throws std::invalid_argument: if the variable type of `constrained` is
//   neither SCALAR nor BROADCAST_MATRIX
void Sigmoid::inverse(
    graph::NodeValue& constrained,
    const graph::NodeValue& unconstrained) {
  // The atomic type is a precondition and is only checked in debug builds.
  assert(constrained.type.atomic_type == graph::AtomicType::PROBABILITY);
  if (constrained.type.variable_type == graph::VariableType::SCALAR) {
    auto y = unconstrained._double;
    constrained._double = 1 / (1 + std::exp(-y));
  } else if (
      constrained.type.variable_type == graph::VariableType::BROADCAST_MATRIX) {
    // Elementwise expit over the whole matrix.
    auto y = unconstrained._matrix.array();
    constrained._matrix = 1 / (1 + (-y).exp());
  } else {
    // Reaching this branch means the *variable* type is unsupported, not the
    // atomic type, so report it the same way unconstrained_gradient does.
    throw std::invalid_argument(
        "Sigmoid transformation requires scalar or broadcast matrix values.");
  }
}

/*
Return the log of the absolute jacobian determinant:
  log |det(d x / d y)|
:param constrained: the node value x in constrained space (only its type is
  inspected here)
:param unconstrained: the node value y in unconstrained space
:returns: the log absolute jacobian determinant as a double
:throws std::invalid_argument: if the variable type of `constrained` is
  neither SCALAR nor BROADCAST_MATRIX

for scalar, log |det(d x / d y)|
  = log |exp(-y) / (1 + exp(-y))^2|
  = log |exp(y) / (1 + exp(y))^2|
  = y - 2 log(1 + exp(y))
for matrix, log |det(d x / d y)|
  = log |prod {exp(-y_i) / (1 + exp(-y_i))^2}|
  = sum{y_i - 2 * Log[1 + exp(y_i)]}
*/
double Sigmoid::log_abs_jacobian_determinant(
    const graph::NodeValue& constrained,
    const graph::NodeValue& unconstrained) {
  // The atomic type is a precondition and is only checked in debug builds.
  assert(constrained.type.atomic_type == graph::AtomicType::PROBABILITY);
  if (constrained.type.variable_type == graph::VariableType::SCALAR) {
    auto y = unconstrained._double;
    return y - 2 * util::log1pexp(y);
  } else if (
      constrained.type.variable_type == graph::VariableType::BROADCAST_MATRIX) {
    /* Because this transformation is applied to each element of the input
    matrix independently, the jacobian is a matrix whose only nonzero values
    are on the diagonal; each diagonal entry is the derivative of one output
    element with respect to the corresponding input element. The function is
    monotonically increasing, so those derivatives are positive and taking
    their absolute value changes nothing. Therefore, the determinant of the
    jacobian is the product of the diagonal entries, which is the product of
    the elementwise derivatives. The log of that determinant is the sum of the
    logs of the derivatives. */
    auto y = unconstrained._matrix.array();
    return (y - 2 * util::log1pexp(y).array()).sum();
  } else {
    // Reaching this branch means the *variable* type is unsupported, not the
    // atomic type, so report it the same way unconstrained_gradient does.
    throw std::invalid_argument(
        "Sigmoid transformation requires scalar or broadcast matrix values.");
  }
  // No trailing return: every branch above either returns or throws.
}

/*
Given the gradient of the joint log prob w.r.t x (constrained), update the value
so that it is taken w.r.t y (unconstrained).
  back_grad = back_grad * dx / dy + d(log |det(d x / d y)|) / dy

:param back_grad: on entry, the gradient w.r.t x (constrained); on exit, the
  gradient w.r.t y (unconstrained). Updated in place.
:param constrained: the node value x in constrained space (only its type is
  inspected here)
:param unconstrained: the node value y in unconstrained space
:throws std::invalid_argument: if the variable type of `constrained` is
  neither SCALAR nor BROADCAST_MATRIX

Given
  x = constrained
  y = unconstrained
  back_grad = d/dx g(x) = g'(x) // for the joint log_prob function g

we want (for the first term) d/dy g(x)
  where x = expit(y) = 1 / (1 + exp(-y))
  where expit'(y) = exp(-y) / (1 + exp(-y))^2

  d/dy g(x) =
  d/dy g(expit(y)) =
  g'(expit(y)) expit'(y) =
  g'(x) expit'(y) =
  back_grad * exp(-y) / (1 + exp(-y))^2

and for the second term:

  d(log |det(d x / d y)|) / dy =
  d(sum{y_i - 2 * Log[1 + exp(y_i)]}) / dy =
  {1 - 2 * expit(y_i)} =
  {-Tanh[y_i/2]}

Numerical note: expit'(y) is an even function of y, i.e.
  exp(-y) / (1 + exp(-y))^2 == exp(-|y|) / (1 + exp(-|y|))^2
so it is evaluated at -|y|. With that form the exponential stays in (0, 1];
the direct form yields inf / inf = NaN once exp(-y) overflows for very
negative y.
*/
void Sigmoid::unconstrained_gradient(
    graph::DoubleMatrix& back_grad,
    const graph::NodeValue& constrained,
    const graph::NodeValue& unconstrained) {
  // The atomic type is a precondition and is only checked in debug builds.
  assert(constrained.type.atomic_type == graph::AtomicType::PROBABILITY);
  if (constrained.type.variable_type == graph::VariableType::SCALAR) {
    auto y = unconstrained._double;
    auto exp_neg_abs = std::exp(-std::fabs(y)); // exp(-|y|), cannot overflow
    auto one_plus = 1 + exp_neg_abs;
    auto dxdy = exp_neg_abs / (one_plus * one_plus); // expit'(y)
    auto dlddy = -std::tanh(y / 2);
    back_grad = back_grad * dxdy + dlddy;
  } else if (
      constrained.type.variable_type == graph::VariableType::BROADCAST_MATRIX) {
    auto y = unconstrained._matrix.array();
    auto exp_neg_abs = (-y.abs()).exp(); // exp(-|y|), cannot overflow
    auto dxdy = exp_neg_abs / (1 + exp_neg_abs).square(); // expit'(y)
    auto dlddy = -(y / 2).tanh();
    back_grad = back_grad.array() * dxdy + dlddy;
  } else {
    throw std::invalid_argument(
        "Sigmoid transformation requires scalar or broadcast matrix values.");
  }
}

} // namespace transform
} // namespace beanmachine
27 changes: 27 additions & 0 deletions src/beanmachine/graph/transform/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
namespace beanmachine {
namespace transform {

// The Log transform maps values from the range (0..INF) to (-INF..INF), for
// example for the HalfCauchy and Half_Normal distributions. Implements the
// natural logarithm function (see
// https://en.wikipedia.org/wiki/Natural_logarithm).
class Log : public graph::Transformation {
public:
Log() : Transformation(graph::TransformType::LOG) {}
Expand All @@ -32,5 +36,28 @@ class Log : public graph::Transformation {
const graph::NodeValue& unconstrained) override;
};

// The Sigmoid transform maps values from the range (0..1) to (-INF..INF), for
// example for the Beta and Dirichlet distributions. Implements the sigmoid
// function "logit" (see https://en.wikipedia.org/wiki/Logit).
class Sigmoid : public graph::Transformation {
public:
Sigmoid() : Transformation(graph::TransformType::SIGMOID) {}
~Sigmoid() override {}

void operator()(
const graph::NodeValue& constrained,
graph::NodeValue& unconstrained) override;
void inverse(
graph::NodeValue& constrained,
const graph::NodeValue& unconstrained) override;
double log_abs_jacobian_determinant(
const graph::NodeValue& constrained,
const graph::NodeValue& unconstrained) override;
void unconstrained_gradient(
graph::DoubleMatrix& back_grad,
const graph::NodeValue& constrained,
const graph::NodeValue& unconstrained) override;
};

} // namespace transform
} // namespace beanmachine
Loading