From fcf8eeace68a47fe79b4602a8c3f7cb509a188e0 Mon Sep 17 00:00:00 2001 From: Rafael Pastrana Date: Sun, 29 Oct 2023 17:42:39 -0400 Subject: [PATCH] [Solvers] Implemented `solver_anderson` to support Anderson acceleration --- src/jax_fdm/equilibrium/iterative.py | 48 +++++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/src/jax_fdm/equilibrium/iterative.py b/src/jax_fdm/equilibrium/iterative.py index 7134d38..a11a8ad 100644 --- a/src/jax_fdm/equilibrium/iterative.py +++ b/src/jax_fdm/equilibrium/iterative.py @@ -1,5 +1,7 @@ from functools import partial +import jax + import jax.numpy as jnp from jax import vjp @@ -9,18 +11,47 @@ from equinox.internal import while_loop from jaxopt import FixedPointIteration +from jaxopt import AndersonAcceleration # ========================================================================== # Iterative solvers # ========================================================================== +def solver_anderson(f, a, x_init, solver_config): + """ + Solve for a fixed point of a function f(a, x) using Anderson acceleration in jaxopt. + """ + tmax = solver_config["tmax"] + eta = solver_config["eta"] + verbose = solver_config["verbose"] + + def f_swapped(x, a): + return f(a, x) + + fpi = AndersonAcceleration(fixed_point_fun=f_swapped, + maxiter=tmax, + tol=eta, # 1e-5 is the default + has_aux=False, + history_size=5, # 5 is the default + ridge=1e-5, # 1e-5 is the default + # implicit_diff=True, + # jit=True, + # unroll=False, + verbose=verbose) + + result = fpi.run(x_init, a) + + return result.params + + def solver_fixedpoint(f, a, x_init, solver_config): """ Solve for a fixed point of a function f(a, x) using forward iteration in jaxopt.
""" tmax = solver_config["tmax"] eta = solver_config["eta"] + verbose = solver_config["verbose"] def f_swapped(x, a): return f(a, x) @@ -29,9 +60,10 @@ def f_swapped(x, a): maxiter=tmax, tol=eta, has_aux=False, - implicit_diff=True, - unroll=False, - ) + # implicit_diff=True, + # jit=True, + # unroll=False, + verbose=verbose) result = fpi.run(x_init, a) @@ -44,9 +76,13 @@ def solver_forward(f, a, x_init, solver_config): """ tmax = solver_config["tmax"] eta = solver_config["eta"] + verbose = solver_config["verbose"] def distance(x_prev, x): - return jnp.mean(jnp.linalg.norm(x_prev - x, axis=1)) + residual = jnp.mean(jnp.linalg.norm(x_prev - x, axis=1)) + if verbose: + jax.debug.print("Residual: {}", residual) + return residual def cond_fun(carry): x_prev, x = carry @@ -124,6 +160,8 @@ def rev_iter(packed, u): _, vjp_x = vjp(lambda x: fn(a, x), x_star) return x_star_bar + vjp_x(u)[0] + solver_config = {k: v for k, v in solver_config.items()} + solver_config["eta"] = 1e-3 partial_func = solver(rev_iter, (a, x_star, x_star_bar), x_star_bar, @@ -131,7 +169,7 @@ def rev_iter(packed, u): a_bar = vjp_a(partial_func)[0] - return a_bar, None # jnp.zeros_like(x_star) + return a_bar, None fixed_point.defvjp(fixed_point_fwd, fixed_point_bwd)