Add simple wrapper for read_iterator pickleability (#754)
sysuresh committed Aug 20, 2024
1 parent 9af332f commit 2b3deda
Showing 2 changed files with 123 additions and 0 deletions.
71 changes: 71 additions & 0 deletions src/fairseq2/data/utils.py
@@ -0,0 +1,71 @@
from collections.abc import Callable, Iterator
from typing import TypeVar

from typing_extensions import Self, TypeAlias

from fairseq2.data import DataPipelineBuilder, read_iterator

T = TypeVar("T")

IteratorFactory: TypeAlias = Callable[[], Iterator[T]]


class IteratorPickleWrapper(Iterator[T]):
    def __init__(self, iterator_factory: IteratorFactory[T]) -> None:
        self._iterator_factory: IteratorFactory[T] = iterator_factory
        self._iterator: Iterator[T] = self._iterator_factory()
        self._counter = 0

    def __iter__(self) -> Self:
        return self

    def __next__(self) -> T:
        out = next(self._iterator)
        self._counter += 1
        return out

    def __getstate__(self) -> tuple[IteratorFactory[T], int]:
        return self._iterator_factory, self._counter

    def __setstate__(self, state: tuple[IteratorFactory[T], int]) -> None:
        self._iterator_factory, counter = state
        self._iterator = self._iterator_factory()
        for i in range(counter):
            next(self._iterator)
        self._counter = counter


def read_pickle_wrapped_iterator(
    iterator_factory: IteratorFactory[T],
) -> DataPipelineBuilder:
    """Read each element of the iterator created by ``iterator_factory``.

    If the iterator is not pickleable, this function wraps it in
    ``IteratorPickleWrapper``, a simple class that increments an internal
    ``_counter`` every time ``__next__`` is called. Upon pickling, this
    counter is saved; upon unpickling, a new iterator is created from
    ``iterator_factory`` and advanced ``counter`` times. Note that this
    makes the time complexity of unpickling linear in the number of times
    ``__next__`` was called prior to pickling.

    :param iterator_factory:
        The iterator factory.
    """

    iterator = iterator_factory()
    try:
        return read_iterator(
            iterator, reset_fn=lambda x: iterator_factory(), infinite=False
        )
    except TypeError as e:
        if (
            str(e)
            != "`iterator` is not pickleable; set `skip_pickling_check` to True to bypass (see `read_iterator` documentation for details)."
        ):
            raise
        return read_iterator(
            IteratorPickleWrapper(iterator_factory),
            reset_fn=lambda x: IteratorPickleWrapper(iterator_factory),
            infinite=False,
        )
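
For reference, the pickling behavior described in the docstring can be exercised directly with a plain ``pickle`` round trip. A minimal sketch, not part of this commit; ``example_factory`` is a hypothetical module-level factory so that the factory itself is pickleable:

import pickle
from collections.abc import Iterator

from fairseq2.data.utils import IteratorPickleWrapper


def example_factory() -> Iterator[int]:
    # Hypothetical factory used only for illustration.
    return iter(range(10))


wrapper = IteratorPickleWrapper(example_factory)

assert next(wrapper) == 0
assert next(wrapper) == 1

# __getstate__ stores only (factory, counter); __setstate__ rebuilds the
# iterator from the factory and advances it counter times, so the restored
# copy resumes where the original left off.
restored = pickle.loads(pickle.dumps(wrapper))

assert next(restored) == 2
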
52 changes: 52 additions & 0 deletions tests/unit/data/test_read_pickle_wrapped_iterator.py
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Iterator

import pytest

from fairseq2.data import read_iterator
from fairseq2.data.utils import read_pickle_wrapped_iterator


def example_generator() -> Iterator[int]:
    for i in range(10):
        yield i


class TestReadAndPickleWrapIterator:
    def test_read_and_pickle_wrap_iterator_works(self) -> None:
        with pytest.raises(TypeError):
            read_iterator(
                example_generator(),
                reset_fn=lambda x: example_generator(),
                infinite=False,
            ).and_return()

        pipeline = read_pickle_wrapped_iterator(example_generator).and_return()

        it = iter(pipeline)

        assert next(it) == 0
        assert next(it) == 1

        state = pipeline.state_dict()

        assert next(it) == 2
        assert next(it) == 3
        assert next(it) == 4

        pipeline.load_state_dict(state)

        assert next(it) == 2
        assert next(it) == 3
        assert next(it) == 4

        pipeline.reset()

        for _ in range(2):
            assert list(pipeline) == [*range(10)]
            pipeline.reset()
