diff --git a/src/dxtb/_src/calculators/types/base.py b/src/dxtb/_src/calculators/types/base.py index 37aed0de..e47c6b4c 100644 --- a/src/dxtb/_src/calculators/types/base.py +++ b/src/dxtb/_src/calculators/types/base.py @@ -53,7 +53,7 @@ class and implement the :meth:`calculate` method and the corresponding methods from dxtb._src.constants import defaults from dxtb._src.param import Param from dxtb._src.timing import timer -from dxtb._src.typing import Any, Self, Tensor, get_default_dtype, override +from dxtb._src.typing import Any, Self, Tensor, override from dxtb.config import Config from dxtb.integrals import Integrals from dxtb.typing import Tensor, TensorLike diff --git a/test/test_calculator/test_dd.py b/test/test_calculator/test_dd.py index 50d37d0f..a436e16b 100644 --- a/test/test_calculator/test_dd.py +++ b/test/test_calculator/test_dd.py @@ -26,18 +26,21 @@ from dxtb import GFN1_XTB as par from dxtb import Calculator -from dxtb._src.timing import timer -from dxtb._src.typing import MockTensor +from dxtb._src.typing import MockTensor, Tensor def test_fail_dtype() -> None: - numbers = torch.tensor([6, 1]) + numbers = torch.tensor([3, 1]) positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) - charge = torch.tensor([0.0]) - spin = torch.tensor([0.0]) + charge = torch.tensor(0.0) + spin = torch.tensor(0.0) calc = Calculator(numbers, par, opts={"verbosity": 0}) + # same dtype works + e = calc.get_energy(positions, charge, spin) + assert isinstance(e, Tensor) + with pytest.raises(DtypeError): calc.get_energy(positions.type(torch.double), charge, spin) @@ -47,19 +50,20 @@ def test_fail_dtype() -> None: with pytest.raises(DtypeError): calc.get_energy(positions, charge, spin.type(torch.double)) - # because of the exception, the timer for the setup is never stopped - timer.reset() - def test_fail_device() -> None: - numbers = torch.tensor([6, 1]) + numbers = torch.tensor([3, 1]) _positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) - _charge = torch.tensor([0.0]) - _spin = torch.tensor([0.0]) + _charge = torch.tensor(0.0) + _spin = torch.tensor(0.0) calc = Calculator(numbers, par, opts={"verbosity": 0}, dtype=torch.float) + # same device works + e = calc.get_energy(_positions, _charge, _spin) + assert isinstance(e, Tensor) + with pytest.raises(DeviceError): positions = MockTensor(_positions) positions.device = torch.device("cuda") @@ -74,6 +78,3 @@ def test_fail_device() -> None: spin = MockTensor(_spin) spin.device = torch.device("cuda") calc.get_energy(_positions, _charge, spin) - - # because of the exception, the timer for the setup is never stopped - timer.reset()