Skip to content

Commit

Permalink
1. Integrated code for web search
Browse files Browse the repository at this point in the history
2. Add .env sample file

3. Optimized the keyword filtering logic, removed empty strings

4. Remove GPT response in search to improve search speed

5. Removed the search logic from chatgpt2api.py, and deleted the gptsearch function inside search_web_and_summary in agent.py.

6. Optimized the extraction prompt for search keywords.

7. Refactored the chatgpt2api.py code, using precise token calculation, and fixed the error in the gpt-3.5-turbo max token parameter.

8. Optimized log display
  • Loading branch information
yym68686 committed Dec 12, 2023
1 parent cf6dcb8 commit e2b0aec
Show file tree
Hide file tree
Showing 5 changed files with 291 additions and 228 deletions.
8 changes: 8 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
BOT_TOKEN=
API_URL=
API=
GOOGLE_API_KEY=
GOOGLE_CSE_ID=
claude_api_key=
ADMIN_LIST=
GROUP_LIST=
130 changes: 34 additions & 96 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ async def command_bot(update, context, language=None, prompt=translator_prompt,
if has_command:
message = ' '.join(context.args)
if prompt and has_command:
prompt = prompt.format(language)
if translator_prompt == prompt:
prompt = prompt.format(language)
message = prompt + message
if message:
if "claude" in config.GPT_ENGINE and config.ClaudeAPI:
Expand Down Expand Up @@ -128,70 +129,6 @@ async def getChatGPT(update, context, title, robot, message, chatid, messageid):
result = re.sub(r",", ',', result)
await context.bot.edit_message_text(chat_id=chatid, message_id=messageid, text=escape(result), parse_mode='MarkdownV2', disable_web_page_preview=True)

@decorators.GroupAuthorization
@decorators.Authorization
async def search(update, context, title, robot):
    """Handle the ``/search`` command by streaming an answer into one chat message.

    Sends a usage hint and returns when the command has no arguments.
    Otherwise posts a placeholder message, consumes the chunks yielded by the
    answer generator, and periodically edits the placeholder with the text
    accumulated so far. A final edit is made after the stream ends (or fails).

    Args:
        update: Incoming update; ``update.message`` and ``update.effective_*``
            are read for the text, chat id and user info.
        context: Handler context; ``context.args`` holds the command arguments
            and ``context.bot`` is used to send, edit and delete messages.
        title: Text the reply starts with; also checked for the substring
            ``'claude2'``.
        robot: Object providing ``search_summary`` (the default answer
            generator) and ``reset``.
    """
    # NOTE(review): the value computed here is never read -- `message` is
    # reassigned below (usage hint or joined args) before any use.
    message = update.message.text if config.NICK is None else update.message.text[botNicKLength:].strip() if update.message.text[:botNicKLength].lower() == botNick else None
    # Log who sent what (wrapped in ANSI colour escape codes).
    print("\033[32m", update.effective_user.username, update.effective_user.id, update.message.text, "\033[0m")
    if (len(context.args) == 0):
        # No query given: reply with a usage example and stop.
        message = (
            f"格式错误哦~,示例:\n\n"
            f"`/search 今天的微博热搜有哪些?`\n\n"
            f"👆点击上方命令复制格式\n\n"
        )
        await context.bot.send_message(chat_id=update.effective_chat.id, text=escape(message), parse_mode='MarkdownV2', disable_web_page_preview=True)
        return
    message = ' '.join(context.args)
    result = title  # accumulated reply text, starting with the title
    text = message  # the query passed to the answer generator
    modifytime = 0  # number of chunks received so far
    lastresult = ''  # text most recently pushed to the chat
    # Placeholder message that is edited in place as chunks arrive.
    message = await context.bot.send_message(
        chat_id=update.message.chat_id,
        text="搜索中💭",
        parse_mode='MarkdownV2',
        reply_to_message_id=update.message.message_id,
    )
    messageid = message.message_id
    get_answer = robot.search_summary
    # Switch to the gpt4free backend when no API key is configured, or when
    # g4f is enabled and GPT-backed search is disabled.
    if not config.API or (config.USE_G4F and not config.SEARCH_USE_GPT):
        import utils.gpt4free as gpt4free
        get_answer = gpt4free.get_response

    try:
        for data in get_answer(text, convo_id=str(update.message.chat_id), pass_history=config.PASS_HISTORY):
            result = result + data
            tmpresult = result
            modifytime = modifytime + 1
            # Append a closing backtick / code fence when the partial text has
            # an odd number of them (presumably so the partial Markdown still
            # renders). If both checks fire, the second assignment wins.
            if re.sub(r"```", '', result).count("`") % 2 != 0:
                tmpresult = result + "`"
            if result.count("```") % 2 != 0:
                tmpresult = result + "\n```"
            # Only edit on every 20th chunk, and only when the text changed.
            if modifytime % 20 == 0 and lastresult != tmpresult:
                if 'claude2' in title:
                    # NOTE(review): as written this replaces "," with "," and
                    # is a no-op; possibly a full-width comma was intended --
                    # confirm against the repository.
                    tmpresult = re.sub(r",", ',', tmpresult)
                await context.bot.edit_message_text(chat_id=update.message.chat_id, message_id=messageid, text=escape(tmpresult), parse_mode='MarkdownV2', disable_web_page_preview=True)
                lastresult = tmpresult
    except Exception as e:
        # Dump the partial reply, the error and the traceback to stdout.
        print('\033[31m')
        print("response_msg", result)
        print("error", e)
        traceback.print_exc()
        print('\033[0m')
        if config.API:
            # Reset this chat's conversation to the configured system prompt.
            robot.reset(convo_id=str(update.message.chat_id), system_prompt=config.systemprompt)
        if "You exceeded your current quota, please check your plan and billing details." in str(e):
            # Quota exhausted: remove the placeholder and clear the API key in
            # config. The emptied messageid suppresses the final edit below.
            print("OpenAI api 已过期!")
            await context.bot.delete_message(chat_id=update.message.chat_id, message_id=messageid)
            messageid = ''
            config.API = ''
        result += f"`出错啦!{e}`"
    print(result)
    # Final edit with the complete text, unless it was already sent or the
    # placeholder message was deleted.
    if lastresult != result and messageid:
        if 'claude2' in title:
            # NOTE(review): same no-op substitution as in the loop above.
            result = re.sub(r",", ',', result)
        await context.bot.edit_message_text(chat_id=update.message.chat_id, message_id=messageid, text=escape(result), parse_mode='MarkdownV2', disable_web_page_preview=True)

@decorators.GroupAuthorization
@decorators.Authorization
async def image(update, context):
Expand Down Expand Up @@ -540,35 +477,35 @@ async def handle_pdf(update, context):
)
await context.bot.send_message(chat_id=update.message.chat_id, text=escape(message), parse_mode='MarkdownV2', disable_web_page_preview=True)

