Skip to content

Commit

Permalink
Test branch for PR 296
Browse files Browse the repository at this point in the history
  • Loading branch information
trunk-io[bot] committed Feb 1, 2024
2 parents a83669a + 14c868c commit 9486185
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pysages/methods/funn.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def update(state, data):
)
bias = (-Jxi.T @ F).reshape(state.bias.shape)
#
return FUNNState(xi, bias, hist, Fsum, F, Wp, state.Wp, nn, state.ncalls)
return FUNNState(xi, bias, hist, Fsum, F, Wp, state.Wp, nn, ncalls)

return snapshot, initialize, generalize(update, helpers)

Expand Down
22 changes: 14 additions & 8 deletions pysages/methods/spectral_abf.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def update(state, data):
)
bias = np.reshape(-Jxi.T @ force, state.bias.shape)
#
return SpectralABFState(xi, bias, hist, Fsum, force, Wp, state.Wp, fun, state.ncalls)
return SpectralABFState(xi, bias, hist, Fsum, force, Wp, state.Wp, fun, ncalls)

return snapshot, initialize, generalize(update, helpers)

Expand Down Expand Up @@ -228,6 +228,7 @@ def build_force_estimator(method: SpectralABF):
"""
N = method.N
grid = method.grid
dims = grid.shape.size
model = method.model
get_grad = build_grad_evaluator(model)

Expand All @@ -242,17 +243,17 @@ def _estimate_force(state):
return cond(state.pred, interpolate_force, average_force, state)

if method.restraints is None:
estimate_force = _estimate_force
ob_force = jit(lambda state: np.zeros(dims))
else:
lo, hi, kl, kh = method.restraints

def restraints_force(state):
def ob_force(state):
xi = state.xi.reshape(grid.shape.size)
return apply_restraints(lo, hi, kl, kh, xi)

def estimate_force(state):
ob = np.any(np.array(state.ind) == grid.shape) # Out of bounds condition
return cond(ob, restraints_force, _estimate_force, state)
def estimate_force(state):
ob = np.any(np.array(state.ind) == grid.shape) # Out of bounds condition
return cond(ob, ob_force, _estimate_force, state)

return estimate_force

Expand Down Expand Up @@ -303,7 +304,11 @@ def average_forces(hist, Fsum):
return Fsum / np.maximum(hist, 1)

def build_fes_fn(fun):
return jit(lambda x: evaluate(fun, x))
def fes_fn(x):
A = evaluate(fun, x)
return A.max() - A

return jit(fes_fn)

def first_or_all(seq):
return seq[0] if len(seq) == 1 else seq
Expand All @@ -318,7 +323,7 @@ def first_or_all(seq):
fes_fn = build_fes_fn(s.fun)
hists.append(s.hist)
mean_forces.append(average_forces(s.hist, s.Fsum))
free_energies.append(fes_fn(mesh))
free_energies.append(fes_fn(mesh).reshape(grid.shape))
funs.append(s.fun)
fes_fns.append(fes_fn)

Expand All @@ -330,4 +335,5 @@ def first_or_all(seq):
fun=first_or_all(funs),
fes_fn=first_or_all(fes_fns),
)

return numpyfy_vals(ana_result)

0 comments on commit 9486185

Please sign in to comment.