1 change: 1 addition & 0 deletions src/beanmachine/graph/graph.h
@@ -343,6 +343,7 @@ enum class OperatorType {
MATRIX_EXP,
LOG_PROB,
MATRIX_SUM,
MATRIX_LOG,
};

enum class DistributionType {
8 changes: 8 additions & 0 deletions src/beanmachine/graph/operator/backward.cpp
@@ -354,6 +354,14 @@ void MatrixExp::backward() {
}
}

void MatrixLog::backward() {
assert(in_nodes.size() == 1);
if (in_nodes[0]->needs_gradient()) {
in_nodes[0]->back_grad1 +=
back_grad1.as_matrix().cwiseQuotient(in_nodes[0]->value._matrix);
}
}

void MatrixSum::backward() {
assert(in_nodes.size() == 1);
if (in_nodes[0]->needs_gradient()) {
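MatrixLog::backward above applies the standard reverse-mode rule for an elementwise log: the incoming adjoint is divided elementwise by the operand's value and accumulated into the operand's adjoint. As a sketch, with bars denoting adjoints:

\bar{x}_{ij} \mathrel{+}= \frac{\bar{y}_{ij}}{x_{ij}}, \qquad y_{ij} = \log x_{ij}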
13 changes: 13 additions & 0 deletions src/beanmachine/graph/operator/gradient.cpp
@@ -528,6 +528,19 @@ void MatrixExp::compute_gradients() {
value._matrix.cwiseProduct(in_nodes[0]->Grad2);
}

void MatrixLog::compute_gradients() {
assert(in_nodes.size() == 1);
// f(x) = log(g(x))
// f'(x) = g'(x) / g(x)
// f''(x) = (g''(x) * g(x) - g'(x) * g'(x)) / (g(x) * g(x))
// = g''(x) / g(x) - f'(x) * f'(x)
auto g = in_nodes[0]->value._matrix;
auto g1 = in_nodes[0]->Grad1;
auto g2 = in_nodes[0]->Grad2;
Grad1 = g1.cwiseQuotient(g);
Grad2 = g2.cwiseQuotient(g) - Grad1.cwiseProduct(Grad1);
}

void LogProb::compute_gradients() {
auto dist = (Distribution*)in_nodes[0];
auto value = in_nodes[1];
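The forward-mode formulas in MatrixLog::compute_gradients follow from the elementwise chain and quotient rules; as a derivation sketch:

f(x) = \log g(x), \qquad f'(x) = \frac{g'(x)}{g(x)}, \qquad
f''(x) = \frac{g''(x)\,g(x) - g'(x)^2}{g(x)^2} = \frac{g''(x)}{g(x)} - f'(x)^2

The minus sign is the quotient-rule term; for example, \frac{d^2}{dx^2}\log x = -\frac{1}{x^2}, which is negative for all positive x.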
29 changes: 29 additions & 0 deletions src/beanmachine/graph/operator/linalgop.cpp
@@ -475,5 +475,34 @@ void MatrixSum::eval(std::mt19937& /* gen */) {
value._double = in_nodes[0]->value._matrix.sum();
}

MatrixLog::MatrixLog(const std::vector<graph::Node*>& in_nodes)
: Operator(graph::OperatorType::MATRIX_LOG) {
if (in_nodes.size() != 1) {
throw std::invalid_argument("MATRIX_LOG requires one parent node");
}
auto type = in_nodes[0]->value.type;
if (type.variable_type != graph::VariableType::BROADCAST_MATRIX) {
throw std::invalid_argument(
"the parent of MATRIX_LOG must be a BROADCAST_MATRIX");
}
auto atomic_type = type.atomic_type;
graph::AtomicType new_type;
if (atomic_type == graph::AtomicType::POS_REAL) {
new_type = graph::AtomicType::REAL;
} else if (atomic_type == graph::AtomicType::PROBABILITY) {
new_type = graph::AtomicType::NEG_REAL;
} else {
throw std::invalid_argument(
"operator MATRIX_LOG requires a probability or pos_real parent");
}
value = graph::NodeValue(graph::ValueType(
graph::VariableType::BROADCAST_MATRIX, new_type, type.rows, type.cols));
}

void MatrixLog::eval(std::mt19937& /* gen */) {
assert(in_nodes.size() == 1);
value._matrix = Eigen::log(in_nodes[0]->value._matrix.array());
}

} // namespace oper
} // namespace beanmachine
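The constructor's type mapping follows the range of the log: it maps (0, infinity) onto the whole real line and (0, 1) onto the negative reals, hence a POS_REAL parent produces a REAL matrix and a PROBABILITY parent produces a NEG_REAL matrix. A minimal standalone sketch of the eval semantics in plain Eigen (not Bean Machine's NodeValue types; the values are illustrative):

#include <iostream>
#include <Eigen/Dense>

int main() {
  // Positive reals: the elementwise log can land anywhere on the real line.
  Eigen::MatrixXd pos(1, 3);
  pos << 2.0, 1.0, 3.0;
  Eigen::MatrixXd log_pos = pos.array().log().matrix();  // ~[0.693, 0.0, 1.099]

  // Probabilities: the elementwise log is always <= 0.
  Eigen::MatrixXd prob(1, 2);
  prob << 0.25, 0.5;
  Eigen::MatrixXd log_prob = prob.array().log().matrix();  // ~[-1.386, -0.693]

  std::cout << log_pos << "\n" << log_prob << "\n";
  return 0;
}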
15 changes: 15 additions & 0 deletions src/beanmachine/graph/operator/linalgop.h
@@ -183,5 +183,20 @@ class MatrixSum : public Operator {
static bool is_registered;
};

class MatrixLog : public Operator {
public:
explicit MatrixLog(const std::vector<graph::Node*>& in_nodes);
~MatrixLog() override {}

void eval(std::mt19937& gen) override;
void backward() override;
void compute_gradients() override;

static std::unique_ptr<Operator> new_op(
const std::vector<graph::Node*>& in_nodes) {
return std::make_unique<MatrixLog>(in_nodes);
}
};

} // namespace oper
} // namespace beanmachine
4 changes: 4 additions & 0 deletions src/beanmachine/graph/operator/register.cpp
@@ -136,6 +136,10 @@ bool ::beanmachine::oper::OperatorFactory::factories_are_registered =
graph::OperatorType::MATRIX_EXP,
&(MatrixExp::new_op)) &&

OperatorFactory::register_op(
graph::OperatorType::MATRIX_LOG,
&(MatrixLog::new_op)) &&

// matrix index
OperatorFactory::register_op(
graph::OperatorType::INDEX,
104 changes: 104 additions & 0 deletions src/beanmachine/graph/operator/tests/gradient_test.cpp
@@ -1076,6 +1076,110 @@ TEST(testgradient, matrix_exp_grad) {
EXPECT_NEAR((*grad1[2]), 0.1487, 1e-3);
}

TEST(testgradient, matrix_log_grad_forward) {
Graph g;

// Test forward differentiation

Eigen::MatrixXd m1(1, 2);
m1 << 2.0, 3.0;
auto cm1 = g.add_constant_pos_matrix(m1);
auto c = g.add_constant_pos_real(2.0);
auto cm = g.add_operator(OperatorType::MATRIX_SCALE, {c, cm1});
auto mlog = g.add_operator(OperatorType::MATRIX_LOG, {cm});
// f(x) = log(2 * g(x))
// g(x) = [2, 3]
// but we artificially set
// g'(x) = [1, 1]
// g''(x) = [0, 0]
// for testing.

Node* cm1_node = g.get_node(cm1);
cm1_node->Grad1 = Eigen::MatrixXd::Ones(1, 2);
cm1_node->Grad2 = Eigen::MatrixXd::Zero(1, 2);
Node* cm_node = g.get_node(cm);
Node* mlog_node = g.get_node(mlog);
std::mt19937 gen;
cm_node->eval(gen);
cm_node->compute_gradients();
mlog_node->eval(gen);
mlog_node->compute_gradients();
Eigen::MatrixXd first_grad = mlog_node->Grad1;
Eigen::MatrixXd expected_first_grad(1, 2);
// By the chain rule, f'(x) = 2 * g'(x) / (2 * g(x)) = g'(x) / g(x) = [0.5, 0.33]
expected_first_grad << 0.5, 1.0 / 3.0;
_expect_near_matrix(first_grad, expected_first_grad);
Eigen::MatrixXd second_grad = mlog_node->Grad2;
Eigen::MatrixXd expected_second_grad(1, 2);
// f''(x) = (g''(x) * g(x) - g'(x) * g'(x)) / (g(x) * g(x))
// = ([0, 0] * [2, 3] - [1, 1] * [1, 1]) / ([2, 3] * [2, 3])
// = [-0.25, -0.11]
expected_second_grad << -0.25, -1.0 / 9.0;
_expect_near_matrix(second_grad, expected_second_grad);
}

TEST(testgradient, matrix_log_grad_backward) {
/*
# Test backward differentiation
#
# Build the same model in PyTorch and BMG; we should get the same
# backwards gradients as PyTorch.

import torch
hn = torch.distributions.HalfNormal(1.0)
s0 = torch.tensor(1.0, requires_grad=True)
s1 = torch.tensor(0.5, requires_grad=True)
mlog0 = s0.log()
mlog1 = s1.log()
n0 = torch.distributions.Normal(mlog0, 1.0)
n1 = torch.distributions.Normal(mlog1, 1.0)
sn0 = torch.tensor(2.5, requires_grad=True)
sn1 = torch.tensor(1.5, requires_grad=True)
log_prob = (n0.log_prob(sn0) + n1.log_prob(sn1) +
hn.log_prob(s0) + hn.log_prob(s1))
torch.autograd.grad(log_prob, s0, retain_graph=True) # 1.5000
torch.autograd.grad(log_prob, s1, retain_graph=True) # 3.8863
torch.autograd.grad(log_prob, sn0, retain_graph=True) # -2.5000
torch.autograd.grad(log_prob, sn1, retain_graph=True) # -2.1931
*/

Graph g;
auto one = g.add_constant_pos_real(1.0);
auto hn = g.add_distribution(
DistributionType::HALF_NORMAL, AtomicType::POS_REAL, {one});
auto two = g.add_constant((natural_t)2);
auto hn_sample =
g.add_operator(OperatorType::IID_SAMPLE, std::vector<uint>{hn, two});
Eigen::MatrixXd hn_observed(2, 1);
hn_observed << 1.0, 0.5;
g.observe(hn_sample, hn_observed);

auto mlog_pos = g.add_operator(OperatorType::MATRIX_LOG, {hn_sample});
auto mlog = g.add_operator(OperatorType::TO_REAL_MATRIX, {mlog_pos});
auto index_zero = g.add_constant((natural_t)0);
auto mlog0 = g.add_operator(OperatorType::INDEX, {mlog, index_zero});
auto index_one = g.add_constant((natural_t)1);
auto mlog1 = g.add_operator(OperatorType::INDEX, {mlog, index_one});

auto n0 = g.add_distribution(
DistributionType::NORMAL, AtomicType::REAL, {mlog0, one});
auto n1 = g.add_distribution(
DistributionType::NORMAL, AtomicType::REAL, {mlog1, one});

auto ns0 = g.add_operator(OperatorType::SAMPLE, std::vector<uint>{n0});
g.observe(ns0, 2.5);
auto ns1 = g.add_operator(OperatorType::SAMPLE, std::vector<uint>{n1});
g.observe(ns1, 1.5);

std::vector<DoubleMatrix*> grad1;
g.eval_and_grad(grad1);
EXPECT_EQ(grad1.size(), 3);
EXPECT_NEAR((*grad1[0])(0), 1.5000, 1e-3);
EXPECT_NEAR((*grad1[0])(1), 3.8863, 1e-3);
EXPECT_NEAR((*grad1[1]), -2.5000, 1e-3);
EXPECT_NEAR((*grad1[2]), -2.1931, 1e-3);
}

TEST(testgradient, matrix_elementwise_mult_forward) {
Graph g;

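As an independent sanity check of the forward-mode expectations in matrix_log_grad_forward, the derivatives of f(x) = log(2x) at x = 2 and x = 3 can be approximated with central finite differences; a standalone sketch (not part of the test suite; the step size and printout are illustrative):

#include <cmath>
#include <cstdio>

int main() {
  const double xs[2] = {2.0, 3.0};
  const double h = 1e-5;
  for (double x : xs) {
    auto f = [](double t) { return std::log(2.0 * t); };
    // Central differences for the first and second derivatives.
    double d1 = (f(x + h) - f(x - h)) / (2.0 * h);             // ~ 1/x
    double d2 = (f(x + h) - 2.0 * f(x) + f(x - h)) / (h * h);  // ~ -1/(x*x)
    std::printf("x=%.1f  f'=%.4f  f''=%.4f\n", x, d1, d2);
  }
  // Approximate output:
  //   x=2.0  f'=0.5000  f''=-0.2500
  //   x=3.0  f'=0.3333  f''=-0.1111
  return 0;
}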
43 changes: 42 additions & 1 deletion src/beanmachine/graph/operator/tests/operator_test.cpp
@@ -1760,7 +1760,7 @@ TEST(testoperator, cholesky) {
}
}

TEST(testgradient, matrix_exp) {
TEST(testoperator, matrix_exp) {
Graph g;

// negative tests
@@ -1801,6 +1801,47 @@ TEST(testgradient, matrix_exp) {
}
}

TEST(testoperator, matrix_log) {
Graph g;

// negative tests
// MATRIX_LOG requires matrix parent
auto real_number = g.add_constant(2.0);
EXPECT_THROW(
g.add_operator(OperatorType::MATRIX_LOG, {real_number}),
std::invalid_argument);
// must be pos real or prob
Eigen::MatrixXb bools(2, 1);
bools << false, true;
auto bools_matrix = g.add_constant_bool_matrix(bools);
EXPECT_THROW(
g.add_operator(OperatorType::MATRIX_LOG, {bools_matrix}),
std::invalid_argument);
// can only have one parent
Eigen::MatrixXd m1(3, 1);
m1 << 2.0, 1.0, 3.0;
auto m1_matrix = g.add_constant_pos_matrix(m1);
Eigen::MatrixXd m2(1, 2);
m2 << 0.5, 20.0;
auto m2_matrix = g.add_constant_pos_matrix(m2);
EXPECT_THROW(
g.add_operator(OperatorType::MATRIX_LOG, {m1_matrix, m2_matrix}),
std::invalid_argument);

auto mlog = g.add_operator(OperatorType::MATRIX_LOG, {m1_matrix});
g.query(mlog);

auto mlog_infer = g.infer(2, InferenceType::REJECTION)[0][0];
Eigen::MatrixXd mlog_expected(3, 1);
mlog_expected << 2.0, 1.0, 3.0;
mlog_expected = Eigen::log(mlog_expected.array());
for (uint i = 0; i < mlog_infer.type.rows; i++) {
for (uint j = 0; j < mlog_infer.type.cols; j++) {
EXPECT_NEAR(mlog_expected(i, j), mlog_infer._matrix(i, j), 1e-4);
}
}
}

TEST(testoperator, log_prob) {
Graph g;
std::mt19937 gen;
3 changes: 2 additions & 1 deletion src/beanmachine/graph/pybindings.cpp
@@ -90,7 +90,8 @@ PYBIND11_MODULE(graph, module) {
.value("CHOICE", OperatorType::CHOICE)
.value("CHOLESKY", OperatorType::CHOLESKY)
.value("MATRIX_EXP", OperatorType::MATRIX_EXP)
.value("MATRIX_SUM", OperatorType::MATRIX_SUM);
.value("MATRIX_SUM", OperatorType::MATRIX_SUM)
.value("MATRIX_LOG", OperatorType::MATRIX_LOG);

py::enum_<DistributionType>(module, "DistributionType")
.value("TABULAR", DistributionType::TABULAR)
12 changes: 12 additions & 0 deletions src/beanmachine/graph/to_dot.cpp
@@ -189,6 +189,18 @@ class DOT {
return "ToMatrix";
case OperatorType::COLUMN_INDEX:
return "ColumnIndex";
case OperatorType::BROADCAST_ADD:
return "BroadcastAdd";
case OperatorType::CHOLESKY:
return "Cholesky";
case OperatorType::MATRIX_EXP:
return "MatrixExp";
case OperatorType::LOG_PROB:
return "LogProb";
case OperatorType::MATRIX_SUM:
return "MatrixSum";
case OperatorType::MATRIX_LOG:
return "MatrixLog";
default:
return "Operator";
}
14 changes: 9 additions & 5 deletions src/beanmachine/ppl/compiler/tests/fix_matrix_type_test.py
@@ -19,6 +19,7 @@ def _rv_id() -> RVIdentifier:

class FixMatrixOpTest(unittest.TestCase):
def test_fix_matrix_addition(self) -> None:
self.maxDiff = None
bmg = BMGraphBuilder()
zeros = bmg.add_real_matrix(torch.zeros(2))
ones = bmg.add_pos_real_matrix(torch.ones(2))
Expand Down Expand Up @@ -149,7 +150,7 @@ def test_fix_matrix_addition(self) -> None:
N11[label="~"];
N12[label="2"];
N13[label="ToMatrix"];
N14[label="Operator"];
N14[label="MatrixExp"];
N15[label="ToReal"];
N16[label="ElementwiseMultiply"];
N17[label="MatrixAdd"];
@@ -184,6 +185,7 @@ def test_fix_matrix_addition(self) -> None:
self.assertEqual(expectation.strip(), observed.strip())

def test_fix_elementwise_multiply(self) -> None:
self.maxDiff = None
bmg = BMGraphBuilder()
zeros = bmg.add_real_matrix(torch.zeros(2))
ones = bmg.add_pos_real_matrix(torch.ones(2))
@@ -319,11 +321,11 @@ def test_fix_elementwise_multiply(self) -> None:
N11[label="~"];
N12[label="2"];
N13[label="ToMatrix"];
N14[label="Operator"];
N14[label="MatrixExp"];
N15[label="ToReal"];
N16[label="MatrixAdd"];
N17[label="ElementwiseMultiply"];
N18[label="Operator"];
N18[label="MatrixSum"];
N0 -> N2;
N0 -> N8;
N1 -> N2;
@@ -356,6 +358,7 @@ def test_fix_elementwise_multiply(self) -> None:
self.assertEqual(expectation.strip(), observed.strip())

def test_fix_matrix_sum(self) -> None:
self.maxDiff = None
bmg = BMGraphBuilder()
probs = bmg.add_real_matrix(torch.tensor([[0.75, 0.25], [0.125, 0.875]]))
tensor_elements = []
@@ -466,7 +469,7 @@ def test_fix_matrix_sum(self) -> None:
N21[label="2"];
N22[label="ToMatrix"];
N23[label="ToReal"];
N24[label="Operator"];
N24[label="MatrixSum"];
N0 -> N2;
N0 -> N12;
N1 -> N2;
@@ -506,6 +509,7 @@ def test_fix_matrix_sum(self) -> None:
self.assertEqual(expectation.strip(), observed_bmg.strip())

def test_fix_matrix_exp(self) -> None:
self.maxDiff = None
bmg = BMGraphBuilder()
probs = bmg.add_real_matrix(torch.tensor([[0.75, 0.25], [0.125, 0.875]]))
tensor_elements = []
@@ -615,7 +619,7 @@ def test_fix_matrix_exp(self) -> None:
N21[label="2"];
N22[label="ToMatrix"];
N23[label="ToReal"];
N24[label="Operator"];
N24[label="MatrixExp"];
N0 -> N2;
N0 -> N12;
N1 -> N2;