@decorators.GroupAuthorization
@decorators.Authorization
async def qa(update, context):
    """Handle ``/qa <knowledge base> <question>``.

    With exactly two arguments, the first (a link or a local path) and the
    second (the question) are passed to ``docQA`` and the ``"answer"`` entry
    of its result is sent back to the chat. With any other argument count a
    usage message is sent instead.
    """
    args = context.args
    if len(args) == 2:
        knowledge_source, question = args
        print("\033[32m", update.effective_user.username, update.effective_user.id, update.message.text, "\033[0m")
        await context.bot.send_chat_action(chat_id=update.message.chat_id, action=ChatAction.TYPING)
        qa_result = await docQA(knowledge_source, question, get_doc_from_local)
        answer = qa_result["answer"]
        print(answer)
        # A variant that appended the source links taken from
        # qa_result["source_documents"] existed only as commented-out code;
        # just the answer text is sent.
        await context.bot.send_message(chat_id=update.message.chat_id, text=escape(answer), parse_mode='MarkdownV2', disable_web_page_preview=True)
        return

    # Wrong number of arguments: explain the expected format.
    usage = "".join((
        "格式错误哦~,需要两个参数,注意路径或者链接、问题之间的空格\n\n",
        "请输入 `/qa 知识库链接 要问的问题`\n\n",
        "例如知识库链接为 https://abc.com ,问题是 蘑菇怎么分类?\n\n",
        "则输入 `/qa https://abc.com 蘑菇怎么分类?`\n\n",
        "问题务必不能有空格,👆点击上方命令复制格式\n\n",
        "除了输入网址,同时支持本地知识库,本地知识库文件夹路径为 `./wiki`,问题是 蘑菇怎么分类?\n\n",
        "则输入 `/qa ./wiki 蘑菇怎么分类?`\n\n",
        "问题务必不能有空格,👆点击上方命令复制格式\n\n",
        "本地知识库目前只支持 Markdown 文件\n\n",
    ))
    await context.bot.send_message(chat_id=update.effective_chat.id, text=escape(usage), parse_mode='MarkdownV2', disable_web_page_preview=True)
