From 2018a0847caf6a175bac79fb1ff01dc5aaed67c7 Mon Sep 17 00:00:00 2001 From: PromptEngineer <134474669+PromtEngineer@users.noreply.github.com> Date: Sun, 4 Feb 2024 00:14:32 -0800 Subject: [PATCH] Update run_localGPT.py --- run_localGPT.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/run_localGPT.py b/run_localGPT.py index 5f01a4e7..24a1978a 100644 --- a/run_localGPT.py +++ b/run_localGPT.py @@ -130,7 +130,7 @@ def get_embeddings(): if "instructor" in EMBEDDING_MODEL_NAME: return HuggingFaceInstructEmbeddings( model_name=EMBEDDING_MODEL_NAME, - model_kwargs={"device": compute_device}, + model_kwargs={"device": device_type}, embed_instruction='Represent the document for retrieval:', query_instruction='Represent the question for retrieving supporting documents:' ) @@ -138,14 +138,14 @@ def get_embeddings(): elif "bge" in EMBEDDING_MODEL_NAME: return HuggingFaceBgeEmbeddings( model_name=EMBEDDING_MODEL_NAME, - model_kwargs={"device": compute_device}, + model_kwargs={"device": device_type}, query_instruction='Represent this sentence for searching relevant passages:' ) else: return HuggingFaceEmbeddings( model_name=EMBEDDING_MODEL_NAME, - model_kwargs={"device": compute_device}, + model_kwargs={"device": device_type}, ) embeddings = get_embeddings() logging.info(f"Loaded embeddings from {EMBEDDING_MODEL_NAME}")