Skip to content

Commit

Permalink
update the plotting script
Browse files Browse the repository at this point in the history
  • Loading branch information
tjira committed Jun 24, 2024
1 parent 91a4b49 commit 76c9144
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
7 changes: 6 additions & 1 deletion example/makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ qdyn_1d_HO_imaginary:
qdyn_2d_HO_imaginary:
"../bin/acorn_expression" -d 2 -g -8 8 -o U_DIA.mat -p 64 -e "0.5*(x^2+y^2)"
"../bin/acorn_expression" -d 2 -g -8 8 -o PSI_DIA_GUESS.mat -p 64 -e "exp(-(x^2+y^2))" "0"
"../bin/acorn_qdyn" -d 2 -i 1000 -m 1 -o 3 -p 0 -s 0.1 --savewfn
"../bin/acorn_qdyn" -d 2 -i 1000 -m 1 -o 12 -p 0 -s 0.1 --savewfn

# target to perform an imaginary-time propagation of 3D adiabatic wavepacket
qdyn_3d_HO_imaginary:
Expand Down Expand Up @@ -170,6 +170,11 @@ random_qdyn_hydrogen_imaginary:
"../bin/acorn_expression" -d 3 -g -16 16 -o PSI_DIA_GUESS.mat -p 128 -e "exp(-(x^2+y^2+z^2))" "0"
"../bin/acorn_qdyn" -d 3 -i 1000 -m 0.999455 -o 1 -p 0 -s 0.1

random_qdyn_kepler_real:
"../bin/acorn_expression" -d 2 -g -16 16 -o U_DIA.mat -p 64 -e "10/(x^2+y^2)-10/sqrt(x^2+y^2)"
"../bin/acorn_expression" -d 2 -g -16 16 -o PSI_DIA_GUESS.mat -p 64 -e "exp(-((x-1)^2+(y-1)^2))" "0"
"../bin/acorn_qdyn" -d 2 -i 1000 -m 1 -p 0 -s 0.1 --savewfn

random_dyn_1d_tully_1: random_qdyn_1d_tully_1 random_cdyn_lz_1d_tully_1
random_dyn_1d_tully_2: random_qdyn_1d_tully_2 random_cdyn_lz_1d_tully_2
random_dyn_1d_ds_1: random_qdyn_1d_ds_1 random_cdyn_lz_1d_ds_1
Expand Down
31 changes: 22 additions & 9 deletions script/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ def update(frame):
return fig, update

def two(args, mats):
# raise and exception when more than one column is requested to plot
if len(args.extract) > 1: raise Exception("YOU HAVE TO SPECIFY ONLY ONE COLUMN TO EXTRACT")

# set the column names of the data so that the first two columns are independent variables and the rest are unique dependent variables
for i, mat in enumerate(mats): mat.columns = [("x" if j == 0 else "y" if j == 1 else sum([mats[k].shape[1] - 2 for k in range(i)]) + j - 1) for j in range(mat.shape[1])]

Expand All @@ -52,26 +49,41 @@ def two(args, mats):
# create the meshgrid for the data
x, y = np.meshgrid(*[np.linspace(data[0][0][v].min(), data[0][0][v].max(), 128) for v in ["x", "y"]])

# take the norm of the data if more than one column is extracted
if len(args.extract) > 1:
for frame in data:
for mat in frame:
mat.iloc[:, 2] = np.sqrt((mat.iloc[:, 2:]**2).sum(axis=1)); mat = mat.iloc[:, [0, 1, 2]]
else: data = [[mat.iloc[:, [0, 1, 1 + args.extract[0]]] for mat in frame] for frame in data]

# calculate the number of rows and columns the resulting image will have
rows = [i for i in range(2, len(data[0]) + 1) if len(data[0]) % i == 0 and i * i <= len(data[0])]; rows = rows[-1] if len(rows) > 0 else 1; cols = len(data[0]) // rows

# calculate min and max of the data
zmin = np.min([[np.min(mat.iloc[:, 1 + args.extract[0]]) for mat in frame] for frame in data]); zmax = np.max([[np.max(mat.iloc[:, 1 + args.extract[0]]) for mat in frame] for frame in data])
zmin = np.min([[np.min(mat.iloc[:, 2]) for mat in frame] for frame in data]); zmax = np.max([[np.max(mat.iloc[:, 2]) for mat in frame] for frame in data])

# set the heatmap parameters
params = {"cbar":False, "xticklabels":False, "yticklabels":False, "vmin":zmin, "vmax":zmax, "rasterized":True, "cmap":"icefire"}
# set the heatmap and surface plot parameters
hmparams = {"cbar":False, "xticklabels":False, "yticklabels":False, "vmin":zmin, "vmax":zmax, "rasterized":True, "cmap":"icefire"}
spparams = {"rasterized":True, "cmap":"icefire", "vmin":zmin, "vmax":zmax}

# initialize the figure and axis
fig, ax = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows))
fig, ax = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), subplot_kw={"projection": "3d" if args.surface else None})

# set the z limits for the surface plot
if args.surface:
for i, axis in enumerate(np.array([ax]).flatten()): axis.set_zlim(zmin - 0.05 * (zmax - zmin), zmax + 0.05 * (zmax - zmin))

# plot the heatmaps
for i, mat in enumerate(data[len(data) - 1 if args.last else 0]):
sns.heatmap(ax=np.array([ax]).flatten()[i], data=si.griddata((mat.x, mat.y), mat.iloc[:, 1 + args.extract[0]], (x, y), method="cubic"), **params)
if not args.surface: sns.heatmap(ax=np.array([ax]).flatten()[i], data=si.griddata((mat.x, mat.y), mat.iloc[:, 2], (x, y), method="cubic"), **hmparams)
else: np.array([ax]).flatten()[i].plot_trisurf(mat.x, mat.y, mat.iloc[:, 2], **spparams)

# define the animation update function
def update(frame):
[[coll.remove() for coll in axis.collections] for axis in np.array([ax]).flatten()]
for i, mat in enumerate(data[frame]):
ax.collections[i].remove(); sns.heatmap(ax=np.array([ax]).flatten()[i], data=si.griddata((mat.x, mat.y), mat.iloc[:, 1 + args.extract[0]], (x, y), method="cubic"), **params)
if not args.surface: sns.heatmap(ax=np.array([ax]).flatten()[i], data=si.griddata((mat.x, mat.y), mat.iloc[:, 2], (x, y), method="cubic"), **hmparams)
else: np.array([ax]).flatten()[i].plot_trisurf(mat.x, mat.y, mat.iloc[:, 2], **spparams)

# set the tight layout and return
plt.tight_layout(); return fig, update
Expand All @@ -90,6 +102,7 @@ def update(frame):
parser.add_argument("--image", action="store_true", help="Display only the image without frames, ticks and labels.")
parser.add_argument("--last", action="store_true", help="Display only the last frame.")
parser.add_argument("--legend", action="store_true", help="Display the legend.")
parser.add_argument("--surface", action="store_true", help="Display the data as a surface plot.")
parser.add_argument("--mp4", action="store_true", help="Save the plot as a gif.")
parser.add_argument("--gif", action="store_true", help="Save the plot as an mp4.")
parser.add_argument("--png", action="store_true", help="Save the plot as a png.")
Expand Down

0 comments on commit 76c9144

Please sign in to comment.