zbparse/flask_app/general/llm/多线程提问.py

245 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
import queue
import concurrent.futures
import time
from dashscope import Assistants, Messages, Runs, Threads
from llama_index.indices.managed.dashscope import DashScopeCloudRetriever
from flask_app.general.llm.大模型通用函数 import read_txt_to_string
from flask_app.general.llm.通义千问long import qianwen_long, upload_file
from flask_app.general.llm.qianwen_plus import qianwen_plus
def read_questions_from_file(file_path):
    """Parse a UTF-8 text file of numbered questions into a list of strings.

    A question starts on a line matching ``<number>.`` (e.g. ``1. ...``);
    subsequent non-empty lines are appended to it until the next number.
    Blank lines and lines starting with ``#`` are ignored.

    :param file_path: path to the question file
    :return: list of question strings (leading number stripped, inner
             lines joined with newlines, outer whitespace stripped)
    """
    questions = []
    buffer = ""
    with open(file_path, 'r', encoding='utf-8') as fh:
        for raw in fh:
            text = raw.strip()
            # Skip blanks and '#' comment lines.
            if not text or text.startswith('#'):
                continue
            # A leading "<digits>." marks the start of a new question.
            if re.match(r'^(\d+)\.', text):
                if buffer:
                    questions.append(buffer.strip())
                # Drop the numeric prefix; keep only the question text.
                buffer = text.split('.', 1)[1].strip() + "\n"
            else:
                # Continuation line of the current question.
                buffer += text + "\n"
    # Flush the final question, if any.
    if buffer:
        questions.append(buffer.strip())
    return questions
#正文和文档名之间的内容
def send_message(assistant, message='百炼是什么?'):
    """Run one query against a DashScope assistant and collect its replies.

    Creates a fresh thread, posts *message*, starts a run on *assistant*,
    waits for it to finish, then gathers the text of every message in the
    thread (oldest-first, since the listing is reversed).

    :param assistant: a DashScope Assistants object to run the query on
    :param message: the user query text
    :return: list of reply text strings
    """
    print(f"Query: {message}")
    # One throwaway thread per query.
    thread = Threads.create()
    print(thread)
    Messages.create(thread.id, content=message)
    run = Runs.create(thread.id, assistant_id=assistant.id)
    # Block until the run completes (or requires action).
    Runs.wait(run.id, thread_id=thread.id)
    # The listing is newest-first; reverse to chronological order.
    replies = [
        entry['content'][0]['text']['value']
        for entry in Messages.list(thread.id)['data'][::-1]
    ]
    return replies
def rag_assistant(knowledge_name):
    """Create a DashScope assistant backed by a named knowledge base.

    Resolves *knowledge_name* to a cloud retriever pipeline and wires a
    ``rag`` tool (plus the code interpreter) into a new qwen-max assistant.

    :param knowledge_name: name of the DashScope knowledge base
    :return: the created Assistants object
    """
    retriever = DashScopeCloudRetriever(knowledge_name)
    pipeline_id = str(retriever.pipeline_id)
    # RAG tool payload pointing the assistant at the retrieval pipeline.
    rag_tool = {
        "type": "rag",
        "prompt_ra": {
            "pipeline_id": pipeline_id,
            "parameters": {
                "type": "object",
                "properties": {
                    "query_word": {
                        "type": "str",
                        "value": "${documents}",
                    },
                },
            },
        },
    }
    # NOTE(review): temperature is passed as a string here — presumably
    # accepted by the API; confirm against the DashScope Assistants docs.
    return Assistants.create(
        model='qwen-max',
        name='smart helper',
        description='智能助手,支持知识库查询和插件调用。',
        temperature='0.3',
        instructions="请记住以下材料,他们对回答问题有帮助,请你简洁准确地给出回答,不要给出无关内容。${documents}",
        tools=[{"type": "code_interpreter"}, rag_tool],
    )
def pure_assistant():
    """Create a plain qwen-max assistant with no knowledge base attached.

    Only the code interpreter tool is enabled; used as the fallback route
    in :func:`llm_call` when no RAG pipeline is requested.

    :return: the created Assistants object
    """
    return Assistants.create(
        model='qwen-max',
        name='smart helper',
        description='智能助手,能基于用户的要求精准简洁地回答用户的提问',
        instructions='智能助手,能基于用户的要求精准简洁地回答用户的提问',
        tools=[{"type": "code_interpreter"}],
    )
