Added --model-dir option
c0sogi committed Sep 12, 2023
1 parent ff2b2fd commit 214b7f8
Showing 3 changed files with 66 additions and 26 deletions.
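In short: the hard-coded `models/ggml` and `models/gptq` defaults are replaced by a single `--model-dir` option (short form `-m`, default `./models`), e.g. `--model-dir /data/llm-models`. The path resolver then probes that directory and its `ggml`, `gguf`, and `gptq` subdirectories, anchored at the path itself, the project root, and the current working directory, and the resolution cache now re-resolves entries whose files have disappeared from disk.
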
5 changes: 3 additions & 2 deletions llama_api/schemas/models.py
@@ -3,6 +3,7 @@
 from typing import List, Literal, Optional
 
 from ..modules.base import BaseLLMModel
+from ..shared.config import MainCliArgs
 from ..utils.path import path_resolver


@@ -104,7 +105,7 @@ def __post_init__(self) -> None:
     def model_path_resolved(self) -> str:
         return path_resolver(
             self.model_path,
-            default_relative_directory="models/ggml",
+            default_model_directory=MainCliArgs.model_dir.value,
         )


@@ -170,7 +171,7 @@ def __post_init__(self) -> None:
     def model_path_resolved(self) -> str:
         return path_resolver(
             self.model_path,
-            default_relative_directory="models/gptq",
+            default_model_directory=MainCliArgs.model_dir.value,
         )


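Both the GGML and GPTQ schemas now resolve against the same user-configurable directory rather than per-format hard-coded ones. A minimal sketch of the resulting call, assuming CLI arguments have been parsed; the model file name is hypothetical:

    from llama_api.shared.config import MainCliArgs
    from llama_api.utils.path import path_resolver

    # Searches --model-dir (default "./models") and its ggml/gguf/gptq
    # subdirectories instead of the old hard-coded models/ggml location.
    posix_path = path_resolver(
        "llama-2-7b-chat.ggmlv3.q4_0.bin",  # hypothetical file name
        default_model_directory=MainCliArgs.model_dir.value,
    )
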
6 changes: 6 additions & 0 deletions llama_api/shared/config.py
@@ -115,6 +115,12 @@ class MainCliArgs(AppSettingsCliArgs):
         short_option="t",
         help="Tunnel the server through cloudflared",
     )
+    model_dir: CliArg[str] = CliArg(
+        type=str,
+        short_option="m",
+        help="Directory to store models; default is `./models`",
+        default="./models",
+    )
     # xformers: CliArg[bool] = CliArg(
     #     type=bool,
     #     action="store_true",
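A note on consumption: `CliArg` exposes the parsed value through `.value`, which is exactly how the schema change above reads it. A hedged sketch, assuming argument parsing has already populated the field (`CliArg` internals are outside this diff):

    from llama_api.shared.config import MainCliArgs

    # "./models" unless the user passed --model-dir/-m on the command line.
    model_dir = MainCliArgs.model_dir.value
    print(f"Models will be stored in: {model_dir}")
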
81 changes: 57 additions & 24 deletions llama_api/utils/path.py
@@ -51,7 +51,9 @@ def __init__(
     @property
     def model_type(self) -> Literal["ggml", "gptq"]:
         """Get the model type: ggml or gptq."""
-        classifications: List[Classification] = self.hf_info["classifications"]
+        classifications: List[Classification] = self.hf_info[
+            "classifications"
+        ]
         if "ggml" in classifications:
             return "ggml"
         elif (
@@ -107,7 +109,12 @@ def preferred_ggml_files(self) -> List[str]:
             if self.is_ggml(file_name)
         ]
         if not ggml_file_names:
-            raise FileNotFoundError("No GGML file found.")
+            raise FileNotFoundError(
+                "No GGML file found in following links:"
+                + "\n".join(
+                    f"- {link}" for link in self.hf_info["file_names"]
+                )
+            )
 
         # Sort the GGML files by the preferences
         # Return the most preferred GGML file, or the first one if none of the
@@ -140,7 +147,9 @@ def resolve(self) -> str:
             (
                 link
                 for link in self.hf_info["links"]
-                if any(ggml in link for ggml in self.preferred_ggml_files)
+                if any(
+                    ggml in link for ggml in self.preferred_ggml_files
+                )
             ),
             None,
         )
@@ -155,12 +164,34 @@
 
         # The model is not downloaded, and the download failed
         raise FileNotFoundError(
-            f"`{model_path.name}` not found in {model_path.parent}"
+            f"`{model_path.name}` not found in {model_path.resolve()}"
         )
 
 
+def _make_model_dir_candidates(path: str) -> "set[Path]":
+    return {
+        dir_path.resolve()
+        for dir_path in (
+            Path(path),
+            Path(path) / "ggml",
+            Path(path) / "gguf",
+            Path(path) / "gptq",
+            Config.project_root,
+            Config.project_root / path,
+            Config.project_root / path / "ggml",
+            Config.project_root / path / "gguf",
+            Config.project_root / path / "gptq",
+            Path.cwd(),
+            Path.cwd() / path,
+            Path.cwd() / path / "ggml",
+            Path.cwd() / path / "gguf",
+            Path.cwd() / path / "gptq",
+        )
+    }
+
+
 def resolve_model_path_to_posix(
-    model_path: str, default_relative_directory: Optional[str] = None
+    model_path: str, default_model_directory: Optional[str] = None
 ) -> str:
     """Resolve a model path to a POSIX path."""
     with logger.log_any_error("Error resolving model path"):
@@ -171,20 +202,14 @@ def resolve_model_path_to_posix(
                 logger.info(f"`{path.name}` found in {path.parent}")
                 return path.resolve().as_posix()
             raise FileNotFoundError(
-                f"`{path.name}` not found in {path.parent}"
+                f"`{path.name}` not found in {path.resolve()}"
             )
 
-        parent_dir_candidates = [
-            Config.project_root / "models",
-            Config.project_root / "llama_api",
-            Config.project_root,
-            Path.cwd(),
-        ]
-
-        if default_relative_directory is not None:
+        parent_dir_candidates = _make_model_dir_candidates("models")
+        if default_model_directory is not None:
             # Add the default relative directory to the list of candidates
-            parent_dir_candidates.insert(
-                0, Path.cwd() / Path(default_relative_directory)
+            parent_dir_candidates.update(
+                _make_model_dir_candidates(default_model_directory)
             )
 
         # Try to find the model in all possible scenarios
@@ -196,15 +221,19 @@
         if model_path.count("/") != 1:
             raise FileNotFoundError(
                 f"`{model_path}` not found in any of the following "
-                f"directories: {parent_dir_candidates}"
+                "directories:\n"
+                + "\n".join(
+                    f"- {(parent_dir / model_path).resolve()}"
+                    for parent_dir in parent_dir_candidates
+                )
             )
         # Try to resolve the model path from Huggingface
         return HuggingfaceResolver(model_path).resolve()
 
 
 def resolve_model_path_to_posix_with_cache(
     model_path: str,
-    default_relative_directory: Optional[str] = None,
+    default_model_directory: Optional[str] = None,
 ) -> str:
     """Resolve a model path to a POSIX path, with caching."""
     from filelock import FileLock, Timeout
@@ -229,9 +258,13 @@
                         f"Invalid cache entry for model path `{model_path}`: "
                         f"{resolved}"
                     )
-                if not resolved:
+                if not resolved or not Path(resolved).exists():
+                    unresolved = resolved
                     resolved = resolve_model_path_to_posix(
-                        model_path, default_relative_directory
+                        model_path, default_model_directory
                     )
+                    logger.warning(
+                        f"Model path `{unresolved}` resolved to `{resolved}`"
+                    )
                 cache[model_path] = resolved

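One behavioral note on the hunk above: a cached entry is now re-resolved not only when it is empty but also when the cached file has vanished from disk, and the stale value is logged. A self-contained sketch of just that check; the dict stands in for the file-backed cache used here:

    from pathlib import Path

    cache = {"my-model": "/data/models/ggml/deleted.bin"}  # hypothetical stale entry
    resolved = cache.get("my-model")
    if not resolved or not Path(resolved).exists():
        # Same condition as the commit: fall through to a fresh
        # resolve_model_path_to_posix call and log the stale value.
        print(f"stale cache entry: {resolved!r}")
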
@@ -247,19 +280,19 @@
                 + f": {e}"
             )
         return resolve_model_path_to_posix(
-            model_path, default_relative_directory
+            model_path, default_model_directory
         )
 
 
 def path_resolver(
-    model_path: str, default_relative_directory: Optional[str] = None
+    model_path: str, default_model_directory: Optional[str] = None
 ) -> str:
     """Resolve a model path to a POSIX path, with caching if possible."""
     try:
         return resolve_model_path_to_posix_with_cache(
-            model_path, default_relative_directory
+            model_path, default_model_directory
         )
     except ImportError:
         return resolve_model_path_to_posix(
-            model_path, default_relative_directory
+            model_path, default_model_directory
         )
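The widened search space is easiest to see by running the candidate helper in isolation. A standalone sketch that mirrors `_make_model_dir_candidates`, with `Config.project_root` replaced by a stand-in constant since `Config` is defined outside this diff:

    from pathlib import Path

    project_root = Path("/srv/llama-api")  # stand-in for Config.project_root

    def make_model_dir_candidates(path: str) -> "set[Path]":
        # Mirrors the new helper: the given directory, the project root,
        # and the current working directory, each probed directly and
        # through ggml/gguf/gptq subdirectories.
        candidates = {project_root.resolve(), Path.cwd().resolve()}
        for base in (Path(path), project_root / path, Path.cwd() / path):
            candidates.add(base.resolve())
            for fmt in ("ggml", "gguf", "gptq"):
                candidates.add((base / fmt).resolve())
        return candidates

    for candidate in sorted(make_model_dir_candidates("models")):
        print(candidate)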
