diff --git a/src/dependency_injection/container.py b/src/dependency_injection/container.py index abc7c71..ee7defa 100644 --- a/src/dependency_injection/container.py +++ b/src/dependency_injection/container.py @@ -19,6 +19,7 @@ def __init__(self, name: str = None): self._registrations = {} self._singleton_instances = {} self._scoped_instances = {} + self._has_resolved = False @classmethod def get_instance(cls, name: str = None) -> Self: @@ -30,6 +31,16 @@ def get_instance(cls, name: str = None) -> Self: return cls._instances[(cls, name)] + def get_registrations(self) -> Dict[Type, Registration]: + return self._registrations + + def set_registrations(self, registrations) -> None: + if self._has_resolved: + raise Exception( + "You can't set registrations after a dependency has been resolved." + ) + self._registrations = registrations + def register_transient( self, dependency: Type, @@ -99,6 +110,8 @@ def register_instance( self._singleton_instances[dependency] = instance def resolve(self, dependency: Type, scope_name: str = DEFAULT_SCOPE_NAME) -> Type: + self._has_resolved = True + if scope_name not in self._scoped_instances: self._scoped_instances[scope_name] = {} diff --git a/src/dependency_injection/serialization.py b/src/dependency_injection/serialization.py new file mode 100644 index 0000000..2bfbad0 --- /dev/null +++ b/src/dependency_injection/serialization.py @@ -0,0 +1,15 @@ +import pickle + +from typing import Dict, Type + +from dependency_injection.registration import Registration + + +class RegistrationSerializer: + @staticmethod + def serialize(registrations) -> bytes: + return pickle.dumps(registrations) + + @staticmethod + def deserialize(serialized_state) -> Dict[Type, Registration]: + return pickle.loads(serialized_state) diff --git a/tests/unit_test/container/register/test_get_registrations.py b/tests/unit_test/container/register/test_get_registrations.py new file mode 100644 index 0000000..5d9d50c --- /dev/null +++ b/tests/unit_test/container/register/test_get_registrations.py @@ -0,0 +1,32 @@ +from dependency_injection.container import DependencyContainer +from dependency_injection.scope import Scope +from unit_test.unit_test_case import UnitTestCase + + +class TestGetRegistrations(UnitTestCase): + def test_get_registrations_returns_empty_dict_initially(self): + # arrange + dependency_container = DependencyContainer.get_instance() + + # act + registrations = dependency_container.get_registrations() + + # assert + self.assertEqual(registrations, {}) + + def test_get_registrations_returns_correct_registrations(self): + # arrange + class Vehicle: + pass + + dependency_container = DependencyContainer.get_instance() + dependency_container.register_transient(Vehicle) + + # act + registrations = dependency_container.get_registrations() + + # assert + self.assertIn(Vehicle, registrations) + self.assertEqual(registrations[Vehicle].dependency, Vehicle) + self.assertEqual(registrations[Vehicle].implementation, Vehicle) + self.assertEqual(registrations[Vehicle].scope, Scope.TRANSIENT) diff --git a/tests/unit_test/container/register/test_set_registrations.py b/tests/unit_test/container/register/test_set_registrations.py new file mode 100644 index 0000000..04d6b46 --- /dev/null +++ b/tests/unit_test/container/register/test_set_registrations.py @@ -0,0 +1,43 @@ +from dependency_injection.container import DependencyContainer +from unit_test.unit_test_case import UnitTestCase + + +class TestSetRegistrations(UnitTestCase): + def test_set_registrations_before_first_resolution(self): + # arrange + class Vehicle: + pass + + dummy_container = DependencyContainer.get_instance("dummy_container") + dummy_container.register_transient(Vehicle) + new_registrations = dummy_container.get_registrations() + + container = DependencyContainer.get_instance() + + # act + container.set_registrations(new_registrations) # no exception + + def test_not_allowed_to_set_registrations_after_first_resolution(self): + # arrange + class Vehicle: + pass + + class Fruit: + pass + + dummy_container = DependencyContainer.get_instance("dummy_container") + dummy_container.register_transient(Vehicle) + new_registrations = dummy_container.get_registrations() + + container = DependencyContainer.get_instance() + container.register_transient(Fruit) + container.resolve(Fruit) + + # act & assert + with self.assertRaises(Exception) as context: + container.set_registrations(new_registrations) + + self.assertIn( + "You can't set registrations after a dependency has been resolved.", + str(context.exception), + ) diff --git a/tests/unit_test/serialization/__init__.py b/tests/unit_test/serialization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit_test/serialization/test_registration_serializer.py b/tests/unit_test/serialization/test_registration_serializer.py new file mode 100644 index 0000000..9be563d --- /dev/null +++ b/tests/unit_test/serialization/test_registration_serializer.py @@ -0,0 +1,63 @@ +import pickle + +from unit_test.unit_test_case import UnitTestCase + +from dependency_injection.serialization import RegistrationSerializer +from dependency_injection.registration import Registration +from dependency_injection.scope import Scope + + +class Vehicle: + pass + + +class Car: + def __init__(self, vehicle: Vehicle): + self.vehicle = vehicle + + +class TestRegistrationSerializer(UnitTestCase): + def test_serialize_and_deserialize(self): + # arrange + registrations = { + Vehicle: Registration( + dependency=Vehicle, + implementation=Car, + scope=Scope.TRANSIENT, + tags={"example_tag"}, + constructor_args={"vehicle": Vehicle()}, + factory=None, + factory_args={}, + ) + } + + # act + serialized = RegistrationSerializer.serialize(registrations) + deserialized = RegistrationSerializer.deserialize(serialized) + + # assert + self.assertEqual(deserialized.keys(), registrations.keys()) + self.assertEqual( + deserialized[Vehicle].dependency, registrations[Vehicle].dependency + ) + self.assertEqual( + deserialized[Vehicle].implementation, registrations[Vehicle].implementation + ) + self.assertEqual(deserialized[Vehicle].scope, registrations[Vehicle].scope) + self.assertEqual(deserialized[Vehicle].tags, registrations[Vehicle].tags) + self.assertEqual( + deserialized[Vehicle].constructor_args.keys(), + registrations[Vehicle].constructor_args.keys(), + ) + self.assertEqual(deserialized[Vehicle].factory, registrations[Vehicle].factory) + self.assertEqual( + deserialized[Vehicle].factory_args, registrations[Vehicle].factory_args + ) + + def test_deserialize_invalid_data(self): + # arrange + invalid_data = b"not a valid pickle" + + # act & assert + with self.assertRaises(pickle.UnpicklingError): + RegistrationSerializer.deserialize(invalid_data)