def llm_call(question, knowledge_name, file_id, result_queue, ans_index, llm_type, need_extra=False):
    """Dispatch one question to the selected LLM backend and enqueue the result.

    Puts ``(ans_index, (question, answer))`` into *result_queue* on success,
    or ``(ans_index, None)`` when the backend returns an empty answer or
    raises — the caller reassembles results by index.

    :param question: the query text to send
    :param knowledge_name: knowledge base name (used only when llm_type == 1)
    :param file_id: uploaded file id (used only when llm_type == 2)
    :param result_queue: queue.Queue receiving (index, result) tuples
    :param ans_index: position of this question in the caller's list
    :param llm_type: 1 = RAG assistant, 2 = qianwen-long with file,
                     3/4 = qianwen-plus, anything else = plain assistant
    :param need_extra: extra-output flag forwarded to the backend
    """
    try:
        if llm_type == 1:
            print(f"rag_assistant! question:{question}")
            assistant = rag_assistant(knowledge_name)
            ans = send_message(assistant, message=question)
            result_queue.put((ans_index, (question, ans)))
        elif llm_type == 2:
            # qianwen-long answers against the previously uploaded file.
            qianwen_res = qianwen_long(file_id, question, 2, 1, need_extra)
            if not qianwen_res:
                result_queue.put((ans_index, None))  # empty answer -> placeholder
            else:
                result_queue.put((ans_index, (question, qianwen_res)))
        elif llm_type in (3, 4):
            # Types 3 and 4 were duplicated verbatim; both route to
            # qianwen_plus (the doubao backend was retired).
            qianwen_plus_res = qianwen_plus(question, need_extra)
            if not qianwen_plus_res:
                result_queue.put((ans_index, None))  # empty answer -> placeholder
            else:
                result_queue.put((ans_index, (question, qianwen_plus_res)))
        else:
            assistant = pure_assistant()
            ans = send_message(assistant, message=question)
            result_queue.put((ans_index, (question, ans)))
    except Exception as e:
        print(f"LLM 调用失败,查询索引 {ans_index},错误:{e}")
        result_queue.put((ans_index, None))  # None marks a failed query
def multi_threading(queries, knowledge_name="", file_id="", llm_type=1, need_extra=False):
    """Fan a list of questions out to :func:`llm_call` on a thread pool.

    Results come back in the original question order; failed or empty
    answers are dropped. For ``llm_type == 4`` the submissions are batched
    in groups of 10 with a 5-second pause between groups (to improve the
    backend's prompt-cache hit rate).

    :param queries: list of question strings
    :param knowledge_name: forwarded to llm_call (RAG route)
    :param file_id: forwarded to llm_call (qianwen-long route)
    :param llm_type: backend selector, see :func:`llm_call`
    :param need_extra: forwarded to llm_call
    :return: list of (question, answer) tuples, or [] if every query failed
    """
    if not queries:
        return []
    print("多线程提问starting multi_threading...")
    result_queue = queue.Queue()
    future_to_index = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=40) as executor:
        if llm_type == 4:
            # Submit in groups of 10, sleeping 5s between groups.
            batch = 10
            total = len(queries)
            for start in range(0, total, batch):
                for pos in range(start, min(start + batch, total)):
                    fut = executor.submit(
                        llm_call, queries[pos], knowledge_name, file_id,
                        result_queue, pos, llm_type, need_extra,
                    )
                    future_to_index[fut] = pos
                if start + batch < total:
                    time.sleep(5)
        else:
            # All other types: submit everything at once.
            for pos, query in enumerate(queries):
                fut = executor.submit(
                    llm_call, query, knowledge_name, file_id,
                    result_queue, pos, llm_type, need_extra,
                )
                future_to_index[fut] = pos
        # Surface any exception a worker raised outside llm_call's own
        # try/except, and record a None placeholder for that index.
        for future in concurrent.futures.as_completed(future_to_index):
            idx = future_to_index[future]
            try:
                future.result()
            except Exception as exc:
                print(f"查询索引 {idx} 生成了一个异常:{exc}")
                result_queue.put((idx, None))
    # Reassemble answers by index so output order matches input order.
    results = [None] * len(queries)
    while not result_queue.empty():
        idx, item = result_queue.get()
        results[idx] = item
    # Everything failed -> empty list; otherwise drop the None entries.
    if all(item is None for item in results):
        return []
    return [item for item in results if item is not None]
if __name__ == "__main__":
    # Ad-hoc smoke test: read a pre-extracted tender text, ask the same
    # question 14 times through the qianwen-plus route (llm_type=4), and
    # print the answers. Requires the local txt file to exist.
    start_time = time.time()
    processed_filepath = r"C:\Users\Administrator\Desktop\货物标\extract_files\107国道.txt"
    full_text = read_txt_to_string(processed_filepath)
    temp = "请告诉我LED 全彩显示屏的功能是怎样的请以JSON格式返回键名为'LED 全彩显示屏',键值为字符串。"
    user_query = f"文本内容:{full_text}\n" + temp
    # 14 identical queries, matching range(1, 15) in the original loop.
    query = [user_query for _ in range(1, 15)]
    res = multi_threading(query, "", "", 4)
    for _, response in res:
        print(response)