Skip to content

Commit

Permalink
Add serialization & deserialization support of container registrations.
Browse files Browse the repository at this point in the history
  • Loading branch information
runemalm committed Aug 8, 2024
1 parent 748d626 commit 7bf15f8
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/dependency_injection/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(self, name: str = None):
self._registrations = {}
self._singleton_instances = {}
self._scoped_instances = {}
self._has_resolved = False

@classmethod
def get_instance(cls, name: str = None) -> Self:
Expand All @@ -30,6 +31,16 @@ def get_instance(cls, name: str = None) -> Self:

return cls._instances[(cls, name)]

def get_registrations(self) -> Dict[Type, Registration]:
return self._registrations

def set_registrations(self, registrations) -> None:
if self._has_resolved:
raise Exception(
"You can't set registrations after a dependency has been resolved."
)
self._registrations = registrations

def register_transient(
self,
dependency: Type,
Expand Down Expand Up @@ -99,6 +110,8 @@ def register_instance(
self._singleton_instances[dependency] = instance

def resolve(self, dependency: Type, scope_name: str = DEFAULT_SCOPE_NAME) -> Type:
self._has_resolved = True

if scope_name not in self._scoped_instances:
self._scoped_instances[scope_name] = {}

Expand Down
15 changes: 15 additions & 0 deletions src/dependency_injection/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pickle

from typing import Dict, Type

from dependency_injection.registration import Registration


class RegistrationSerializer:
@staticmethod
def serialize(registrations) -> bytes:
return pickle.dumps(registrations)

@staticmethod
def deserialize(serialized_state) -> Dict[Type, Registration]:
return pickle.loads(serialized_state)
32 changes: 32 additions & 0 deletions tests/unit_test/container/register/test_get_registrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from dependency_injection.container import DependencyContainer
from dependency_injection.scope import Scope
from unit_test.unit_test_case import UnitTestCase


class TestGetRegistrations(UnitTestCase):
def test_get_registrations_returns_empty_dict_initially(self):
# arrange
dependency_container = DependencyContainer.get_instance()

# act
registrations = dependency_container.get_registrations()

# assert
self.assertEqual(registrations, {})

def test_get_registrations_returns_correct_registrations(self):
# arrange
class Vehicle:
pass

dependency_container = DependencyContainer.get_instance()
dependency_container.register_transient(Vehicle)

# act
registrations = dependency_container.get_registrations()

# assert
self.assertIn(Vehicle, registrations)
self.assertEqual(registrations[Vehicle].dependency, Vehicle)
self.assertEqual(registrations[Vehicle].implementation, Vehicle)
self.assertEqual(registrations[Vehicle].scope, Scope.TRANSIENT)
43 changes: 43 additions & 0 deletions tests/unit_test/container/register/test_set_registrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from dependency_injection.container import DependencyContainer
from unit_test.unit_test_case import UnitTestCase


class TestSetRegistrations(UnitTestCase):
def test_set_registrations_before_first_resolution(self):
# arrange
class Vehicle:
pass

dummy_container = DependencyContainer.get_instance("dummy_container")
dummy_container.register_transient(Vehicle)
new_registrations = dummy_container.get_registrations()

container = DependencyContainer.get_instance()

# act
container.set_registrations(new_registrations) # no exception

def test_not_allowed_to_set_registrations_after_first_resolution(self):
# arrange
class Vehicle:
pass

class Fruit:
pass

dummy_container = DependencyContainer.get_instance("dummy_container")
dummy_container.register_transient(Vehicle)
new_registrations = dummy_container.get_registrations()

container = DependencyContainer.get_instance()
container.register_transient(Fruit)
container.resolve(Fruit)

# act & assert
with self.assertRaises(Exception) as context:
container.set_registrations(new_registrations)

self.assertIn(
"You can't set registrations after a dependency has been resolved.",
str(context.exception),
)
Empty file.
63 changes: 63 additions & 0 deletions tests/unit_test/serialization/test_registration_serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import pickle

from unit_test.unit_test_case import UnitTestCase

from dependency_injection.serialization import RegistrationSerializer
from dependency_injection.registration import Registration
from dependency_injection.scope import Scope


class Vehicle:
pass


class Car:
def __init__(self, vehicle: Vehicle):
self.vehicle = vehicle


class TestRegistrationSerializer(UnitTestCase):
def test_serialize_and_deserialize(self):
# arrange
registrations = {
Vehicle: Registration(
dependency=Vehicle,
implementation=Car,
scope=Scope.TRANSIENT,
tags={"example_tag"},
constructor_args={"vehicle": Vehicle()},
factory=None,
factory_args={},
)
}

# act
serialized = RegistrationSerializer.serialize(registrations)
deserialized = RegistrationSerializer.deserialize(serialized)

# assert
self.assertEqual(deserialized.keys(), registrations.keys())
self.assertEqual(
deserialized[Vehicle].dependency, registrations[Vehicle].dependency
)
self.assertEqual(
deserialized[Vehicle].implementation, registrations[Vehicle].implementation
)
self.assertEqual(deserialized[Vehicle].scope, registrations[Vehicle].scope)
self.assertEqual(deserialized[Vehicle].tags, registrations[Vehicle].tags)
self.assertEqual(
deserialized[Vehicle].constructor_args.keys(),
registrations[Vehicle].constructor_args.keys(),
)
self.assertEqual(deserialized[Vehicle].factory, registrations[Vehicle].factory)
self.assertEqual(
deserialized[Vehicle].factory_args, registrations[Vehicle].factory_args
)

def test_deserialize_invalid_data(self):
# arrange
invalid_data = b"not a valid pickle"

# act & assert
with self.assertRaises(pickle.UnpicklingError):
RegistrationSerializer.deserialize(invalid_data)

0 comments on commit 7bf15f8

Please sign in to comment.