Skip to content

Commit

Permalink
[Models] EquilibriumModel has computation attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
arpastrana committed Oct 29, 2023
1 parent cf3621b commit 53d5d0f
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 36 deletions.
29 changes: 23 additions & 6 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
#### Models
- Added support for efficient reverse-mode AD of the calculation of equilibrium states in the presence of shape-dependent loads, via implicit differentiation. Forward-mode AD is pending.
- Added `EquilibriumModel.equilibrium_iterative` to compute equilibrium states that have shape-dependent edge and face loads using fixed point iteration.
- Added `EquiibriumModel.edges_load` and `EquiibriumModel.faces_load` to allow computation of edge and face loads
- Added `EquiibriumModel.edges_load` and `EquiibriumModel.faces_load` to allow computation of edge and face loads.
- Implemented `EquilibriumModelSparse.stiffness_matrix`.
- Implemented `EquilibriumModel.stiffness_matrix`.
- Implemented `EquilibriumModel.force_matrix`.
- Implemented `EquilibriumModel.force_fixed_matrix`.
- Added `linearsolve_fn`, `itersolve_fn`, `implicit_diff`, and `verbose` as attributes of `EquilibriumModel`.

#### Equilibrium
- Implemented `equilibrium.states.LoadState`
- Implemented `equilibrium.states.EquilibriumParametersState`
- Restored `vectors` field in `EquilibriumState`.
- Implemented `equilibrium.states.LoadState`.
- Implemented `equilibrium.states.EquilibriumParametersState`.

#### Solvers
- Implemented `solver_anderson`, to find fixed points of a function with `jaxopt.AndersonAcceleration`.
- Defined a `jax.custom_vjp` for `fixed_point`, an interface function that solves for fixed points of a function for different root-finding solver types: `solver_fixedpoint`, `solver_forward`, and `solver_newton`.
- Implemented `solver_fixedpoint`, a function that wraps `jaxopt.FixedPointIterator` to calculate static equilibrium iteratively.
- Implemented `solver_forward`, to find fixed points of a function using an `equinox.while_loop`.
Expand All @@ -31,9 +34,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
#### Loads
- Added `equilibrium.loads` module to enable support for edge and face-loads, which correspond to line and area loads, respectively.
These two load types can be optionally become follower loads setting the `is_local` input flag to `True`. A follower load will update its direction iteratively, according to the local coordinate system of an edge or a face at an iteration. The two main functions that enable this feature are `loads.nodes_load_from_faces` and `loads.nodes_load_from_edges`. These functions are wrapped by `EquilibriumModel` under `EquiibriumModel.edges_load` and `EquiibriumModel.faces_load`.
- Implemented `equilibrium.loads.nodes_`
- Implemented `equilibrium.loads.nodes_`.

#### Datastructures
- Report standard deviation in `FDDatastructure.print_stats()`.
- Added constructor method `FDNetwork.from_mesh`.
- Added `FDMesh.face_lcs` to calculate the local coordinaty system of a mesh face.
- Added `datastructures.FDDatastructure.edges_loads`.
Expand All @@ -46,13 +50,19 @@ These two load types can be optionally become follower loads setting the `is_loc
- Implemented `structures.Graph`.
- Implemented `structures.GraphSparse`.
- Added `FDNetwork.is_edge_fully_supported`.
- Added `EquilibriumMeshStructure.from_mesh` with support for inhomogenous faces (i.e. faces with different number of vertices). The solution is to pad the rows of the `faces` 2D array with `-1` to match `max_num_vertices`.

#### Goals

- Implemented `NetworkXYZLaplacianGoal`

#### Optimization
- Added `optimization.Optimizer.loads_static` attribute to store edge and face loads during optimization.

#### Geometry
- Added `polygon_lcs` to compute the local coordinate system of a closed polygon.
- Added `line_lcs` to compute the local coordinate system of a line.
- Added `nan` gradient guardrail to `normalize_vector` calculations.

#### Parameters
- Added support for mesh vertex parameters.
Expand All @@ -66,13 +76,15 @@ These two load types can be optionally become follower loads setting the `is_loc
- Implemented helper function `sparse_blockdiag_matrix` to `spsolve_gpu_ravel`.

#### Visualization
- Added `plotters/VectorArtist` to custom plot loads and reactions arrows.
- Implemented `LossPlotter._print_error_stats` to report loss breakdown of error terms.

### Changed

#### Models

#### Equilibrium
- The functions `fdm` and `constrained_fdm` take iterative equilibrium parameters as function arguments.
- The functions `fdm` and `constrained_fdm` can take an `FDMesh` as input, in addition to `FDNetwork`.

