Skip to content

Commit

Permalink
[Solvers] Implemented solver_anderson to support Anderson acceleration
Browse files Browse the repository at this point in the history
  • Loading branch information
arpastrana committed Oct 29, 2023
1 parent 93a037f commit fcf8eea
Showing 1 changed file with 43 additions and 5 deletions.
48 changes: 43 additions & 5 deletions src/jax_fdm/equilibrium/iterative.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from functools import partial

import jax

import jax.numpy as jnp

from jax import vjp
Expand All @@ -9,18 +11,47 @@
from equinox.internal import while_loop

from jaxopt import FixedPointIteration
from jaxopt import AndersonAcceleration


# ==========================================================================
# Iterative solvers
# ==========================================================================

def solver_anderson(f, a, x_init, solver_config):
"""
Solve for a fixed point of a function f(a, x) using anderson acceleration in jaxopt.
"""
tmax = solver_config["tmax"]
eta = solver_config["eta"]
verbose = solver_config["verbose"]

def f_swapped(x, a):
return f(a, x)

fpi = AndersonAcceleration(fixed_point_fun=f_swapped,
maxiter=tmax,
tol=eta, # 1e-5 is the default,
has_aux=False,
history_size=5, # 5 is default
ridge=1e-5, # 1e-5 is the default
# implicit_diff=True,
# jit=True,
# unroll=False,
verbose=verbose)

result = fpi.run(x_init, a)

return result.params


def solver_fixedpoint(f, a, x_init, solver_config):
"""
Solve for a fixed point of a function f(a, x) using forward iteration in jaxopt.
"""
tmax = solver_config["tmax"]
eta = solver_config["eta"]
verbose = solver_config["verbose"]

def f_swapped(x, a):
return f(a, x)
Expand All @@ -29,9 +60,10 @@ def f_swapped(x, a):
maxiter=tmax,
tol=eta,
has_aux=False,
implicit_diff=True,
unroll=False,
)
# implicit_diff=True,
# jit=True,
# unroll=False,
verbose=verbose)

result = fpi.run(x_init, a)

Expand All @@ -44,9 +76,13 @@ def solver_forward(f, a, x_init, solver_config):
"""
tmax = solver_config["tmax"]
eta = solver_config["eta"]
verbose = solver_config["verbose"]

def distance(x_prev, x):
return jnp.mean(jnp.linalg.norm(x_prev - x, axis=1))
residual = jnp.mean(jnp.linalg.norm(x_prev - x, axis=1))
if verbose:
jax.debug.print("Residual: {}", residual)
return residual

def cond_fun(carry):
x_prev, x = carry
Expand Down Expand Up @@ -124,14 +160,16 @@ def rev_iter(packed, u):
_, vjp_x = vjp(lambda x: fn(a, x), x_star)
return x_star_bar + vjp_x(u)[0]

solver_config = {k: v for k, v in solver_config.items()}
solver_config["eta"] = 1e-3
partial_func = solver(rev_iter,
(a, x_star, x_star_bar),
x_star_bar,
solver_config)

a_bar = vjp_a(partial_func)[0]

return a_bar, None # jnp.zeros_like(x_star)
return a_bar, None


fixed_point.defvjp(fixed_point_fwd, fixed_point_bwd)

0 comments on commit fcf8eea

Please sign in to comment.