Skip to content

Commit

Permalink
Upgrade to pytorch-lightning >=2 (#114)
Browse files Browse the repository at this point in the history
* Upgrade to pytorch-lightning >=2

Upgrades pytorch lightning, modifying usage of pl.Trainer
to conform with the new way to set the devices and checkpoint.

* Update kaggle download commands in README

Include the commands for unzipping the files, including the zipped
csvs within the original zip from the first challenge.

* Try disabling mac memory allocation limit in CI

See if we can use the mac GPU for the trainer test

* Use cpu for if cuda unavailable in trainer test

Avoids trying to use MPS on mac in CI which has insufficient memory

* Store created val.csv alongside input csv

In preprocessing_utils.py save the created val.csv in the
same folder as the input csv instead of in the current working
directory, and add logging so the user knows where it is saved.

This makes the two functions in the file more consistent in where
they save their output.

* Support comma-separated device string
  • Loading branch information
jamt9000 committed Sep 19, 2024
1 parent be0403c commit 0eabf50
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 11 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,14 @@ cd jigsaw_data
# download data
kaggle competitions download -c jigsaw-toxic-comment-classification-challenge
unzip jigsaw-toxic-comment-classification-challenge.zip -d jigsaw-toxic-comment-classification-challenge
find jigsaw-toxic-comment-classification-challenge -name '*.csv.zip' | xargs -n1 unzip -d jigsaw-toxic-comment-classification-challenge
kaggle competitions download -c jigsaw-unintended-bias-in-toxicity-classification
unzip jigsaw-unintended-bias-in-toxicity-classification.zip -d jigsaw-unintended-bias-in-toxicity-classification
kaggle competitions download -c jigsaw-multilingual-toxic-comment-classification
unzip jigsaw-multilingual-toxic-comment-classification.zip -d jigsaw-multilingual-toxic-comment-classification
```
## Start Training
Expand Down
20 changes: 17 additions & 3 deletions preprocessing_utils.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,45 @@
import argparse
import logging
from pathlib import Path

import numpy as np
import pandas as pd

logger = logging.getLogger("preprocessing_utils")
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
level=logging.INFO,
)


def update_test(test_csv_file):
"""Combines disjointed test and labels csv files into one file."""
test_csv_file = Path(test_csv_file)
test_set = pd.read_csv(test_csv_file)
data_labels = pd.read_csv(test_csv_file[:-4] + "_labels.csv")
data_labels = pd.read_csv(str(test_csv_file)[:-4] + "_labels.csv")
for category in data_labels.columns[1:]:
test_set[category] = data_labels[category]
if "content" in test_set.columns:
test_set.rename(columns={"content": "comment_text"}, inplace=True)
test_set.to_csv(f"{test_csv_file.split('.csv')[0]}_updated.csv")
output_file = test_csv_file.parent / f"{test_csv_file.stem}_updated.csv"
test_set.to_csv(output_file)
logger.info("Updated test set saved to %s", output_file)
return test_set


def create_val_set(csv_file, val_fraction):
"""Takes in a csv file path and creates a validation set
out of it specified by val_fraction.
"""
csv_file = Path(csv_file)
dataset = pd.read_csv(csv_file)
np.random.seed(0)
dataset_mod = dataset[dataset.toxic != -1]
indices = np.random.rand(len(dataset_mod)) > val_fraction
val_set = dataset_mod[~indices]
val_set.to_csv("val.csv")
output_file = csv_file.parent / "val.csv"
logger.info("Validation set saved to %s", output_file)
val_set.to_csv(output_file)


if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ classifiers = [
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
]
requires-python = ">=3.9,<3.12"
requires-python = ">=3.9,<3.13"
dependencies = [
"sentencepiece >= 0.1.94",
"torch < 2.2",
"torch >=2",
"transformers >= 3",
]

Expand All @@ -29,7 +29,7 @@ dev = [
"datasets >= 1.0.2",
"pandas >= 1.1.2",
"pytest",
"pytorch-lightning<2.0.0,>1.5.0",
"pytorch-lightning>2",
"scikit-learn >= 0.23.2",
"tqdm",
"pre-commit",
Expand Down
3 changes: 2 additions & 1 deletion tests/test_trainer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json

import src.data_loaders as module_data

import torch
from pytorch_lightning import seed_everything, Trainer
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -37,7 +38,7 @@ def get_instance(module, name, config, *args, **kwargs):
)

trainer = Trainer(
gpus=0 if torch.cuda.is_available() else None,
accelerator="gpu" if torch.cuda.is_available() else "cpu",
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
Expand Down
19 changes: 15 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os

import pytorch_lightning as pl

import src.data_loaders as module_data
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
Expand Down Expand Up @@ -159,7 +160,7 @@ def cli_main():
"--device",
default=None,
type=str,
help="indices of GPUs to enable (default: None)",
help="comma-separated indices of GPUs to enable (default: None)",
)
parser.add_argument(
"--num_workers",
Expand Down Expand Up @@ -208,16 +209,26 @@ def get_instance(module, name, config, *args, **kwargs):
monitor="val_loss",
mode="min",
)

if args.device is None:
devices = "auto"
else:
devices = [int(d.strip()) for d in args.device.split(",")]

trainer = pl.Trainer(
gpus=args.device,
devices=devices,
max_epochs=args.n_epochs,
accumulate_grad_batches=config["accumulate_grad_batches"],
callbacks=[checkpoint_callback],
resume_from_checkpoint=args.resume,
default_root_dir="saved/" + config["name"],
deterministic=True,
)
trainer.fit(model, data_loader, valid_data_loader)
trainer.fit(
model=model,
train_dataloaders=data_loader,
val_dataloaders=valid_data_loader,
ckpt_path=args.resume,
)


if __name__ == "__main__":
Expand Down

0 comments on commit 0eabf50

Please sign in to comment.