Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Couple of fixes #296

Merged
merged 2 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pysages/methods/funn.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def update(state, data):
)
bias = (-Jxi.T @ F).reshape(state.bias.shape)
#
return FUNNState(xi, bias, hist, Fsum, F, Wp, state.Wp, nn, state.ncalls)
return FUNNState(xi, bias, hist, Fsum, F, Wp, state.Wp, nn, ncalls)

return snapshot, initialize, generalize(update, helpers)

Expand Down
22 changes: 14 additions & 8 deletions pysages/methods/spectral_abf.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def update(state, data):
)
bias = np.reshape(-Jxi.T @ force, state.bias.shape)
#
return SpectralABFState(xi, bias, hist, Fsum, force, Wp, state.Wp, fun, state.ncalls)
return SpectralABFState(xi, bias, hist, Fsum, force, Wp, state.Wp, fun, ncalls)

return snapshot, initialize, generalize(update, helpers)

Expand Down Expand Up @@ -228,6 +228,7 @@ def build_force_estimator(method: SpectralABF):
"""
N = method.N
grid = method.grid
dims = grid.shape.size
model = method.model
get_grad = build_grad_evaluator(model)

Expand All @@ -242,17 +243,17 @@ def _estimate_force(state):
return cond(state.pred, interpolate_force, average_force, state)

if method.restraints is None:
estimate_force = _estimate_force
ob_force = jit(lambda state: np.zeros(dims))
else:
lo, hi, kl, kh = method.restraints

def restraints_force(state):
def ob_force(state):
xi = state.xi.reshape(grid.shape.size)
return apply_restraints(lo, hi, kl, kh, xi)

def estimate_force(state):
ob = np.any(np.array(state.ind) == grid.shape) # Out of bounds condition
return cond(ob, restraints_force, _estimate_force, state)
def estimate_force(state):
ob = np.any(np.array(state.ind) == grid.shape) # Out of bounds condition
return cond(ob, ob_force, _estimate_force, state)

return estimate_force

Expand Down Expand Up @@ -303,7 +304,11 @@ def average_forces(hist, Fsum):
return Fsum / np.maximum(hist, 1)

def build_fes_fn(fun):
return jit(lambda x: evaluate(fun, x))
def fes_fn(x):
A = evaluate(fun, x)
return A.max() - A

return jit(fes_fn)

def first_or_all(seq):
return seq[0] if len(seq) == 1 else seq
Expand All @@ -318,7 +323,7 @@ def first_or_all(seq):
fes_fn = build_fes_fn(s.fun)
hists.append(s.hist)
mean_forces.append(average_forces(s.hist, s.Fsum))
free_energies.append(fes_fn(mesh))
free_energies.append(fes_fn(mesh).reshape(grid.shape))
funs.append(s.fun)
fes_fns.append(fes_fn)

Expand All @@ -330,4 +335,5 @@ def first_or_all(seq):
fun=first_or_all(funs),
fes_fn=first_or_all(fes_fns),
)

return numpyfy_vals(ana_result)
Loading