Skip to content

Commit

Permalink
Add AssetStore to dependency graph
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalioglu committed Sep 24, 2024
1 parent 9cc7a9b commit 8146616
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 132 deletions.
3 changes: 3 additions & 0 deletions src/fairseq2/assets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
PackageAssetMetadataProvider as PackageAssetMetadataProvider,
)
from fairseq2.assets.metadata_provider import load_metadata_file as load_metadata_file
from fairseq2.assets.metadata_provider import (
register_package_metadata_provider as register_package_metadata_provider,
)
from fairseq2.assets.store import AssetStore as AssetStore
from fairseq2.assets.store import EnvironmentResolver as EnvironmentResolver
from fairseq2.assets.store import StandardAssetStore as StandardAssetStore
Expand Down
126 changes: 106 additions & 20 deletions src/fairseq2/assets/metadata_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from collections.abc import Sequence
from copy import deepcopy
from pathlib import Path
from typing import Any, NoReturn, final
from typing import Any, Literal, NoReturn, final

import yaml
from importlib_resources import files
Expand All @@ -20,6 +20,11 @@
from yaml import YAMLError

from fairseq2.assets.error import AssetError
from fairseq2.dependency import DependencyContainer, DependencyResolver
from fairseq2.logging import get_log_writer
from fairseq2.utils.env import get_path_from_env

log = get_log_writer(__name__)


class AssetMetadataProvider(ABC):
Expand All @@ -41,15 +46,27 @@ def get_names(self) -> list[str]:
def clear_cache(self) -> None:
"""Clear any cached asset metadata."""

@property
@abstractmethod
def scope(self) -> str:
"""The scope (e.g. user or global) of this provider."""


class AbstractAssetMetadataProvider(AssetMetadataProvider):
"""Provides a skeletal implementation of :class:`AssetMetadataProvider`."""

_cache: dict[str, dict[str, Any]] | None
_scope: str

def __init__(self) -> None:
def __init__(self, *, scope: Literal["global", "user"] = "global") -> None:
"""
:param scope:
The scope of the provider.
"""
self._cache = None

self._scope = scope

@final
@override
def get_metadata(self, name: str) -> dict[str, Any]:
Expand All @@ -66,7 +83,7 @@ def get_metadata(self, name: str) -> dict[str, Any]:
return deepcopy(metadata)
except Exception as ex:
raise AssetMetadataError(
f"The metadata of the asset '{name}' cannot be copied. See nested exception for details and please file a bug report to the asset owner."
f"The metadata of the asset '{name}' cannot be used. Please file a bug report to the asset owner."
) from ex

@final
Expand All @@ -93,19 +110,29 @@ def _ensure_cache_loaded(self) -> dict[str, dict[str, Any]]:
def _load_cache(self) -> dict[str, dict[str, Any]]:
...

@final
@property
@override
def scope(self) -> str:
return self._scope


@final
class FileAssetMetadataProvider(AbstractAssetMetadataProvider):
"""Provides asset metadata stored on a file system."""

_base_dir: Path

def __init__(self, base_dir: Path) -> None:
def __init__(
self, base_dir: Path, *, scope: Literal["global", "user"] = "global"
) -> None:
"""
:param base_dir:
The base directory under which the asset metadata is stored.
:param scope:
The scope of the provider.
"""
super().__init__()
super().__init__(scope=scope)

self._base_dir = base_dir.expanduser().resolve()

Expand Down Expand Up @@ -149,12 +176,16 @@ class PackageAssetMetadataProvider(AbstractAssetMetadataProvider):
_package_name: str
_package_path: MultiplexedPath

def __init__(self, package_name: str) -> None:
def __init__(
self, package_name: str, scope: Literal["global", "user"] = "global"
) -> None:
"""
:param package_name:
The name of the package in which the asset metadata is stored.
:param scope:
The scope of the provider.
"""
super().__init__()
super().__init__(scope=scope)

self._package_name = package_name

Expand All @@ -171,7 +202,7 @@ def _load_cache(self) -> dict[str, dict[str, Any]]:
for name, metadata in load_metadata_file(file):
if name in cache:
raise AssetMetadataError(
f"Two assets under the namespace package '{self._package_name}' have the same name '{name}'."
f"Two assets under the package '{self._package_name}' have the same name '{name}'."
)

metadata["__source__"] = f"package:{self._package_name}"
Expand Down Expand Up @@ -222,21 +253,21 @@ def load_metadata_file(file: Path) -> list[tuple[str, dict[str, Any]]]:
for idx, metadata in enumerate(all_metadata):
if not isinstance(metadata, dict):
raise AssetMetadataError(
f"The asset metadata at index {idx} in {file} has an invalid format."
f"The asset metadata at index {idx} in the file '{file}' has an invalid format."
)

try:
name = metadata.pop("name")
except KeyError:
raise AssetMetadataError(
f"The asset metadata at index {idx} in {file} does not have a name entry."
f"The asset metadata at index {idx} in the file {file} does not have a name entry."
) from None

try:
canonical_name = _canonicalize_name(name)
except ValueError as ex:
raise AssetMetadataError(
f"The asset metadata at index {idx} in {file} has an invalid name. See nested exception for details."
f"The asset metadata at index {idx} in the file {file} has an invalid name. See nested exception for details."
) from ex

metadata["__base_path__"] = file.parent
Expand All @@ -250,19 +281,18 @@ def load_metadata_file(file: Path) -> list[tuple[str, dict[str, Any]]]:
class InProcAssetMetadataProvider(AssetMetadataProvider):
"""Provides asset metadata stored in memory."""

