
Commit

If max_tokens is None, it is maximized to the remaining context window.
c0sogi committed Sep 13, 2023
1 parent 214b7f8 commit 749a93d
Showing 8 changed files with 59 additions and 47 deletions.
7 changes: 5 additions & 2 deletions llama_api/mixins/completion.py
@@ -16,7 +16,7 @@ class CompletionStatus:
started_at: float = field(default_factory=time, init=False)
state: Literal["done", "interrupted"] = field(default="done", init=False)

# These fields are set by `accept_settings` method.
# These fields are set by `build_max_tokens` method.
input_text: str = field(default="", init=False)
input_tokens: int = field(default=0, init=False)

@@ -47,7 +47,10 @@ def get_finish_reason(
"""Get the finish reason for the completion."""
return (
"length"
if self.completion_status[request.completion_id].generated_tokens
if request.max_tokens is not None
and self.completion_status[
request.completion_id
].generated_tokens
>= request.max_tokens
else "stop"
if request.grammar is None
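Because max_tokens may now be None, the finish-reason check only reports "length" when a limit was actually set and reached. A minimal sketch of that guard (a standalone function for illustration; the real code reads the count from completion_status and has a grammar-dependent branch that this hunk truncates):

from typing import Optional

def finish_reason(generated_tokens: int, max_tokens: Optional[int]) -> str:
    # "length" only when a limit exists and has been reached; otherwise "stop".
    if max_tokens is not None and generated_tokens >= max_tokens:
        return "length"
    return "stop"

assert finish_reason(128, 128) == "length"
assert finish_reason(128, None) == "stop"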
20 changes: 15 additions & 5 deletions llama_api/modules/base.py
@@ -297,10 +297,23 @@ def get_text_generator(
)
else:
prompt = request.prompt

# Build the settings for generating the text
self.build_max_tokens_from_settings(request, prompt)
self.build_stops_from_settings(request)
return self.generate_text(prompt, request)

def build_max_tokens_from_settings(
self,
request: Union[CreateChatCompletionRequest, CreateCompletionRequest],
prompt: str,
) -> int:
"""Build the max_tokens parameter for generating the text."""
prompt_ids = self.encode(prompt)
prompt_tokens = len(prompt_ids)
context_window = self.llm_model.max_total_tokens

if request.max_tokens is None:
request.max_tokens = context_window - prompt_tokens
if MainCliArgs.max_tokens_limit.value:
request.max_tokens = min(
request.max_tokens, MainCliArgs.max_tokens_limit.value
@@ -331,10 +344,7 @@ def get_text_generator(
completion_id = request.completion_id
self.completion_status[completion_id].input_text = prompt
self.completion_status[completion_id].input_tokens = prompt_tokens

# Cache the stops for later use of stop_checker
self.build_stops_from_settings(request)
return self.generate_text(prompt, request)
return request.max_tokens


class BaseEmbeddingGenerator(ABC):
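The new build_max_tokens_from_settings helper is the core of the commit: when a request arrives with max_tokens=None, it is filled with the tokens left in the context window, optionally capped by MainCliArgs.max_tokens_limit, and the request is mutated in place. A self-contained sketch of that contract, using stand-in types rather than the project's request classes (indentation is lost in this view, so whether the cap applies only when max_tokens was omitted is uncertain; the sketch applies it unconditionally):

from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeRequest:  # stand-in for CreateCompletionRequest / CreateChatCompletionRequest
    max_tokens: Optional[int] = None

def build_max_tokens(request: FakeRequest, prompt_tokens: int,
                     context_window: int, limit: Optional[int] = None) -> int:
    """Fill in max_tokens when the client omitted it, then apply the configured cap."""
    if request.max_tokens is None:
        request.max_tokens = context_window - prompt_tokens
    if limit:
        request.max_tokens = min(request.max_tokens, limit)
    return request.max_tokens

req = FakeRequest()                          # client sent no max_tokens
build_max_tokens(req, prompt_tokens=1000, context_window=4096)
assert req.max_tokens == 3096                # remaining context window
assert build_max_tokens(FakeRequest(max_tokens=256), 1000, 4096, limit=512) == 256

Mutating request.max_tokens while also returning it is what lets the backends later assert that the field is set before entering their generation loops.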
1 change: 1 addition & 0 deletions llama_api/modules/exllama.py
@@ -214,6 +214,7 @@ def _generate_text(
else None
) or None

assert settings.max_tokens is not None, "max_tokens must be set"
for _ in range(settings.max_tokens):
# If the generator was interrupted, stop the generation
if self.check_interruption(completion_status):
1 change: 1 addition & 0 deletions llama_api/modules/llama_cpp.py
@@ -185,6 +185,7 @@ def _generate_text(
_load_cache(client, client.cache, input_ids)
if self.check_interruption(completion_status):
return
assert settings.max_tokens is not None, "max_tokens must be set"
for _, token_id in zip(
range(settings.max_tokens),
client.generate(
5 changes: 2 additions & 3 deletions llama_api/schemas/api.py
@@ -189,9 +189,8 @@ class TextGenerationSettings(BaseModel):
default_factory=lambda: f"cmpl-{str(uuid4())}",
description="The unique ID of the text generation",
)
max_tokens: int = Field(
default=128,
ge=1,
max_tokens: Optional[int] = Field(
default=None,
description="The maximum number of tokens to generate.",
)
temperature: float = Field(
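The schema change is what allows clients to omit the field in the first place: max_tokens is now Optional with a None default (the previous default=128 and ge=1 constraint are dropped). A trimmed sketch of the field, not the full TextGenerationSettings model:

from typing import Optional
from pydantic import BaseModel, Field

class Settings(BaseModel):  # illustrative stand-in for TextGenerationSettings
    max_tokens: Optional[int] = Field(
        default=None,
        description="The maximum number of tokens to generate.",
    )

print(Settings().max_tokens)  # None -> the backend fills in the remaining context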
2 changes: 1 addition & 1 deletion llama_api/utils/errors.py
@@ -82,7 +82,7 @@ def context_length_exceeded(
return 400, ErrorResponse(
message=message.format(
context_window,
completion_tokens + prompt_tokens,
(completion_tokens or 0) + prompt_tokens,
prompt_tokens,
completion_tokens,
),
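With completion_tokens now possibly None, the context-length error formatter needs the same guard, since None + int would raise a TypeError while building the message. A reduced illustration (the surrounding message template is omitted):

from typing import Optional

def total_requested(prompt_tokens: int, completion_tokens: Optional[int]) -> int:
    # Before this commit: completion_tokens + prompt_tokens -> TypeError when None
    return (completion_tokens or 0) + prompt_tokens

assert total_requested(1000, None) == 1000
assert total_requested(1000, 256) == 1256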
4 changes: 2 additions & 2 deletions llama_api/utils/model_definition_finder.py
@@ -2,7 +2,7 @@
from os import environ
from pathlib import Path
from types import ModuleType
from typing import Dict, List, Tuple, Union
from typing import Dict, Tuple, Union


from ..schemas.api import (
@@ -180,7 +180,7 @@ def _collect_from_environs(

@classmethod
def _refresh_modules(cls) -> None:
model_definition_paths = [] # type: List[Path]
model_definition_paths = [] # type: list[Path]

for path in Path(".").glob(cls.MODULE_GLOB_PATTERN):
if path.stem == "model_definitions":
66 changes: 32 additions & 34 deletions llama_api/utils/path.py
@@ -194,41 +194,40 @@ def resolve_model_path_to_posix(
model_path: str, default_model_directory: Optional[str] = None
) -> str:
"""Resolve a model path to a POSIX path."""
with logger.log_any_error("Error resolving model path"):
path = Path(model_path)
if path.is_absolute():
# The path is already absolute
if path.exists():
logger.info(f"`{path.name}` found in {path.parent}")
return path.resolve().as_posix()
raise FileNotFoundError(
f"`{path.name}` not found in {path.resolve()}"
)
path = Path(model_path)
if path.is_absolute():
# The path is already absolute
if path.exists():
logger.info(f"`{path.name}` found in {path.parent}")
return path.resolve().as_posix()
raise FileNotFoundError(
f"`{path.name}` not found in {path.resolve()}"
)

parent_dir_candidates = _make_model_dir_candidates("models")
if default_model_directory is not None:
# Add the default relative directory to the list of candidates
parent_dir_candidates.update(
_make_model_dir_candidates(default_model_directory)
)
parent_dir_candidates = _make_model_dir_candidates("models")
if default_model_directory is not None:
# Add the default relative directory to the list of candidates
parent_dir_candidates.update(
_make_model_dir_candidates(default_model_directory)
)

# Try to find the model in all possible scenarios
for parent_dir in parent_dir_candidates:
if (parent_dir / model_path).exists():
logger.info(f"`{path.name}` found in {parent_dir}")
return (parent_dir / model_path).resolve().as_posix()
# Try to find the model in all possible scenarios
for parent_dir in parent_dir_candidates:
if (parent_dir / model_path).exists():
logger.info(f"`{path.name}` found in {parent_dir}")
return (parent_dir / model_path).resolve().as_posix()

if model_path.count("/") != 1:
raise FileNotFoundError(
f"`{model_path}` not found in any of the following "
"directories:\n"
+ "\n".join(
f"- {(parent_dir / model_path).resolve()}"
for parent_dir in parent_dir_candidates
)
if model_path.count("/") != 1:
raise FileNotFoundError(
f"`{model_path}` not found in any of the following "
"directories:\n"
+ "\n".join(
f"- {(parent_dir / model_path).resolve()}"
for parent_dir in parent_dir_candidates
)
# Try to resolve the model path from Huggingface
return HuggingfaceResolver(model_path).resolve()
)
# Try to resolve the model path from Huggingface
return HuggingfaceResolver(model_path).resolve()


def resolve_model_path_to_posix_with_cache(
@@ -269,9 +268,8 @@ def resolve_model_path_to_posix_with_cache(
cache[model_path] = resolved

# Update the cache file
with logger.log_any_error("Error writing model path cache"):
with open(cache_file, "w") as f:
f.write(orjson.dumps(cache).decode())
with open(cache_file, "w") as f:
f.write(orjson.dumps(cache).decode())
return resolved
except (Timeout, TypeError) as e:
logger.warning(
