zbparse/flask_app/main/回答来源.py
2024-08-29 16:37:09 +08:00

217 lines
8.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#基于多线程提问,现已废弃
# assistant_id
import queue
import concurrent.futures
from dashscope import Assistants, Messages, Runs, Threads
from llama_index.indices.managed.dashscope import DashScopeCloudRetriever
from json_utils import extract_content_from_json
prompt = """
# 角色
你是一个文档处理专家,专门负责理解和操作基于特定内容的文档任务,这包括解析、总结、搜索或生成与给定文档相关的各类信息。
## 技能
### 技能 1文档解析与摘要
- 深入理解并分析${document1}的内容,提取关键信息。
- 根据需求生成简洁明了的摘要,保持原文核心意义不变。
### 技能 2信息检索与关联
- 在${document1}中高效检索特定信息或关键词。
- 能够识别并链接到文档内部或外部的相关内容,增强信息的连贯性和深度。
## 限制
- 所有操作均需基于${document1}的内容,不可超出此范围创造信息。
- 在处理敏感或机密信息时,需遵守严格的隐私和安全规定。
- 确保所有生成或改编的内容逻辑连贯,无误导性信息。
请注意,上述技能执行时将直接利用并参考${document1}的具体内容,以确保所有产出紧密相关且高质量。
"""
prom = '请记住以下材料,他们对回答问题有帮助,请你简洁准确地给出回答,不要给出无关内容。${document1}'
#正文和文档名之间的内容
def extract_content_between_tags(text):
results = []
# 使用“【正文】”来分割文本
parts = text.split('【正文】')[1:] # 跳过第一个分割结果,因为它前面不会有内容
for index, part in enumerate(parts):
# 查找“【文档名】”标签的位置
doc_name_index = part.find('【文档名】')
# 查找 'file_ids' 标签的位置
file_ids_index = part.find("'file_ids'")
# 根据是否找到“【文档名】”来决定提取内容的截止点
if doc_name_index != -1:
end_index = doc_name_index
elif file_ids_index != -1:
end_index = file_ids_index
else:
end_index = len(part)
# 提取内容
content = part[:end_index].strip()
results.append(content)
# 如果存在 file_ids处理最后一部分特别提取 file_ids 前的内容
if "'file_ids'" in parts[-1]:
file_ids_index = parts[-1].find("'file_ids'")
if file_ids_index != -1:
last_content = parts[-1][:file_ids_index].strip()
results[-1] = last_content # 更新最后一部分的内容,确保只到 file_ids
return results
def find_references_in_extracted(formatted_ans, extracted_references):
results = {} # 用来存储匹配结果的字典
# 递归函数,用于处理多层嵌套的字典
def recurse_through_dict(current_dict, path=[]):
for key, value in current_dict.items():
# 检查值是否还是字典,如果是,进行递归
if isinstance(value, dict):
recurse_through_dict(value, path + [key])
else:
# 特定值处理:如果值为'未知',直接设置索引为-1
if value == '未知':
results['.'.join(path + [key])] = -1
else:
# 进行匹配检查
found = False
for index, reference in enumerate(extracted_references):
if str(value) in reference: # 转换为字符串,确保兼容性
results['.'.join(path + [key])] = index # 使用点表示法记录路径
found = True
break
if not found:
results['.'.join(path + [key])] = None
# 从根字典开始递归
recurse_through_dict(formatted_ans)
return results
def send_message(assistant, message='百炼是什么?'):
ans = []
print(f"Query: {message}")
# create thread.
thread = Threads.create()
print(thread)
# create a message.
message = Messages.create(thread.id, content=message)
# create run
run = Runs.create(thread.id, assistant_id=assistant.id)
print(run)
# wait for run completed or requires_action
run_status = Runs.wait(run.id, thread_id=thread.id)
print(run_status)
reference_txt = str(run_status)
extracted_references = extract_content_between_tags(reference_txt) #引用的文章来源list
# get the thread messages.
msgs = Messages.list(thread.id)
for message in msgs['data'][::-1]:
ans.append(message['content'][0]['text']['value'])
return ans,extracted_references
def rag_assistant(knowledge_name):
retriever = DashScopeCloudRetriever(knowledge_name)
pipeline_id = str(retriever.pipeline_id)
assistant = Assistants.create(
model='qwen-max',
name='smart helper',
description='智能助手,支持知识库查询和插件调用。',
temperature='0.3',
instructions=prom,
tools=[
{
"type": "code_interpreter"
},
{
"type": "rag",
"prompt_ra": {
"pipeline_id": pipeline_id,
"parameters": {
"type": "object",
"properties": {
"query_word": {
"type": "str",
"value": "${document1}"
}
}
}
}
}]
)
return assistant
def pure_assistant():
assistant = Assistants.create(
model='qwen-max',
name='smart helper',
description='智能助手,能基于用户的要求精准简洁地回答用户的提问',
instructions='智能助手,能基于用户的要求精准简洁地回答用户的提问',
tools=[
{
"type": "code_interpreter"
},
]
)
return assistant
def llm_call(question, knowledge_name, result_queue, ans_index, use_rag=True):
if use_rag:
assistant = rag_assistant(knowledge_name)
else:
assistant = pure_assistant()
ans,extracted_references = send_message(assistant, message=question)
for index, reference in enumerate(extracted_references, start=0):
print(f"{index}. {reference}")
formatted_ans=extract_content_from_json(ans[1])
print(formatted_ans)
results = find_references_in_extracted(formatted_ans, extracted_references)
for key, index in results.items():
print(f"{key}: Found at index {index}")
result_queue.put((ans_index, (question, ans))) # 在队列中添加索引 (question, ans)
def multi_threading(queries, knowledge_name, use_rag=True):
result_queue = queue.Queue()
# 使用 ThreadPoolExecutor 管理线程
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
# 使用字典保存每个提交的任务的Future对象以便按顺序访问结果
future_to_query = {executor.submit(llm_call, query, knowledge_name, result_queue, index, use_rag): index for
index, query in enumerate(queries)}
# 收集每个线程的结果
for future in concurrent.futures.as_completed(future_to_query):
index = future_to_query[future]
# 由于 llm_call 函数本身会处理结果,这里只需要确保任务执行完成
try:
future.result() # 可以用来捕获异常或确认任务完成
except Exception as exc:
print(f"Query {index} generated an exception: {exc}")
# 从队列中获取所有结果并按索引排序
results = [None] * len(queries)
while not result_queue.empty():
index, result = result_queue.get()
results[index] = result
return results
if __name__ == "__main__":
# 读取问题列表
questions = ["该招标文件的工程概况或项目概况招标范围是请按json格式给我提供信息键名分别为'工程概况','招标范围',若存在嵌套信息,嵌套内容键名以文件中对应字段命名,若存在未知信息,在对应的键值中填'未知'"]
knowledge_name = "招标解析5word"
results = multi_threading(questions, knowledge_name, use_rag=True)
# 打印结果
for question, response in results:
print(f"Question: {question}")
print(f"Response: {response}")