#### Sparse solver
Expand All @@ -89,15 +101,19 @@ These two load types can be optionally become follower loads setting the `is_loc
- Changed signature of `Regularizer.__call__` to take in parameters instead of equilibirum state.

#### Datastructures
- Overhauled `EquilibriumStructure` and `EquilibirumStructureSparse`. They are subclasses `equinox.Module`, and now they are meant to be immutable. They also have little idea of what an `FDNetwork` is.
- Overhauled `EquilibriumStructure` and `EquilibriumStructureSparse`. They are subclasses `equinox.Module`, and now they are meant to be immutable. They also have little idea of what an `FDNetwork` is.
- Modified `face_matrix` adjacency matrix creation function to skip -1 vertices. This is to add support for `MeshStructures` that have faces with different number of vertices.

#### Optimization
- `Optimizer.problem` takes an `FDNetwork` as input.
- `Optimizer.problem` takes boolean `jit_fn` as arg to disable jitting if needed.
- Changed `ParameterManager` to require an `FDNetwork` as argument at initialization.
- Changed `Parameter.value` signature. Gets value from `network` directly, not from `structure.network`
- `optimization.OptimizationRecorder` has support to store, export and import named tuple parameters.

#### Visualization
- Fixed bug in `viewers/network_artist.py` that overshifted load arrows.
- Edge coloring considers force sign for `force` color scheme in `artists/network_artist.py`.
- Fixed bug with the coloring of reaction forces in `viewers/network_artist.py`.
- Fixed bug with the coloring of reaction forces in `artists/network_artist.py`.
- `LossPlotter` has support to plot named tuple parameters.
Expand All @@ -108,7 +124,8 @@ These two load types can be optionally become follower loads setting the `is_loc
- Removed `EquilibriumModel.from_network`.
- Removed `sparse.force_densities_to_A`. Superseded by `EquilibriumModelSparse.stiffness_matrix`.
- Removed `sparse.force_densities_to_b`. Superseded by `EquilibriumModel.force_matrix`.
- Removed partial jitting from `Loss.__call__`
- Removed partial jitting from `Loss.__call__`.
- Removed partial jitting from `Error.__call__`.


## [0.7.1] 2023-05-08
Expand Down
98 changes: 68 additions & 30 deletions src/jax_fdm/equilibrium/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@

from jax_fdm.equilibrium.iterative import fixed_point
from jax_fdm.equilibrium.iterative import solver_fixedpoint
from jax_fdm.equilibrium.iterative import solver_forward

from jax_fdm.equilibrium.loads import nodes_load_from_faces
from jax_fdm.equilibrium.loads import nodes_load_from_edges

from jax_fdm.equilibrium.states import LoadState


# ==========================================================================
# Equilibrium model
Expand All @@ -23,12 +26,21 @@ class EquilibriumModel:
"""
The equilibrium model.
"""
def __init__(self, tmax=100, eta=1e-6, is_load_local=False):
def __init__(self,
tmax=100,
eta=1e-6,
is_load_local=False,
itersolve_fn=None,
implicit_diff=True,
verbose=False):

self.tmax = tmax
self.eta = eta
self.is_load_local = is_load_local

self.linearsolve_func = jnp.linalg.solve
self.linearsolve_fn = jnp.linalg.solve
self.itersolve_fn = itersolve_fn or solver_forward
self.implicit_diff = implicit_diff
self.verbose = verbose

# ----------------------------------------------------------------------
# Edges
Expand Down Expand Up @@ -80,7 +92,7 @@ def nodes_free_positions(self, q, xyz_fixed, loads, structure):
A = self.stiffness_matrix(q, structure)
b = self.force_matrix(q, xyz_fixed, loads, structure)

return self.linearsolve_func(A, b)
return self.linearsolve_fn(A, b)

def nodes_equilibrium(self, q, xyz_fixed, loads_nodes, structure):
"""
Expand Down Expand Up @@ -135,25 +147,37 @@ def __call__(self, params, structure):
"""
Compute an equilibrium state using the force density method (FDM).
"""
q, xyz_fixed, loads = params
loads_nodes = loads.nodes
q, xyz_fixed, loads_state = params
loads_nodes = loads_state.nodes

tmax = self.tmax
eta = self.eta
solver = self.itersolve_fn
implicit_diff = self.implicit_diff
verbose = self.verbose

xyz = self.equilibrium(q, xyz_fixed, loads_nodes, structure)

if tmax > 1:
# NOTE: Setting node loads to zero when tmax > 1 is temporary
loads_nodes = jnp.zeros_like(loads_nodes)
loads_state = LoadState(loads_nodes,
loads_state.edges,
loads_state.faces)

xyz = self.equilibrium_iterative(q,
xyz_fixed,
loads,
loads_state,
structure,
tmax,
eta,
verbose=False)
xyz_init=xyz,
tmax=tmax,
eta=eta,
solver=solver,
implicit_diff=implicit_diff,
verbose=verbose)

# activating this function raises a Zero error in sparse mode. Why?
loads_nodes = self.nodes_load(xyz, loads, structure)
# TODO: reactivate loads nodes
loads_nodes = self.nodes_load(xyz, loads_state, structure)

return self.equilibrium_state(q, xyz, loads_nodes, structure)

Expand All @@ -167,7 +191,17 @@ def equilibrium(self, q, xyz_fixed, loads_nodes, structure):
"""
return self.nodes_equilibrium(q, xyz_fixed, loads_nodes, structure)