# @decorators.GroupAuthorization
# @decorators.Authorization
# async def qa(update, context):
# if (len(context.args) != 2):
# message = (
# f"格式错误哦~,需要两个参数,注意路径或者链接、问题之间的空格\n\n"
# f"请输入 `/qa 知识库链接 要问的问题`\n\n"
# f"例如知识库链接为 https://abc.com ,问题是 蘑菇怎么分类?\n\n"
# f"则输入 `/qa https://abc.com 蘑菇怎么分类?`\n\n"
# f"问题务必不能有空格,👆点击上方命令复制格式\n\n"
# f"除了输入网址,同时支持本地知识库,本地知识库文件夹路径为 `./wiki`,问题是 蘑菇怎么分类?\n\n"
# f"则输入 `/qa ./wiki 蘑菇怎么分类?`\n\n"
# f"问题务必不能有空格,👆点击上方命令复制格式\n\n"
# f"本地知识库目前只支持 Markdown 文件\n\n"
# )
# await context.bot.send_message(chat_id=update.effective_chat.id, text=escape(message), parse_mode='MarkdownV2', disable_web_page_preview=True)
# return
# print("\033[32m", update.effective_user.username, update.effective_user.id, update.message.text, "\033[0m")
# await context.bot.send_chat_action(chat_id=update.message.chat_id, action=ChatAction.TYPING)
# result = await docQA(context.args[0], context.args[1], get_doc_from_local)
# print(result["answer"])
# # source_url = set([i.metadata['source'] for i in result["source_documents"]])
# # source_url = "\n".join(source_url)
# # message = (
# # f"{result['result']}\n\n"
# # f"参考链接:\n"
# # f"{source_url}"
# # )
# await context.bot.send_message(chat_id=update.message.chat_id, text=escape(result["answer"]), parse_mode='MarkdownV2', disable_web_page_preview=True)

async def start(update, context): # 当用户输入/start时,返回文本
user = update.effective_user
Expand Down Expand Up @@ -617,13 +554,14 @@ async def post_init(application: Application) -> None:

application.add_handler(CommandHandler("start", start))
application.add_handler(CommandHandler("pic", image))
application.add_handler(CommandHandler("search", lambda update, context: search(update, context, title=f"`🤖️ {config.GPT_ENGINE}`\n\n", robot=config.ChatGPTbot)))
application.add_handler(CommandHandler("search", lambda update, context: command_bot(update, context, prompt="search: ", title=f"`🤖️ {config.GPT_ENGINE}`\n\n", robot=config.ChatGPTbot, has_command="search")))
# application.add_handler(CommandHandler("search", lambda update, context: search(update, context, title=f"`🤖️ {config.GPT_ENGINE}`\n\n", robot=config.ChatGPTbot)))
application.add_handler(CallbackQueryHandler(button_press))
application.add_handler(CommandHandler("reset", reset_chat))
application.add_handler(CommandHandler("en2zh", lambda update, context: command_bot(update, context, config.LANGUAGE, robot=config.translate_bot)))
application.add_handler(CommandHandler("zh2en", lambda update, context: command_bot(update, context, "english", robot=config.translate_bot)))
application.add_handler(CommandHandler("info", info))
application.add_handler(CommandHandler("qa", qa))
# application.add_handler(CommandHandler("qa", qa))
application.add_handler(MessageHandler(filters.Document.PDF | filters.Document.TXT | filters.Document.DOC, handle_pdf))
application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, lambda update, context: command_bot(update, context, prompt=None, title=f"`🤖️ {config.GPT_ENGINE}`\n\n", robot=config.ChatGPTbot, has_command=False)))
application.add_handler(MessageHandler(filters.COMMAND, unknown))
Expand Down
94 changes: 94 additions & 0 deletions test/test_token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import tiktoken
from utils.function_call import function_call_list
import config
import requests
import json
import re

from dotenv import load_dotenv
load_dotenv()

