From 6ef6ed0f231a46ad3f6d1ae09a21653b5b945dea Mon Sep 17 00:00:00 2001
From: QB3
Date: Thu, 26 Aug 2021 12:31:51 +0200
Subject: [PATCH 01/27] [WIP] first draft for sparse vjp

---
 jaxopt/_src/implicit_diff.py | 104 +++++++++++++++++++++++++++++++++++
 jaxopt/implicit_diff.py      |   1 +
 tests/implicit_diff_test.py  |  14 +++++
 3 files changed, 119 insertions(+)

diff --git a/jaxopt/_src/implicit_diff.py b/jaxopt/_src/implicit_diff.py
index 206e86ed..adc9e40a 100644
--- a/jaxopt/_src/implicit_diff.py
+++ b/jaxopt/_src/implicit_diff.py
@@ -19,6 +19,7 @@
 from typing import Callable
 from typing import Tuple
 
+import numpy as np # to be removed, this is for the first draft
 import jax
 
 from jaxopt._src import linear_solve
@@ -73,6 +74,109 @@ def fun_args(*args):
   return vjp_fun_args(u)
 
 
+def root_vjp(optimality_fun: Callable,
+             sol: Any,
+             args: Tuple,
+             cotangent: Any,
+             solve: Callable = linear_solve.solve_normal_cg) -> Any:
+  """Vector-Jacobian product of a root.
+
+  The invariant is ``optimality_fun(sol, *args) == 0``.
+
+  Args:
+    optimality_fun: the optimality function to use.
+    sol: solution / root (pytree).
+    args: tuple containing the arguments with respect to which we wish to
+      differentiate ``sol`` against.
+    cotangent: vector to left-multiply the Jacobian with
+      (pytree, same structure as ``sol``).
+    solve: a linear solver of the form ``x = solve(matvec, b)``,
+      where ``matvec(x) = Ax`` and ``Ax=b``.
+  Returns:
+    vjps: tuple of the same length as ``len(args)`` containing the vjps w.r.t.
+      each argument. Each ``vjps[i]`` has the same pytree structure as
+      ``args[i]``.
+  """
+  def fun_sol(sol):
+    # We close over the arguments.
+    return optimality_fun(sol, *args)
+
+  _, vjp_fun_sol = jax.vjp(fun_sol, sol)
+
+  # Compute the multiplication A^T u = (u^T A)^T.
+  matvec = lambda u: vjp_fun_sol(u)[0]
+
+  # The solution of A^T u = v, where
+  # A = jacobian(optimality_fun, argnums=0)
+  # v = -cotangent.
+  v = tree_scalar_mul(-1, cotangent)
+  u = solve(matvec, v)
+
+  def fun_args(*args):
+    # We close over the solution.
+    return optimality_fun(sol, *args)
+
+  _, vjp_fun_args = jax.vjp(fun_args, *args)
+
+  return vjp_fun_args(u)
+
+
+def sparse_root_vjp(optimality_fun: Callable,
+                    sol: Any,
+                    args: Tuple,
+                    cotangent: Any,
+                    solve: Callable = linear_solve.solve_normal_cg) -> Any:
+  """Sparse vector-Jacobian product of a root.
+
+  The invariant is ``optimality_fun(sol, *args) == 0``.
+
+  Args:
+    optimality_fun: the optimality function to use (F in the paper).
+    sol: solution / root (pytree).
+    args: tuple containing the arguments with respect to which we wish to
+      differentiate ``sol`` against.
+    cotangent: vector to left-multiply the Jacobian with
+      (pytree, same structure as ``sol``).
+    solve: a linear solver of the form ``x = solve(matvec, b)``,
+      where ``matvec(x) = Ax`` and ``Ax=b``.
+  Returns:
+    vjps: tuple of the same length as ``len(args)`` containing the vjps w.r.t.
+      each argument. Each ``vjps[i]`` has the same pytree structure as
+      ``args[i]``.
+  """
+  support = sol != 0  # nonzero coefficients of the solution
+  restricted_sol = sol[support]  # solution restricted to the support
+
+  def fun_sol(restricted_sol):
+    # We close over the arguments.
+    # Maybe this could be optimized
+    return optimality_fun(sol, *args)[support]
+
+  _, vjp_fun_sol = jax.vjp(fun_sol(restricted_sol), restricted_sol)
+
+  # Compute the multiplication A^T u = (u^T A)^T restricted to the support.
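+  # Restricting the solve to the support is what makes this version "sparse":
+  # the linear system solved below has size (support size) x (support size)
+  # instead of n_features x n_features.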
+ def restricted_matvec(restricted_v): + return vjp_fun_sol(restricted_v)[0] + + # The solution of A^T u = v, where + # A = jacobian(optimality_fun, argnums=0) + # v = -cotangent. + restricted_v = tree_scalar_mul(-1, cotangent[support]) + restricted_u = solve(restricted_matvec, restricted_v) + + u = np.zeros_like(sol) + u[support] = restricted_u + + def fun_args(*args): + # We close over the solution. + return optimality_fun(sol, *args) + + _, vjp_fun_args = jax.vjp(fun_args, *args) + + return vjp_fun_args(u) + + def _jvp_sol(optimality_fun, sol, args, tangent): """JVP in the first argument of optimality_fun.""" # We close over the arguments. diff --git a/jaxopt/implicit_diff.py b/jaxopt/implicit_diff.py index f043d428..9b808617 100644 --- a/jaxopt/implicit_diff.py +++ b/jaxopt/implicit_diff.py @@ -16,3 +16,4 @@ from jaxopt._src.implicit_diff import custom_fixed_point from jaxopt._src.implicit_diff import root_jvp from jaxopt._src.implicit_diff import root_vjp +from jaxopt._src.implicit_diff import sparse_root_vjp diff --git a/tests/implicit_diff_test.py b/tests/implicit_diff_test.py index cc36a900..b2104063 100644 --- a/tests/implicit_diff_test.py +++ b/tests/implicit_diff_test.py @@ -55,6 +55,20 @@ def test_root_vjp(self): J_num = test_util.ridge_solver_jac(X, y, lam, eps=1e-4) self.assertArraysAllClose(J, J_num, atol=5e-2) + def test_sparse_root_vjp(self): + X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0) + optimality_fun = jax.grad(ridge_objective) + lam = 5.0 + sol = ridge_solver(None, lam, X, y) + vjp = lambda g: idf.sparse_root_vjp(optimality_fun=optimality_fun, + sol=sol, + args=(lam, X, y), + cotangent=g)[0] # vjp w.r.t. lam + I = jnp.eye(len(sol)) + J = jax.vmap(vjp)(I) + J_num = test_util.ridge_solver_jac(X, y, lam, eps=1e-4) + self.assertArraysAllClose(J, J_num, atol=5e-2) + def test_root_jvp(self): X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0) optimality_fun = jax.grad(ridge_objective) From 64034b848601fbd8852b5d5c058fb4ac86ba0326 Mon Sep 17 00:00:00 2001 From: QB3 Date: Thu, 26 Aug 2021 12:46:10 +0200 Subject: [PATCH 02/27] [ci skip] fixed call jax.vjp, tests still do not pass --- jaxopt/_src/implicit_diff.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/jaxopt/_src/implicit_diff.py b/jaxopt/_src/implicit_diff.py index adc9e40a..3ea69044 100644 --- a/jaxopt/_src/implicit_diff.py +++ b/jaxopt/_src/implicit_diff.py @@ -101,6 +101,7 @@ def fun_sol(sol): # We close over the arguments. return optimality_fun(sol, *args) + # import ipdb; ipdb.set_trace() _, vjp_fun_sol = jax.vjp(fun_sol, sol) # Compute the multiplication A^T u = (u^T A)^T. @@ -151,9 +152,10 @@ def sparse_root_vjp(optimality_fun: Callable, def fun_sol(restricted_sol): # We close over the arguments. # Maybe this could be optimized - return optimality_fun(sol, *args)[support] + return optimality_fun(restricted_sol, *args)[support] - _, vjp_fun_sol = jax.vjp(fun_sol(restricted_sol), restricted_sol) + # import ipdb; ipdb.set_trace() + _, vjp_fun_sol = jax.vjp(fun_sol, restricted_sol) # Compute the multiplication A^T u = (u^T A)^T resticted to the support. 
def restricted_matvec(restricted_v): From 71b8b2c077202dd3839426c278070a9b6f76b772 Mon Sep 17 00:00:00 2001 From: QB3 Date: Thu, 26 Aug 2021 12:52:21 +0200 Subject: [PATCH 03/27] [ci skip] made test implicit diff for sparse_jvp pass --- jaxopt/_src/implicit_diff.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/jaxopt/_src/implicit_diff.py b/jaxopt/_src/implicit_diff.py index 3ea69044..3a11970f 100644 --- a/jaxopt/_src/implicit_diff.py +++ b/jaxopt/_src/implicit_diff.py @@ -25,6 +25,7 @@ from jaxopt._src import linear_solve from jaxopt._src.tree_util import tree_scalar_mul from jaxopt._src.tree_util import tree_sub +from jaxopt._src.tree_util import tree_zeros_like def root_vjp(optimality_fun: Callable, @@ -167,16 +168,13 @@ def restricted_matvec(restricted_v): restricted_v = tree_scalar_mul(-1, cotangent[support]) restricted_u = solve(restricted_matvec, restricted_v) - u = np.zeros_like(sol) - u[support] = restricted_u - def fun_args(*args): # We close over the solution. - return optimality_fun(sol, *args) + return optimality_fun(sol, *args)[support] _, vjp_fun_args = jax.vjp(fun_args, *args) - return vjp_fun_args(u) + return vjp_fun_args(restricted_u) def _jvp_sol(optimality_fun, sol, args, tangent): From 4e8a7914c154118693ff4d4c99cf9982de5a2f0b Mon Sep 17 00:00:00 2001 From: QB3 Date: Thu, 26 Aug 2021 13:21:36 +0200 Subject: [PATCH 04/27] [ci skip] added test lasso, currently fails --- jaxopt/_src/implicit_diff.py | 2 +- tests/implicit_diff_test.py | 23 +++++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/jaxopt/_src/implicit_diff.py b/jaxopt/_src/implicit_diff.py index 3a11970f..0517c297 100644 --- a/jaxopt/_src/implicit_diff.py +++ b/jaxopt/_src/implicit_diff.py @@ -153,7 +153,7 @@ def sparse_root_vjp(optimality_fun: Callable, def fun_sol(restricted_sol): # We close over the arguments. # Maybe this could be optimized - return optimality_fun(restricted_sol, *args)[support] + return optimality_fun(sol, *args)[support] # import ipdb; ipdb.set_trace() _, vjp_fun_sol = jax.vjp(fun_sol, restricted_sol) diff --git a/tests/implicit_diff_test.py b/tests/implicit_diff_test.py index b2104063..19e52bde 100644 --- a/tests/implicit_diff_test.py +++ b/tests/implicit_diff_test.py @@ -30,6 +30,12 @@ def ridge_objective(params, lam, X, y): return 0.5 * jnp.mean(residuals ** 2) + 0.5 * lam * jnp.sum(params ** 2) +def lasso_objective(params, lam, X, y): + residuals = jnp.dot(X, params) - y + return 0.5 * jnp.mean(residuals ** 2) / len(y) + lam * jnp.sum( + jnp.abs(params)) + + # def ridge_solver(init_params, lam, X, y): def ridge_solver(init_params, lam, X, y): del init_params # not used @@ -69,6 +75,23 @@ def test_sparse_root_vjp(self): J_num = test_util.ridge_solver_jac(X, y, lam, eps=1e-4) self.assertArraysAllClose(J, J_num, atol=5e-2) + def test_lasso_sparse_root_vjp(self): + X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0) + optimality_fun = jax.grad(lasso_objective) + lam = 5.0 + lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y) + lam = lam_max / 2 + sol = test_util.lasso_skl(X, y, lam) + vjp = lambda g: idf.sparse_root_vjp(optimality_fun=optimality_fun, + sol=sol, + args=(lam, X, y), + cotangent=g)[0] # vjp w.r.t. 
lam + I = jnp.eye(len(sol)) + J = jax.vmap(vjp)(I) + J_num = test_util.lasso_skl_jac(X, y, lam, eps=1e-4) + # import ipdb; ipdb.set_trace() + self.assertArraysAllClose(J, J_num, atol=5e-2) + def test_root_jvp(self): X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0) optimality_fun = jax.grad(ridge_objective) From a9883d60d7ce3e180bb49ee2a2617a217def693d Mon Sep 17 00:00:00 2001 From: QB3 Date: Thu, 26 Aug 2021 13:31:12 +0200 Subject: [PATCH 05/27] [ci skip] added test lasso without sparsity --- tests/implicit_diff_test.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/tests/implicit_diff_test.py b/tests/implicit_diff_test.py index 19e52bde..03126207 100644 --- a/tests/implicit_diff_test.py +++ b/tests/implicit_diff_test.py @@ -61,24 +61,38 @@ def test_root_vjp(self): J_num = test_util.ridge_solver_jac(X, y, lam, eps=1e-4) self.assertArraysAllClose(J, J_num, atol=5e-2) - def test_sparse_root_vjp(self): + # def test_sparse_root_vjp(self): + # X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0) + # optimality_fun = jax.grad(ridge_objective) + # lam = 5.0 + # sol = ridge_solver(None, lam, X, y) + # vjp = lambda g: idf.sparse_root_vjp(optimality_fun=optimality_fun, + # sol=sol, + # args=(lam, X, y), + # cotangent=g)[0] # vjp w.r.t. lam + # I = jnp.eye(len(sol)) + # J = jax.vmap(vjp)(I) + # J_num = test_util.ridge_solver_jac(X, y, lam, eps=1e-4) + # self.assertArraysAllClose(J, J_num, atol=5e-2) + + def test_lasso_root_vjp(self): X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0) - optimality_fun = jax.grad(ridge_objective) - lam = 5.0 - sol = ridge_solver(None, lam, X, y) - vjp = lambda g: idf.sparse_root_vjp(optimality_fun=optimality_fun, - sol=sol, - args=(lam, X, y), - cotangent=g)[0] # vjp w.r.t. lam + optimality_fun = jax.grad(lasso_objective) + lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y) + lam = lam_max / 2 + sol = test_util.lasso_skl(X, y, lam) + vjp = lambda g: idf.root_vjp(optimality_fun=optimality_fun, + sol=sol, + args=(lam, X, y), + cotangent=g)[0] # vjp w.r.t. 
lam I = jnp.eye(len(sol)) J = jax.vmap(vjp)(I) - J_num = test_util.ridge_solver_jac(X, y, lam, eps=1e-4) + J_num = test_util.lasso_skl_jac(X, y, lam, eps=1e-4) self.assertArraysAllClose(J, J_num, atol=5e-2) def test_lasso_sparse_root_vjp(self): X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0) optimality_fun = jax.grad(lasso_objective) - lam = 5.0 lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y) lam = lam_max / 2 sol = test_util.lasso_skl(X, y, lam) From 7733f9a10252865ccddc665469c9a4dd6508307c Mon Sep 17 00:00:00 2001 From: QB3 Date: Thu, 26 Aug 2021 13:59:55 +0200 Subject: [PATCH 06/27] Trigger google-cla From cbe7bac011eb1f438c78d56c42ea17840c2f7141 Mon Sep 17 00:00:00 2001 From: QB3 Date: Thu, 26 Aug 2021 15:14:47 +0200 Subject: [PATCH 07/27] [ci skip] made test pass for lasso without sparse computation --- tests/implicit_diff_test.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/implicit_diff_test.py b/tests/implicit_diff_test.py index 03126207..5a117922 100644 --- a/tests/implicit_diff_test.py +++ b/tests/implicit_diff_test.py @@ -15,10 +15,12 @@ from absl.testing import absltest from absl.testing import parameterized +import numpy as np import jax from jax import test_util as jtu import jax.numpy as jnp +from jaxopt import prox from jaxopt import implicit_diff as idf from jaxopt._src import test_util @@ -77,7 +79,12 @@ def test_root_vjp(self): def test_lasso_root_vjp(self): X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0) - optimality_fun = jax.grad(lasso_objective) + L = jax.numpy.linalg.norm(X, ord=2) ** 2 + + def optimality_fun(params, lam, X, y): + return prox.prox_lasso( + params - X.T @ (X @ params - y) / L, lam * len(y) / L) - params + lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y) lam = lam_max / 2 sol = test_util.lasso_skl(X, y, lam) From 478d871a756c649c0e406921a21c5c6ed66e90ac Mon Sep 17 00:00:00 2001 From: QB3 Date: Thu, 26 Aug 2021 15:32:57 +0200 Subject: [PATCH 08/27] [ci skip]@ new try for sparse computation, still fails --- jaxopt/_src/implicit_diff.py | 6 ++++-- tests/implicit_diff_test.py | 8 +++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/jaxopt/_src/implicit_diff.py b/jaxopt/_src/implicit_diff.py index 0517c297..cc96ee84 100644 --- a/jaxopt/_src/implicit_diff.py +++ b/jaxopt/_src/implicit_diff.py @@ -153,9 +153,10 @@ def sparse_root_vjp(optimality_fun: Callable, def fun_sol(restricted_sol): # We close over the arguments. # Maybe this could be optimized - return optimality_fun(sol, *args)[support] + sol_ = tree_zeros_like(sol) + jax.ops.index_update(sol_, support, restricted_sol) + return optimality_fun(sol_, *args)[support] - # import ipdb; ipdb.set_trace() _, vjp_fun_sol = jax.vjp(fun_sol, restricted_sol) # Compute the multiplication A^T u = (u^T A)^T resticted to the support. @@ -168,6 +169,7 @@ def restricted_matvec(restricted_v): restricted_v = tree_scalar_mul(-1, cotangent[support]) restricted_u = solve(restricted_matvec, restricted_v) + import ipdb; ipdb.set_trace() def fun_args(*args): # We close over the solution. 
return optimality_fun(sol, *args)[support] diff --git a/tests/implicit_diff_test.py b/tests/implicit_diff_test.py index 5a117922..923a169f 100644 --- a/tests/implicit_diff_test.py +++ b/tests/implicit_diff_test.py @@ -99,7 +99,13 @@ def optimality_fun(params, lam, X, y): def test_lasso_sparse_root_vjp(self): X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0) - optimality_fun = jax.grad(lasso_objective) + + L = jax.numpy.linalg.norm(X, ord=2) ** 2 + + def optimality_fun(params, lam, X, y): + return prox.prox_lasso( + params - X.T @ (X @ params - y) / L, lam * len(y) / L) - params + lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y) lam = lam_max / 2 sol = test_util.lasso_skl(X, y, lam) From 0346b9703b63d3dbb0290e0a3f084826fad69792 Mon Sep 17 00:00:00 2001 From: QB3 Date: Thu, 26 Aug 2021 16:53:27 +0200 Subject: [PATCH 09/27] [ci skip] made sparse jvp work, remains to see how much we win --- jaxopt/_src/implicit_diff.py | 8 +++----- tests/implicit_diff_test.py | 14 ++++++++++++-- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/jaxopt/_src/implicit_diff.py b/jaxopt/_src/implicit_diff.py index cc96ee84..1d69c829 100644 --- a/jaxopt/_src/implicit_diff.py +++ b/jaxopt/_src/implicit_diff.py @@ -102,7 +102,6 @@ def fun_sol(sol): # We close over the arguments. return optimality_fun(sol, *args) - # import ipdb; ipdb.set_trace() _, vjp_fun_sol = jax.vjp(fun_sol, sol) # Compute the multiplication A^T u = (u^T A)^T. @@ -154,8 +153,8 @@ def fun_sol(restricted_sol): # We close over the arguments. # Maybe this could be optimized sol_ = tree_zeros_like(sol) - jax.ops.index_update(sol_, support, restricted_sol) - return optimality_fun(sol_, *args)[support] + sol_ = jax.ops.index_update(sol_, support, restricted_sol) + return optimality_fun(sol_, *args) _, vjp_fun_sol = jax.vjp(fun_sol, restricted_sol) @@ -169,10 +168,9 @@ def restricted_matvec(restricted_v): restricted_v = tree_scalar_mul(-1, cotangent[support]) restricted_u = solve(restricted_matvec, restricted_v) - import ipdb; ipdb.set_trace() def fun_args(*args): # We close over the solution. - return optimality_fun(sol, *args)[support] + return optimality_fun(sol, *args) _, vjp_fun_args = jax.vjp(fun_args, *args) diff --git a/tests/implicit_diff_test.py b/tests/implicit_diff_test.py index 923a169f..df223780 100644 --- a/tests/implicit_diff_test.py +++ b/tests/implicit_diff_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from numpy.core.numeric import ones from absl.testing import absltest from absl.testing import parameterized @@ -23,6 +24,7 @@ from jaxopt import prox from jaxopt import implicit_diff as idf from jaxopt._src import test_util +from jaxopt import objective from sklearn import datasets @@ -103,12 +105,20 @@ def test_lasso_sparse_root_vjp(self): L = jax.numpy.linalg.norm(X, ord=2) ** 2 def optimality_fun(params, lam, X, y): - return prox.prox_lasso( - params - X.T @ (X @ params - y) / L, lam * len(y) / L) - params + support = params != 0 + res = X[:, support].T @ (X[:, support] @ params[support] - y) / L + res = params[support] - res + res = prox.prox_lasso(res, lam * len(y) / L) + res -= params[support] + return res lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y) lam = lam_max / 2 sol = test_util.lasso_skl(X, y, lam) + + # jax.jacobian(optimality_fun)(jnp.ones(X.shape[1]), lam, X, y) + # test the mask in optimality_fun + vjp = lambda g: idf.sparse_root_vjp(optimality_fun=optimality_fun, sol=sol, args=(lam, X, y), From 8a3c12821e8932bd88466c799587d46029bb2dc0 Mon Sep 17 00:00:00 2001 From: QB3 Date: Thu, 26 Aug 2021 19:15:23 +0200 Subject: [PATCH 10/27] [ci skip] added little bench for sparse vjp, sol is better but running time is the same --- benchmarks/sparse_vjp.py | 66 +++++++++++++++++++++++++++++++++++++ tests/implicit_diff_test.py | 15 --------- 2 files changed, 66 insertions(+), 15 deletions(-) create mode 100644 benchmarks/sparse_vjp.py diff --git a/benchmarks/sparse_vjp.py b/benchmarks/sparse_vjp.py new file mode 100644 index 00000000..818347ac --- /dev/null +++ b/benchmarks/sparse_vjp.py @@ -0,0 +1,66 @@ +import time +import jax + +import jax.numpy as jnp + +from jaxopt import prox +from jaxopt import implicit_diff as idf +from jaxopt._src import test_util + + +from sklearn import datasets + +X, y = datasets.make_regression(n_samples=10, n_features=1000, random_state=0) + +L = jax.numpy.linalg.norm(X, ord=2) ** 2 + + +def optimality_fun(params, lam, X, y): + return prox.prox_lasso( + params - X.T @ (X @ params - y) / L, lam * len(y) / L) - params + + +def optimality_fun_sparse(params, lam, X, y): + support = params != 0 + res = X[:, support].T @ (X[:, support] @ params[support] - y) / L + res = params[support] - res + res = prox.prox_lasso(res, lam * len(y) / L) + res -= params[support] + return res + + +lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y) +lam = lam_max / 2 +t_start = time.time() +sol = test_util.lasso_skl(X, y, lam) +t_optim = time.time() - t_start + +vjp = lambda g: idf.root_vjp(optimality_fun=optimality_fun, + sol=sol, + args=(lam, X, y), + cotangent=g)[0] # vjp w.r.t. lam + +vjp_sparse = lambda g: idf.sparse_root_vjp( + optimality_fun=optimality_fun_sparse, + sol=sol, + args=(lam, X, y), + cotangent=g)[0] # vjp w.r.t. 
lam + +t_start = time.time() +I = jnp.eye(len(sol)) +J = jax.vmap(vjp)(I) +t_jac = time.time() - t_start + +t_start = time.time() +I = jnp.eye(len(sol)) +J_sparse = jax.vmap(vjp_sparse)(I) +t_jac_sparse = time.time() - t_start + +print("Time taken to solve the Lasso optimization problem %.3f" % t_optim) +print("Time taken to compute the Jacobian %.3f" % t_jac) +print("Time taken to compute the Jacobian with the sparse implementation %.3f" % t_jac) + + +# Computation time are the same, which is very weird to me +# However, the Jacobian computed the sparse way is much closer to the real +# Jacobian diff --git a/tests/implicit_diff_test.py b/tests/implicit_diff_test.py index df223780..78e77d25 100644 --- a/tests/implicit_diff_test.py +++ b/tests/implicit_diff_test.py @@ -65,20 +65,6 @@ def test_root_vjp(self): J_num = test_util.ridge_solver_jac(X, y, lam, eps=1e-4) self.assertArraysAllClose(J, J_num, atol=5e-2) - # def test_sparse_root_vjp(self): - # X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0) - # optimality_fun = jax.grad(ridge_objective) - # lam = 5.0 - # sol = ridge_solver(None, lam, X, y) - # vjp = lambda g: idf.sparse_root_vjp(optimality_fun=optimality_fun, - # sol=sol, - # args=(lam, X, y), - # cotangent=g)[0] # vjp w.r.t. lam - # I = jnp.eye(len(sol)) - # J = jax.vmap(vjp)(I) - # J_num = test_util.ridge_solver_jac(X, y, lam, eps=1e-4) - # self.assertArraysAllClose(J, J_num, atol=5e-2) - def test_lasso_root_vjp(self): X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0) L = jax.numpy.linalg.norm(X, ord=2) ** 2 @@ -126,7 +112,6 @@ def optimality_fun(params, lam, X, y): I = jnp.eye(len(sol)) J = jax.vmap(vjp)(I) J_num = test_util.lasso_skl_jac(X, y, lam, eps=1e-4) - # import ipdb; ipdb.set_trace() self.assertArraysAllClose(J, J_num, atol=5e-2) def test_root_jvp(self): From c9b0daea3f90dec392fbcb73097eb88d39010aa2 Mon Sep 17 00:00:00 2001 From: QB3 Date: Fri, 27 Aug 2021 10:51:36 +0200 Subject: [PATCH 11/27] [ci skip] try implemetation with hardcoded support --- benchmarks/sparse_vjp.py | 16 ++++++++-------- jaxopt/_src/implicit_diff.py | 12 ++++++------ tests/implicit_diff_test.py | 16 ++++++++++------ 3 files changed, 24 insertions(+), 20 deletions(-) diff --git a/benchmarks/sparse_vjp.py b/benchmarks/sparse_vjp.py index 818347ac..5be001e7 100644 --- a/benchmarks/sparse_vjp.py +++ b/benchmarks/sparse_vjp.py @@ -20,13 +20,13 @@ def optimality_fun(params, lam, X, y): params - X.T @ (X @ params - y) / L, lam * len(y) / L) - params -def optimality_fun_sparse(params, lam, X, y): - support = params != 0 - res = X[:, support].T @ (X[:, support] @ params[support] - y) / L - res = params[support] - res - res = prox.prox_lasso(res, lam * len(y) / L) - res -= params[support] - return res +# def optimality_fun_sparse(params, lam, X, y): +# support = params != 0 +# res = X[:, support].T @ (X[:, support] @ params[support] - y) / L +# res = params[support] - res +# res = prox.prox_lasso(res, lam * len(y) / L) +# res -= params[support] +# return res lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y) @@ -41,7 +41,7 @@ def optimality_fun_sparse(params, lam, X, y): cotangent=g)[0] # vjp w.r.t. lam vjp_sparse = lambda g: idf.sparse_root_vjp( - optimality_fun=optimality_fun_sparse, + optimality_fun=optimality_fun, sol=sol, args=(lam, X, y), cotangent=g)[0] # vjp w.r.t. 
lam diff --git a/jaxopt/_src/implicit_diff.py b/jaxopt/_src/implicit_diff.py index 1d69c829..2f08950d 100644 --- a/jaxopt/_src/implicit_diff.py +++ b/jaxopt/_src/implicit_diff.py @@ -149,12 +149,12 @@ def sparse_root_vjp(optimality_fun: Callable, support = sol != 0 # nonzeros coefficients of the solution restricted_sol = sol[support] # solution restricted to the support + lam, X, y = args + new_args = lam, X[:, support], y + def fun_sol(restricted_sol): # We close over the arguments. - # Maybe this could be optimized - sol_ = tree_zeros_like(sol) - sol_ = jax.ops.index_update(sol_, support, restricted_sol) - return optimality_fun(sol_, *args) + return optimality_fun(restricted_sol, *new_args) _, vjp_fun_sol = jax.vjp(fun_sol, restricted_sol) @@ -170,9 +170,9 @@ def restricted_matvec(restricted_v): def fun_args(*args): # We close over the solution. - return optimality_fun(sol, *args) + return optimality_fun(restricted_sol, *args) - _, vjp_fun_args = jax.vjp(fun_args, *args) + _, vjp_fun_args = jax.vjp(fun_args, *new_args) return vjp_fun_args(restricted_u) diff --git a/tests/implicit_diff_test.py b/tests/implicit_diff_test.py index 78e77d25..06dbb1e4 100644 --- a/tests/implicit_diff_test.py +++ b/tests/implicit_diff_test.py @@ -91,12 +91,16 @@ def test_lasso_sparse_root_vjp(self): L = jax.numpy.linalg.norm(X, ord=2) ** 2 def optimality_fun(params, lam, X, y): - support = params != 0 - res = X[:, support].T @ (X[:, support] @ params[support] - y) / L - res = params[support] - res - res = prox.prox_lasso(res, lam * len(y) / L) - res -= params[support] - return res + return prox.prox_lasso( + params - X.T @ (X @ params - y) / L, lam * len(y) / L) - params + + # def optimality_fun(params, lam, X, y): + # support = params != 0 + # res = X[:, support].T @ (X[:, support] @ params[support] - y) / L + # res = params[support] - res + # res = prox.prox_lasso(res, lam * len(y) / L) + # res -= params[support] + # return res lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y) lam = lam_max / 2 From b698e02bb8343d28e99fb5485ba6d1b680e69543 Mon Sep 17 00:00:00 2001 From: QB3 Date: Mon, 30 Aug 2021 12:35:27 +0200 Subject: [PATCH 12/27] [ci skip] take larger number of features, see speed ups --- benchmarks/sparse_vjp.py | 4 +- jaxopt/_src/implicit_diff.py | 90 ++++++++++++++++++------------------ 2 files changed, 47 insertions(+), 47 deletions(-) diff --git a/benchmarks/sparse_vjp.py b/benchmarks/sparse_vjp.py index 5be001e7..0321f43a 100644 --- a/benchmarks/sparse_vjp.py +++ b/benchmarks/sparse_vjp.py @@ -10,7 +10,7 @@ from sklearn import datasets -X, y = datasets.make_regression(n_samples=10, n_features=1000, random_state=0) +X, y = datasets.make_regression(n_samples=10, n_features=10_000, random_state=0) L = jax.numpy.linalg.norm(X, ord=2) ** 2 @@ -58,7 +58,7 @@ def optimality_fun(params, lam, X, y): print("Time taken to solve the Lasso optimization problem %.3f" % t_optim) print("Time taken to compute the Jacobian %.3f" % t_jac) -print("Time taken to compute the Jacobian with the sparse implementation %.3f" % t_jac) +print("Time taken to compute the Jacobian with the sparse implementation %.3f" % t_jac_sparse) # Computation time are the same, which is very weird to me diff --git a/jaxopt/_src/implicit_diff.py b/jaxopt/_src/implicit_diff.py index 2f08950d..da72461c 100644 --- a/jaxopt/_src/implicit_diff.py +++ b/jaxopt/_src/implicit_diff.py @@ -28,51 +28,51 @@ from jaxopt._src.tree_util import tree_zeros_like -def root_vjp(optimality_fun: Callable, - sol: Any, - args: Tuple, - cotangent: Any, - solve: Callable = 
linear_solve.solve_normal_cg) -> Any: - """Vector-Jacobian product of a root. - - The invariant is ``optimality_fun(sol, *args) == 0``. - - Args: - optimality_fun: the optimality function to use. - sol: solution / root (pytree). - args: tuple containing the arguments with respect to which we wish to - differentiate ``sol`` against. - cotangent: vector to left-multiply the Jacobian with - (pytree, same structure as ``sol``). - solve: a linear solver of the form, ``x = solve(matvec, b)``, - where ``matvec(x) = Ax`` and ``Ax=b``. - Returns: - vjps: tuple of the same length as ``len(args)`` containing the vjps w.r.t. - each argument. Each ``vjps[i]` has the same pytree structure as - ``args[i]``. - """ - def fun_sol(sol): - # We close over the arguments. - return optimality_fun(sol, *args) - - _, vjp_fun_sol = jax.vjp(fun_sol, sol) - - # Compute the multiplication A^T u = (u^T A)^T. - matvec = lambda u: vjp_fun_sol(u)[0] - - # The solution of A^T u = v, where - # A = jacobian(optimality_fun, argnums=0) - # v = -cotangent. - v = tree_scalar_mul(-1, cotangent) - u = solve(matvec, v) - - def fun_args(*args): - # We close over the solution. - return optimality_fun(sol, *args) - - _, vjp_fun_args = jax.vjp(fun_args, *args) - - return vjp_fun_args(u) +# def root_vjp(optimality_fun: Callable, +# sol: Any, +# args: Tuple, +# cotangent: Any, +# solve: Callable = linear_solve.solve_normal_cg) -> Any: +# """Vector-Jacobian product of a root. + +# The invariant is ``optimality_fun(sol, *args) == 0``. + +# Args: +# optimality_fun: the optimality function to use. +# sol: solution / root (pytree). +# args: tuple containing the arguments with respect to which we wish to +# differentiate ``sol`` against. +# cotangent: vector to left-multiply the Jacobian with +# (pytree, same structure as ``sol``). +# solve: a linear solver of the form, ``x = solve(matvec, b)``, +# where ``matvec(x) = Ax`` and ``Ax=b``. +# Returns: +# vjps: tuple of the same length as ``len(args)`` containing the vjps w.r.t. +# each argument. Each ``vjps[i]` has the same pytree structure as +# ``args[i]``. +# """ +# def fun_sol(sol): +# # We close over the arguments. +# return optimality_fun(sol, *args) + +# _, vjp_fun_sol = jax.vjp(fun_sol, sol) + +# # Compute the multiplication A^T u = (u^T A)^T. +# matvec = lambda u: vjp_fun_sol(u)[0] + +# # The solution of A^T u = v, where +# # A = jacobian(optimality_fun, argnums=0) +# # v = -cotangent. +# v = tree_scalar_mul(-1, cotangent) +# u = solve(matvec, v) + +# def fun_args(*args): +# # We close over the solution. 
+# return optimality_fun(sol, *args) + +# _, vjp_fun_args = jax.vjp(fun_args, *args) + +# return vjp_fun_args(u) def root_vjp(optimality_fun: Callable, From 0c230ceda3de5c0508c3408f16641f7f05ff8804 Mon Sep 17 00:00:00 2001 From: QB3 Date: Mon, 30 Aug 2021 15:50:38 +0200 Subject: [PATCH 13/27] add make_restricted_optimality_fun to sparse_vjp --- benchmarks/sparse_vjp.py | 18 ++++++----- jaxopt/_src/implicit_diff.py | 61 +++++------------------------------- tests/implicit_diff_test.py | 26 +++++++-------- 3 files changed, 30 insertions(+), 75 deletions(-) diff --git a/benchmarks/sparse_vjp.py b/benchmarks/sparse_vjp.py index 0321f43a..4c2df42d 100644 --- a/benchmarks/sparse_vjp.py +++ b/benchmarks/sparse_vjp.py @@ -10,7 +10,8 @@ from sklearn import datasets -X, y = datasets.make_regression(n_samples=10, n_features=10_000, random_state=0) +X, y = datasets.make_regression( + n_samples=10, n_features=10_000, random_state=0) L = jax.numpy.linalg.norm(X, ord=2) ** 2 @@ -20,19 +21,19 @@ def optimality_fun(params, lam, X, y): params - X.T @ (X @ params - y) / L, lam * len(y) / L) - params -# def optimality_fun_sparse(params, lam, X, y): -# support = params != 0 -# res = X[:, support].T @ (X[:, support] @ params[support] - y) / L -# res = params[support] - res -# res = prox.prox_lasso(res, lam * len(y) / L) -# res -= params[support] -# return res +def make_restricted_optimality_fun(support): + def restricted_optimality_fun(restricted_params, lam, X, y): + # this is suboptimal, I would try to compute restricted_X once for all + restricted_X = X[:, support] + return optimality_fun(restricted_params, lam, restricted_X, y) + return restricted_optimality_fun lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y) lam = lam_max / 2 t_start = time.time() sol = test_util.lasso_skl(X, y, lam) +print(sol) t_optim = time.time() - t_start vjp = lambda g: idf.root_vjp(optimality_fun=optimality_fun, @@ -42,6 +43,7 @@ def optimality_fun(params, lam, X, y): vjp_sparse = lambda g: idf.sparse_root_vjp( optimality_fun=optimality_fun, + make_restricted_optimality_fun=make_restricted_optimality_fun, sol=sol, args=(lam, X, y), cotangent=g)[0] # vjp w.r.t. lam diff --git a/jaxopt/_src/implicit_diff.py b/jaxopt/_src/implicit_diff.py index da72461c..6092aba2 100644 --- a/jaxopt/_src/implicit_diff.py +++ b/jaxopt/_src/implicit_diff.py @@ -25,54 +25,6 @@ from jaxopt._src import linear_solve from jaxopt._src.tree_util import tree_scalar_mul from jaxopt._src.tree_util import tree_sub -from jaxopt._src.tree_util import tree_zeros_like - - -# def root_vjp(optimality_fun: Callable, -# sol: Any, -# args: Tuple, -# cotangent: Any, -# solve: Callable = linear_solve.solve_normal_cg) -> Any: -# """Vector-Jacobian product of a root. - -# The invariant is ``optimality_fun(sol, *args) == 0``. - -# Args: -# optimality_fun: the optimality function to use. -# sol: solution / root (pytree). -# args: tuple containing the arguments with respect to which we wish to -# differentiate ``sol`` against. -# cotangent: vector to left-multiply the Jacobian with -# (pytree, same structure as ``sol``). -# solve: a linear solver of the form, ``x = solve(matvec, b)``, -# where ``matvec(x) = Ax`` and ``Ax=b``. -# Returns: -# vjps: tuple of the same length as ``len(args)`` containing the vjps w.r.t. -# each argument. Each ``vjps[i]` has the same pytree structure as -# ``args[i]``. -# """ -# def fun_sol(sol): -# # We close over the arguments. 
-# return optimality_fun(sol, *args) - -# _, vjp_fun_sol = jax.vjp(fun_sol, sol) - -# # Compute the multiplication A^T u = (u^T A)^T. -# matvec = lambda u: vjp_fun_sol(u)[0] - -# # The solution of A^T u = v, where -# # A = jacobian(optimality_fun, argnums=0) -# # v = -cotangent. -# v = tree_scalar_mul(-1, cotangent) -# u = solve(matvec, v) - -# def fun_args(*args): -# # We close over the solution. -# return optimality_fun(sol, *args) - -# _, vjp_fun_args = jax.vjp(fun_args, *args) - -# return vjp_fun_args(u) def root_vjp(optimality_fun: Callable, @@ -123,6 +75,7 @@ def fun_args(*args): def sparse_root_vjp(optimality_fun: Callable, + make_restricted_optimality_fun: Callable, sol: Any, args: Tuple, cotangent: Any, @@ -134,6 +87,7 @@ def sparse_root_vjp(optimality_fun: Callable, Args: optimality_fun: the optimality function to use. F in the paper + make_restricted_optimality_fun: TODO XXX. sol: solution / root (pytree). args: tuple containing the arguments with respect to which we wish to differentiate ``sol`` against. @@ -149,12 +103,11 @@ def sparse_root_vjp(optimality_fun: Callable, support = sol != 0 # nonzeros coefficients of the solution restricted_sol = sol[support] # solution restricted to the support - lam, X, y = args - new_args = lam, X[:, support], y + restricted_optimality_fun = make_restricted_optimality_fun(support) def fun_sol(restricted_sol): # We close over the arguments. - return optimality_fun(restricted_sol, *new_args) + return restricted_optimality_fun(restricted_sol, *args) _, vjp_fun_sol = jax.vjp(fun_sol, restricted_sol) @@ -170,9 +123,11 @@ def restricted_matvec(restricted_v): def fun_args(*args): # We close over the solution. - return optimality_fun(restricted_sol, *args) + return restricted_optimality_fun(restricted_sol, *args) + # return optimality_fun(restricted_sol, *args) - _, vjp_fun_args = jax.vjp(fun_args, *new_args) + _, vjp_fun_args = jax.vjp(fun_args, *args) + # _, vjp_fun_args = jax.vjp(fun_args, *new_args) return vjp_fun_args(restricted_u) diff --git a/tests/implicit_diff_test.py b/tests/implicit_diff_test.py index 06dbb1e4..997bfceb 100644 --- a/tests/implicit_diff_test.py +++ b/tests/implicit_diff_test.py @@ -94,25 +94,23 @@ def optimality_fun(params, lam, X, y): return prox.prox_lasso( params - X.T @ (X @ params - y) / L, lam * len(y) / L) - params - # def optimality_fun(params, lam, X, y): - # support = params != 0 - # res = X[:, support].T @ (X[:, support] @ params[support] - y) / L - # res = params[support] - res - # res = prox.prox_lasso(res, lam * len(y) / L) - # res -= params[support] - # return res + def make_restricted_optimality_fun(support): + def restricted_optimality_fun(restricted_params, lam, X, y): + # this is suboptimal, I would try to compute restricted_X once for all + restricted_X = X[:, support] + return optimality_fun(restricted_params, lam, restricted_X, y) + return restricted_optimality_fun lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y) lam = lam_max / 2 sol = test_util.lasso_skl(X, y, lam) - # jax.jacobian(optimality_fun)(jnp.ones(X.shape[1]), lam, X, y) - # test the mask in optimality_fun - - vjp = lambda g: idf.sparse_root_vjp(optimality_fun=optimality_fun, - sol=sol, - args=(lam, X, y), - cotangent=g)[0] # vjp w.r.t. lam + vjp = lambda g: idf.sparse_root_vjp( + optimality_fun=optimality_fun, + make_restricted_optimality_fun=make_restricted_optimality_fun, + sol=sol, + args=(lam, X, y), + cotangent=g)[0] # vjp w.r.t. 
lam I = jnp.eye(len(sol)) J = jax.vmap(vjp)(I) J_num = test_util.lasso_skl_jac(X, y, lam, eps=1e-4) From 933f287f6dd88d5db5f1f80d172f72a2fef4083d Mon Sep 17 00:00:00 2001 From: QB3 Date: Mon, 30 Aug 2021 16:24:48 +0200 Subject: [PATCH 14/27] [ci skip] simplified + rearanged args in tests --- tests/implicit_diff_test.py | 53 ++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/tests/implicit_diff_test.py b/tests/implicit_diff_test.py index 997bfceb..e881cbf8 100644 --- a/tests/implicit_diff_test.py +++ b/tests/implicit_diff_test.py @@ -49,6 +49,17 @@ def ridge_solver(init_params, lam, X, y): return jnp.linalg.solve(XX + lam * len(y) * I, Xy) +X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0) +lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y) +lam = lam_max / 2 +L = jax.numpy.linalg.norm(X, ord=2) ** 2 + + +def lasso_optimality_fun(params, X, y, lam, tol=1e-4): + return prox.prox_lasso( + params - X.T @ (X @ params - y) / L, lam * len(y) / L) - params + + class ImplicitDiffTest(jtu.JaxTestCase): def test_root_vjp(self): @@ -66,39 +77,23 @@ def test_root_vjp(self): self.assertArraysAllClose(J, J_num, atol=5e-2) def test_lasso_root_vjp(self): - X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0) - L = jax.numpy.linalg.norm(X, ord=2) ** 2 - - def optimality_fun(params, lam, X, y): - return prox.prox_lasso( - params - X.T @ (X @ params - y) / L, lam * len(y) / L) - params - - lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y) - lam = lam_max / 2 sol = test_util.lasso_skl(X, y, lam) - vjp = lambda g: idf.root_vjp(optimality_fun=optimality_fun, + vjp = lambda g: idf.root_vjp(optimality_fun=lasso_optimality_fun, sol=sol, - args=(lam, X, y), - cotangent=g)[0] # vjp w.r.t. lam + args=(X, y, lam), + cotangent=g)[2] # vjp w.r.t. lam I = jnp.eye(len(sol)) J = jax.vmap(vjp)(I) J_num = test_util.lasso_skl_jac(X, y, lam, eps=1e-4) self.assertArraysAllClose(J, J_num, atol=5e-2) def test_lasso_sparse_root_vjp(self): - X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0) - - L = jax.numpy.linalg.norm(X, ord=2) ** 2 - - def optimality_fun(params, lam, X, y): - return prox.prox_lasso( - params - X.T @ (X @ params - y) / L, lam * len(y) / L) - params def make_restricted_optimality_fun(support): - def restricted_optimality_fun(restricted_params, lam, X, y): + def restricted_optimality_fun(restricted_params, X, y, lam): # this is suboptimal, I would try to compute restricted_X once for all restricted_X = X[:, support] - return optimality_fun(restricted_params, lam, restricted_X, y) + return lasso_optimality_fun(restricted_params, restricted_X, y, lam) return restricted_optimality_fun lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y) @@ -106,11 +101,11 @@ def restricted_optimality_fun(restricted_params, lam, X, y): sol = test_util.lasso_skl(X, y, lam) vjp = lambda g: idf.sparse_root_vjp( - optimality_fun=optimality_fun, + optimality_fun=lasso_optimality_fun, make_restricted_optimality_fun=make_restricted_optimality_fun, sol=sol, - args=(lam, X, y), - cotangent=g)[0] # vjp w.r.t. lam + args=(X, y, lam), + cotangent=g)[2] # vjp w.r.t. 
lam I = jnp.eye(len(sol)) J = jax.vmap(vjp)(I) J_num = test_util.lasso_skl_jac(X, y, lam, eps=1e-4) @@ -140,6 +135,16 @@ def test_custom_root(self): J = jax.jacrev(ridge_solver_decorated, argnums=1)(None, lam, X=X, y=y) self.assertArraysAllClose(J, J_num, atol=5e-2) + # def test_custom_root_lasso(self): + # lasso_solver_decorated = idf.custom_root( + # lasso_optimality_fun)(test_util.lasso_skl) + # sol = test_util.lasso_skl(X=X, y=y, lam=lam) + # sol_decorated = lasso_solver_decorated(X=X, y=y, lam=lam) + # self.assertArraysAllClose(sol, sol_decorated, atol=1e-4) + # J_num = test_util.lasso_skl_jac(X=X, y=y, lam=lam, tol=1e-4) + # J = jax.jacrev(lasso_solver_decorated, argnums=2)(X, y, lam) + # self.assertArraysAllClose(J, J_num, atol=5e-2) + def test_custom_root_with_has_aux(self): def ridge_solver_with_aux(init_params, lam, X, y): return ridge_solver(init_params, lam, X, y), None From 1ddeeb9517f11ae543fc38607ee01b882cf32b06 Mon Sep 17 00:00:00 2001 From: QB3 Date: Mon, 30 Aug 2021 16:25:02 +0200 Subject: [PATCH 15/27] [ci skip] CLN --- jaxopt/_src/implicit_diff.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/jaxopt/_src/implicit_diff.py b/jaxopt/_src/implicit_diff.py index 6092aba2..cbc6eb78 100644 --- a/jaxopt/_src/implicit_diff.py +++ b/jaxopt/_src/implicit_diff.py @@ -124,10 +124,8 @@ def restricted_matvec(restricted_v): def fun_args(*args): # We close over the solution. return restricted_optimality_fun(restricted_sol, *args) - # return optimality_fun(restricted_sol, *args) _, vjp_fun_args = jax.vjp(fun_args, *args) - # _, vjp_fun_args = jax.vjp(fun_args, *new_args) return vjp_fun_args(restricted_u) From 3aae541c67f3bd1932230fef5ea74391ab84c6ef Mon Sep 17 00:00:00 2001 From: QB3 Date: Mon, 30 Aug 2021 16:41:12 +0200 Subject: [PATCH 16/27] [ci skip] made test_custom_root_lasso --- tests/implicit_diff_test.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/tests/implicit_diff_test.py b/tests/implicit_diff_test.py index e881cbf8..c67edb7a 100644 --- a/tests/implicit_diff_test.py +++ b/tests/implicit_diff_test.py @@ -40,7 +40,11 @@ def lasso_objective(params, lam, X, y): jnp.abs(params)) -# def ridge_solver(init_params, lam, X, y): +def lasso_solver(params, X, y, lam): + sol = test_util.lasso_skl(X, y, lam) + return sol + + def ridge_solver(init_params, lam, X, y): del init_params # not used XX = jnp.dot(X.T, X) @@ -135,15 +139,15 @@ def test_custom_root(self): J = jax.jacrev(ridge_solver_decorated, argnums=1)(None, lam, X=X, y=y) self.assertArraysAllClose(J, J_num, atol=5e-2) - # def test_custom_root_lasso(self): - # lasso_solver_decorated = idf.custom_root( - # lasso_optimality_fun)(test_util.lasso_skl) - # sol = test_util.lasso_skl(X=X, y=y, lam=lam) - # sol_decorated = lasso_solver_decorated(X=X, y=y, lam=lam) - # self.assertArraysAllClose(sol, sol_decorated, atol=1e-4) - # J_num = test_util.lasso_skl_jac(X=X, y=y, lam=lam, tol=1e-4) - # J = jax.jacrev(lasso_solver_decorated, argnums=2)(X, y, lam) - # self.assertArraysAllClose(J, J_num, atol=5e-2) + def test_custom_root_lasso(self): + lasso_solver_decorated = idf.custom_root( + lasso_optimality_fun)(lasso_solver) + sol = test_util.lasso_skl(X=X, y=y, lam=lam) + sol_decorated = lasso_solver_decorated(None, X=X, y=y, lam=lam) + self.assertArraysAllClose(sol, sol_decorated, atol=1e-4) + J_num = test_util.lasso_skl_jac(X=X, y=y, lam=lam, tol=1e-4) + J = jax.jacrev(lasso_solver_decorated, argnums=3)(None, X, y, lam) + self.assertArraysAllClose(J, J_num, atol=5e-2) def 
test_custom_root_with_has_aux(self): def ridge_solver_with_aux(init_params, lam, X, y): From 3a3ef0b19ba44a7faed271f71cd47f34d8b2ceb9 Mon Sep 17 00:00:00 2001 From: QB3 Date: Mon, 30 Aug 2021 18:37:05 +0200 Subject: [PATCH 17/27] [ci skip] added sparse custom root + tests --- jaxopt/_src/implicit_diff.py | 54 ++++++++++++++++++++++++++++++++++++ jaxopt/implicit_diff.py | 1 + tests/implicit_diff_test.py | 28 ++++++++++++------- 3 files changed, 73 insertions(+), 10 deletions(-) diff --git a/jaxopt/_src/implicit_diff.py b/jaxopt/_src/implicit_diff.py index cbc6eb78..b1bc5577 100644 --- a/jaxopt/_src/implicit_diff.py +++ b/jaxopt/_src/implicit_diff.py @@ -201,6 +201,36 @@ def solver_fun_bwd(tup, cotangent): return wrapped_solver_fun +def _sparse_custom_root( + solver_fun, optimality_fun, make_restricted_optimality_fun, solve, has_aux): + def solver_fun_fwd(init_params, *args): + res = solver_fun(init_params, *args) + return res, (res, args) + + def solver_fun_bwd(tup, cotangent): + res, args = tup + + # solver_fun can return auxiliary data if has_aux = True. + if has_aux: + cotangent = cotangent[0] + sol = res[0] + else: + sol = res + + # Compute VJPs w.r.t. args. + vjps = sparse_root_vjp( + optimality_fun=optimality_fun, + make_restricted_optimality_fun=make_restricted_optimality_fun, + sol=sol, args=args, cotangent=cotangent, solve=solve) + # For init_params, we return None. + return (None,) + vjps + + wrapped_solver_fun = jax.custom_vjp(solver_fun) + wrapped_solver_fun.defvjp(solver_fun_fwd, solver_fun_bwd) + + return wrapped_solver_fun + + def custom_root(optimality_fun: Callable, has_aux: bool = False, solve: Callable = linear_solve.solve_normal_cg): @@ -222,6 +252,30 @@ def wrapper(solver_fun): return wrapper +def sparse_custom_root(optimality_fun: Callable, + make_restricted_optimality_fun: Callable, + has_aux: bool = False, + solve: Callable = linear_solve.solve_normal_cg): + """Decorator for adding implicit differentiation to a root solver. + + Args: + optimality_fun: an equation function, ``optimality_fun(params, *args)`. + The invariant is ``optimality_fun(sol, *args) == 0`` at the + solution / root ``sol``. + has_aux: whether the decorated solver function returns auxiliary data. + solve: a linear solver of the form, ``solve(matvec, b)``. + + Returns: + A solver function decorator, i.e., + ``custom_root(optimality_fun)(solver_fun)``. 
+ """ + def wrapper(solver_fun): + return _sparse_custom_root( + solver_fun, optimality_fun, make_restricted_optimality_fun, solve, + has_aux) + return wrapper + + def custom_fixed_point(fixed_point_fun: Callable, has_aux: bool = False, solve: Callable = linear_solve.solve_normal_cg): diff --git a/jaxopt/implicit_diff.py b/jaxopt/implicit_diff.py index 9b808617..68a4fab8 100644 --- a/jaxopt/implicit_diff.py +++ b/jaxopt/implicit_diff.py @@ -17,3 +17,4 @@ from jaxopt._src.implicit_diff import root_jvp from jaxopt._src.implicit_diff import root_vjp from jaxopt._src.implicit_diff import sparse_root_vjp +from jaxopt._src.implicit_diff import sparse_custom_root diff --git a/tests/implicit_diff_test.py b/tests/implicit_diff_test.py index c67edb7a..4f023a49 100644 --- a/tests/implicit_diff_test.py +++ b/tests/implicit_diff_test.py @@ -59,6 +59,14 @@ def ridge_solver(init_params, lam, X, y): L = jax.numpy.linalg.norm(X, ord=2) ** 2 +def make_restricted_optimality_fun(support): + def restricted_optimality_fun(restricted_params, X, y, lam): + # this is suboptimal, I would try to compute restricted_X once for all + restricted_X = X[:, support] + return lasso_optimality_fun(restricted_params, restricted_X, y, lam) + return restricted_optimality_fun + + def lasso_optimality_fun(params, X, y, lam, tol=1e-4): return prox.prox_lasso( params - X.T @ (X @ params - y) / L, lam * len(y) / L) - params @@ -92,16 +100,6 @@ def test_lasso_root_vjp(self): self.assertArraysAllClose(J, J_num, atol=5e-2) def test_lasso_sparse_root_vjp(self): - - def make_restricted_optimality_fun(support): - def restricted_optimality_fun(restricted_params, X, y, lam): - # this is suboptimal, I would try to compute restricted_X once for all - restricted_X = X[:, support] - return lasso_optimality_fun(restricted_params, restricted_X, y, lam) - return restricted_optimality_fun - - lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y) - lam = lam_max / 2 sol = test_util.lasso_skl(X, y, lam) vjp = lambda g: idf.sparse_root_vjp( @@ -149,6 +147,16 @@ def test_custom_root_lasso(self): J = jax.jacrev(lasso_solver_decorated, argnums=3)(None, X, y, lam) self.assertArraysAllClose(J, J_num, atol=5e-2) + def test_sparse_custom_root_lasso(self): + lasso_solver_decorated = idf.sparse_custom_root( + lasso_optimality_fun, make_restricted_optimality_fun)(lasso_solver) + sol = test_util.lasso_skl(X=X, y=y, lam=lam) + sol_decorated = lasso_solver_decorated(None, X=X, y=y, lam=lam) + self.assertArraysAllClose(sol, sol_decorated, atol=1e-4) + J_num = test_util.lasso_skl_jac(X=X, y=y, lam=lam, tol=1e-4) + J = jax.jacrev(lasso_solver_decorated, argnums=3)(None, X, y, lam) + self.assertArraysAllClose(J, J_num, atol=5e-2) + def test_custom_root_with_has_aux(self): def ridge_solver_with_aux(init_params, lam, X, y): return ridge_solver(init_params, lam, X, y), None From 4669b96edece4c7da16befacd092dd916ecf73bc Mon Sep 17 00:00:00 2001 From: QB3 Date: Mon, 30 Aug 2021 20:14:27 +0200 Subject: [PATCH 18/27] [ci skip] adapted example lasso, toward a sparse implementation --- examples/lasso_implicit_diff_sparse.py | 94 ++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 examples/lasso_implicit_diff_sparse.py diff --git a/examples/lasso_implicit_diff_sparse.py b/examples/lasso_implicit_diff_sparse.py new file mode 100644 index 00000000..3e7abc03 --- /dev/null +++ b/examples/lasso_implicit_diff_sparse.py @@ -0,0 +1,94 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file 
except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implicit differentiation of the lasso based on a sparse implementation.""" + +from absl import app +import jax +import jax.numpy as jnp +from jaxopt import implicit_diff +from jaxopt import linear_solve +from jaxopt import OptaxSolver +from jaxopt import prox +from jaxopt._src import test_util +import optax +from sklearn import datasets +from sklearn import model_selection +from sklearn import preprocessing + +# def main(argv): +# del argv + +# Prepare data. +# X, y = datasets.load_boston(return_X_y=True) + +X, y = datasets.make_regression( + n_samples=10, n_features=10_000, random_state=0) + +X = preprocessing.normalize(X) +# data = (X_tr, X_val, y_tr, y_val) +data = model_selection.train_test_split(X, y, test_size=0.33, random_state=0) + +L = jax.numpy.linalg.norm(X, ord=2) ** 2 + + +def optimality_fun(params, lam, data): + X, y = data + return prox.prox_lasso( + params - X.T @ (X @ params - y) / L, lam * len(y) / L) - params + + +@implicit_diff.custom_root(optimality_fun=optimality_fun) +def lasso_solver(init_params, lam, data): + """Solve Lasso.""" + X_tr, y_tr = data + # TODO add warm start? + sol = test_util.lasso_skl(X, y, lam) + return sol + + +# Perhaps confusingly, theta is a parameter of the outer objective, +# but l2reg = jnp.exp(theta) is an hyper-parameter of the inner objective. +def outer_objective(theta, init_inner, data): + """Validation loss.""" + X_tr, X_val, y_tr, y_val = data + # We use the bijective mapping l2reg = jnp.exp(theta) + # both to optimize in log-space and to ensure positivity. + lam = jnp.exp(theta) + w_fit = lasso_solver(init_inner, lam, (X_tr, y_tr)) + y_pred = jnp.dot(X_val, w_fit) + loss_value = jnp.mean((y_pred - y_val) ** 2) + # We return w_fit as auxiliary data. + # Auxiliary data is stored in the optimizer state (see below). + return loss_value, w_fit + + +# Initialize solver. +solver = OptaxSolver(opt=optax.adam(1e-2), fun=outer_objective, has_aux=True) +lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y) +lam = lam_max / 10 +theta_init = jnp.log(lam) +theta, state = solver.init(theta_init) +init_w = jnp.zeros(X.shape[1]) + +# Run outer loop. +for _ in range(10): + theta, state = solver.update( + params=theta, state=state, init_inner=init_w, data=data) + # The auxiliary data returned by the outer loss is stored in the state. 
+ init_w = state.aux + print(f"[Step {state.iter_num}] Validation loss: {state.value:.3f}.") + +# if __name__ == "__main__": +# app.run(main) From d110ff16e16be848854118882dad47d6b7b8ba08 Mon Sep 17 00:00:00 2001 From: QB3 Date: Tue, 31 Aug 2021 09:49:49 +0200 Subject: [PATCH 19/27] [ci skip] added benchmark sparse custom root --- benchmarks/sparse_custom_root.py | 73 ++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 benchmarks/sparse_custom_root.py diff --git a/benchmarks/sparse_custom_root.py b/benchmarks/sparse_custom_root.py new file mode 100644 index 00000000..976cb5e2 --- /dev/null +++ b/benchmarks/sparse_custom_root.py @@ -0,0 +1,73 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import jax +import jax.numpy as jnp + +from jaxopt import prox +from jaxopt import implicit_diff as idf +from jaxopt._src import test_util + +from sklearn import datasets + + +def lasso_objective(params, lam, X, y): + residuals = jnp.dot(X, params) - y + return 0.5 * jnp.mean(residuals ** 2) / len(y) + lam * jnp.sum( + jnp.abs(params)) + + +def lasso_solver(params, X, y, lam): + sol = test_util.lasso_skl(X, y, lam) + return sol + + +X, y = datasets.make_regression( + n_samples=10, n_features=10_000, random_state=0) +lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y) +lam = lam_max / 2 +L = jax.numpy.linalg.norm(X, ord=2) ** 2 + + +def make_restricted_optimality_fun(support): + def restricted_optimality_fun(restricted_params, X, y, lam): + # this is suboptimal, I would try to compute restricted_X once for all + restricted_X = X[:, support] + return lasso_optimality_fun(restricted_params, restricted_X, y, lam) + return restricted_optimality_fun + + +def lasso_optimality_fun(params, X, y, lam, tol=1e-4): + return prox.prox_lasso( + params - X.T @ (X @ params - y) / L, lam * len(y) / L) - params + + +t_start = time.time() +lasso_solver_decorated = idf.custom_root(lasso_optimality_fun)(lasso_solver) +sol = test_util.lasso_skl(X=X, y=y, lam=lam) +J = jax.jacrev(lasso_solver_decorated, argnums=3)(None, X, y, lam) +t_custom = time.time() - t_start + + +t_start = time.time() +lasso_solver_decorated = idf.sparse_custom_root( + lasso_optimality_fun, make_restricted_optimality_fun)(lasso_solver) +sol = test_util.lasso_skl(X=X, y=y, lam=lam) +J = jax.jacrev(lasso_solver_decorated, argnums=3)(None, X, y, lam) +t_custom_sparse = time.time() - t_start + + +print("Time taken to compute the Jacobian %.3f" % t_custom) +print("Time taken to compute the Jacobian with the sparse implementation %.3f" % t_custom_sparse) From d15ae5582a5392e4f278e98f88de3e688ab68485 Mon Sep 17 00:00:00 2001 From: QB3 Date: Tue, 31 Aug 2021 09:50:25 +0200 Subject: [PATCH 20/27] [ci skip] added sparse_custom_root to implicit diff example --- examples/lasso_implicit_diff_sparse.py | 30 ++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/examples/lasso_implicit_diff_sparse.py b/examples/lasso_implicit_diff_sparse.py index 3e7abc03..35833b17 
100644 --- a/examples/lasso_implicit_diff_sparse.py +++ b/examples/lasso_implicit_diff_sparse.py @@ -14,6 +14,7 @@ """Implicit differentiation of the lasso based on a sparse implementation.""" +import time from absl import app import jax import jax.numpy as jnp @@ -34,7 +35,7 @@ # X, y = datasets.load_boston(return_X_y=True) X, y = datasets.make_regression( - n_samples=10, n_features=10_000, random_state=0) + n_samples=30, n_features=10_000, random_state=0) X = preprocessing.normalize(X) # data = (X_tr, X_val, y_tr, y_val) @@ -49,7 +50,18 @@ def optimality_fun(params, lam, data): params - X.T @ (X @ params - y) / L, lam * len(y) / L) - params -@implicit_diff.custom_root(optimality_fun=optimality_fun) +def make_restricted_optimality_fun(support): + def restricted_optimality_fun(restricted_params, lam, data): + # this is suboptimal, I would try to compute restricted_X once for all + X, y = data + restricted_X = X[:, support] + return optimality_fun(restricted_params, lam, (restricted_X, y)) + return restricted_optimality_fun + + +@implicit_diff.sparse_custom_root( + optimality_fun=optimality_fun, + make_restricted_optimality_fun=make_restricted_optimality_fun) def lasso_solver(init_params, lam, data): """Solve Lasso.""" X_tr, y_tr = data @@ -57,6 +69,17 @@ def lasso_solver(init_params, lam, data): sol = test_util.lasso_skl(X, y, lam) return sol +# @implicit_diff.custom_root( +# optimality_fun=optimality_fun) +# def lasso_solver(init_params, lam, data): +# """Solve Lasso.""" +# X_tr, y_tr = data +# # TODO add warm start? +# sol = test_util.lasso_skl(X, y, lam) +# return sol + + + # Perhaps confusingly, theta is a parameter of the outer objective, # but l2reg = jnp.exp(theta) is an hyper-parameter of the inner objective. @@ -82,6 +105,7 @@ def outer_objective(theta, init_inner, data): theta, state = solver.init(theta_init) init_w = jnp.zeros(X.shape[1]) +t_start = time.time() # Run outer loop. for _ in range(10): theta, state = solver.update( @@ -89,6 +113,8 @@ def outer_objective(theta, init_inner, data): # The auxiliary data returned by the outer loss is stored in the state. init_w = state.aux print(f"[Step {state.iter_num}] Validation loss: {state.value:.3f}.") +t_ellapsed = time.time() - t_start # if __name__ == "__main__": # app.run(main) +print("Time taken for 10 iterations: %.2f" % t_ellapsed) From 77661b1207d2f2f6faa46dbdfb4912f9480ac991 Mon Sep 17 00:00:00 2001 From: QB3 Date: Sun, 12 Sep 2021 21:58:20 +0200 Subject: [PATCH 21/27] improved benchmark file --- benchmarks/sparse_vjp.py | 55 +++++++++++++++++++++++++++------------- 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/benchmarks/sparse_vjp.py b/benchmarks/sparse_vjp.py index 4c2df42d..1c2d5bd0 100644 --- a/benchmarks/sparse_vjp.py +++ b/benchmarks/sparse_vjp.py @@ -2,18 +2,19 @@ import jax import jax.numpy as jnp +import numpy as onp from jaxopt import prox from jaxopt import implicit_diff as idf from jaxopt._src import test_util - +from jaxopt import linear_solve from sklearn import datasets X, y = datasets.make_regression( - n_samples=10, n_features=10_000, random_state=0) + n_samples=1_000, n_features=10_000, random_state=0) -L = jax.numpy.linalg.norm(X, ord=2) ** 2 +L = onp.linalg.norm(X, ord=2) ** 2 def optimality_fun(params, lam, X, y): @@ -36,31 +37,51 @@ def restricted_optimality_fun(restricted_params, lam, X, y): print(sol) t_optim = time.time() - t_start -vjp = lambda g: idf.root_vjp(optimality_fun=optimality_fun, - sol=sol, - args=(lam, X, y), - cotangent=g)[0] # vjp w.r.t. 
lam +rand = onp.random.normal(0, 1, len(sol)) +dict_times = {} +dict_grad = {} + +for maxiter in [10, 100, 1000, 2_000]: + def solve(matvec, b): + return linear_solve.solve_normal_cg( + matvec, b, None, tol=1e-32, maxiter=maxiter) + + vjp = lambda g: idf.root_vjp( + optimality_fun=optimality_fun, + sol=sol, + args=(lam, X, y), + cotangent=g, + solve=solve)[0] # vjp w.r.t. lam + + t_start = time.time() + grad = vjp(rand) + t_jac = time.time() - t_start + dict_times[maxiter] = t_jac + dict_grad[maxiter] = grad.copy() + + +def solve_sparse(matvec, b): + return linear_solve.solve_cg( + matvec, b, None, tol=1e-32, maxiter=(sol != 0).sum()) + vjp_sparse = lambda g: idf.sparse_root_vjp( optimality_fun=optimality_fun, make_restricted_optimality_fun=make_restricted_optimality_fun, sol=sol, args=(lam, X, y), - cotangent=g)[0] # vjp w.r.t. lam - -t_start = time.time() -I = jnp.eye(len(sol)) -J = jax.vmap(vjp)(I) -t_jac = time.time() - t_start + cotangent=g, + solve=solve_sparse)[0] # vjp w.r.t. lam t_start = time.time() -I = jnp.eye(len(sol)) -J_sparse = jax.vmap(vjp_sparse)(I) +grad_sparse = vjp_sparse(rand) t_jac_sparse = time.time() - t_start print("Time taken to solve the Lasso optimization problem %.3f" % t_optim) -print("Time taken to compute the Jacobian %.3f" % t_jac) -print("Time taken to compute the Jacobian with the sparse implementation %.3f" % t_jac_sparse) +for maxiter in dict_times.keys(): + print("Time taken to compute the gradient with n= %i iterations %.3f | distance to the sparse gradient %.e" % ( + maxiter, dict_times[maxiter], jnp.linalg.norm(dict_grad[maxiter] - grad_sparse) / grad_sparse)) +print("Time taken to compute the gradient with the sparse implementation %.3f" % t_jac_sparse) # Computation time are the same, which is very weird to me From a7e54360cb65b3f0aa45a0c2d9878eb894cbd898 Mon Sep 17 00:00:00 2001 From: QB3 Date: Sun, 12 Sep 2021 21:59:32 +0200 Subject: [PATCH 22/27] jax.numpy.linalg >> onp.linalg.norm --- examples/lasso_implicit_diff_sparse.py | 7 +++---- jaxopt/_src/implicit_diff.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/lasso_implicit_diff_sparse.py b/examples/lasso_implicit_diff_sparse.py index 35833b17..a4f7387e 100644 --- a/examples/lasso_implicit_diff_sparse.py +++ b/examples/lasso_implicit_diff_sparse.py @@ -18,6 +18,7 @@ from absl import app import jax import jax.numpy as jnp +import numpy as onp from jaxopt import implicit_diff from jaxopt import linear_solve from jaxopt import OptaxSolver @@ -37,11 +38,11 @@ X, y = datasets.make_regression( n_samples=30, n_features=10_000, random_state=0) -X = preprocessing.normalize(X) +# X = preprocessing.normalize(X) # data = (X_tr, X_val, y_tr, y_val) data = model_selection.train_test_split(X, y, test_size=0.33, random_state=0) -L = jax.numpy.linalg.norm(X, ord=2) ** 2 +L = onp.linalg.norm(X, ord=2) ** 2 def optimality_fun(params, lam, data): @@ -79,8 +80,6 @@ def lasso_solver(init_params, lam, data): # return sol - - # Perhaps confusingly, theta is a parameter of the outer objective, # but l2reg = jnp.exp(theta) is an hyper-parameter of the inner objective. 
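Editor's note (not part of the patches): the benchmark above caps the sparse conjugate-gradient solve at ``maxiter=(sol != 0).sum()``, and the reason is the size of the linear system being solved. Writing ``beta`` for the root and ``F`` for the optimality function, the dense ``root_vjp`` solves

    A^T u = v,   where A = jacobian(F, argnums=0)(beta, *args)   (a p x p system)
                 and   v = -cotangent,

while ``sparse_root_vjp`` / ``sparse_root_vjp2`` solve the same system restricted to the support S = {j : beta_j != 0},

    (A restricted to S x S)^T u_S = v_S,   an |S| x |S| system,

and then evaluate the VJP with respect to the remaining arguments using ``u_S``. Assuming the restricted matrix is symmetric positive definite (which ``solve_cg`` requires), conjugate gradient terminates after at most |S| iterations in exact arithmetic, so capping at |S| is safe; for the Lasso the support is typically no larger than the number of samples, hence far smaller than the 10,000-100,000 features used in these benchmarks.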
def outer_objective(theta, init_inner, data): diff --git a/jaxopt/_src/implicit_diff.py b/jaxopt/_src/implicit_diff.py index b1bc5577..f0b97b1f 100644 --- a/jaxopt/_src/implicit_diff.py +++ b/jaxopt/_src/implicit_diff.py @@ -79,7 +79,7 @@ def sparse_root_vjp(optimality_fun: Callable, sol: Any, args: Tuple, cotangent: Any, - solve: Callable = linear_solve.solve_normal_cg) -> Any: + solve: Callable = linear_solve.solve_cg) -> Any: """Sparse vector-Jacobian product of a root. The invariant is ``optimality_fun(sol, *args) == 0``. From a8ee7ccdf344d089637ae29d9ae04c2f99449966 Mon Sep 17 00:00:00 2001 From: QB3 Date: Sun, 12 Sep 2021 22:18:10 +0200 Subject: [PATCH 23/27] [ci skip] added back version with hardcoded support --- jaxopt/_src/implicit_diff.py | 56 ++++++++++++++++++++++++++++++++++++ jaxopt/implicit_diff.py | 1 + tests/implicit_diff_test.py | 7 +++++ 3 files changed, 64 insertions(+) diff --git a/jaxopt/_src/implicit_diff.py b/jaxopt/_src/implicit_diff.py index f0b97b1f..aced6c99 100644 --- a/jaxopt/_src/implicit_diff.py +++ b/jaxopt/_src/implicit_diff.py @@ -129,6 +129,62 @@ def fun_args(*args): return vjp_fun_args(restricted_u) +def sparse_root_vjp2(optimality_fun: Callable, + # filter_args: Callable, + sol: Any, + args: Tuple, + cotangent: Any, + solve: Callable = linear_solve.solve_cg) -> Any: + """Sparse vector-Jacobian product of a root. + + The invariant is ``optimality_fun(sol, *args) == 0``. + + Args: + optimality_fun: the optimality function to use. + F in the paper + make_restricted_optimality_fun: TODO XXX. + sol: solution / root (pytree). + args: tuple containing the arguments with respect to which we wish to + differentiate ``sol`` against. + cotangent: vector to left-multiply the Jacobian with + (pytree, same structure as ``sol``). + solve: a linear solver of the form, ``x = solve(matvec, b)``, + where ``matvec(x) = Ax`` and ``Ax=b``. + Returns: + vjps: tuple of the same length as ``len(args)`` containing the vjps w.r.t. + each argument. Each ``vjps[i]` has the same pytree structure as + ``args[i]``. + """ + support = sol != 0 # nonzeros coefficients of the solution + restricted_sol = sol[support] # solution restricted to the support + + X, y, lam = args + new_args = X[:, support], y, lam + + def fun_sol(restricted_sol): + # We close over the arguments. + return optimality_fun(restricted_sol, *new_args) + + _, vjp_fun_sol = jax.vjp(fun_sol, restricted_sol) + + # Compute the multiplication A^T u = (u^T A)^T resticted to the support. + def restricted_matvec(restricted_v): + return vjp_fun_sol(restricted_v)[0] + + # The solution of A^T u = v, where + # A = jacobian(optimality_fun, argnums=0) + # v = -cotangent. + restricted_v = tree_scalar_mul(-1, cotangent[support]) + restricted_u = solve(restricted_matvec, restricted_v) + + def fun_args(*args): + # We close over the solution. 
+ return optimality_fun(restricted_sol, *args) + + _, vjp_fun_args = jax.vjp(fun_args, * new_args) + + return vjp_fun_args(restricted_u) + def _jvp_sol(optimality_fun, sol, args, tangent): """JVP in the first argument of optimality_fun.""" diff --git a/jaxopt/implicit_diff.py b/jaxopt/implicit_diff.py index 68a4fab8..d17fe912 100644 --- a/jaxopt/implicit_diff.py +++ b/jaxopt/implicit_diff.py @@ -17,4 +17,5 @@ from jaxopt._src.implicit_diff import root_jvp from jaxopt._src.implicit_diff import root_vjp from jaxopt._src.implicit_diff import sparse_root_vjp +from jaxopt._src.implicit_diff import sparse_root_vjp2 from jaxopt._src.implicit_diff import sparse_custom_root diff --git a/tests/implicit_diff_test.py b/tests/implicit_diff_test.py index 4f023a49..5ed9dd9c 100644 --- a/tests/implicit_diff_test.py +++ b/tests/implicit_diff_test.py @@ -108,10 +108,17 @@ def test_lasso_sparse_root_vjp(self): sol=sol, args=(X, y, lam), cotangent=g)[2] # vjp w.r.t. lam + vjp2 = lambda g: idf.sparse_root_vjp2( + optimality_fun=lasso_optimality_fun, + sol=sol, + args=(X, y, lam), + cotangent=g)[2] # vjp w.r.t. lam I = jnp.eye(len(sol)) J = jax.vmap(vjp)(I) + J2 = jax.vmap(vjp2)(I) J_num = test_util.lasso_skl_jac(X, y, lam, eps=1e-4) self.assertArraysAllClose(J, J_num, atol=5e-2) + self.assertArraysAllClose(J2, J_num, atol=5e-2) def test_root_jvp(self): X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0) From d7e5cc1b12ddeb35ca83da8eb628ff4bd7825dd6 Mon Sep 17 00:00:00 2001 From: QB3 Date: Mon, 13 Sep 2021 11:37:42 +0200 Subject: [PATCH 24/27] [ciskip] updated benchmark --- benchmarks/sparse_vjp.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/benchmarks/sparse_vjp.py b/benchmarks/sparse_vjp.py index 1c2d5bd0..842fa3a1 100644 --- a/benchmarks/sparse_vjp.py +++ b/benchmarks/sparse_vjp.py @@ -12,21 +12,21 @@ from sklearn import datasets X, y = datasets.make_regression( - n_samples=1_000, n_features=10_000, random_state=0) + n_samples=100, n_features=100_000, random_state=0) L = onp.linalg.norm(X, ord=2) ** 2 -def optimality_fun(params, lam, X, y): +def optimality_fun(params, X, y, lam): return prox.prox_lasso( params - X.T @ (X @ params - y) / L, lam * len(y) / L) - params def make_restricted_optimality_fun(support): - def restricted_optimality_fun(restricted_params, lam, X, y): + def restricted_optimality_fun(restricted_params, X, y, lam): # this is suboptimal, I would try to compute restricted_X once for all restricted_X = X[:, support] - return optimality_fun(restricted_params, lam, restricted_X, y) + return optimality_fun(restricted_params, restricted_X, y, lam) return restricted_optimality_fun @@ -34,14 +34,14 @@ def restricted_optimality_fun(restricted_params, lam, X, y): lam = lam_max / 2 t_start = time.time() sol = test_util.lasso_skl(X, y, lam) -print(sol) t_optim = time.time() - t_start +onp.random.seed(0) rand = onp.random.normal(0, 1, len(sol)) dict_times = {} dict_grad = {} -for maxiter in [10, 100, 1000, 2_000]: +for maxiter in [10, 100, 1000, 2000]: def solve(matvec, b): return linear_solve.solve_normal_cg( matvec, b, None, tol=1e-32, maxiter=maxiter) @@ -49,9 +49,9 @@ def solve(matvec, b): vjp = lambda g: idf.root_vjp( optimality_fun=optimality_fun, sol=sol, - args=(lam, X, y), + args=(X, y, lam), cotangent=g, - solve=solve)[0] # vjp w.r.t. lam + solve=solve)[2] # vjp w.r.t. 
lam t_start = time.time() grad = vjp(rand) @@ -69,19 +69,31 @@ def solve_sparse(matvec, b): optimality_fun=optimality_fun, make_restricted_optimality_fun=make_restricted_optimality_fun, sol=sol, - args=(lam, X, y), + args=(X, y, lam), cotangent=g, - solve=solve_sparse)[0] # vjp w.r.t. lam + solve=solve_sparse)[2] # vjp w.r.t. lam + +vjp_sparse2 = lambda g: idf.sparse_root_vjp2( + optimality_fun=optimality_fun, + sol=sol, + args=(X, y, lam), + cotangent=g, + solve=solve_sparse)[2] # vjp w.r.t. lam t_start = time.time() grad_sparse = vjp_sparse(rand) t_jac_sparse = time.time() - t_start +t_start = time.time() +grad_sparse2 = vjp_sparse(rand) +t_jac_sparse2 = time.time() - t_start + print("Time taken to solve the Lasso optimization problem %.3f" % t_optim) for maxiter in dict_times.keys(): print("Time taken to compute the gradient with n= %i iterations %.3f | distance to the sparse gradient %.e" % ( maxiter, dict_times[maxiter], jnp.linalg.norm(dict_grad[maxiter] - grad_sparse) / grad_sparse)) print("Time taken to compute the gradient with the sparse implementation %.3f" % t_jac_sparse) +print("Time taken to compute the gradient with the sparse2 implementation %.3f" % t_jac_sparse2) # Computation time are the same, which is very weird to me From 1c59af76c4f82f7d47295ac3f708bb0d6a13d94e Mon Sep 17 00:00:00 2001 From: QB3 Date: Mon, 13 Sep 2021 11:38:45 +0200 Subject: [PATCH 25/27] [ciskip] added sparse_custom root with other implem, currently fails --- jaxopt/_src/implicit_diff.py | 58 ++++++++++++++++++++++++++++++++++-- jaxopt/implicit_diff.py | 1 + tests/implicit_diff_test.py | 3 ++ 3 files changed, 59 insertions(+), 3 deletions(-) diff --git a/jaxopt/_src/implicit_diff.py b/jaxopt/_src/implicit_diff.py index aced6c99..1294825b 100644 --- a/jaxopt/_src/implicit_diff.py +++ b/jaxopt/_src/implicit_diff.py @@ -177,11 +177,11 @@ def restricted_matvec(restricted_v): restricted_v = tree_scalar_mul(-1, cotangent[support]) restricted_u = solve(restricted_matvec, restricted_v) - def fun_args(*args): + def fun_args(*args_): # We close over the solution. - return optimality_fun(restricted_sol, *args) + return optimality_fun(restricted_sol, *args_) - _, vjp_fun_args = jax.vjp(fun_args, * new_args) + _, vjp_fun_args = jax.vjp(fun_args, *new_args) return vjp_fun_args(restricted_u) @@ -287,6 +287,35 @@ def solver_fun_bwd(tup, cotangent): return wrapped_solver_fun +def _sparse_custom_root2( + solver_fun, optimality_fun, solve, has_aux): + def solver_fun_fwd(init_params, *args): + res = solver_fun(init_params, *args) + return res, (res, args) + + def solver_fun_bwd(tup, cotangent): + res, args = tup + + # solver_fun can return auxiliary data if has_aux = True. + if has_aux: + cotangent = cotangent[0] + sol = res[0] + else: + sol = res + + # Compute VJPs w.r.t. args. + vjps = sparse_root_vjp2( + optimality_fun=optimality_fun, + sol=sol, args=args, cotangent=cotangent, solve=solve) + # For init_params, we return None. + return (None,) + vjps + + wrapped_solver_fun = jax.custom_vjp(solver_fun) + wrapped_solver_fun.defvjp(solver_fun_fwd, solver_fun_bwd) + + return wrapped_solver_fun + + def custom_root(optimality_fun: Callable, has_aux: bool = False, solve: Callable = linear_solve.solve_normal_cg): @@ -332,6 +361,29 @@ def wrapper(solver_fun): return wrapper +def sparse_custom_root2( + optimality_fun: Callable, has_aux: bool = False, + solve: Callable = linear_solve.solve_normal_cg): + """Decorator for adding implicit differentiation to a root solver. 
+ + Args: + optimality_fun: an equation function, ``optimality_fun(params, *args)`. + The invariant is ``optimality_fun(sol, *args) == 0`` at the + solution / root ``sol``. + has_aux: whether the decorated solver function returns auxiliary data. + solve: a linear solver of the form, ``solve(matvec, b)``. + + Returns: + A solver function decorator, i.e., + ``custom_root(optimality_fun)(solver_fun)``. + """ + def wrapper(solver_fun): + return _sparse_custom_root2( + solver_fun, optimality_fun, solve, has_aux) + + return wrapper + + def custom_fixed_point(fixed_point_fun: Callable, has_aux: bool = False, solve: Callable = linear_solve.solve_normal_cg): diff --git a/jaxopt/implicit_diff.py b/jaxopt/implicit_diff.py index d17fe912..3981ba7b 100644 --- a/jaxopt/implicit_diff.py +++ b/jaxopt/implicit_diff.py @@ -19,3 +19,4 @@ from jaxopt._src.implicit_diff import sparse_root_vjp from jaxopt._src.implicit_diff import sparse_root_vjp2 from jaxopt._src.implicit_diff import sparse_custom_root +from jaxopt._src.implicit_diff import sparse_custom_root2 diff --git a/tests/implicit_diff_test.py b/tests/implicit_diff_test.py index 5ed9dd9c..f511b6cc 100644 --- a/tests/implicit_diff_test.py +++ b/tests/implicit_diff_test.py @@ -157,11 +157,14 @@ def test_custom_root_lasso(self): def test_sparse_custom_root_lasso(self): lasso_solver_decorated = idf.sparse_custom_root( lasso_optimality_fun, make_restricted_optimality_fun)(lasso_solver) + lasso_solver_decorated2 = idf.sparse_custom_root2( + lasso_optimality_fun)(lasso_solver) sol = test_util.lasso_skl(X=X, y=y, lam=lam) sol_decorated = lasso_solver_decorated(None, X=X, y=y, lam=lam) self.assertArraysAllClose(sol, sol_decorated, atol=1e-4) J_num = test_util.lasso_skl_jac(X=X, y=y, lam=lam, tol=1e-4) J = jax.jacrev(lasso_solver_decorated, argnums=3)(None, X, y, lam) + J2 = jax.jacrev(lasso_solver_decorated2, argnums=3)(None, X, y, lam) self.assertArraysAllClose(J, J_num, atol=5e-2) def test_custom_root_with_has_aux(self): From 91348aa4095bf7ee8a3c7f0f9b58616c7b26bd45 Mon Sep 17 00:00:00 2001 From: QB3 Date: Mon, 13 Sep 2021 12:22:15 +0200 Subject: [PATCH 26/27] [ciskip] X.T@(X @ params - y) >> grad(obj.square) --- benchmarks/sparse_custom_root.py | 4 +++- benchmarks/sparse_vjp.py | 5 ++++- examples/lasso_implicit_diff_sparse.py | 5 ++++- tests/implicit_diff_test.py | 4 +++- 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/benchmarks/sparse_custom_root.py b/benchmarks/sparse_custom_root.py index 976cb5e2..b10acf48 100644 --- a/benchmarks/sparse_custom_root.py +++ b/benchmarks/sparse_custom_root.py @@ -19,6 +19,7 @@ from jaxopt import prox from jaxopt import implicit_diff as idf from jaxopt._src import test_util +from jaxopt import objective from sklearn import datasets @@ -50,8 +51,9 @@ def restricted_optimality_fun(restricted_params, X, y, lam): def lasso_optimality_fun(params, X, y, lam, tol=1e-4): + n_samples = X.shape[0] return prox.prox_lasso( - params - X.T @ (X @ params - y) / L, lam * len(y) / L) - params + params - jax.grad(objective.least_squares)(params, (X, y)) * n_samples / L, lam * len(y) / L) - params t_start = time.time() diff --git a/benchmarks/sparse_vjp.py b/benchmarks/sparse_vjp.py index 842fa3a1..3dfcc872 100644 --- a/benchmarks/sparse_vjp.py +++ b/benchmarks/sparse_vjp.py @@ -8,6 +8,7 @@ from jaxopt import implicit_diff as idf from jaxopt._src import test_util from jaxopt import linear_solve +from jaxopt import objective from sklearn import datasets @@ -18,8 +19,10 @@ def optimality_fun(params, X, y, lam): + n_samples = 
X.shape[0] return prox.prox_lasso( - params - X.T @ (X @ params - y) / L, lam * len(y) / L) - params + params - jax.grad(objective.least_squares)(params, (X, y)) * n_samples / L, + lam * len(y) / L) - params def make_restricted_optimality_fun(support): diff --git a/examples/lasso_implicit_diff_sparse.py b/examples/lasso_implicit_diff_sparse.py index a4f7387e..dca673d0 100644 --- a/examples/lasso_implicit_diff_sparse.py +++ b/examples/lasso_implicit_diff_sparse.py @@ -23,6 +23,7 @@ from jaxopt import linear_solve from jaxopt import OptaxSolver from jaxopt import prox +from jaxopt import objective from jaxopt._src import test_util import optax from sklearn import datasets @@ -47,8 +48,10 @@ def optimality_fun(params, lam, data): X, y = data + n_samples = X.shape[0] return prox.prox_lasso( - params - X.T @ (X @ params - y) / L, lam * len(y) / L) - params + params - jax.grad(objective.least_squares)(params, (X, y)) * n_samples / L, + lam * len(y) / L) - params def make_restricted_optimality_fun(support): diff --git a/tests/implicit_diff_test.py b/tests/implicit_diff_test.py index f511b6cc..7e690a14 100644 --- a/tests/implicit_diff_test.py +++ b/tests/implicit_diff_test.py @@ -68,8 +68,10 @@ def restricted_optimality_fun(restricted_params, X, y, lam): def lasso_optimality_fun(params, X, y, lam, tol=1e-4): + n_samples = X.shape[0] return prox.prox_lasso( - params - X.T @ (X @ params - y) / L, lam * len(y) / L) - params + params - jax.grad(objective.least_squares)(params, (X, y)) * n_samples / L, + lam * len(y) / L) - params class ImplicitDiffTest(jtu.JaxTestCase): From c64c5967bf0dfcd9328f9a69ccec9db09dd9db22 Mon Sep 17 00:00:00 2001 From: QB3 Date: Tue, 14 Sep 2021 16:50:37 +0200 Subject: [PATCH 27/27] [ci skip] made test custom root work --- jaxopt/_src/implicit_diff.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/jaxopt/_src/implicit_diff.py b/jaxopt/_src/implicit_diff.py index 1294825b..dcf8923c 100644 --- a/jaxopt/_src/implicit_diff.py +++ b/jaxopt/_src/implicit_diff.py @@ -177,11 +177,14 @@ def restricted_matvec(restricted_v): restricted_v = tree_scalar_mul(-1, cotangent[support]) restricted_u = solve(restricted_matvec, restricted_v) - def fun_args(*args_): + def fun_args(*args): # We close over the solution. - return optimality_fun(restricted_sol, *args_) + X, y, lam = args + new_args = X[:, support], y, lam + return optimality_fun(restricted_sol, *new_args) - _, vjp_fun_args = jax.vjp(fun_args, *new_args) + _, vjp_fun_args = jax.vjp(fun_args, *args) + # _, vjp_fun_args = jax.vjp(fun_args, *new_args) return vjp_fun_args(restricted_u)
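Editor's sketch (not part of the patches): a minimal end-to-end use of the sparse VJP on a small Lasso problem, mirroring ``test_lasso_sparse_root_vjp`` above. It assumes the jaxopt fork from this PR, in which ``sparse_root_vjp2`` is exported from ``jaxopt.implicit_diff``; the problem sizes and the cotangent vector are arbitrary choices for illustration only.

import jax.numpy as jnp
import numpy as onp
from sklearn import datasets

from jaxopt import implicit_diff as idf
from jaxopt import prox
from jaxopt._src import test_util

X, y = datasets.make_regression(n_samples=10, n_features=30, random_state=0)
L = onp.linalg.norm(X, ord=2) ** 2               # Lipschitz constant of the data fit
lam = jnp.max(jnp.abs(X.T @ y)) / (2 * len(y))   # half of lambda_max

def lasso_optimality_fun(params, X, y, lam):
  # Proximal-gradient fixed-point residual: zero exactly at the Lasso solution.
  return prox.prox_lasso(
      params - X.T @ (X @ params - y) / L, lam * len(y) / L) - params

sol = test_util.lasso_skl(X, y, lam)             # inner solver used as a black box

# VJP of the solution w.r.t. lam, computed on the support only.
# sparse_root_vjp2 unpacks args as (X, y, lam), so the argument order matters.
cotangent = jnp.ones_like(sol)
grad_lam = idf.sparse_root_vjp2(
    optimality_fun=lasso_optimality_fun,
    sol=sol,
    args=(X, y, lam),
    cotangent=cotangent)[2]
print(grad_lam)

Compared with ``root_vjp``, the only user-visible difference is that the linear solve runs on the nonzero coefficients of ``sol``, which is the speed-up the benchmarks above try to measure.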