forked from artefactory/xhec-mlops-project-student
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added functions to main.py and formatted other py files
- Loading branch information
1 parent
578e483
commit d92e4d8
Showing
4 changed files
with
34 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,42 @@ | ||
# This module is the training flow: it reads the data, preprocesses it, trains a model and saves it. | ||
|
||
import argparse | ||
from pathlib import Path | ||
|
||
from preprocessing import extract_x_y_split, read_data, transform_data | ||
from training import train_model | ||
from utils import save_pickle | ||
|
||
def main(trainset_path: Path) -> None: | ||
"""Train a model using the data at the given path and save the model (pickle).""" | ||
# Read data | ||
|
||
# Preprocess data | ||
def main(trainset_path: Path, output_path: Path) -> None: | ||
"""Train a model using the data at the given path and save the model (pickle). | ||
# (Optional) Pickle encoder if need be | ||
Parameters | ||
---------- | ||
trainset_path : Path | ||
Path of the train data. | ||
output_path : Path | ||
Path to which the model is saved as a pickle. | ||
Returns | ||
------- | ||
None : None | ||
""" | ||
# Read data | ||
data = read_data(trainset_path) | ||
# Preprocess data | ||
trans_data = transform_data(data) | ||
X_train, X_test, y_train, y_test = extract_x_y_split(trans_data) | ||
# Train model | ||
|
||
# Pickle model --> The model should be saved in pkl format the `src/web_service/local_objects` folder | ||
model = train_model(X_train, y_train) | ||
# Pickle model --> The model should be saved in pkl format in the | ||
# `src/web_service/local_objects` folder | ||
save_pickle(output_path, model) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser(description="Train a model using the data at the given path.") | ||
parser.add_argument("trainset_path", type=str, help="Path to the training set") | ||
parser.add_argument("output_path", type=str, help="Path where the pickle model is saved") | ||
args = parser.parse_args() | ||
main(args.trainset_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters