From 0584b68d4caba121536f800735dbdf2cce1df803 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sun, 20 Aug 2023 14:33:08 +0200 Subject: [PATCH] add missing day/time_labelers --- latent_calendar/extensions.py | 34 +++++++++++++++++++++++++++--- latent_calendar/plot/core/model.py | 34 +++++++++++++++++++----------- 2 files changed, 53 insertions(+), 15 deletions(-) diff --git a/latent_calendar/extensions.py b/latent_calendar/extensions.py index 7169c4e..fb3f6d9 100644 --- a/latent_calendar/extensions.py +++ b/latent_calendar/extensions.py @@ -388,7 +388,11 @@ def plot_across_column( config = StartEndConfig(start=start_col, end=end_col, minutes=duration) plot_dataframe_grid_across_column( - self._obj, grid_col=grid_col, config=config, max_cols=max_cols, alpha=alpha, + self._obj, + grid_col=grid_col, + config=config, + max_cols=max_cols, + alpha=alpha, day_labeler=day_labeler, time_labeler=time_labeler, grid_lines=grid_lines, @@ -439,6 +443,8 @@ def plot_profile_by_row( model: LatentCalendar, index_func=lambda idx: idx, include_components: bool = True, + day_labeler: DayLabeler = DayLabeler(), + time_labeler: TimeLabeler = TimeLabeler(), ) -> np.ndarray: """Plot each row of the DataFrame as a profile plot. Data must have been transformed to wide format first. @@ -446,6 +452,8 @@ def plot_profile_by_row( model: model to use for prediction and transform index_func: function to generate title for each row include_components: whether to include components in the plot + day_labeler: DayLabeler instance to use for day labels + time_labeler: TimeLabeler instance to use for time labels Returns: grid of axes @@ -456,23 +464,37 @@ def plot_profile_by_row( model=model, index_func=index_func, include_components=include_components, + day_labeler=day_labeler, + time_labeler=time_labeler, ) def plot_raw_and_predicted_by_row( - self, *, model: LatentCalendar, index_func=lambda idx: idx + self, + *, + model: LatentCalendar, + index_func=lambda idx: idx, + day_labeler: DayLabeler = DayLabeler(), + time_labeler: TimeLabeler = TimeLabeler(), ) -> np.ndarray: """Plot raw and predicted values for a model. Data must have been transformed to wide format first. Args: model: model to use for prediction index_func: function to generate title for each row + day_labeler: DayLabeler instance to use for day labels + time_labeler: TimeLabeler instance to use for time labels Returns: grid of axes """ return plot_profile_by_row( - self._obj, model=model, index_func=index_func, include_components=False + self._obj, + model=model, + index_func=index_func, + include_components=False, + day_labeler=day_labeler, + time_labeler=time_labeler, ) def plot_model_predictions_by_row( @@ -482,6 +504,8 @@ def plot_model_predictions_by_row( model: LatentCalendar, index_func=lambda idx: idx, divergent: bool = True, + day_labeler: DayLabeler = DayLabeler(), + time_labeler: TimeLabeler = TimeLabeler(), ) -> np.ndarray: """Plot model predictions for each row of the DataFrame. Data must have been transformed to wide format first. @@ -490,6 +514,8 @@ def plot_model_predictions_by_row( model: model to use for prediction index_func: function to generate title for each row divergent: whether to use divergent colormap + day_labeler: DayLabeler instance to use for day labels + time_labeler: TimeLabeler instance to use for time labels Returns: grid of axes @@ -501,4 +527,6 @@ def plot_model_predictions_by_row( model=model, index_func=index_func, divergent=divergent, + day_labeler=day_labeler, + time_labeler=time_labeler, ) diff --git a/latent_calendar/plot/core/model.py b/latent_calendar/plot/core/model.py index 4151004..4ba9501 100644 --- a/latent_calendar/plot/core/model.py +++ b/latent_calendar/plot/core/model.py @@ -55,7 +55,9 @@ def plot_profile( # Raw Data ax = axes[0] - plot_raw_data(array=array, ax=ax, day_labeler=day_labeler, time_labeler=time_labeler) + plot_raw_data( + array=array, ax=ax, day_labeler=day_labeler, time_labeler=time_labeler + ) # Under Model ax = axes[1] @@ -64,7 +66,7 @@ def plot_profile( ax=ax, display_y_axis=False, divergent=divergent, - day_labeler=day_labeler, + day_labeler=day_labeler, time_labeler=time_labeler, ) @@ -102,7 +104,7 @@ def plot_profile_by_row( index_func, divergent: bool = True, include_components: bool = True, - day_labeler: DayLabeler = DayLabeler(), + day_labeler: DayLabeler = DayLabeler(), time_labeler: TimeLabeler = TimeLabeler(), ) -> np.ndarray: nrows = len(df) @@ -123,7 +125,7 @@ def plot_profile_by_row( divergent=divergent, include_components=include_components, day_labeler=day_labeler, - time_labeler=time_labeler, + time_labeler=time_labeler, ) ylabel = index_func(idx) @@ -165,7 +167,9 @@ def plot_model_predictions( X_to_predict_probs = model.predict(X_to_predict)[0] ax = axes[0] - plot_raw_data(array=X_to_predict, ax=ax, day_labeler=day_labeler, time_labeler=time_labeler) + plot_raw_data( + array=X_to_predict, ax=ax, day_labeler=day_labeler, time_labeler=time_labeler + ) ax.set_title(f"Raw Data for Prediction") ax = axes[1] @@ -174,13 +178,19 @@ def plot_model_predictions( ax=ax, display_y_axis=False, divergent=divergent, - day_labeler=day_labeler, - time_labeler=time_labeler + day_labeler=day_labeler, + time_labeler=time_labeler, ) ax.set_title("Distribution from Prediction") ax = axes[2] - plot_raw_data(array=X_holdout, ax=ax, display_y_axis=False, day_labeler=day_labeler, time_labeler=time_labeler) + plot_raw_data( + array=X_holdout, + ax=ax, + display_y_axis=False, + day_labeler=day_labeler, + time_labeler=time_labeler, + ) ax.set_title("Raw Data in Future") return axes @@ -192,8 +202,8 @@ def plot_model_predictions_by_row( model: LatentCalendar, index_func=lambda idx: idx, divergent: bool = True, - day_labeler: DayLabeler = DayLabeler(), - time_labeler: TimeLabeler = TimeLabeler(), + day_labeler: DayLabeler = DayLabeler(), + time_labeler: TimeLabeler = TimeLabeler(), ) -> np.ndarray: nrows = len(df) @@ -213,8 +223,8 @@ def plot_model_predictions_by_row( model=model, axes=axes_row, divergent=divergent, - day_labeler=day_labeler, - time_labeler=time_labeler, + day_labeler=day_labeler, + time_labeler=time_labeler, ) ylabel = index_func(idx)