diff --git a/src/beanmachine/graph/graph.h b/src/beanmachine/graph/graph.h
index 52129203f8..8677cce900 100644
--- a/src/beanmachine/graph/graph.h
+++ b/src/beanmachine/graph/graph.h
@@ -343,6 +343,7 @@ enum class OperatorType {
   MATRIX_EXP,
   LOG_PROB,
   MATRIX_SUM,
+  MATRIX_LOG,
 };
 
 enum class DistributionType {
diff --git a/src/beanmachine/graph/operator/backward.cpp b/src/beanmachine/graph/operator/backward.cpp
index 2f5ccc4e3f..6bec129491 100644
--- a/src/beanmachine/graph/operator/backward.cpp
+++ b/src/beanmachine/graph/operator/backward.cpp
@@ -354,6 +354,14 @@ void MatrixExp::backward() {
   }
 }
 
+void MatrixLog::backward() {
+  assert(in_nodes.size() == 1);
+  if (in_nodes[0]->needs_gradient()) {
+    in_nodes[0]->back_grad1 +=
+        back_grad1.as_matrix().cwiseQuotient(in_nodes[0]->value._matrix);
+  }
+}
+
 void MatrixSum::backward() {
   assert(in_nodes.size() == 1);
   if (in_nodes[0]->needs_gradient()) {
diff --git a/src/beanmachine/graph/operator/gradient.cpp b/src/beanmachine/graph/operator/gradient.cpp
index 1c21ad4b43..138b8745be 100644
--- a/src/beanmachine/graph/operator/gradient.cpp
+++ b/src/beanmachine/graph/operator/gradient.cpp
@@ -528,6 +528,19 @@ void MatrixExp::compute_gradients() {
       value._matrix.cwiseProduct(in_nodes[0]->Grad2);
 }
 
+void MatrixLog::compute_gradients() {
+  assert(in_nodes.size() == 1);
+  // f(x) = log(g(x))
+  // f'(x) = g'(x) / g(x)
+  // f''(x) = (g''(x) * g(x) - g'(x) * g'(x)) / (g(x) * g(x))
+  //        = g''(x) / g(x) - f'(x) * f'(x)
+  auto g = in_nodes[0]->value._matrix;
+  auto g1 = in_nodes[0]->Grad1;
+  auto g2 = in_nodes[0]->Grad2;
+  Grad1 = g1.cwiseQuotient(g);
+  Grad2 = g2.cwiseQuotient(g) - Grad1.cwiseProduct(Grad1);
+}
+
 void LogProb::compute_gradients() {
   auto dist = (Distribution*)in_nodes[0];
   auto value = in_nodes[1];
diff --git a/src/beanmachine/graph/operator/linalgop.cpp b/src/beanmachine/graph/operator/linalgop.cpp
index ce7f5345a9..5db0c48fef 100644
--- a/src/beanmachine/graph/operator/linalgop.cpp
+++ b/src/beanmachine/graph/operator/linalgop.cpp
@@ -475,5 +475,34 @@ void MatrixSum::eval(std::mt19937& /* gen */) {
   value._double = in_nodes[0]->value._matrix.sum();
 }
 
+MatrixLog::MatrixLog(const std::vector<graph::Node*>& in_nodes)
+    : Operator(graph::OperatorType::MATRIX_LOG) {
+  if (in_nodes.size() != 1) {
+    throw std::invalid_argument("MATRIX_LOG requires one parent node");
+  }
+  auto type = in_nodes[0]->value.type;
+  if (type.variable_type != graph::VariableType::BROADCAST_MATRIX) {
+    throw std::invalid_argument(
+        "the parent of MATRIX_LOG must be a BROADCAST_MATRIX");
+  }
+  auto atomic_type = type.atomic_type;
+  graph::AtomicType new_type;
+  if (atomic_type == graph::AtomicType::POS_REAL) {
+    new_type = graph::AtomicType::REAL;
+  } else if (atomic_type == graph::AtomicType::PROBABILITY) {
+    new_type = graph::AtomicType::NEG_REAL;
+  } else {
+    throw std::invalid_argument(
+        "operator MATRIX_LOG requires a probability or pos_real parent");
+  }
+  value = graph::NodeValue(graph::ValueType(
+      graph::VariableType::BROADCAST_MATRIX, new_type, type.rows, type.cols));
+}
+
+void MatrixLog::eval(std::mt19937& /* gen */) {
+  assert(in_nodes.size() == 1);
+  value._matrix = Eigen::log(in_nodes[0]->value._matrix.array());
+}
+
 } // namespace oper
 } // namespace beanmachine
diff --git a/src/beanmachine/graph/operator/linalgop.h b/src/beanmachine/graph/operator/linalgop.h
index b3fb25b030..85aef4d4e9 100644
--- a/src/beanmachine/graph/operator/linalgop.h
+++ b/src/beanmachine/graph/operator/linalgop.h
@@ -183,5 +183,20 @@ class MatrixSum : public Operator {
   static bool is_registered;
 };
 
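+// Elementwise natural logarithm of a BROADCAST_MATRIX parent.
+// A POS_REAL parent produces a REAL result; a PROBABILITY parent produces a
+// NEG_REAL result, since the log of a value in (0, 1] is non-positive.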
+class MatrixLog : public Operator {
+ public:
+  explicit MatrixLog(const std::vector<graph::Node*>& in_nodes);
+  ~MatrixLog() override {}
+
+  void eval(std::mt19937& gen) override;
+  void backward() override;
+  void compute_gradients() override;
+
+  static std::unique_ptr<Operator> new_op(
+      const std::vector<graph::Node*>& in_nodes) {
+    return std::make_unique<MatrixLog>(in_nodes);
+  }
+};
+
 } // namespace oper
 } // namespace beanmachine
diff --git a/src/beanmachine/graph/operator/register.cpp b/src/beanmachine/graph/operator/register.cpp
index 956c2c15ff..b4802e362f 100644
--- a/src/beanmachine/graph/operator/register.cpp
+++ b/src/beanmachine/graph/operator/register.cpp
@@ -136,6 +136,10 @@ bool ::beanmachine::oper::OperatorFactory::factories_are_registered =
         graph::OperatorType::MATRIX_EXP,
         &(MatrixExp::new_op)) &&
 
+    OperatorFactory::register_op(
+        graph::OperatorType::MATRIX_LOG,
+        &(MatrixLog::new_op)) &&
+
     // matrix index
     OperatorFactory::register_op(
         graph::OperatorType::INDEX,
diff --git a/src/beanmachine/graph/operator/tests/gradient_test.cpp b/src/beanmachine/graph/operator/tests/gradient_test.cpp
index 3a7516cc7b..acbaff39ad 100644
--- a/src/beanmachine/graph/operator/tests/gradient_test.cpp
+++ b/src/beanmachine/graph/operator/tests/gradient_test.cpp
@@ -1076,6 +1076,110 @@ TEST(testgradient, matrix_exp_grad) {
   EXPECT_NEAR((*grad1[2]), 0.1487, 1e-3);
 }
 
+TEST(testgradient, matrix_log_grad_forward) {
+  Graph g;
+
+  // Test forward differentiation
+
+  Eigen::MatrixXd m1(1, 2);
+  m1 << 2.0, 3.0;
+  auto cm1 = g.add_constant_pos_matrix(m1);
+  auto c = g.add_constant_pos_real(2.0);
+  auto cm = g.add_operator(OperatorType::MATRIX_SCALE, {c, cm1});
+  auto mlog = g.add_operator(OperatorType::MATRIX_LOG, {cm});
+  // f(x) = log(2 * g(x))
+  // g(x) = [2, 3]
+  // but we artificially set
+  // g'(x) = [1, 1]
+  // g''(x) = [0, 0]
+  // for testing.
+
+  Node* cm1_node = g.get_node(cm1);
+  cm1_node->Grad1 = Eigen::MatrixXd::Ones(1, 2);
+  cm1_node->Grad2 = Eigen::MatrixXd::Zero(1, 2);
+  Node* cm_node = g.get_node(cm);
+  Node* mlog_node = g.get_node(mlog);
+  std::mt19937 gen;
+  cm_node->eval(gen);
+  cm_node->compute_gradients();
+  mlog_node->eval(gen);
+  mlog_node->compute_gradients();
+  Eigen::MatrixXd first_grad = mlog_node->Grad1;
+  Eigen::MatrixXd expected_first_grad(1, 2);
+  // By the chain rule, f'(x) = 2 * g'(x) / (2 * g(x)) = g'(x) / g(x) = [0.5, 0.33]
+  expected_first_grad << 0.5, 1.0 / 3.0;
+  _expect_near_matrix(first_grad, expected_first_grad);
+  Eigen::MatrixXd second_grad = mlog_node->Grad2;
+  Eigen::MatrixXd expected_second_grad(1, 2);
+  // f''(x) = (g''(x) * g(x) - g'(x) * g'(x)) / (g(x) * g(x))
+  //        = ([0, 0] * [2, 3] - [1, 1] * [1, 1]) / ([2, 3] * [2, 3])
+  //        = [-0.25, -0.11]
+  expected_second_grad << -0.25, -1.0 / 9.0;
+  _expect_near_matrix(second_grad, expected_second_grad);
+}
+
+TEST(testgradient, matrix_log_grad_backward) {
+  /*
+# Test backward differentiation
+#
+# Build the same model in PyTorch and BMG; we should get the same
+# backwards gradients as PyTorch.
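+#
+# In the BMG graph below, s0 and s1 correspond to the two entries of the
+# IID half-normal sample (hn_sample), and sn0/sn1 to the two Normal samples.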
+
+import torch
+hn = torch.distributions.HalfNormal(1.0)
+s0 = torch.tensor(1.0, requires_grad=True)
+s1 = torch.tensor(0.5, requires_grad=True)
+mlog0 = s0.log()
+mlog1 = s1.log()
+n0 = torch.distributions.Normal(mlog0, 1.0)
+n1 = torch.distributions.Normal(mlog1, 1.0)
+sn0 = torch.tensor(2.5, requires_grad=True)
+sn1 = torch.tensor(1.5, requires_grad=True)
+log_prob = (n0.log_prob(sn0) + n1.log_prob(sn1)
+    + hn.log_prob(s0) + hn.log_prob(s1))
+torch.autograd.grad(log_prob, s0, retain_graph=True)  # 1.5000
+torch.autograd.grad(log_prob, s1, retain_graph=True)  # 3.8863
+torch.autograd.grad(log_prob, sn0, retain_graph=True)  # -2.5000
+torch.autograd.grad(log_prob, sn1, retain_graph=True)  # -2.1931
+  */
+
+  Graph g;
+  auto one = g.add_constant_pos_real(1.0);
+  auto hn = g.add_distribution(
+      DistributionType::HALF_NORMAL, AtomicType::POS_REAL, {one});
+  auto two = g.add_constant((natural_t)2);
+  auto hn_sample =
+      g.add_operator(OperatorType::IID_SAMPLE, std::vector<uint>{hn, two});
+  Eigen::MatrixXd hn_observed(2, 1);
+  hn_observed << 1.0, 0.5;
+  g.observe(hn_sample, hn_observed);
+
+  auto mlog_pos = g.add_operator(OperatorType::MATRIX_LOG, {hn_sample});
+  auto mlog = g.add_operator(OperatorType::TO_REAL_MATRIX, {mlog_pos});
+  auto index_zero = g.add_constant((natural_t)0);
+  auto mlog0 = g.add_operator(OperatorType::INDEX, {mlog, index_zero});
+  auto index_one = g.add_constant((natural_t)1);
+  auto mlog1 = g.add_operator(OperatorType::INDEX, {mlog, index_one});
+
+  auto n0 = g.add_distribution(
+      DistributionType::NORMAL, AtomicType::REAL, {mlog0, one});
+  auto n1 = g.add_distribution(
+      DistributionType::NORMAL, AtomicType::REAL, {mlog1, one});
+
+  auto ns0 = g.add_operator(OperatorType::SAMPLE, std::vector<uint>{n0});
+  g.observe(ns0, 2.5);
+  auto ns1 = g.add_operator(OperatorType::SAMPLE, std::vector<uint>{n1});
+  g.observe(ns1, 1.5);
+
+  std::vector<DoubleMatrix*> grad1;
+  g.eval_and_grad(grad1);
+  EXPECT_EQ(grad1.size(), 3);
+  EXPECT_NEAR((*grad1[0])(0), 1.5000, 1e-3);
+  EXPECT_NEAR((*grad1[0])(1), 3.8863, 1e-3);
+  EXPECT_NEAR((*grad1[1]), -2.5000, 1e-3);
+  EXPECT_NEAR((*grad1[2]), -2.1931, 1e-3);
+}
+
 TEST(testgradient, matrix_elementwise_mult_forward) {
   Graph g;
 
diff --git a/src/beanmachine/graph/operator/tests/operator_test.cpp b/src/beanmachine/graph/operator/tests/operator_test.cpp
index 2590501c61..30ce4dec3e 100644
--- a/src/beanmachine/graph/operator/tests/operator_test.cpp
+++ b/src/beanmachine/graph/operator/tests/operator_test.cpp
@@ -1760,7 +1760,7 @@ TEST(testoperator, cholesky) {
   }
 }
 
-TEST(testgradient, matrix_exp) {
+TEST(testoperator, matrix_exp) {
   Graph g;
 
   // negative tests
@@ -1801,6 +1801,47 @@ TEST(testgradient, matrix_exp) {
   }
 }
 
+TEST(testoperator, matrix_log) {
+  Graph g;
+
+  // negative tests
+  // MATRIX_LOG requires matrix parent
+  auto real_number = g.add_constant(2.0);
+  EXPECT_THROW(
+      g.add_operator(OperatorType::MATRIX_LOG, {real_number}),
+      std::invalid_argument);
+  // must be pos real or prob
+  Eigen::MatrixXb bools(2, 1);
+  bools << false, true;
+  auto bools_matrix = g.add_constant_bool_matrix(bools);
+  EXPECT_THROW(
+      g.add_operator(OperatorType::MATRIX_LOG, {bools_matrix}),
+      std::invalid_argument);
+  // can only have one parent
+  Eigen::MatrixXd m1(3, 1);
+  m1 << 2.0, 1.0, 3.0;
+  auto m1_matrix = g.add_constant_pos_matrix(m1);
+  Eigen::MatrixXd m2(1, 2);
+  m2 << 0.5, 20.0;
+  auto m2_matrix = g.add_constant_pos_matrix(m2);
+  EXPECT_THROW(
+      g.add_operator(OperatorType::MATRIX_LOG, {m1_matrix, m2_matrix}),
+      std::invalid_argument);
+
+  auto mlog = g.add_operator(OperatorType::MATRIX_LOG, {m1_matrix});
+  g.query(mlog);
+
+  auto mlog_infer = g.infer(2, InferenceType::REJECTION)[0][0];
+  Eigen::MatrixXd mlog_expected(3, 1);
+  mlog_expected << 2.0, 1.0, 3.0;
+  mlog_expected = Eigen::log(mlog_expected.array());
+  for (uint i = 0; i < mlog_infer.type.rows; i++) {
+    for (uint j = 0; j < mlog_infer.type.cols; j++) {
+      EXPECT_NEAR(mlog_expected(i, j), mlog_infer._matrix(i, j), 1e-4);
+    }
+  }
+}
+
 TEST(testoperator, log_prob) {
   Graph g;
   std::mt19937 gen;
diff --git a/src/beanmachine/graph/pybindings.cpp b/src/beanmachine/graph/pybindings.cpp
index b31ad43db1..bb8fb7be27 100644
--- a/src/beanmachine/graph/pybindings.cpp
+++ b/src/beanmachine/graph/pybindings.cpp
@@ -90,7 +90,8 @@ PYBIND11_MODULE(graph, module) {
       .value("CHOICE", OperatorType::CHOICE)
       .value("CHOLESKY", OperatorType::CHOLESKY)
       .value("MATRIX_EXP", OperatorType::MATRIX_EXP)
-      .value("MATRIX_SUM", OperatorType::MATRIX_SUM);
+      .value("MATRIX_SUM", OperatorType::MATRIX_SUM)
+      .value("MATRIX_LOG", OperatorType::MATRIX_LOG);
 
   py::enum_<DistributionType>(module, "DistributionType")
       .value("TABULAR", DistributionType::TABULAR)
diff --git a/src/beanmachine/graph/to_dot.cpp b/src/beanmachine/graph/to_dot.cpp
index 5832889670..80eb1e82a6 100644
--- a/src/beanmachine/graph/to_dot.cpp
+++ b/src/beanmachine/graph/to_dot.cpp
@@ -189,6 +189,18 @@ class DOT {
         return "ToMatrix";
       case OperatorType::COLUMN_INDEX:
         return "ColumnIndex";
+      case OperatorType::BROADCAST_ADD:
+        return "BroadcastAdd";
+      case OperatorType::CHOLESKY:
+        return "Cholesky";
+      case OperatorType::MATRIX_EXP:
+        return "MatrixExp";
+      case OperatorType::LOG_PROB:
+        return "LogProb";
+      case OperatorType::MATRIX_SUM:
+        return "MatrixSum";
+      case OperatorType::MATRIX_LOG:
+        return "MatrixLog";
       default:
         return "Operator";
     }
diff --git a/src/beanmachine/ppl/compiler/tests/fix_matrix_type_test.py b/src/beanmachine/ppl/compiler/tests/fix_matrix_type_test.py
index b3d3bba5ec..1dd67deaaf 100644
--- a/src/beanmachine/ppl/compiler/tests/fix_matrix_type_test.py
+++ b/src/beanmachine/ppl/compiler/tests/fix_matrix_type_test.py
@@ -19,6 +19,7 @@ def _rv_id() -> RVIdentifier:
 
 class FixMatrixOpTest(unittest.TestCase):
     def test_fix_matrix_addition(self) -> None:
+        self.maxDiff = None
         bmg = BMGraphBuilder()
         zeros = bmg.add_real_matrix(torch.zeros(2))
         ones = bmg.add_pos_real_matrix(torch.ones(2))
@@ -149,7 +150,7 @@ def test_fix_matrix_addition(self) -> None:
   N11[label="~"];
   N12[label="2"];
   N13[label="ToMatrix"];
-  N14[label="Operator"];
+  N14[label="MatrixExp"];
   N15[label="ToReal"];
   N16[label="ElementwiseMultiply"];
   N17[label="MatrixAdd"];
@@ -184,6 +185,7 @@ def test_fix_matrix_addition(self) -> None:
         self.assertEqual(expectation.strip(), observed.strip())
 
     def test_fix_elementwise_multiply(self) -> None:
+        self.maxDiff = None
         bmg = BMGraphBuilder()
         zeros = bmg.add_real_matrix(torch.zeros(2))
         ones = bmg.add_pos_real_matrix(torch.ones(2))
@@ -319,11 +321,11 @@ def test_fix_elementwise_multiply(self) -> None:
   N11[label="~"];
   N12[label="2"];
   N13[label="ToMatrix"];
-  N14[label="Operator"];
+  N14[label="MatrixExp"];
   N15[label="ToReal"];
   N16[label="MatrixAdd"];
   N17[label="ElementwiseMultiply"];
-  N18[label="Operator"];
+  N18[label="MatrixSum"];
   N0 -> N2;
   N0 -> N8;
   N1 -> N2;
@@ -356,6 +358,7 @@ def test_fix_elementwise_multiply(self) -> None:
         self.assertEqual(expectation.strip(), observed.strip())
 
     def test_fix_matrix_sum(self) -> None:
+        self.maxDiff = None
         bmg = BMGraphBuilder()
         probs = bmg.add_real_matrix(torch.tensor([[0.75, 0.25], [0.125, 0.875]]))
         tensor_elements = []
@@ -466,7 +469,7 @@ def test_fix_matrix_sum(self) -> None:
   N21[label="2"];
   N22[label="ToMatrix"];
   N23[label="ToReal"];
-  N24[label="Operator"];
+  N24[label="MatrixSum"];
   N0 -> N2;
   N0 -> N12;
   N1 -> N2;
@@ -506,6 +509,7 @@ def test_fix_matrix_sum(self) -> None:
         self.assertEqual(expectation.strip(), observed_bmg.strip())
 
     def test_fix_matrix_exp(self) -> None:
+        self.maxDiff = None
         bmg = BMGraphBuilder()
         probs = bmg.add_real_matrix(torch.tensor([[0.75, 0.25], [0.125, 0.875]]))
         tensor_elements = []
@@ -615,7 +619,7 @@ def test_fix_matrix_exp(self) -> None:
   N21[label="2"];
   N22[label="ToMatrix"];
   N23[label="ToReal"];
-  N24[label="Operator"];
+  N24[label="MatrixExp"];
   N0 -> N2;
   N0 -> N12;
   N1 -> N2;