_name: str | None
_metadata: dict[str, dict[str, Any]]
_scope: str

def __init__(
self, metadata: Sequence[dict[str, Any]], *, name: str | None = None
self,
metadata: Sequence[dict[str, Any]],
*,
scope: Literal["global", "user"] = "global",
) -> None:
self._name = name
self._metadata = {}

source = "inproc"
super().__init__()

if name is not None:
source = f"{source}:{name}"
self._metadata = {}

for idx, metadata_ in enumerate(metadata):
try:
Expand All @@ -284,10 +314,12 @@ def __init__(
f"Two assets in `metadata` have the same name '{canonical_name}'."
)

metadata_["__source__"] = source
metadata_["__source__"] = "inproc"

self._metadata[canonical_name] = metadata_

self._scope = scope

@override
def get_metadata(self, name: str) -> dict[str, Any]:
try:
Expand All @@ -305,6 +337,11 @@ def get_names(self) -> list[str]:
def clear_cache(self) -> None:
pass

@property
@override
def scope(self) -> str:
    """The scope (e.g. user or global) of this provider."""
    # ``@override`` sits directly on the getter, below ``@property``, so the
    # marker is applied to the function itself rather than to the property
    # object. This matches the order used by the skeletal provider class.
    return self._scope


def _canonicalize_name(name: Any) -> str:
if not isinstance(name, str):
Expand Down Expand Up @@ -343,3 +380,52 @@ def name(self) -> str:

class AssetMetadataError(AssetError):
"""Raised when an asset metadata operation fails."""


def register_objects(container: DependencyContainer) -> None:
    """Register the default asset metadata provider factories.

    :param container:
        The dependency container with which to register the factories.
    """
    # Order matters to keep registration identical: package first, then the
    # system-wide directory, then the per-user configuration directory.
    provider_factories = (
        _create_package_metadata_provider,
        _create_etc_dir_metadata_provider,
        _create_cfg_dir_metadata_provider,
    )

    for provider_factory in provider_factories:
        container.register_factory(AssetMetadataProvider, provider_factory)


def _create_package_metadata_provider(
    resolver: DependencyResolver,
) -> AssetMetadataProvider:
    """Create the provider for the ``fairseq2.assets.cards`` package.

    :param resolver:
        The dependency resolver; not used by this factory.

    :returns:
        A package-based provider with the default ("global") scope.
    """
    cards_package = "fairseq2.assets.cards"

    return PackageAssetMetadataProvider(cards_package)


def _create_etc_dir_metadata_provider(
    resolver: DependencyResolver,
) -> AssetMetadataProvider | None:
    """Create the provider for the system-wide asset directory.

    The directory is read from the ``FAIRSEQ2_ASSET_DIR`` environment variable
    and falls back to ``/etc/fairseq2/assets``.

    :param resolver:
        The dependency resolver; not used by this factory.

    :returns:
        A file-based provider with the default ("global") scope, or ``None``
        if the fallback directory does not exist.
    """
    # NOTE(review): the existence check is applied to the fallback path only;
    # this nesting was reconstructed from a flattened listing -- confirm that
    # an explicitly configured directory is meant to skip the check.
    configured_dir = get_path_from_env("FAIRSEQ2_ASSET_DIR", log)
    if configured_dir is not None:
        return FileAssetMetadataProvider(configured_dir)

    fallback_dir = Path("/etc/fairseq2/assets").resolve()
    if fallback_dir.exists():
        return FileAssetMetadataProvider(fallback_dir)

    return None


def _create_cfg_dir_metadata_provider(
    resolver: DependencyResolver,
) -> AssetMetadataProvider | None:
    """Create the provider for the per-user asset directory.

    The directory is read from the ``FAIRSEQ2_USER_ASSET_DIR`` environment
    variable. If it is not set, ``fairseq2/assets`` under ``XDG_CONFIG_HOME``
    (or under ``~/.config`` when that is not set either) is used instead.

    :param resolver:
        The dependency resolver; not used by this factory.

    :returns:
        A file-based provider with the "user" scope, or ``None`` if the
        derived fallback directory does not exist.
    """
    # NOTE(review): the ``fairseq2/assets`` suffix and the existence check are
    # applied to the fallback path only; this nesting was reconstructed from a
    # flattened listing -- confirm against the intended behavior.
    explicit_dir = get_path_from_env("FAIRSEQ2_USER_ASSET_DIR", log)
    if explicit_dir is not None:
        return FileAssetMetadataProvider(explicit_dir, scope="user")

    config_home = get_path_from_env("XDG_CONFIG_HOME", log)
    if config_home is None:
        config_home = Path("~/.config").expanduser()

    derived_dir = config_home.joinpath("fairseq2/assets").resolve()
    if derived_dir.exists():
        return FileAssetMetadataProvider(derived_dir, scope="user")

    return None


def register_package_metadata_provider(
    container: DependencyContainer, package_name: str
) -> None:
    """Register a factory for a package-based asset metadata provider.

    :param container:
        The dependency container with which to register the factory.
    :param package_name:
        The name of the package in which the asset metadata is stored.
    """

    def _make_provider(resolver: DependencyResolver) -> AssetMetadataProvider:
        # The provider is built only when the factory is invoked;
        # ``package_name`` is captured from the enclosing call.
        provider = PackageAssetMetadataProvider(package_name)

        return provider

    container.register_factory(AssetMetadataProvider, _make_provider)
Loading

0 comments on commit 8146616

Please sign in to comment.