Skip to content

Commit

Permalink
add arrows to pub_plot
Browse files Browse the repository at this point in the history
  • Loading branch information
robertjwilson committed Jun 17, 2024
1 parent 892fa8f commit d191eaf
Showing 1 changed file with 39 additions and 3 deletions.
42 changes: 39 additions & 3 deletions nctoolkit/static_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def pub_plot(
out=None,
breaks=None,
dpi = "figure",
font = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -514,9 +515,13 @@ def pub_plot(
if r_max is not None:
vmax = np.nanpercentile(np.ma.filled(values, np.nan), r_max)

if robust and limits is None:
if robust:
vmin = np.nanpercentile(np.ma.filled(values, np.nan), 2)
vmax = np.nanpercentile(np.ma.filled(values, np.nan), 98)
limits = [0,0]
limits[0] = vmin
limits[1] = vmax

if limits is None:
if r_min is not None:
vmin = np.nanpercentile(np.ma.filled(values, np.nan), r_min)
Expand Down Expand Up @@ -691,10 +696,34 @@ def pub_plot(
else:
fraction = 0.046 * size[1] / size[0]

min_value = np.min(values)
max_value = np.max(values)

min_arrow = False

if vmin is not None:
if vmin > min_value:
min_arrow = True
max_arrow = False
if vmax is not None:
if vmax < max_value:
max_arrow = True

if min_arrow and max_arrow:
extend = "both"
else:
if min_arrow:
extend = "min"
else:
if max_arrow:
extend = "max"
else:
extend = "neither"

if l_location == "bottom":
cb = plt.colorbar(im, fraction=fraction, pad=0.04, location=l_location)
cb = plt.colorbar(im, fraction=fraction, pad=0.04, location=l_location, extend = extend)
else:
cb = plt.colorbar(im, fraction=fraction, pad=0.04)
cb = plt.colorbar(im, fraction=fraction, pad=0.04, extend = extend)

# add breaks to colorbar cb
if breaks is not None:
Expand Down Expand Up @@ -763,6 +792,13 @@ def pub_plot(

if legend_position is None:
cb.remove()

if font is not None:
for text in [ax.title, ax.xaxis.label, ax.yaxis.label, cb.ax.yaxis.label, cb.ax.xaxis.label]:
text.set_fontsize(font)
for item in ([ax.title, ax.xaxis.label, ax.yaxis.label, cb.ax.yaxis.label, cb.ax.xaxis.label] +
ax.get_xticklabels() + ax.get_yticklabels()):
item.set_fontsize(font)

if out is not None:
print("saving as file")
Expand Down

0 comments on commit d191eaf

Please sign in to comment.