Commit
update tests to cover trainer
c-w-feldmann committed May 7, 2024
1 parent aa695af commit 6ebee0a
Showing 2 changed files with 79 additions and 10 deletions.
9 changes: 8 additions & 1 deletion test_extras/test_chemprop/test_chemprop_pipeline.py
@@ -9,6 +9,8 @@
 import pandas as pd
 from chemprop.nn.loss import LossFunction
 from lightning import pytorch as pl
+from lightning.pytorch.accelerators import Accelerator
+from lightning.pytorch.profilers.base import PassThroughProfiler
 from sklearn.base import clone
 from torch import nn
 
@@ -217,8 +219,13 @@ def test_clone(self) -> None:
                     cloned_params[param_name].state_dict()["task_weights"],
                 )
                 self.assertEqual(type(param), type(cloned_params[param_name]))
-            elif isinstance(param, nn.Identity):
+            elif isinstance(param, (nn.Identity, Accelerator, PassThroughProfiler)):
                 self.assertEqual(type(param), type(cloned_params[param_name]))
+            elif param_name == "lightning_trainer__callbacks":
+                self.assertIsInstance(cloned_params[param_name], list)
+                self.assertEqual(len(param), len(cloned_params[param_name]))
+                for callback, cloned_callback in zip(param, cloned_params[param_name]):
+                    self.assertEqual(type(callback), type(cloned_callback))
             else:
                 self.assertEqual(
                     param, cloned_params[param_name], f"Failed for {param_name}"
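Background on why the hunk above compares trainer components by type rather than by value: sklearn's clone rebuilds an estimator from its constructor parameters, so stateful objects such as Lightning callbacks are generally re-created rather than shared, and an identity or equality check would fail even for a faithful clone. A minimal sketch of that comparison pattern, assuming the wrapper re-instantiates callbacks on clone (the variable names here are illustrative, not part of the commit):

# Illustrative only: compare cloned callbacks by type, not identity.
from lightning.pytorch.callbacks.progress.tqdm_progress import TQDMProgressBar

original_callbacks = [TQDMProgressBar()]
# A clone is assumed to rebuild each callback from its class.
cloned_callbacks = [type(cb)() for cb in original_callbacks]

assert all(
    type(orig) is type(copy) for orig, copy in zip(original_callbacks, cloned_callbacks)
)
assert original_callbacks[0] is not cloned_callbacks[0]  # identity is not preserved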
80 changes: 71 additions & 9 deletions test_extras/test_chemprop/test_models.py
@@ -2,9 +2,14 @@
 
 import logging
 import unittest
+from pathlib import Path
+from typing import Iterable, Sequence
 
 from chemprop.nn.loss import BCELoss, LossFunction, MSELoss
 from lightning import pytorch as pl
+from lightning.pytorch.accelerators import Accelerator
+from lightning.pytorch.callbacks.progress.tqdm_progress import TQDMProgressBar
+from lightning.pytorch.profilers.base import PassThroughProfiler
 from sklearn.base import clone
 from torch import Tensor, nn
 
@@ -48,7 +53,43 @@ def get_model() -> ChempropModel:
 
 DEFAULT_PARAMS = {
     "batch_size": 64,
-    "lightning_trainer": pl.Trainer,
+    "lightning_trainer": None,
+    "lightning_trainer__limit_predict_batches": 1.0,
+    "lightning_trainer__fast_dev_run": False,
+    "lightning_trainer__min_steps": None,
+    "lightning_trainer__accumulate_grad_batches": 1,
+    "lightning_trainer__use_distributed_sampler": True,
+    "lightning_trainer__devices": [0],
+    "lightning_trainer__check_val_every_n_epoch": 1,
+    "lightning_trainer__enable_progress_bar": True,
+    "lightning_trainer__max_epochs": 500,
+    "lightning_trainer__max_time": None,
+    "lightning_trainer__val_check_interval": 1.0,
+    "lightning_trainer__log_every_n_steps": 50,
+    "lightning_trainer__min_epochs": None,
+    "lightning_trainer__gradient_clip_algorithm": None,
+    "lightning_trainer__profiler": PassThroughProfiler,
+    "lightning_trainer__max_steps": -1,
+    "lightning_trainer__limit_val_batches": 1.0,
+    "lightning_trainer__gradient_clip_val": None,
+    "lightning_trainer__inference_mode": True,
+    "lightning_trainer__enable_model_summary": False,
+    "lightning_trainer__limit_test_batches": 1.0,
+    "lightning_trainer__reload_dataloaders_every_n_epochs": 0,
+    "lightning_trainer__callbacks": [TQDMProgressBar],
+    "lightning_trainer__accelerator": Accelerator,
+    "lightning_trainer__deterministic": False,
+    "lightning_trainer__logger": None,
+    "lightning_trainer__overfit_batches": 0.0,
+    "lightning_trainer__precision": "32-true",
+    "lightning_trainer__benchmark": False,
+    "lightning_trainer__num_sanity_val_steps": 2,
+    "lightning_trainer__enable_checkpointing": False,
+    "lightning_trainer__limit_train_batches": 1.0,
+    "lightning_trainer__barebones": False,
+    "lightning_trainer__default_root_dir": str(Path(".").resolve()),
+    "lightning_trainer__detect_anomaly": False,
+    "lightning_trainer__num_nodes": 1,
     "model": MPNN,
     "model__agg__dim": 0,
     "model__agg": SumAggregation,
@@ -83,9 +124,11 @@ def get_model() -> ChempropModel:
 }
 
 NO_IDENTITY_CHECK = [
+    "lightning_trainer__accelerator",
+    "lightning_trainer__callbacks",
+    "lightning_trainer__profiler",
     "model__agg",
     "model__message_passing",
-    "lightning_trainer",
     "model",
     "model__predictor",
     "model__predictor__criterion",
@@ -106,9 +149,14 @@ def test_get_params(self) -> None:
         # Check if the parameters are as expected
         for param_name, param in expected_params.items():
             if param_name in NO_IDENTITY_CHECK:
-                if not isinstance(param, type):
+                if isinstance(param, Iterable):
+                    self.assertIsInstance(orig_params[param_name], type(param))
+                    for i, p in enumerate(param):
+                        self.assertIsInstance(orig_params[param_name][i], p)
+                elif isinstance(param, type):
+                    self.assertIsInstance(orig_params[param_name], param)
+                else:
                     raise ValueError(f"{param_name} should be a type.")
-                self.assertIsInstance(orig_params[param_name], param)
             else:
                 self.assertEqual(
                     orig_params[param_name], param, f"Test failed for {param_name}"
@@ -148,8 +196,12 @@ def test_clone(self) -> None:
                     cloned_param.state_dict()["task_weights"],
                 )
                 self.assertEqual(type(param), type(cloned_param))
-            elif isinstance(param, nn.Identity):
+            elif isinstance(param, (nn.Identity, Accelerator, PassThroughProfiler)):
                 self.assertEqual(type(param), type(cloned_param))
+            elif param_name == "lightning_trainer__callbacks":
+                self.assertIsInstance(cloned_param, Sequence)
+                for i, callback in enumerate(param):
+                    self.assertIsInstance(callback, type(cloned_param[i]))
             else:
                 self.assertEqual(param, cloned_param, f"Test failed for {param_name}")
 
@@ -184,9 +236,14 @@ def test_get_params(self) -> None:
         self.assertSetEqual(set(param_dict.keys()), set(expected_params.keys()))
         for param_name, param in expected_params.items():
             if param_name in NO_IDENTITY_CHECK:
-                if not isinstance(param, type):
+                if isinstance(param, Iterable):
+                    self.assertIsInstance(param_dict[param_name], type(param))
+                    for i, p in enumerate(param):
+                        self.assertIsInstance(param_dict[param_name][i], p)
+                elif isinstance(param, type):
+                    self.assertIsInstance(param_dict[param_name], param)
+                else:
                     raise ValueError(f"{param_name} should be a type.")
-                self.assertIsInstance(param_dict[param_name], param)
             else:
                 self.assertEqual(
                     param_dict[param_name], param, f"Test failed for {param_name}"
@@ -206,9 +263,14 @@ def test_get_params(self) -> None:
         self.assertSetEqual(set(param_dict.keys()), set(expected_params.keys()))
         for param_name, param in expected_params.items():
             if param_name in NO_IDENTITY_CHECK:
-                if not isinstance(param, type):
+                if isinstance(param, Iterable):
+                    self.assertIsInstance(param_dict[param_name], type(param))
+                    for i, p in enumerate(param):
+                        self.assertIsInstance(param_dict[param_name][i], p)
+                elif isinstance(param, type):
+                    self.assertIsInstance(param_dict[param_name], param)
+                else:
                     raise ValueError(f"{param_name} should be a type.")
-                self.assertIsInstance(param_dict[param_name], param)
             else:
                 self.assertEqual(
                     param_dict[param_name], param, f"Test failed for {param_name}"
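As background for the DEFAULT_PARAMS keys above: the `lightning_trainer__*` entries follow sklearn's double-underscore convention for nested parameters, so trainer options can be read and overridden through get_params/set_params instead of constructing a pl.Trainer by hand. A minimal sketch of that usage, assuming the get_model() helper from test_models.py and the wrapper's sklearn-compatible parameter API (illustrative, not part of the commit):

# Illustrative sketch of sklearn-style nested-parameter routing.
model = get_model()  # ChempropModel wrapper, as built in the tests above

# Override individual trainer options without touching the rest of the config.
model.set_params(
    lightning_trainer__max_epochs=10,
    lightning_trainer__enable_progress_bar=False,
)

params = model.get_params(deep=True)
assert params["lightning_trainer__max_epochs"] == 10
assert params["lightning_trainer__enable_progress_bar"] is False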
