Skip to content

Commit

Permalink
add missing day/time_labelers
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 committed Aug 20, 2023
1 parent 1eab436 commit 0584b68
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 15 deletions.
34 changes: 31 additions & 3 deletions latent_calendar/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,11 @@ def plot_across_column(
config = StartEndConfig(start=start_col, end=end_col, minutes=duration)

plot_dataframe_grid_across_column(
self._obj, grid_col=grid_col, config=config, max_cols=max_cols, alpha=alpha,
self._obj,
grid_col=grid_col,
config=config,
max_cols=max_cols,
alpha=alpha,
day_labeler=day_labeler,
time_labeler=time_labeler,
grid_lines=grid_lines,
Expand Down Expand Up @@ -439,13 +443,17 @@ def plot_profile_by_row(
model: LatentCalendar,
index_func=lambda idx: idx,
include_components: bool = True,
day_labeler: DayLabeler = DayLabeler(),
time_labeler: TimeLabeler = TimeLabeler(),
) -> np.ndarray:
"""Plot each row of the DataFrame as a profile plot. Data must have been transformed to wide format first.
Args:
model: model to use for prediction and transform
index_func: function to generate title for each row
include_components: whether to include components in the plot
day_labeler: DayLabeler instance to use for day labels
time_labeler: TimeLabeler instance to use for time labels
Returns:
grid of axes
Expand All @@ -456,23 +464,37 @@ def plot_profile_by_row(
model=model,
index_func=index_func,
include_components=include_components,
day_labeler=day_labeler,
time_labeler=time_labeler,
)

def plot_raw_and_predicted_by_row(
self, *, model: LatentCalendar, index_func=lambda idx: idx
self,
*,
model: LatentCalendar,
index_func=lambda idx: idx,
day_labeler: DayLabeler = DayLabeler(),
time_labeler: TimeLabeler = TimeLabeler(),
) -> np.ndarray:
"""Plot raw and predicted values for a model. Data must have been transformed to wide format first.
Args:
model: model to use for prediction
index_func: function to generate title for each row
day_labeler: DayLabeler instance to use for day labels
time_labeler: TimeLabeler instance to use for time labels
Returns:
grid of axes
"""
return plot_profile_by_row(
self._obj, model=model, index_func=index_func, include_components=False
self._obj,
model=model,
index_func=index_func,
include_components=False,
day_labeler=day_labeler,
time_labeler=time_labeler,
)

def plot_model_predictions_by_row(
Expand All @@ -482,6 +504,8 @@ def plot_model_predictions_by_row(
model: LatentCalendar,
index_func=lambda idx: idx,
divergent: bool = True,
day_labeler: DayLabeler = DayLabeler(),
time_labeler: TimeLabeler = TimeLabeler(),
) -> np.ndarray:
"""Plot model predictions for each row of the DataFrame. Data must have been transformed to wide format first.
Expand All @@ -490,6 +514,8 @@ def plot_model_predictions_by_row(
model: model to use for prediction
index_func: function to generate title for each row
divergent: whether to use divergent colormap
day_labeler: DayLabeler instance to use for day labels
time_labeler: TimeLabeler instance to use for time labels
Returns:
grid of axes
Expand All @@ -501,4 +527,6 @@ def plot_model_predictions_by_row(
model=model,
index_func=index_func,
divergent=divergent,
day_labeler=day_labeler,
time_labeler=time_labeler,
)
34 changes: 22 additions & 12 deletions latent_calendar/plot/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def plot_profile(

# Raw Data
ax = axes[0]
plot_raw_data(array=array, ax=ax, day_labeler=day_labeler, time_labeler=time_labeler)
plot_raw_data(
array=array, ax=ax, day_labeler=day_labeler, time_labeler=time_labeler
)

# Under Model
ax = axes[1]
Expand All @@ -64,7 +66,7 @@ def plot_profile(
ax=ax,
display_y_axis=False,
divergent=divergent,
day_labeler=day_labeler,
day_labeler=day_labeler,
time_labeler=time_labeler,
)

Expand Down Expand Up @@ -102,7 +104,7 @@ def plot_profile_by_row(
index_func,
divergent: bool = True,
include_components: bool = True,
day_labeler: DayLabeler = DayLabeler(),
day_labeler: DayLabeler = DayLabeler(),
time_labeler: TimeLabeler = TimeLabeler(),
) -> np.ndarray:
nrows = len(df)
Expand All @@ -123,7 +125,7 @@ def plot_profile_by_row(
divergent=divergent,
include_components=include_components,
day_labeler=day_labeler,
time_labeler=time_labeler,
time_labeler=time_labeler,
)

ylabel = index_func(idx)
Expand Down Expand Up @@ -165,7 +167,9 @@ def plot_model_predictions(
X_to_predict_probs = model.predict(X_to_predict)[0]

ax = axes[0]
plot_raw_data(array=X_to_predict, ax=ax, day_labeler=day_labeler, time_labeler=time_labeler)
plot_raw_data(
array=X_to_predict, ax=ax, day_labeler=day_labeler, time_labeler=time_labeler
)
ax.set_title(f"Raw Data for Prediction")

ax = axes[1]
Expand All @@ -174,13 +178,19 @@ def plot_model_predictions(
ax=ax,
display_y_axis=False,
divergent=divergent,
day_labeler=day_labeler,
time_labeler=time_labeler
day_labeler=day_labeler,
time_labeler=time_labeler,
)
ax.set_title("Distribution from Prediction")

ax = axes[2]
plot_raw_data(array=X_holdout, ax=ax, display_y_axis=False, day_labeler=day_labeler, time_labeler=time_labeler)
plot_raw_data(
array=X_holdout,
ax=ax,
display_y_axis=False,
day_labeler=day_labeler,
time_labeler=time_labeler,
)
ax.set_title("Raw Data in Future")

return axes
Expand All @@ -192,8 +202,8 @@ def plot_model_predictions_by_row(
model: LatentCalendar,
index_func=lambda idx: idx,
divergent: bool = True,
day_labeler: DayLabeler = DayLabeler(),
time_labeler: TimeLabeler = TimeLabeler(),
day_labeler: DayLabeler = DayLabeler(),
time_labeler: TimeLabeler = TimeLabeler(),
) -> np.ndarray:
nrows = len(df)

Expand All @@ -213,8 +223,8 @@ def plot_model_predictions_by_row(
model=model,
axes=axes_row,
divergent=divergent,
day_labeler=day_labeler,
time_labeler=time_labeler,
day_labeler=day_labeler,
time_labeler=time_labeler,
)

ylabel = index_func(idx)
Expand Down

0 comments on commit 0584b68

Please sign in to comment.