def equilibrium_iterative(self, q, xyz_fixed, loads, structure, tmax=100, eta=1e-6, solver=None, implicit_diff=True, verbose=False):
def equilibrium_iterative(self,
q,
xyz_fixed,
load_state,
structure,
xyz_init=None,
tmax=100,
eta=1e-6,
solver=None,
implicit_diff=True,
verbose=False):
"""
Calculate static equilibrium on a structure iteratively.
Expand All @@ -183,52 +217,59 @@ def equilibrium_iterative_fn(params, xyz_init):
TODO: Extract closure into function shared with the other nodes equilibrium function?
"""
A, f_fixed, xyz_fixed, loads = params
A, f_fixed, xyz_fixed, load_state = params

free = structure.indices_free
freefixed = structure.indices_freefixed

loads_nodes = self.nodes_load(xyz_init, loads, structure)
loads_nodes = self.nodes_load(xyz_init, load_state, structure)
b = loads_nodes[free, :] - f_fixed
xyz_free = self.linearsolve_func(A, b)
xyz_free = self.linearsolve_fn(A, b)
xyz_ = self.nodes_positions(xyz_free, xyz_fixed, freefixed)

return self.nodes_positions(xyz_free, xyz_fixed, freefixed)
return xyz_

# recompute xyz_init if not input
if xyz_init is None:
xyz_init = self.equilibrium(q, xyz_fixed, load_state.nodes, structure)

xyz_init = self.equilibrium(q, xyz_fixed, loads.nodes, structure)
A = self.stiffness_matrix(q, structure)
f_fixed = self.force_fixed_matrix(q, xyz_fixed, structure)

solver = solver or solver_fixedpoint
solver_kwargs = {"solver_config": {"tmax": tmax, "eta": eta},
solver_kwargs = {"solver_config": {"tmax": tmax, "eta": eta, "verbose": verbose},
"f": equilibrium_iterative_fn,
"a": (A, f_fixed, xyz_fixed, loads),
"a": (A, f_fixed, xyz_fixed, load_state),
"x_init": xyz_init}

if implicit_diff:
return fixed_point(solver, **solver_kwargs)
xyz_new = fixed_point(solver, **solver_kwargs)

xyz_new = solver(**solver_kwargs)

return solver(**solver_kwargs)
return xyz_new

# ----------------------------------------------------------------------
# Equilibrium state
# ----------------------------------------------------------------------

def equilibrium_state(self, q, xyz, loads, structure):
def equilibrium_state(self, q, xyz, loads_nodes, structure):
"""
Assembles an equilibrium state object.
"""
connectivity = structure.connectivity

vectors = self.edges_vectors(xyz, connectivity)
residuals = self.nodes_residuals(q, loads, vectors, connectivity)
lengths = self.edges_lengths(vectors)
residuals = self.nodes_residuals(q, loads_nodes, vectors, connectivity)
forces = self.edges_forces(q, lengths)

return EquilibriumState(xyz=xyz,
residuals=residuals,
lengths=lengths,
forces=forces,
loads=loads)
loads=loads_nodes,
vectors=vectors)

# ----------------------------------------------------------------------
# Matrices
Expand All @@ -239,7 +280,6 @@ def stiffness_matrix(q, structure):
"""
The stiffness matrix of the structure.
"""
# shorthand
c_free = structure.connectivity_free

return c_free.T @ (q[:, None] * c_free)
Expand All @@ -248,7 +288,6 @@ def force_matrix(self, q, xyz_fixed, loads, structure):
"""
The force residual matrix of the structure.
"""
# shorthand
free = structure.indices_free

return loads[free, :] - self.force_fixed_matrix(q, xyz_fixed, structure)
Expand All @@ -258,7 +297,6 @@ def force_fixed_matrix(q, xyz_fixed, structure):
"""
The force matrix block of the residual forces at the fixed nodes.
"""
# shorthands
c_free = structure.connectivity_free
c_fixed = structure.connectivity_fixed

Expand All @@ -275,14 +313,14 @@ class EquilibriumModelSparse(EquilibriumModel):
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.linearsolve_func = spsolve
self.linearsolve_fn = spsolve

@staticmethod
def stiffness_matrix(q, structure):
"""
Computes the LHS matrix in CSC format from a vector of force densities.
"""
# short hands
# shorthands
index_array = structure.index_array
diag_indices = structure.diag_indices
diags = structure.diags
Expand Down

0 comments on commit 53d5d0f

Please sign in to comment.