def get_token_count(messages, model="gpt-3.5-turbo") -> int:
    """Estimate how many prompt tokens a chat ``messages`` list consumes.

    Args:
        messages: Iterable of message dicts (e.g. ``{"role": ..., "content": ...}``).
            Every truthy value is passed to the tokenizer's ``encode``, so
            values are expected to be strings.
        model: Model name used to select the tiktoken encoding. Defaults to
            ``"gpt-3.5-turbo"``, the value that used to be hard-coded.

    Returns:
        The sum of the encoded lengths of all truthy values plus the fixed
        per-message, per-``name`` and reply-priming overheads.
    """
    # The original also called tiktoken.get_encoding("cl100k_base") and threw
    # the result away; only the model-specific encoding is actually used.
    encoding = tiktoken.encoding_for_model(model)

    # Fixed overheads, in tokens, kept at the values this script uses.
    per_message_overhead = 5  # <im_start>{role/name}\n{content}<im_end>\n framing
    name_overhead = 5         # extra cost added when a "name" key is present
    reply_priming = 5         # every reply is primed with <im_start>assistant

    num_tokens = reply_priming
    for message in messages:
        num_tokens += per_message_overhead
        for key, value in message.items():
            if value:  # skip None / empty values
                num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += name_overhead
    return num_tokens
# print(get_token_count(message_list))



def _parse_token_counts(error_message):
    """Extract the token counts from the parenthesised part of an API error message."""
    parenthesized = re.findall(r"\((.*?)\)", error_message)
    if not parenthesized:
        # Previously surfaced as a bare IndexError from ``[0]``.
        raise ValueError(f"Unknown error: {error_message}")
    # Token counts are integers; matching digits only avoids int() failing on
    # a capture such as "52." that the old pattern r"\d+\.?\d*" allowed.
    numbers = [int(number) for number in re.findall(r"\d+", parenthesized[0])]
    if len(numbers) == 2:
        # Two numbers: the first is reported as the message tokens, and the
        # total equals it because there is no functions figure.
        return {
            "messages": numbers[0],
            "total": numbers[0],
        }
    if len(numbers) == 3:
        # Three numbers: the first two are reported as messages and functions.
        return {
            "messages": numbers[0],
            "functions": numbers[1],
            "total": numbers[0] + numbers[1],
        }
    raise ValueError(f"Unknown error: {error_message}")


def get_message_token(url, json_post):
    """POST ``json_post`` to ``url`` and read token counts from the error reply.

    The bearer token is taken from the ``API`` environment variable. When the
    server answers with a non-200 status, the ``error.message`` field of the
    JSON body is printed and the numbers inside its first parenthesised
    section are returned.

    Args:
        url: Endpoint to post to.
        json_post: Request payload, sent as JSON.

    Returns:
        ``{"messages": int, "total": int}`` or
        ``{"messages": int, "functions": int, "total": int}`` on a non-200
        reply; ``None`` when the request succeeds with status 200.

    Raises:
        ValueError: If the error message does not contain two or three
            numbers in parentheses.
    """
    headers = {"Authorization": f"Bearer {os.environ.get('API', None)}"}
    # Use the session as a context manager so its connections are released;
    # the body is fully read before the session closes (no stream=True).
    with requests.Session() as session:
        response = session.post(
            url,
            headers=headers,
            json=json_post,
            timeout=None,
        )
    if response.status_code == 200:
        # Nothing to parse: the token counts only appear in error replies.
        return None
    json_response = json.loads(response.text)
    error_message = json_response["error"]["message"]
    print(error_message)
    return _parse_token_counts(error_message)


if __name__ == "__main__":
    # Sample conversation whose token usage is measured.
    messages = [
        {
            "role": "system",
            "content": "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally in Simplified Chinese. Knowledge cutoff: 2021-09. Current date: [ 2023-12-12 ]",
        },
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "你好!有什么我可以帮助你的吗?"},
    ]

    # Request body. NOTE(review): max_tokens=5000 presumably exceeds the
    # model's context window on purpose, so the API answers with the error
    # message that get_message_token() parses -- confirm.
    json_post = dict(
        model="gpt-3.5-turbo",
        messages=messages,
        stream=True,
        temperature=0.5,
        top_p=0.7,
        presence_penalty=0.0,
        frequency_penalty=0.0,
        n=1,
        user="user",
        max_tokens=5000,
    )
    # To also measure function definitions, merge function_call_list["base"]
    # into json_post and append the "web_search" / "url_fetch" entries to
    # json_post["functions"] (that variant is currently disabled).

    url = config.bot_api_url.chat_url
    print(get_message_token(url, json_post))
Loading

0 comments on commit e2b0aec

Please sign in to comment.