Skip to content

Commit

Permalink
Added functions to main.py and formatted other py files
Browse files Browse the repository at this point in the history
  • Loading branch information
harshitshangari committed Oct 23, 2023
1 parent 578e483 commit d92e4d8
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 15 deletions.
33 changes: 26 additions & 7 deletions src/modelling/main.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,42 @@
# This module is the training flow: it reads the data, preprocesses it, trains a model and saves it.

import argparse
from pathlib import Path

from preprocessing import extract_x_y_split, read_data, transform_data
from training import train_model
from utils import save_pickle

def main(trainset_path: Path) -> None:
"""Train a model using the data at the given path and save the model (pickle)."""
# Read data

# Preprocess data
def main(trainset_path: Path, output_path: Path) -> None:
"""Train a model using the data at the given path and save the model (pickle).
# (Optional) Pickle encoder if need be
Parameters
-------
trainset_path : Path
Path of the train data.
output_path : Path
Path to which the model is saved as a pickle.
Returns
-------
None : None
"""
# Read data
data = read_data(trainset_path)
# Preprocess data
trans_data = transform_data(data)
X_train, X_test, y_train, y_test = extract_x_y_split(trans_data)
# Train model

# Pickle model --> The model should be saved in pkl format the `src/web_service/local_objects` folder
model = train_model(X_train, y_train)
# Pickle model --> The model should be saved in pkl format in the
# `src/web_service/local_objects` folder
save_pickle(output_path, model)


if __name__ == "__main__":
    # Command-line entry point: parse the two positional paths and run the
    # training flow defined by `main`.
    parser = argparse.ArgumentParser(description="Train a model using the data at the given path.")
    parser.add_argument("trainset_path", type=str, help="Path to the training set")
    parser.add_argument("output_path", type=str, help="Path where the pickle model is saved")
    args = parser.parse_args()
    # `main` takes both the training-set path and the output path; passing
    # only the first one raises a TypeError for the missing `output_path`.
    main(args.trainset_path, args.output_path)
2 changes: 1 addition & 1 deletion src/modelling/predicting.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def predict_pipeline(input_data: pd.DataFrame, model: xgb.XGBRegressor) -> np.nd
model : xgb.XGBRegressor
Model used to predict on the input data.
Returns:
Returns
-------
y : np.ndarray
Array of predicted target values.
Expand Down
6 changes: 3 additions & 3 deletions src/modelling/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def read_data(path: str) -> pd.DataFrame:
path : string
String represents the path to the csv file.
Returns:
Returns
-------
df : pd.DataFrame
Pandas dataframe.
Expand All @@ -30,7 +30,7 @@ def transform_data(df: pd.DataFrame) -> pd.DataFrame:
df : pd.DataFrame
Pandas input dataframe.
Returns:
Returns
-------
df : pd.DataFrame
Transformed dataframe.
Expand Down Expand Up @@ -58,7 +58,7 @@ def extract_x_y_split(
target : str (optional)
The name of the target column in the DataFrame. Default is "age".
Returns:
Returns
--------
X_train : pd.DataFrame
The training feature set (X).
Expand Down
8 changes: 4 additions & 4 deletions src/modelling/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pickle

from sklearn.linear_model import LinearRegression
import xgboost as xgb


def load_pickle(path: str) -> pickle:
Expand All @@ -11,7 +11,7 @@ def load_pickle(path: str) -> pickle:
path : string
String represents the path to the object.
Returns:
Returns
-------
loaded_obj: pickle object
Pickle which was contained in the path given as parameter.
Expand All @@ -21,7 +21,7 @@ def load_pickle(path: str) -> pickle:
return loaded_obj


def save_pickle(path: str, obj: LinearRegression) -> None:
def save_pickle(path: str, obj: xgb.XGBRegressor) -> None:
"""Given a path and an object, stores the object as a pickle file in the specified path.
Parameters
Expand All @@ -32,7 +32,7 @@ def save_pickle(path: str, obj: LinearRegression) -> None:
obj : xgb.XGBRegressor
Represents the XGBoost regression model that will be stored.
Returns:
Returns
-------
None : None
"""
Expand Down

0 comments on commit d92e4d8

Please sign in to comment.