diff --git a/README.md b/README.md
index 6afb931..f1c3a59 100644
--- a/README.md
+++ b/README.md
@@ -114,6 +114,8 @@ conda activate wikichat
 python -m spacy download en_core_web_sm # Spacy is only needed for certain WikiChat configurations
 ```
+If you see `Error: Redis lookup failed` after running the chatbot, it probably means Redis is not properly installed. Try reinstalling it by following the [official documentation](https://redis.io/docs/latest/operate/oss_and_stack/install/install-redis/). You can quickly check that Redis is reachable by running `redis-cli ping`, which should reply `PONG`.
+
 Keep this environment activated for all subsequent commands.
 
 Install Docker for your operating system by following the instructions at https://docs.docker.com/engine/install/. WikiChat uses Docker primarily for creating and serving vector databases for retrieval, specifically [🤗 Text Embedding Inference](https://github.com/huggingface/text-embeddings-inference) and [Qdrant](https://github.com/qdrant/qdrant). On recent Ubuntu versions, you can try running `inv install-docker`. For other operating systems, follow the instructions on the Docker website.
@@ -203,6 +205,8 @@ See `wikipedia_preprocessing/preprocess_html_dump.py` for details on how this is
 ```bash
 inv index-collection --collection-path <path to collection> --collection-name <collection name>
 ```
+This command starts one Docker container per available GPU for [🤗 Text Embedding Inference](https://github.com/huggingface/text-embeddings-inference). By default, it uses the Docker image built for NVIDIA GPUs with the Ampere 80 architecture, e.g., A100. Some other GPUs are supported as well, but you need to choose the right image from the list of [available Docker images](https://github.com/huggingface/text-embeddings-inference?tab=readme-ov-file#docker-images).
+
 #### To upload a Qdrant index to 🤗 Hub:
 1. Split the index into smaller parts:
@@ -268,7 +272,7 @@ This script reads the topic (i.e., a Wikipedia title and article) from the corre
 ```bash
 inv simulate-users --num-dialogues 1 --num-turns 2 --simulation-mode passage --language en --subset head
 ```
-Depending on the engine you are using, this might take some time. The simulated dialogues and log files will be saved in `benchmark/simulated_dialogs/`.
+Depending on the engine you are using, this might take some time. The simulated dialogues and log files will be saved in `benchmark/simulated_dialogues/`. You can also provide any of the pipeline parameters from above, as shown in the example below.
 You can experiment with different user characteristics by modifying `user_characteristics` in `benchmark/user_simulator.py`.
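+
+For example, here is a minimal sketch of passing two pipeline parameters through to the simulator; `<pipeline name>` and `<engine name>` are placeholders, and you can use whichever values you would pass to `inv demo`:
+```bash
+inv simulate-users --num-dialogues 1 --num-turns 2 --simulation-mode passage \
+    --language en --subset head --pipeline <pipeline name> --engine <engine name>
+```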
diff --git a/backend_server.py b/backend_server.py
index 172fdb2..43978d5 100644
--- a/backend_server.py
+++ b/backend_server.py
@@ -12,8 +12,8 @@
     add_pipeline_arguments,
     check_pipeline_arguments,
 )
-from tasks.defaults import CHATBOT_DEFAULT_CONFIG
 from pipelines.utils import dict_to_command_line, get_logger
+from tasks.defaults import CHATBOT_DEFAULT_CONFIG
 
 logger = get_logger(__name__)
diff --git a/benchmark/user_simulator.py b/benchmark/user_simulator.py
index a496456..6e6a2dd 100644
--- a/benchmark/user_simulator.py
+++ b/benchmark/user_simulator.py
@@ -154,7 +154,9 @@ def main(args):
     elif args.mode == "multihop":
         with open(args.input_file) as input_file:
             dialogue_inputs = json.load(input_file)
-        dialogue_inputs = repeat_dialogue_inputs(dialogue_inputs, args.num_dialogues)
+        dialogue_inputs = repeat_dialogue_inputs(
+            dialogue_inputs, args.num_dialogues
+        )
         topics = [m["title_1"] + " and " + m["title_2"] for m in dialogue_inputs]
     else:
         raise ValueError("Unknown mode: %s" % args.mode)
diff --git a/pipelines/pipeline_arguments.py b/pipelines/pipeline_arguments.py
index ee14342..0f55fac 100644
--- a/pipelines/pipeline_arguments.py
+++ b/pipelines/pipeline_arguments.py
@@ -3,6 +3,7 @@
 """
 
 from chainlite.llm_config import GlobalVars
+
 from pipelines.chatbot_config import PipelineEnum
 from tasks.defaults import CHATBOT_DEFAULT_CONFIG
diff --git a/retrieval/add_payload_index.py b/retrieval/add_payload_index.py
index 2cfe139..1bee3e2 100644
--- a/retrieval/add_payload_index.py
+++ b/retrieval/add_payload_index.py
@@ -1,4 +1,5 @@
 import argparse
+
 from qdrant_client import QdrantClient
 from qdrant_client.models import PayloadSchemaType
diff --git a/retrieval/create_index.py b/retrieval/create_index.py
index 83cc908..1278c5f 100644
--- a/retrieval/create_index.py
+++ b/retrieval/create_index.py
@@ -24,11 +24,10 @@
 )
 from tqdm import tqdm
 
-
 sys.path.insert(0, "./")
-from tasks.defaults import DEFAULT_QDRANT_COLLECTION_NAME
 from pipelines.utils import get_logger
 from retrieval.qdrant_index import QdrantIndex
+from tasks.defaults import DEFAULT_QDRANT_COLLECTION_NAME
 
 logger = get_logger(__name__)
@@ -243,7 +242,9 @@ def batch_generator(collection_file, embedding_batch_size):
         default=48,
         help="The size of each request sent to GPU. The actual batch size is `embedding_batch_size * num_embedding_workers`",
     )
-    parser.add_argument("--collection_name", default=DEFAULT_QDRANT_COLLECTION_NAME, type=str)
+    parser.add_argument(
+        "--collection_name", default=DEFAULT_QDRANT_COLLECTION_NAME, type=str
+    )
     parser.add_argument(
         "--index",
         action="store_true",
@@ -258,12 +259,12 @@ def batch_generator(collection_file, embedding_batch_size):
     args = parser.parse_args()
     model_port = args.model_port
 
-    embedding_size = QdrantIndex.get_embedding_model_parameters(args.embedding_model_name)[
-        "embedding_dimension"
-    ]
-    query_prefix = QdrantIndex.get_embedding_model_parameters(args.embedding_model_name)[
-        "query_prefix"
-    ]
+    embedding_size = QdrantIndex.get_embedding_model_parameters(
+        args.embedding_model_name
+    )["embedding_dimension"]
+    query_prefix = QdrantIndex.get_embedding_model_parameters(
+        args.embedding_model_name
+    )["query_prefix"]
 
     if args.index:
         collection_size = 0
diff --git a/retrieval/qdrant_index.py b/retrieval/qdrant_index.py
index 3953f8a..924884b 100644
--- a/retrieval/qdrant_index.py
+++ b/retrieval/qdrant_index.py
@@ -2,11 +2,12 @@
 import math
 from time import time
 from typing import Any
-from pydantic import BaseModel, Field
+
 import numpy as np
 import onnxruntime as ort
 import torch
 from huggingface_hub import hf_hub_download
+from pydantic import BaseModel, Field
 from qdrant_client import AsyncQdrantClient
 from qdrant_client.models import (
     FieldCondition,
diff --git a/retrieval/qdrant_snapshot.py b/retrieval/qdrant_snapshot.py
index ac4e6d9..3eba17d 100644
--- a/retrieval/qdrant_snapshot.py
+++ b/retrieval/qdrant_snapshot.py
@@ -47,7 +47,9 @@ def main():
     args = parser.parse_args()
 
-    qdrant_client = QdrantClient(url="http://localhost", port=6333, timeout=60, prefer_grpc=False)
+    qdrant_client = QdrantClient(
+        url="http://localhost", port=6333, timeout=60, prefer_grpc=False
+    )
 
     if args.action == "save":
         qdrant_client.create_snapshot(collection_name=args.collection_name, wait=False)
diff --git a/retrieval/retriever_server.py b/retrieval/retriever_server.py
index 2c92368..7c3e6b1 100644
--- a/retrieval/retriever_server.py
+++ b/retrieval/retriever_server.py
@@ -8,7 +8,7 @@
 
 from async_lru import alru_cache
 from fastapi import FastAPI, Request, Response
-from pydantic import BaseModel, Field, field_validator, ConfigDict
+from pydantic import BaseModel, ConfigDict, Field, field_validator
 from slowapi import Limiter, _rate_limit_exceeded_handler
 from slowapi.errors import RateLimitExceeded
 from slowapi.util import get_remote_address
diff --git a/retrieval/upload_folder_to_hf_hub.py b/retrieval/upload_folder_to_hf_hub.py
index 061517e..4537659 100644
--- a/retrieval/upload_folder_to_hf_hub.py
+++ b/retrieval/upload_folder_to_hf_hub.py
@@ -1,6 +1,8 @@
 import argparse
+
 from huggingface_hub import upload_folder
 
+
 def main(repo_id, folder_path):
     upload_folder(
         folder_path=folder_path,
@@ -10,11 +12,16 @@ def main(repo_id, folder_path):
         multi_commits_verbose=True,
     )
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Upload a folder to HuggingFace Hub")
-    parser.add_argument("--folder_path", type=str, help="The path to the folder to upload")
-    parser.add_argument("--repo_id", type=str, help="The repository ID on HuggingFace Hub")
-
+    parser.add_argument(
+        "--folder_path", type=str, help="The path to the folder to upload"
+    )
+    parser.add_argument(
+        "--repo_id", type=str, help="The repository ID on HuggingFace Hub"
+    )
+
     args = parser.parse_args()
-
-    main(args.repo_id, args.folder_path)
\ No newline at end of file
+
+    main(args.repo_id, args.folder_path)
diff --git a/tasks/benchmark.py b/tasks/benchmark.py
index 0a9e296..63f7403 100644
--- a/tasks/benchmark.py
+++ b/tasks/benchmark.py
@@ -48,7 +48,7 @@ def simulate_users(
     Accepts all parameters that `inv demo` accepts, plus a few additional parameters for the user simulator.
     """
-    
+
     pipeline_flags = (
         f"--pipeline {pipeline} "
         f"--engine {engine} "
@@ -95,7 +95,7 @@ def simulate_users(
         f"--mode {simulation_mode} "
         f"--input_file benchmark/topics/{input_file} "
         f"--num_turns {num_turns} "
-        f"--output_file benchmark/simulated_dialogs/{pipeline}_{subset}_{language}_{engine}.txt "
+        f"--output_file benchmark/simulated_dialogues/{pipeline}_{subset}_{language}_{engine}.txt "
        f"--language {language} "
         f"--no_logging"
     )
diff --git a/tasks/defaults.py b/tasks/defaults.py
index e66780c..001324b 100644
--- a/tasks/defaults.py
+++ b/tasks/defaults.py
@@ -12,7 +12,7 @@
 DEFAULT_EMBEDDING_MODEL_NAME = "BAAI/bge-m3"
 DEFAULT_WIKIPEDIA_DUMP_LANGUAGE = "en"
 DEFAULT_WORKDIR = "workdir"
-DEFAULT_QDRANT_COLLECTION_NAME="wikipedia"
+DEFAULT_QDRANT_COLLECTION_NAME = "wikipedia"
 DEFAULT_BACKEND_PORT = 5001
 DEFAULT_REDIS_PORT = 6379
diff --git a/tasks/docker_utils.py b/tasks/docker_utils.py
index 9997510..1579f45 100644
--- a/tasks/docker_utils.py
+++ b/tasks/docker_utils.py
@@ -1,9 +1,9 @@
 import os
+import sys
 from time import sleep
 
 import docker
 from invoke import task
-import sys
 
 sys.path.insert(0, "./")
 from pipelines.utils import get_logger
diff --git a/tasks/retrieval.py b/tasks/retrieval.py
index d74e5de..7a9d6be 100644
--- a/tasks/retrieval.py
+++ b/tasks/retrieval.py
@@ -5,13 +5,13 @@
 import threading
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
+from typing import Optional
 
 import requests
 from bs4 import BeautifulSoup
 from huggingface_hub import snapshot_download
 from invoke import task
 from tqdm import tqdm
-from typing import Optional
 
 from tasks.docker_utils import (
     start_embedding_docker_container,
@@ -21,19 +21,18 @@
 from tasks.main import start_redis
 
 sys.path.insert(0, "./")
+from pipelines.utils import get_logger
 from tasks.defaults import (
     DEFAULT_EMBEDDING_MODEL_NAME,
     DEFAULT_EMBEDDING_MODEL_PORT,
+    DEFAULT_EMBEDDING_USE_ONNX,
     DEFAULT_NUM_GPUS,
     DEFAULT_QDRANT_COLLECTION_NAME,
     DEFAULT_RETRIEVER_PORT,
-    DEFAULT_EMBEDDING_USE_ONNX,
     DEFAULT_WIKIPEDIA_DUMP_LANGUAGE,
     DEFAULT_WORKDIR,
 )
-from pipelines.utils import get_logger
-
 
 logger = get_logger(__name__)
diff --git a/tasks/setup.py b/tasks/setup.py
index 0e87190..1540a94 100644
--- a/tasks/setup.py
+++ b/tasks/setup.py
@@ -14,7 +14,7 @@ def setup_nvme(c):
     """
     Set up an NVMe drive on the VM by performing the following steps. Only works on certain Linux distributions.
-    
+
     1. Installs the `nvme-cli` package to manage NVMe devices.
     2. Lists available NVMe devices on the system.
     3. Extracts NVMe device names from the listing output.
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index a36e2c5..49c44bb 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -6,7 +6,6 @@
 sys.path.insert(0, "./")
 from backend_server import chat_profiles_dict
-from tasks.defaults import CHATBOT_DEFAULT_CONFIG
 from pipelines.chatbot import create_chain, run_one_turn
 from pipelines.dialogue_state import DialogueState
 from pipelines.pipeline_arguments import (
@@ -14,6 +13,7 @@
     check_pipeline_arguments,
 )
 from pipelines.utils import dict_to_command_line
+from tasks.defaults import CHATBOT_DEFAULT_CONFIG
 
 test_user_utterances = [
     "Hi",  # a turn that doesn't need retrieval
diff --git a/tests/test_retriever.py b/tests/test_retriever.py
index e709d40..4f4cb01 100644
--- a/tests/test_retriever.py
+++ b/tests/test_retriever.py
@@ -1,6 +1,7 @@
+import sys
+
 import pytest
 from pydantic import ValidationError
-import sys
 
 sys.path.insert(0, "./")
 from retrieval.retriever_server import QueryData
diff --git a/wikipedia_preprocessing/get_all_wiki_sizes.py b/wikipedia_preprocessing/get_all_wiki_sizes.py
index a23bfcb..c8f1705 100644
--- a/wikipedia_preprocessing/get_all_wiki_sizes.py
+++ b/wikipedia_preprocessing/get_all_wiki_sizes.py
@@ -6,7 +6,10 @@
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
-    "--wikipedia_date", type=str, required=True, help="Enter the date in the format yyyymmdd"
+    "--wikipedia_date",
+    type=str,
+    required=True,
+    help="Enter the date in the format yyyymmdd",
 )
 args = parser.parse_args()
diff --git a/wikipedia_preprocessing/upload_collections_to_hf_hub.py b/wikipedia_preprocessing/upload_collections_to_hf_hub.py
index f2b7ab8..5eb54a9 100644
--- a/wikipedia_preprocessing/upload_collections_to_hf_hub.py
+++ b/wikipedia_preprocessing/upload_collections_to_hf_hub.py
@@ -1,12 +1,11 @@
 import argparse
-from huggingface_hub import HfApi
 import gzip
-import shutil
 import os
+import shutil
+from huggingface_hub import HfApi
 from tqdm import tqdm
-
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Upload preprocessed Wikipedia collection files to HuggingFace Hub"