Skip to content

Commit

Permalink
Extend CLI to accept a dependency graph (#806)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalioglu committed Sep 24, 2024
1 parent 9fbbbc7 commit 3e2d1ae
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 74 deletions.
51 changes: 37 additions & 14 deletions src/fairseq2/recipes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import os

from fairseq2.dependency import DependencyResolver, get_container
from fairseq2.logging import get_log_writer
from fairseq2.recipes.assets import _setup_asset_cli
from fairseq2.recipes.hg import _setup_hg_cli
Expand All @@ -30,12 +31,12 @@

def main() -> None:
"""Run the command line fairseq2 program."""
from fairseq2 import __version__, setup_extensions
from fairseq2 import __version__, setup_fairseq2

with exception_logger(log):
setup_basic_logging()

setup_extensions()
setup_fairseq2()

cli = Cli(
name="fairseq2",
Expand All @@ -44,12 +45,17 @@ def main() -> None:
description="command line interface of fairseq2",
)

_setup_cli(cli)
container = get_container()

cli()
_setup_cli(cli, container)
_setup_cli_extensions(cli, container)

_setup_legacy_cli_extensions(cli)

def _setup_cli(cli: Cli) -> None:
cli(container)


def _setup_cli(cli: Cli, resolver: DependencyResolver) -> None:
_setup_asset_cli(cli)
_setup_lm_cli(cli)
_setup_llama_cli(cli)
Expand All @@ -58,23 +64,40 @@ def _setup_cli(cli: Cli) -> None:
_setup_wav2vec2_asr_cli(cli)
_setup_hg_cli(cli)

# Set up 3rd party CLI extensions.

def _setup_cli_extensions(cli: Cli, resolver: DependencyResolver) -> None:
for entry_point in entry_points(group="fairseq2.extension.cli"):
try:
setup_cli = entry_point.load()

setup_cli(cli, resolver)
except TypeError:
raise RuntimeError(
f"The entry point '{entry_point.value}' is not a valid fairseq2 CLI extension function."
) from None
except Exception as ex:
if "FAIRSEQ2_EXTENSION_TRACE" in os.environ:
raise RuntimeError(
f"The CLI extension function at '{entry_point.value}' has failed. See nested exception for details."
) from ex

log.warning("The CLI extension function at '{}' has failed. Set `FAIRSEQ2_EXTENSION_TRACE` environment variable to print the stack trace.", entry_point.value) # fmt: skip


def _setup_legacy_cli_extensions(cli: Cli) -> None:
for entry_point in entry_points(group="fairseq2.cli"):
try:
setup_cli_extension = entry_point.load()
setup_cli = entry_point.load()

setup_cli_extension(cli)
setup_cli(cli)
except TypeError:
raise RuntimeError(
f"The entry point '{entry_point.value}' is not a valid fairseq2 CLI setup function."
f"The entry point '{entry_point.value}' is not a valid fairseq2 CLI extension function."
) from None
except Exception as ex:
if "FAIRSEQ2_EXTENSION_TRACE" in os.environ:
raise RuntimeError(
f"The CLI setup function at '{entry_point.value}' has failed. See nested exception for details."
f"The CLI extension function at '{entry_point.value}' has failed. See nested exception for details."
) from ex

log.warning(
"The CLI setup function at '{}' has failed. Set `FAIRSEQ2_EXTENSION_TRACE` environment variable to print the stack trace.",
entry_point.value,
)
log.warning("The CLI extension function at '{}' has failed. Set `FAIRSEQ2_EXTENSION_TRACE` environment variable to print the stack trace.", entry_point.value) # fmt: skip
52 changes: 16 additions & 36 deletions src/fairseq2/recipes/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,11 @@
from rich.pretty import pretty_repr
from typing_extensions import override

from fairseq2.assets import (
AssetCard,
AssetNotFoundError,
AssetStore,
default_asset_store,
)
from fairseq2.assets import AssetCard, AssetNotFoundError, AssetStore
from fairseq2.console import get_console
from fairseq2.data.text import is_tokenizer_card
from fairseq2.datasets import is_dataset_card
from fairseq2.dependency import DependencyContainer, DependencyResolver
from fairseq2.logging import get_log_writer
from fairseq2.models import is_model_card
from fairseq2.recipes.cli import Cli, CliCommandHandler
Expand Down Expand Up @@ -53,18 +49,8 @@ def _setup_asset_cli(cli: Cli) -> None:
class ListAssetsCommand(CliCommandHandler):
"""Lists assets available in the current Python environment."""

_asset_store: AssetStore

def __init__(self, asset_store: AssetStore | None = None) -> None:
"""
:param asset_store:
The asset store from which to retrieve the asset cards. If ``None``,
the default asset store will be used.
"""
self._asset_store = asset_store or default_asset_store

@override
def init_parser(self, parser: ArgumentParser) -> None:
def init_parser(self, parser: ArgumentParser, resolver: DependencyResolver) -> None:
parser.add_argument(
"--type",
choices=["all", "model", "dataset", "tokenizer"],
Expand All @@ -73,9 +59,11 @@ def init_parser(self, parser: ArgumentParser) -> None:
)

@override
def __call__(self, args: Namespace) -> None:
usr_assets = self._retrieve_assets(args, user=True)
glb_assets = self._retrieve_assets(args, user=False)
def __call__(self, args: Namespace, container: DependencyContainer) -> None:
asset_store = container.resolve(AssetStore)

usr_assets = self._retrieve_assets(asset_store, args, user=True)
glb_assets = self._retrieve_assets(asset_store, args, user=False)

console = get_console()

Expand All @@ -88,15 +76,15 @@ def __call__(self, args: Namespace) -> None:
self._dump_assets(console, glb_assets)

def _retrieve_assets(
self, args: Namespace, user: bool
self, asset_store: AssetStore, args: Namespace, user: bool
) -> list[tuple[str, list[str]]]:
assets: dict[str, list[str]] = defaultdict(list)

names = self._asset_store.retrieve_names(scope="user" if user else "global")
names = asset_store.retrieve_names(scope="user" if user else "global")

for name in names:
try:
card = self._asset_store.retrieve_card(
card = asset_store.retrieve_card(
name, scope="all" if user else "global"
)
except AssetNotFoundError:
Expand Down Expand Up @@ -162,18 +150,8 @@ def _dump_assets(
class ShowAssetCommand(CliCommandHandler):
"""Shows the metadata of an asset."""

_asset_store: AssetStore

def __init__(self, asset_store: AssetStore | None = None) -> None:
"""
:param asset_store:
The asset store from which to retrieve the asset cards. If ``None``,
the default asset store will be used.
"""
self._asset_store = asset_store or default_asset_store

@override
def init_parser(self, parser: ArgumentParser) -> None:
def init_parser(self, parser: ArgumentParser, resolver: DependencyResolver) -> None:
parser.add_argument(
"--env",
dest="envs",
Expand All @@ -192,9 +170,11 @@ def init_parser(self, parser: ArgumentParser) -> None:
parser.add_argument("name", help="name of the asset")

@override
def __call__(self, args: Namespace) -> None:
def __call__(self, args: Namespace, container: DependencyContainer) -> None:
asset_store = container.resolve(AssetStore)

try:
card: AssetCard | None = self._asset_store.retrieve_card(
card: AssetCard | None = asset_store.retrieve_card(
args.name, envs=args.envs, scope=args.scope
)
except AssetNotFoundError:
Expand Down
38 changes: 18 additions & 20 deletions src/fairseq2/recipes/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from fairseq2.config_registry import ConfigRegistry
from fairseq2.console import get_console, set_console
from fairseq2.dependency import DependencyContainer, DependencyResolver
from fairseq2.error import AlreadyExistsError
from fairseq2.logging import get_log_writer
from fairseq2.recipes.logging import setup_basic_logging, setup_logging
Expand Down Expand Up @@ -92,7 +93,7 @@ def add_group(
The help text of the command group.
"""
if name in self._groups:
raise ValueError(
raise AlreadyExistsError(
f"`name` must be a unique group name, but '{name}' is already registered."
)

Expand All @@ -111,7 +112,7 @@ def get_group(self, name: str) -> CliGroup:
f"`name` must be a registered group name, but '{name}' is not registered."
) from None

def init_parser(self, parser: ArgumentParser) -> None:
def init_parser(self, parser: ArgumentParser, resolver: DependencyResolver) -> None:
"""Initialize ``parser`` with program-specific arguments."""
parser.add_argument(
"--version", action="version", version=f"%(prog)s {self._version}"
Expand All @@ -132,18 +133,15 @@ def init_parser(self, parser: ArgumentParser) -> None:

sub_parser = sub_parsers.add_parser(group.name, help=help)

group.init_parser(sub_parser)
group.init_parser(sub_parser, resolver)

def __call__(self) -> None:
def __call__(self, container: DependencyContainer) -> None:
"""Run the program."""
set_console(Console(highlight=False))

self._run_command()

def _run_command(self) -> None:
parser = ArgumentParser(self._name, description=self._description)

self.init_parser(parser)
self.init_parser(parser, container)

args = parser.parse_args()

Expand All @@ -152,7 +150,7 @@ def _run_command(self) -> None:

sys.exit(2)

args.command(args)
args.command(args, container)

@property
def name(self) -> str:
Expand Down Expand Up @@ -279,7 +277,7 @@ def get_command(self, name: str) -> CliCommand:
f"`name` must be a registered command name, but '{name}' is not registered."
) from None

def init_parser(self, parser: ArgumentParser) -> None:
def init_parser(self, parser: ArgumentParser, resolver: DependencyResolver) -> None:
"""Initialize ``parser`` with command group-specific arguments."""
sub_parsers = parser.add_subparsers()

Expand All @@ -296,7 +294,7 @@ def init_parser(self, parser: ArgumentParser) -> None:

sub_parser = sub_parsers.add_parser(group.name, help=help)

group.init_parser(sub_parser)
group.init_parser(sub_parser, resolver)

for command in self._commands.values():
help = command.help
Expand All @@ -313,7 +311,7 @@ def init_parser(self, parser: ArgumentParser) -> None:

sub_parser.set_defaults(command=command)

command.init_parser(sub_parser)
command.init_parser(sub_parser, resolver)

@property
def name(self) -> str:
Expand Down Expand Up @@ -352,13 +350,13 @@ def __init__(
self._origin_module = origin_module
self._help = help

def init_parser(self, parser: ArgumentParser) -> None:
def init_parser(self, parser: ArgumentParser, resolver: DependencyResolver) -> None:
"""Initialize ``parser`` with command group-specific arguments."""
self._handler.init_parser(parser)
self._handler.init_parser(parser, resolver)

def __call__(self, args: Namespace) -> None:
def __call__(self, args: Namespace, container: DependencyContainer) -> None:
"""Run the command."""
self._handler(args)
self._handler(args, container)

@property
def name(self) -> str:
Expand All @@ -380,11 +378,11 @@ class CliCommandHandler(ABC):
"""Represents the handler of a command of a command line program."""

@abstractmethod
def init_parser(self, parser: ArgumentParser) -> None:
def init_parser(self, parser: ArgumentParser, resolver: DependencyResolver) -> None:
"""Initialize ``parser`` with command-specific arguments."""

@abstractmethod
def __call__(self, args: Namespace) -> None:
def __call__(self, args: Namespace, container: DependencyContainer) -> None:
"""Run the command."""


Expand Down Expand Up @@ -465,7 +463,7 @@ def __init__(
self._parser = None

@override
def init_parser(self, parser: ArgumentParser) -> None:
def init_parser(self, parser: ArgumentParser, resolver: DependencyResolver) -> None:
self._parser = parser

parser.add_argument(
Expand Down Expand Up @@ -533,7 +531,7 @@ def init_parser(self, parser: ArgumentParser) -> None:
)

@override
def __call__(self, args: Namespace) -> None:
def __call__(self, args: Namespace, container: DependencyContainer) -> None:
console = get_console()

setup_basic_logging(debug=args.debug)
Expand Down
5 changes: 3 additions & 2 deletions src/fairseq2/recipes/llama/convert_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing_extensions import override

from fairseq2.console import get_error_console
from fairseq2.dependency import DependencyContainer, DependencyResolver
from fairseq2.logging import get_log_writer
from fairseq2.models.llama import load_llama_config
from fairseq2.models.llama.integ import convert_to_reference_checkpoint
Expand All @@ -32,7 +33,7 @@ class ConvertCheckpointCommandHandler(CliCommandHandler):
"""Converts fairseq2 LLaMA checkpoints to reference checkpoints."""

@override
def init_parser(self, parser: ArgumentParser) -> None:
def init_parser(self, parser: ArgumentParser, resolver: DependencyResolver) -> None:
parser.add_argument(
"--arch",
metavar="ARCH_NAME",
Expand All @@ -52,7 +53,7 @@ def init_parser(self, parser: ArgumentParser) -> None:
)

@override
def __call__(self, args: Namespace) -> None:
def __call__(self, args: Namespace, container: DependencyContainer) -> None:
if not args.input_dir.exists() or not args.input_dir.is_dir():
log.error("`input_dir` must be a directory.")

Expand Down
5 changes: 3 additions & 2 deletions src/fairseq2/recipes/lm/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from fairseq2.console import get_console
from fairseq2.data.text import load_text_tokenizer
from fairseq2.dependency import DependencyContainer, DependencyResolver
from fairseq2.gang import Gang
from fairseq2.generation import (
Chatbot,
Expand Down Expand Up @@ -43,7 +44,7 @@ class ChatbotCommandHandler(CliCommandHandler):
"""Runs a chatbot."""

@override
def init_parser(self, parser: ArgumentParser) -> None:
def init_parser(self, parser: ArgumentParser, resolver: DependencyResolver) -> None:
parser.add_argument(
"-m",
"--model",
Expand Down Expand Up @@ -107,7 +108,7 @@ def init_parser(self, parser: ArgumentParser) -> None:
)

@override
def __call__(self, args: Namespace) -> None:
def __call__(self, args: Namespace, container: DependencyContainer) -> None:
setup_basic_logging()

# Set up cluster-specific environment variables.
Expand Down

0 comments on commit 3e2d1ae

Please sign in to comment.