diff --git a/edugrad/_tensor/tensor_broadcasted_binary_mlops.py b/edugrad/_tensor/tensor_broadcasted_binary_mlops.py index 3fb1436..40a8cde 100644 --- a/edugrad/_tensor/tensor_broadcasted_binary_mlops.py +++ b/edugrad/_tensor/tensor_broadcasted_binary_mlops.py @@ -2,7 +2,7 @@ import math -from edugrad.helpers import dtypes +from edugrad.dtypes import dtypes import edugrad.function as function diff --git a/edugrad/_tensor/tensor_combine_segment.py b/edugrad/_tensor/tensor_combine_segment.py index 3da0ba5..32eb2e2 100644 --- a/edugrad/_tensor/tensor_combine_segment.py +++ b/edugrad/_tensor/tensor_combine_segment.py @@ -7,7 +7,7 @@ def cat(tensor, *args, dim) -> Tensor: - from edugrad._tensor import Tensor + from edugrad.tensor import Tensor dim = (dim + len(tensor.shape)) if dim < 0 else dim assert all( diff --git a/edugrad/_tensor/tensor_create.py b/edugrad/_tensor/tensor_create.py index 9d84ecb..0419c1e 100644 --- a/edugrad/_tensor/tensor_create.py +++ b/edugrad/_tensor/tensor_create.py @@ -2,7 +2,8 @@ import time import math -from edugrad.helpers import argfix, DType, prod, shape_int, dtypes +from edugrad.helpers import argfix, prod, shape_int +from edugrad.dtypes import DType, dtypes from edugrad.data import TensorData from edugrad.ops import LoadOps diff --git a/edugrad/_tensor/tensor_index_slice.py b/edugrad/_tensor/tensor_index_slice.py index 5f76d27..562fd32 100644 --- a/edugrad/_tensor/tensor_index_slice.py +++ b/edugrad/_tensor/tensor_index_slice.py @@ -1,7 +1,8 @@ from typing import Sequence, Optional, Tuple from collections import defaultdict -from edugrad.helpers import shape_int, dtypes +from edugrad.helpers import shape_int +from edugrad.dtypes import dtypes from edugrad._tensor.tensor_reshape import pad, _flatten @@ -35,7 +36,7 @@ def __getitem__( tensor: "Tensor", val ) -> "Tensor": # val: Union[int, slice, Tensor, None, Ellipsis, Tuple[Union[int, slice, Tensor, None, Ellipsis], ...]] - from edugrad._tensor import Tensor + from edugrad.tensor import Tensor def normalize_int(e, i, dim_sz): if -dim_sz <= e < dim_sz: @@ -141,10 +142,12 @@ def __setitem__(tensor: "Tensor", s, v): # NOTE: using slice is discouraged and things should migrate to pad and shrink -def slice(tensor: "Tensor", arg: Sequence[Optional[Tuple[int, shape_int]]], value: float) -> "Tensor": +def tslice(tensor: "Tensor", arg: Sequence[Optional[Tuple[int, shape_int]]], value: float = 0) -> "Tensor": + from edugrad.tensor import Tensor + arg_ = tuple([a if a is not None else (0, s) for s, a in zip(tensor.shape, arg)]) padding = tuple([(max(0, -p[0]), max(0, p[1] - tensor.shape[i])) for i, p in enumerate(arg_)]) - return pad(tensor, padding, value=value).shrink( + return tensor.pad(padding, value=value).shrink( tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i, p in enumerate(arg_)]) ) # FIXME: tensor.pad(padding, value=value)... returns None... 
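Above, slice is renamed to tslice so it no longer shadows Python's built-in slice; Tensor.slice still exists and forwards to it (value now defaults to 0), routing through tensor.pad, which returns a Tensor again after the fix later in this diff. A minimal usage sketch, assuming the edugrad API as it stands after this diff; the shapes and numbers are illustrative, not taken from the PR:

from edugrad import Tensor

t = Tensor([[1.0, 2.0], [3.0, 4.0]])
# Row window -1..3, column window 0..2: positions outside the original
# tensor are first padded with `value`, then the window is shrunk out.
s = t.slice(((-1, 3), (0, 2)), value=0)
print(s.shape)    # (4, 2)
print(s.numpy())  # [[0. 0.] [1. 2.] [3. 4.] [0. 0.]]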
diff --git a/edugrad/_tensor/tensor_nn.py b/edugrad/_tensor/tensor_nn.py index 5aeb04a..75fa7da 100644 --- a/edugrad/_tensor/tensor_nn.py +++ b/edugrad/_tensor/tensor_nn.py @@ -3,7 +3,8 @@ from __future__ import annotations import math -from edugrad.helpers import make_pair, flatten, dtypes, all_int, shape_int +from edugrad.helpers import make_pair, flatten, all_int, shape_int +from edugrad.dtypes import dtypes # processing ops diff --git a/edugrad/_tensor/tensor_reduce.py b/edugrad/_tensor/tensor_reduce.py index 9657416..02d834b 100644 --- a/edugrad/_tensor/tensor_reduce.py +++ b/edugrad/_tensor/tensor_reduce.py @@ -2,7 +2,8 @@ from __future__ import annotations -from edugrad.helpers import dtypes, prod, all_int +from edugrad.helpers import prod, all_int +from edugrad.dtypes import dtypes from edugrad.function import Function import edugrad.function as function @@ -44,6 +45,7 @@ def _reduce(self, fxn: type[Function], axis: int | tuple[int, ...] | None, keepd return ret if keepdim else ret.reshape(shape=shape) +# ---------------------------------------------------------------------------------------------------------------------- # Functions that use the generic _reduce method for specific reduction operations. @@ -59,7 +61,7 @@ def tmax(tensor: Tensor, axis, keepdim): def tmin(tensor: Tensor, axis, keepdim): """Computes the minimum value of elements over the specified axis.""" - return -((-tensor).tmax((-tensor), axis=axis, keepdim=keepdim)) + return -tmax((-tensor), axis=axis, keepdim=keepdim) def mean(tensor: Tensor, axis, keepdim): @@ -76,6 +78,7 @@ def std(tensor: Tensor, axis, keepdim, correction): return square_sum.div(prod(tensor.shape) / prod(square_sum.shape) - correction).sqrt() +# ---------------------------------------------------------------------------------------------------------------------- # Functions for softmax and its logarithmic variant, as well as argmax and argmin operations. 
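The tmin fix above relies on the identity min(x) = -max(-x), calling the module-level tmax instead of the previously broken method-style call. A quick standalone check of that identity in plain numpy (not part of the diff):

import numpy as np

x = np.array([[3.0, -1.0], [0.5, 7.0]])
assert np.allclose(x.min(axis=0), -(-x).max(axis=0))  # identity used by tmin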
diff --git a/edugrad/_tensor/tensor_reshape.py b/edugrad/_tensor/tensor_reshape.py index b286a86..c1873c1 100644 --- a/edugrad/_tensor/tensor_reshape.py +++ b/edugrad/_tensor/tensor_reshape.py @@ -44,10 +44,12 @@ def shrink(tensor: Tensor, arg: tuple[tuple[shape_int, shape_int] | None, ...]) def pad(tensor: Tensor, arg: tuple[tuple[int, int] | None, ...], value: float) -> Tensor: + from edugrad.tensor import Tensor + if all(x is None or x == (0, 0) for x in arg): return tensor ret = function.Pad.apply(tensor, arg=(narg := tuple(x if x is not None else (0, 0) for x in arg))) - return ret if 0 == value else ret + function.Pad.apply("Tensor".ones_like(tensor), arg=narg).where(0, value) + return ret if 0 == value else ret + function.Pad.apply(Tensor.ones_like(tensor), arg=narg).where(0, value) # (padding_left, padding_right, padding_top, padding_bottom) diff --git a/edugrad/data.py b/edugrad/data.py index af514ae..2216fd2 100644 --- a/edugrad/data.py +++ b/edugrad/data.py @@ -13,7 +13,8 @@ from typing import Tuple import numpy as np from edugrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps, LoadOps # consider reading the docs there -from edugrad.helpers import DType, dtypes, DEBUG +from edugrad.helpers import DEBUG +from edugrad.dtypes import DType, dtypes class TensorData: diff --git a/edugrad/dtypes.py b/edugrad/dtypes.py new file mode 100644 index 0000000..217d3a1 --- /dev/null +++ b/edugrad/dtypes.py @@ -0,0 +1,93 @@ +from typing import ClassVar, Dict, Optional, Final +import numpy as np +from dataclasses import dataclass + + +@dataclass(frozen=True, order=True) +class DType: + """Data type class for managing different data types.""" + + priority: int # Priority for upcasting + itemsize: int # Size of the data type in bytes + name: str # Name of the data type + np: Optional[type] # Corresponding numpy data type + sz: int = 1 # Size factor + + def __repr__(self): + return f"dtypes.{self.name}" + + +class dtypes: + """Container for different data types and utility methods. + + We need this because some layer operation might use different trade-offs between precision and efficiency. In such + cases, we have to translate b/w dtypes. 
+ + """ + + @staticmethod + def is_int(x: DType) -> bool: + """Check if a data type is an integer type.""" + return x in ( + dtypes.int8, + dtypes.int16, + dtypes.int32, + dtypes.int64, + dtypes.uint8, + dtypes.uint16, + dtypes.uint32, + dtypes.uint64, + ) + + @staticmethod + def is_float(x: DType) -> bool: + """Check if a data type is a float type.""" + return x in (dtypes.float16, dtypes.float32, dtypes.float64) + + @staticmethod + def is_unsigned(x: DType) -> bool: + """Check if a data type is an unsigned type.""" + return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64) + + @staticmethod + def from_np(x) -> DType: + """Convert a numpy data type to a DType.""" + return DTYPES_DICT[np.dtype(x).name] + + @staticmethod + def fields() -> Dict[str, DType]: + return DTYPES_DICT + + @staticmethod # NOTE: isinstance(True, int) is True in python + def from_py(x) -> DType: + return ( + dtypes.default_float if isinstance(x, float) else dtypes.bool if isinstance(x, bool) else dtypes.default_int + ) + + # Definition of various data types + bool: Final[DType] = DType(0, 1, "bool", np.bool_) + float16: Final[DType] = DType(9, 2, "half", np.float16) + half = float16 + float32: Final[DType] = DType(10, 4, "float", np.float32) + float = float32 + float64: Final[DType] = DType(11, 8, "double", np.float64) + double = float64 + int8: Final[DType] = DType(1, 1, "char", np.int8) + int16: Final[DType] = DType(3, 2, "short", np.int16) + int32: Final[DType] = DType(5, 4, "int", np.int32) + int64: Final[DType] = DType(7, 8, "long", np.int64) + uint8: Final[DType] = DType(2, 1, "unsigned char", np.uint8) + uint16: Final[DType] = DType(4, 2, "unsigned short", np.uint16) + uint32: Final[DType] = DType(6, 4, "unsigned int", np.uint32) + uint64: Final[DType] = DType(8, 8, "unsigned long", np.uint64) + + default_float: ClassVar[DType] = float32 + default_int: ClassVar[DType] = int32 + + +# Dictionary mapping data type names to DType objects +DTYPES_DICT = { + k: v + for k, v in dtypes.__dict__.items() + if not k.startswith("__") and not callable(v) and not v.__class__ == staticmethod +} diff --git a/edugrad/function.py b/edugrad/function.py index 0cb4f67..709d6be 100644 --- a/edugrad/function.py +++ b/edugrad/function.py @@ -11,7 +11,8 @@ """ import math from typing import Tuple, Optional, cast -from edugrad.helpers import argsort, DType, shape_int +from edugrad.helpers import argsort, shape_int +from edugrad.dtypes import DType from edugrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps from edugrad.data import TensorData @@ -355,3 +356,12 @@ def backward(self, grad_output: TensorData) -> TensorData: ), "symbolic shrink does not support backward" # need this cast because mypy cannot narrow the type even with assert return grad_output.pad(cast(Tuple[Tuple[int, int], ...], self.narg)) + + +class Flip(Function): + def forward(self, x: TensorData, axis: Tuple[int, ...]) -> TensorData: + self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))]) + return x.stride(self.arg) + + def backward(self, grad_output: TensorData) -> TensorData: + return grad_output.stride(self.arg) diff --git a/edugrad/helpers.py b/edugrad/helpers.py index 30c288d..f42d8ae 100644 --- a/edugrad/helpers.py +++ b/edugrad/helpers.py @@ -1,9 +1,7 @@ -from typing import Union, Tuple, Iterator, Optional, Final, Any +from typing import Union, Tuple, Iterator, Any import os import functools -import numpy as np from math import prod # noqa: F401 # pylint:disable=unused-import -from dataclasses import dataclass 
shape_int = int @@ -28,6 +26,12 @@ def flatten(list_: Iterator): return [item for sublist in list_ for item in sublist] +def fully_flatten(l): + return [ + item for sublist in l for item in (fully_flatten(sublist) if isinstance(sublist, (tuple, list)) else [sublist]) + ] + + def argsort(x): """Return the indices that would sort an array. @@ -55,83 +59,3 @@ def getenv(key, default=0): # Global flags for debugging and continuous integration DEBUG = getenv("DEBUG") - - -@dataclass(frozen=True, order=True) -class DType: - """Data type class for managing different data types.""" - - priority: int # Priority for upcasting - itemsize: int # Size of the data type in bytes - name: str # Name of the data type - np: Optional[type] # Corresponding numpy data type - sz: int = 1 # Size factor - - def __repr__(self): - return f"dtypes.{self.name}" - - -class dtypes: - """Container for different data types and utility methods. - - We need this because some layer operation might use different trade-offs between precision and efficiency. In such - cases, we have to translate b/w dtypes. - - """ - - @staticmethod - def is_int(x: DType) -> bool: - """Check if a data type is an integer type.""" - return x in ( - dtypes.int8, - dtypes.int16, - dtypes.int32, - dtypes.int64, - dtypes.uint8, - dtypes.uint16, - dtypes.uint32, - dtypes.uint64, - ) - - @staticmethod - def is_float(x: DType) -> bool: - """Check if a data type is a float type.""" - return x in (dtypes.float16, dtypes.float32, dtypes.float64) - - @staticmethod - def is_unsigned(x: DType) -> bool: - """Check if a data type is an unsigned type.""" - return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64) - - @staticmethod - def from_np(x) -> DType: - """Convert a numpy data type to a DType.""" - return DTYPES_DICT[np.dtype(x).name] - - # Definition of various data types - bool: Final[DType] = DType(0, 1, "bool", np.bool_) - float16: Final[DType] = DType(9, 2, "half", np.float16) - half = float16 - float32: Final[DType] = DType(10, 4, "float", np.float32) - float = float32 - float64: Final[DType] = DType(11, 8, "double", np.float64) - double = float64 - int8: Final[DType] = DType(1, 1, "char", np.int8) - int16: Final[DType] = DType(3, 2, "short", np.int16) - int32: Final[DType] = DType(5, 4, "int", np.int32) - int64: Final[DType] = DType(7, 8, "long", np.int64) - uint8: Final[DType] = DType(2, 1, "unsigned char", np.uint8) - uint16: Final[DType] = DType(4, 2, "unsigned short", np.uint16) - uint32: Final[DType] = DType(6, 4, "unsigned int", np.uint32) - uint64: Final[DType] = DType(8, 8, "unsigned long", np.uint64) - - # Note: bfloat16 isn't supported in numpy - bfloat16: Final[DType] = DType(9, 2, "__bf16", None) - - -# Dictionary mapping data type names to DType objects -DTYPES_DICT = { - k: v - for k, v in dtypes.__dict__.items() - if not k.startswith("__") and not callable(v) and not v.__class__ == staticmethod -} diff --git a/edugrad/tensor.py b/edugrad/tensor.py index 7738f79..32cbadb 100644 --- a/edugrad/tensor.py +++ b/edugrad/tensor.py @@ -1,4 +1,4 @@ -"""Contain the tensor class that can be used for building neural networks with forward and backward pass. +"""Contains the tensor class that can be used for building neural networks with forward and backward pass. The module contains the "high-level ops". 
These are syntax sugar and built on top of the "mid-level ops" containing the functions with forward and backward passes in Function.function which is built on top of the "low-level ops" @@ -11,9 +11,13 @@ from __future__ import annotations import time from typing import ClassVar, Sequence, Any +from collections import defaultdict +import math + import numpy as np -from edugrad.helpers import getenv, DEBUG, DType, dtypes, prod, all_int, round_up, shape_int +from edugrad.helpers import getenv, DEBUG, prod, all_int, round_up, shape_int +from edugrad.dtypes import DType, dtypes from edugrad.data import TensorData from edugrad.ops import LoadOps from edugrad.function import Function @@ -27,10 +31,11 @@ from edugrad._tensor.tensor_combine_segment import cat, stack, repeat, chunk from edugrad._tensor.tensor_reshape import reshape, expand, permute, flip, shrink, pad, pad2d, transpose, _flatten, squeeze, unsqueeze from edugrad._tensor.tensor_nn import _pool, avg_pool2d, max_pool2d, conv2d, linear, binary_crossentropy, binary_crossentropy_logits, sparse_categorical_crossentropy -from edugrad._tensor.tensor_index_slice import __getitem__, __setitem__, slice, gather +from edugrad._tensor.tensor_index_slice import __getitem__, __setitem__, tslice, gather from edugrad._tensor.tensor_broadcasted_binary_mlops import _broadcasted, _to_float, add, sub, mul, div, pow, matmul, maximum, minimum, where from edugrad._tensor.tensor_reduce import _reduce, tsum, tmax, tmin, mean, std, _softmax, softmax, log_softmax, argmax, argmin # fmt: on +from edugrad.helpers import argfix, fully_flatten class Tensor: @@ -69,19 +74,31 @@ def __init__( # internal variables used for autograd graph construction self._ctx: Function | None = None + # -------------------------------------------------------------------------------------------------------------- + # Handles Tensor(x) for x with different data types + if isinstance(data, TensorData): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported" - elif isinstance(data, (int, float)): - data = TensorData.loadop(LoadOps.CONST, tuple(), dtype or Tensor.default_type, data) + elif isinstance(data, (bool, int, float)): + data = TensorData.loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), data) - elif data is None or data.__class__ is list: - assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype" - data = TensorData(np.array([] if data is None else data, dtype=(dtype or Tensor.default_type).np)) + elif isinstance(data, list): + if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): + dtype = dtype or dtypes.bool + elif d and all_int(d): + dtype = dtype or dtypes.default_int + else: + dtype = dtype or dtypes.default_float + # NOTE: cast at the end for the dtypes that do not have a numpy dtype + data = TensorData(np.array(data, dtype.np)).cast(dtype) elif isinstance(data, bytes): data = TensorData(np.frombuffer(data, np.uint8)) + elif data is None: + data = TensorData.loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float) + elif isinstance(data, np.ndarray): assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype" if data.shape == (): @@ -218,9 +235,10 @@ def backward(self): return backward(self) def reshape(self, shape, *args) -> Tensor: return reshape(self, shape, *args) def expand(self, shape, *args) -> Tensor: return expand(self, shape, *args) + def permute(self, order, *args) -> Tensor: return permute(self, order, *args) def flip(self, axis, *args) ->
Tensor: return flip(self, axis, *args) - def pad(self, arg:tuple[tuple[int, int] | None, ...], value:float=0.0) -> Tensor: pad(self, arg, value) + def pad(self, arg:tuple[tuple[int, int] | None, ...], value:float=0.0) -> Tensor: return pad(self, arg, value) # (padding_left, padding_right, padding_top, padding_bottom) def pad2d(self, padding:list[int] | tuple[int, ...], value:float=0) -> Tensor: return pad2d(self, padding, value) def shrink(self, arg:tuple[tuple[shape_int, shape_int] | None, ...]) -> Tensor: return shrink(self, arg) @@ -243,14 +261,14 @@ def __setitem__(self,s,v): return __setitem__(self,s,v) # NOTE: using slice is discouraged and things should migrate to pad and shrink def slice(self, arg:Sequence[tuple[int, shape_int] | None], value:float=0) -> Tensor: - return slice(self, arg, value) + return tslice(self, arg, value=value) def gather(self: Tensor, idx: Tensor, dim: int) -> Tensor: return gather(self, idx, dim) # ------------------------------------------------------------------------------------------------------------------ # tensor_combine_segment.py - def cat(self, *args, dim=0) -> Tensor: return cat(self, *args, dim) + def cat(self, *args, dim=0) -> Tensor: return cat(self, *args, dim=dim) @staticmethod def stack(tensors, dim=0) -> Tensor: stack(tensors, dim) def repeat(self, repeats) -> Tensor: repeat(self, repeats) @@ -322,10 +340,13 @@ def exp(self): return function.Exp.apply(self) def relu(self): return function.Relu.apply(self) def sigmoid(self): return function.Sigmoid.apply(self) def sqrt(self): return function.Sqrt.apply(self) + def sin(self): return function.Sin.apply(self) + def cos(self): return ((math.pi/2)-self).sin() # math functions (unary) skipped - # activation functions (unary) skipped + # activation functions (unary) + def elu(self, alpha=1.0): return self.relu() - alpha*(1-self.exp()).relu() # ------------------------------------------------------------------------------------------------------------------ # tensor_bradcasted_binary_mlops.py diff --git a/environment.yaml b/environment.yaml index bc3a68d..1d8a3d1 100644 --- a/environment.yaml +++ b/environment.yaml @@ -11,3 +11,4 @@ dependencies: # Tests - pytest + - pytorch diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/gradcheck.py b/tests/gradcheck.py new file mode 100644 index 0000000..cd0636a --- /dev/null +++ b/tests/gradcheck.py @@ -0,0 +1,54 @@ +import numpy as np +from edugrad.tensor import Tensor + + +def mask_like(like, mask_inx, mask_value=1.0): + mask = np.zeros_like(like).reshape(-1) + mask[mask_inx] = mask_value + return mask.reshape(like.shape) + + +def jacobian(func, input): + output = func(input) + + ji = input.numpy().reshape(-1).shape[-1] + jo = output.numpy().reshape(-1).shape[-1] + J = np.zeros((jo, ji), dtype=np.float32) + + for o in range(jo): + input.grad = None + output = func(input) + + # tinygrad doesn't support slicing, tiny-hack to select + # the needed scalar an backpropagate only through it + o_scalar = Tensor(mask_like(output.numpy(), o, 1.0)).mul(output).sum() + o_scalar.backward() + + for i, grad in enumerate(input.grad.numpy().reshape(-1)): + J[o, i] = grad + return J + + +def numerical_jacobian(func, input, eps=1e-3): + output = func(input) + + ji = input.numpy().reshape(-1).shape[-1] + jo = output.numpy().reshape(-1).shape[-1] + NJ = np.zeros((jo, ji), dtype=np.float32) + + for i in range(ji): + eps_perturb = mask_like(input.numpy(), i, mask_value=eps) + + output_perturb_add = 
func(Tensor(input.numpy() + eps_perturb)).numpy().reshape(-1) + output_perturb_sub = func(Tensor(input.numpy() - eps_perturb)).numpy().reshape(-1) + + grad_approx = ((output_perturb_add) - (output_perturb_sub)) / (2 * eps) + + NJ[:, i] = grad_approx + return NJ + + +def gradcheck(func, input, eps=1e-3, atol=1e-3, rtol=1e-3): + NJ = numerical_jacobian(func, input, eps) + J = jacobian(func, input) + return np.allclose(J, NJ, atol=atol, rtol=rtol) diff --git a/tests/test_dummy.py b/tests/test_dummy.py deleted file mode 100644 index 63ae41d..0000000 --- a/tests/test_dummy.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_dummy(): - assert True is True diff --git a/tests/test_tensor.py b/tests/test_tensor.py new file mode 100644 index 0000000..3293fd7 --- /dev/null +++ b/tests/test_tensor.py @@ -0,0 +1,441 @@ +import numpy as np +import torch +import unittest, copy +from edugrad import Tensor +from edugrad.dtypes import dtypes + +# from tinygrad.helpers import temp + +from tests.gradcheck import numerical_jacobian, jacobian, gradcheck + +x_init = np.random.randn(1, 3).astype(np.float32) +U_init = np.random.randn(3, 3).astype(np.float32) +V_init = np.random.randn(3, 3).astype(np.float32) +W_init = np.random.randn(3, 3).astype(np.float32) +m_init = np.random.randn(1, 3).astype(np.float32) + + +class TestTinygrad(unittest.TestCase): + def test_zerodim_initialization(self): + a = Tensor(55) + b = Tensor(3.14) + + self.assertEqual(a.shape, ()) + self.assertEqual(b.shape, ()) + + def test_plus_equals(self): + a = Tensor.randn(10, 10) + b = Tensor.randn(10, 10) + c = a + b + val1 = c.numpy() + a += b + val2 = a.numpy() + np.testing.assert_allclose(val1, val2) + + def test_backward_pass(self): + def test_tinygrad(): + x = Tensor(x_init, requires_grad=True) + W = Tensor(W_init, requires_grad=True) + m = Tensor(m_init) + out = x.dot(W).relu() + out = out.log_softmax() + out = out.mul(m).add(m).sum() + out.backward() + return out.numpy(), x.grad.numpy(), W.grad.numpy() + + def test_pytorch(): + x = torch.tensor(x_init, requires_grad=True) + W = torch.tensor(W_init, requires_grad=True) + m = torch.tensor(m_init) + out = x.matmul(W).relu() + out = torch.nn.functional.log_softmax(out, dim=1) + out = out.mul(m).add(m).sum() + out.backward() + return out.detach().numpy(), x.grad, W.grad + + for x, y in zip(test_tinygrad(), test_pytorch()): + np.testing.assert_allclose(x, y, atol=1e-5) + + # @unittest.skipIf(Device.DEFAULT == "WEBGPU", "this test uses more than 8 bufs which breaks webgpu") #TODO: remove after #1461 + def test_backward_pass_diamond_model(self): + def test_tinygrad(): + u = Tensor(U_init, requires_grad=True) + v = Tensor(V_init, requires_grad=True) + w = Tensor(W_init, requires_grad=True) + x = u.mul(v).relu() + y = u.mul(w).relu() + out = x.add(y).mul(y).relu() + out = out.log_softmax() + out = out.sum() + out.backward() + return out.numpy(), u.grad.numpy(), v.grad.numpy(), w.grad.numpy() + + def test_pytorch(): + u = torch.tensor(U_init, requires_grad=True) + v = torch.tensor(V_init, requires_grad=True) + w = torch.tensor(W_init, requires_grad=True) + x = u.mul(v).relu() + y = u.mul(w).relu() + out = x.add(y).mul(y).relu() + out = torch.nn.functional.log_softmax(out, dim=1) + out = out.sum() + out.backward() + return out.detach().numpy(), u.grad, v.grad, w.grad + + for x, y in zip(test_tinygrad(), test_pytorch()): + np.testing.assert_allclose(x, y, atol=1e-5) + + def test_nograd(self): + x = Tensor(x_init, requires_grad=False) + m = Tensor(m_init, requires_grad=False) + W = Tensor(W_init, 
requires_grad=True) + tmp = x.mul(m) + mm = tmp.matmul(W) + out = mm.relu() + out = out.sum() + out.backward() + assert x.grad is None + assert m.grad is None + assert tmp.grad is None + assert mm.grad is not None + assert W.grad is not None + + def test_jacobian(self): + W = np.random.RandomState(42069).random((10, 5)).astype(np.float32) + x = np.random.RandomState(69420).random((1, 10)).astype(np.float32) + + torch_x = torch.tensor(x, requires_grad=True) + torch_W = torch.tensor(W, requires_grad=True) + + def torch_func(x): + return torch.nn.functional.log_softmax(x.matmul(torch_W).relu(), dim=1) + + PJ = torch.autograd.functional.jacobian(torch_func, torch_x).squeeze().numpy() + + tiny_x = Tensor(x, requires_grad=True) + tiny_W = Tensor(W, requires_grad=True) + + def tiny_func(x): + return x.dot(tiny_W).relu().log_softmax() + + J = jacobian(tiny_func, tiny_x) + NJ = numerical_jacobian(tiny_func, tiny_x) + + np.testing.assert_allclose(PJ, J, atol=1e-5) + np.testing.assert_allclose(PJ, NJ, atol=1e-3) + + def test_gradcheck(self): + W = np.random.RandomState(1337).random((10, 5)).astype(np.float32) + x = np.random.RandomState(7331).random((1, 10)).astype(np.float32) + + tiny_x = Tensor(x, requires_grad=True) + tiny_W = Tensor(W, requires_grad=True) + + def tiny_func(x): + return x.dot(tiny_W).relu().log_softmax() + + self.assertTrue(gradcheck(tiny_func, tiny_x, eps=1e-3)) + + # coarse approx. since a "big" eps and the non-linearities of the model + self.assertFalse(gradcheck(tiny_func, tiny_x, eps=1e-5)) + + def test_random_fns_are_deterministic_with_seed(self): + for random_fn in [Tensor.randn, Tensor.normal, Tensor.uniform, Tensor.scaled_uniform]: + with self.subTest(msg=f"Tensor.{random_fn.__name__}"): + Tensor.manual_seed(1337) + a = random_fn(10, 10) + Tensor.manual_seed(1337) + b = random_fn(10, 10) + np.testing.assert_allclose(a.numpy(), b.numpy()) + + def test_randn_isnt_inf_on_zero(self): + # simulate failure case of rand handing a zero to randn + original_rand, Tensor.rand = Tensor.rand, Tensor.zeros + try: + self.assertNotIn(np.inf, Tensor.randn(16).numpy()) + except: + raise + finally: + Tensor.rand = original_rand + + def test_zeros_like_has_same_dtype_and_shape(self): + for datatype in [dtypes.float16, dtypes.float32, dtypes.int8, dtypes.int32, dtypes.int64, dtypes.uint8]: + a = Tensor([1, 2, 3], dtype=datatype) + b = Tensor.zeros_like(a) + assert a.dtype == b.dtype, f"dtype mismatch {a.dtype=} != {b.dtype}" + assert a.shape == b.shape, f"shape mismatch {a.shape} != {b.shape}" + + a = Tensor([1, 2, 3]) + b = Tensor.zeros_like(a, dtype=dtypes.int8) + assert ( + a.dtype == dtypes.default_int and b.dtype == dtypes.int8 + ), "a.dtype should be int and b.dtype should be char" + assert a.shape == b.shape, f"shape mismatch {a.shape} != {b.shape}" + + def test_ones_like_has_same_dtype_and_shape(self): + for datatype in [dtypes.float16, dtypes.float32, dtypes.int8, dtypes.int32, dtypes.int64, dtypes.uint8]: + a = Tensor([1, 2, 3], dtype=datatype) + b = Tensor.ones_like(a) + assert a.dtype == b.dtype, f"dtype mismatch {a.dtype=} != {b.dtype}" + assert a.shape == b.shape, f"shape mismatch {a.shape} != {b.shape}" + + a = Tensor([1, 2, 3]) + b = Tensor.ones_like(a, dtype=dtypes.int8) + assert ( + a.dtype == dtypes.default_int and b.dtype == dtypes.int8 + ), "a.dtype should be int and b.dtype should be char" + assert a.shape == b.shape, f"shape mismatch {a.shape} != {b.shape}" + + def test_ndim(self): + assert Tensor(1).ndim == 0 + assert Tensor.randn(1).ndim == 1 + assert Tensor.randn(2, 
2, 2).ndim == 3 + assert Tensor.randn(1, 1, 1, 1, 1, 1).ndim == 6 + + def test_argfix(self): + self.assertEqual(Tensor.zeros().shape, ()) + self.assertEqual(Tensor.ones().shape, ()) + + self.assertEqual(Tensor.zeros([]).shape, ()) + self.assertEqual(Tensor.ones([]).shape, ()) + + self.assertEqual(Tensor.zeros(tuple()).shape, ()) + self.assertEqual(Tensor.ones(tuple()).shape, ()) + + self.assertEqual(Tensor.zeros(1).shape, (1,)) + self.assertEqual(Tensor.ones(1).shape, (1,)) + + self.assertEqual(Tensor.zeros(1, 10, 20).shape, (1, 10, 20)) + self.assertEqual(Tensor.ones(1, 10, 20).shape, (1, 10, 20)) + + self.assertEqual(Tensor.zeros([1]).shape, (1,)) + self.assertEqual(Tensor.ones([1]).shape, (1,)) + + self.assertEqual(Tensor.zeros([10, 20, 40]).shape, (10, 20, 40)) + self.assertEqual(Tensor.ones([10, 20, 40]).shape, (10, 20, 40)) + + self.assertEqual(Tensor.rand(1, 10, 20).shape, (1, 10, 20)) + self.assertEqual(Tensor.rand((10, 20, 40)).shape, (10, 20, 40)) + + self.assertEqual(Tensor.empty(1, 10, 20).shape, (1, 10, 20)) + self.assertEqual(Tensor.empty((10, 20, 40)).shape, (10, 20, 40)) + + def test_numel(self): + assert Tensor.randn(10, 10).numel() == 100 + assert Tensor.randn(1, 2, 5).numel() == 10 + assert Tensor.randn(1, 1, 1, 1, 1, 1).numel() == 1 + assert Tensor([]).numel() == 0 + assert Tensor.randn(1, 0, 2, 5).numel() == 0 + + def test_element_size(self): + for _, dtype in dtypes.fields().items(): + assert ( + dtype.itemsize == Tensor.randn(3, dtype=dtype).element_size() + ), f"Tensor.element_size() not matching Tensor.dtype.itemsize for {dtype}" + + def test_deepwalk_ctx_check(self): + layer = Tensor.uniform(1, 1, requires_grad=True) + x = Tensor.randn(1, 1, 1) + x.dot(layer).mean().backward() + x = Tensor.randn(1, 1, 1) + x.dot(layer).mean().backward() + + def test_zerosized_tensors(self): + np.testing.assert_equal(Tensor([]).numpy(), np.array([])) + np.testing.assert_equal(Tensor(None).numpy(), np.array([])) + + def test_tensor_ndarray_dtype(self): + arr = np.array([1]) # where dtype is implicitly int64 + assert Tensor(arr).dtype == dtypes.int64 + assert ( + Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32 + ) # check if ndarray correctly casts to Tensor dtype + assert Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64 # check that it works for something else + + def test_tensor_list_dtype(self): + for arr in ([1], [[[1]]], [[1, 1], [1, 1]], [[[1, 1], [1, 1]], [[1, 1], [1, 1]]]): + x = Tensor(arr) + have = x.dtype + assert Tensor(arr).dtype == dtypes.default_int + assert Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32 + assert Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64 + + for arr in ( + [True], + [[[False]]], + [[True, False], [True, False]], + [[[False, True], [False, False]], [[True, True], [False, True]]], + ): + assert Tensor(arr).dtype == dtypes.bool + assert Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32 + assert Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64 + + # empty tensor defaults + for arr in ([], [[[]]], [[], []]): + t = Tensor(arr) + assert t.dtype == dtypes.default_float + np.testing.assert_allclose(t.numpy(), np.array(arr)) + + # mixture of bool and int + for arr in ([True, 3], [[True], [3]], [[[True]], [[3]]], [[True, 3], [3, True]]): + t = Tensor(arr) + assert t.dtype == dtypes.default_int + np.testing.assert_allclose(t.numpy(), np.array(arr)) + + # mixture of bool, int and float + for arr in ( + [[True, True], [3.0, True]], + [[0, 1], [3.0, 4]], + [[[0], [1]], [[3.0], [4]]], + [[[True], [1]], 
[[3.0], [4]]], + ): + t = Tensor(arr) + assert t.dtype == dtypes.default_float + np.testing.assert_allclose(t.numpy(), np.array(arr)) + + def test_tensor_list_shapes(self): + self.assertEqual(Tensor([[[]]]).shape, (1, 1, 0)) + self.assertEqual(Tensor([[], []]).shape, (2, 0)) + self.assertEqual(Tensor([[[[]], [[]]], [[[]], [[]]], [[[]], [[]]]]).shape, (3, 2, 1, 0)) + + def test_tensor_list_errors(self): + # inhomogeneous shape + with self.assertRaises(ValueError): + Tensor([[], [[]]]) + with self.assertRaises(ValueError): + Tensor([[1], []]) + with self.assertRaises(ValueError): + Tensor([[1], [1], 1]) + with self.assertRaises(ValueError): + Tensor([[[1, 1, 1], [1, 1]]]) + with self.assertRaises(ValueError): + Tensor([[1, 1, 1], [[1, 1, 1]]]) + + def test_tensor_copy(self): + x = copy.deepcopy(Tensor.ones((3, 3, 3))) + np.testing.assert_allclose(x.numpy(), np.ones((3, 3, 3))) + + def test_item_to_tensor_to_item(self): + for a in [0, 1, 2, 3, -1, -100, 100, -101.1, 2.345, 100.1, True, False]: + item = Tensor(a).item() + assert type(item) == type(a), a + np.testing.assert_allclose(item, a), a + buffered_item = Tensor([a]).item() + assert type(buffered_item) == type(a), a + np.testing.assert_allclose(buffered_item, a), a + reshaped_item = Tensor([a]).reshape((1, 1, 1, 1, 1)).item() + assert type(reshaped_item) == type(a), a + np.testing.assert_allclose(reshaped_item, a), a + + +class TestZeroShapeTensor(unittest.TestCase): + def test_rand(self): + t = Tensor.rand(3, 2, 0) + assert t.shape == (3, 2, 0) + np.testing.assert_equal(t.numpy(), np.zeros((3, 2, 0))) + t = Tensor.rand(0) + assert t.shape == (0,) + np.testing.assert_equal(t.numpy(), np.zeros((0,))) + t = Tensor.rand(0, 0, 0) + assert t.shape == (0, 0, 0) + np.testing.assert_equal(t.numpy(), np.zeros((0, 0, 0))) + + def test_full(self): + t = Tensor.zeros(3, 2, 0) + assert t.shape == (3, 2, 0) + np.testing.assert_equal(t.numpy(), np.zeros((3, 2, 0))) + t = Tensor.full((3, 2, 0), 12) + assert t.shape == (3, 2, 0) + np.testing.assert_equal(t.numpy(), np.full((3, 2, 0), 12)) + + def test_reshape(self): + t = Tensor.zeros(3, 2, 0) + a = t.reshape(7, 0) + assert a.shape == (7, 0) + np.testing.assert_equal(a.numpy(), np.zeros((7, 0))) + with self.assertRaises(ValueError): + # cannot reshape array of size 0 into shape () + a = t.reshape(()) + + def test_expand(self): + t = Tensor.full((3, 2, 0), 12) + # with numpy operands could not be broadcast together with remapped shapes [original->remapped]: (3,2,0) + # and requested shape (6,2,0) + with self.assertRaises(ValueError): + t = t.expand((6, 2, 0)) + # assert t.shape == (6, 2, 0) + # np.testing.assert_equal(t.numpy(), np.full((6, 2, 0), 12)) + + def test_pad(self): + t = Tensor.rand(3, 2, 0).pad((None, None, (1, 1)), 1) + assert t.shape == (3, 2, 2) + np.testing.assert_equal(t.numpy(), np.ones((3, 2, 2))) + + # torch does not support padding non-zero dim with 0-size. 
torch.nn.functional.pad(torch.zeros(3,2,0), [0,0,0,4,0,0]) + t = Tensor.rand(3, 2, 0).pad((None, (1, 1), None), 1) + assert t.shape == (3, 4, 0) + np.testing.assert_equal(t.numpy(), np.ones((3, 4, 0))) + + t = Tensor.rand(3, 2, 0).pad(((1, 1), None, None), 1) + assert t.shape == (5, 2, 0) + np.testing.assert_equal(t.numpy(), np.ones((5, 2, 0))) + + def test_shrink_into_zero(self): + t = Tensor.rand(3, 4) + assert t.shrink((None, (2, 2))).shape == (3, 0) + assert t.shrink(((2, 2), None)).shape == (0, 4) + assert t.shrink(((2, 2), (2, 2))).shape == (0, 0) + + def test_cat(self): + s = Tensor.rand(3, 2, 2) + t = Tensor.rand(3, 2, 0).cat(s, dim=2) + assert t.shape == (3, 2, 2) + np.testing.assert_equal(t.numpy(), s.numpy()) + + # torch does not support padding non-zero dim with 0-size. torch.nn.functional.pad(torch.zeros(3,2,0), [0,0,0,4,0,0]) + s = Tensor.rand(3, 4, 0) + t = Tensor.rand(3, 2, 0).cat(s, dim=1) + assert t.shape == (3, 6, 0) + np.testing.assert_equal(t.numpy(), np.zeros((3, 6, 0))) + + def test_elementwise(self): + a = Tensor.rand(3, 2, 0) + a_exp = a.exp() + assert a_exp.shape == (3, 2, 0) + np.testing.assert_equal(a_exp.numpy(), np.exp(a.numpy())) + + b = Tensor.rand(3, 2, 0) + assert b.shape == (3, 2, 0) + ab = a * b + assert ab.shape == (3, 2, 0) + np.testing.assert_equal(ab.numpy(), a.numpy() * b.numpy()) + + mask = Tensor.rand(3, 2, 0) > 0.5 + assert mask.shape == (3, 2, 0) + c = mask.where(a, b) + assert c.shape == (3, 2, 0) + np.testing.assert_equal(c.numpy(), np.where(mask.numpy(), a.numpy(), b.numpy())) + + def test_reduce_over_non_zero(self): + a = Tensor.ones(3, 2, 0).sum(axis=1) + assert a.shape == (3, 0) + np.testing.assert_equal(a.numpy(), np.sum(np.zeros((3, 2, 0)), axis=1)) + + def test_reduce_over_zero(self): + a = Tensor.ones(3, 2, 0).sum(axis=2) + assert a.shape == (3, 2) + np.testing.assert_equal(a.numpy(), np.sum(np.zeros((3, 2, 0)), axis=2)) + + a = Tensor.ones(3, 2, 0).sum(axis=2, keepdim=True) + assert a.shape == (3, 2, 1) + np.testing.assert_equal(a.numpy(), np.sum(np.zeros((3, 2, 0)), axis=2, keepdims=True)) + + def test_reduce_default(self): + np.testing.assert_equal(Tensor([]).max().numpy(), -float("inf")) + np.testing.assert_equal(Tensor([]).min().numpy(), float("inf")) + np.testing.assert_equal(Tensor([]).sum().numpy(), 0) + np.testing.assert_equal(Tensor([]).mean().numpy(), 0) + + +if __name__ == "__main__": + unittest.main()
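For running the new gradcheck helpers outside the unittest suite, a minimal sketch (executed from the repo root so the tests package added above is importable; the model, shapes, and values here are made up for illustration and mirror the pattern used in test_gradcheck):

import numpy as np
from edugrad import Tensor
from tests.gradcheck import gradcheck, jacobian, numerical_jacobian

W = Tensor(np.random.rand(4, 3).astype(np.float32), requires_grad=True)
x = Tensor(np.random.rand(1, 4).astype(np.float32), requires_grad=True)

def f(t):
    return t.dot(W).relu().log_softmax()

J = jacobian(f, x)             # analytic Jacobian via backward()
NJ = numerical_jacobian(f, x)  # central differences with eps=1e-3
assert J.shape == NJ.shape == (3, 4)
assert gradcheck(f, x, eps=1e-3)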