284 lines
11 KiB
Python
284 lines
11 KiB
Python
import os
|
||
import time
|
||
import PyPDF2
|
||
import requests
|
||
from ratelimit import sleep_and_retry, limits
|
||
from flask_app.general.读取文件.clean_pdf import extract_common_header, clean_page_content
|
||
|
||
def pdf2txt(file_path):
|
||
common_header = extract_common_header(file_path)
|
||
# print(f"公共抬头:{common_header}")
|
||
# print("--------------------正文开始-------------------")
|
||
result = ""
|
||
with open(file_path, 'rb') as file:
|
||
reader = PyPDF2.PdfReader(file)
|
||
num_pages = len(reader.pages)
|
||
# print(f"Total pages: {num_pages}")
|
||
for page_num in range(num_pages):
|
||
page = reader.pages[page_num]
|
||
text = page.extract_text()
|
||
if text:
|
||
# print(f"--------第{page_num}页-----------")
|
||
cleaned_text = clean_page_content(text,common_header)
|
||
# print(cleaned_text)
|
||
result += cleaned_text
|
||
# print(f"Page {page_num + 1} Content:\n{cleaned_text}")
|
||
else:
|
||
print(f"Page {page_num + 1} is empty or text could not be extracted.")
|
||
directory = os.path.dirname(os.path.abspath(file_path))
|
||
output_path = os.path.join(directory, 'extract.txt')
|
||
# 将结果保存到 extract.txt 文件中
|
||
try:
|
||
with open(output_path, 'w', encoding='utf-8') as output_file:
|
||
output_file.write(result)
|
||
print(f"提取内容已保存到: {output_path}")
|
||
except IOError as e:
|
||
print(f"写入文件时发生错误: {e}")
|
||
# 返回保存的文件路径
|
||
return output_path
|
||
|
||
def read_txt_to_string(file_path):
|
||
"""
|
||
读取txt文件内容并返回一个包含所有内容的字符串,保持原有格式。
|
||
|
||
参数:
|
||
- file_path (str): txt文件的路径
|
||
|
||
返回:
|
||
- str: 包含文件内容的字符串
|
||
"""
|
||
try:
|
||
with open(file_path, 'r', encoding='utf-8') as file: # 确保使用适当的编码
|
||
content = file.read() # 使用 read() 保持文件格式
|
||
return content
|
||
except FileNotFoundError:
|
||
return "错误:文件未找到。"
|
||
except Exception as e:
|
||
return f"错误:读取文件时发生错误。详细信息:{e}"
|
||
|
||
def get_total_tokens(text):
|
||
"""
|
||
调用 API 计算给定文本的总 Token 数量。 注:doubao的计算方法!与qianwen不一样
|
||
返回:
|
||
- int: 文本的 total_tokens 数量。
|
||
"""
|
||
# API 请求 URL
|
||
url = "https://ark.cn-beijing.volces.com/api/v3/tokenization"
|
||
|
||
# 获取 API 密钥
|
||
doubao_api_key = os.getenv("DOUBAO_API_KEY")
|
||
if not doubao_api_key:
|
||
raise ValueError("DOUBAO_API_KEY 环境变量未设置")
|
||
|
||
# 请求头
|
||
headers = {
|
||
"Content-Type": "application/json",
|
||
"Authorization": "Bearer " + doubao_api_key
|
||
}
|
||
model = "ep-20241119121710-425g6"
|
||
# 请求体
|
||
payload = {
|
||
"model": model,
|
||
"text": [text] # API 文档中要求 text 是一个列表
|
||
}
|
||
|
||
try:
|
||
response = requests.post(url, headers=headers, json=payload)
|
||
response.raise_for_status()
|
||
response_data = response.json()
|
||
total_tokens=response_data["data"][0]["total_tokens"]
|
||
return total_tokens
|
||
except Exception as e:
|
||
print(f"获取 Token 数量失败:{e}")
|
||
return 0
|
||
|
||
@sleep_and_retry
|
||
@limits(calls=10, period=1) # 每秒最多调用10次
|
||
def doubao_model(full_user_query, need_extra=False):
|
||
"""
|
||
对于429错误,一共尝试三次,前两次等待若干时间再发起调用,第三次换模型
|
||
:param full_user_query:
|
||
:param need_extra:
|
||
:return:
|
||
"""
|
||
print("call doubao...")
|
||
# 相关参数
|
||
url = "https://ark.cn-beijing.volces.com/api/v3/chat/completions"
|
||
doubao_api_key = os.getenv("DOUBAO_API_KEY")
|
||
|
||
# 定义主模型和备用模型
|
||
models = {
|
||
"pro_32k": "ep-20241119121710-425g6", # 豆包Pro 32k模型
|
||
"pro_128k": "ep-20241119121743-xt6wg" # 128k模型
|
||
}
|
||
|
||
# 判断用户查询字符串的长度
|
||
token_count = get_total_tokens(full_user_query)
|
||
if token_count > 31500:
|
||
selected_model = "pro_128k" # 如果长度超过32k,直接使用128k模型
|
||
else:
|
||
selected_model = "pro_32k" # 默认使用32k模型
|
||
|
||
# 请求头
|
||
headers = {
|
||
"Content-Type": "application/json",
|
||
"Authorization": "Bearer " + doubao_api_key
|
||
}
|
||
|
||
max_retries_429 = 3 # 针对 429 错误的最大重试次数
|
||
max_retries_other = 1 # 针对其他错误的最大重试次数
|
||
attempt = 0
|
||
response = None # 确保 response 被定义
|
||
|
||
while True:
|
||
# 请求数据
|
||
data = {
|
||
"model": models[selected_model],
|
||
"messages": [
|
||
{
|
||
"role": "user",
|
||
"content": full_user_query
|
||
}
|
||
],
|
||
"temperature": 0.2
|
||
}
|
||
try:
|
||
response = requests.post(url, headers=headers, json=data) # 设置超时时间为10秒
|
||
response.raise_for_status() # 如果响应状态码不是200,将引发HTTPError
|
||
|
||
# 获取响应 JSON
|
||
response_json = response.json()
|
||
|
||
# 获取返回内容
|
||
content = response_json["choices"][0]["message"]["content"]
|
||
|
||
# 获取 completion_tokens
|
||
completion_tokens = response_json["usage"].get("completion_tokens", 0)
|
||
|
||
# 根据 need_extra 返回不同的结果
|
||
if need_extra:
|
||
return content, completion_tokens
|
||
else:
|
||
return content
|
||
|
||
except requests.exceptions.RequestException as e:
|
||
# 获取状态码并处理不同的重试逻辑
|
||
status_code = response.status_code if response is not None else None
|
||
print(f"请求失败,状态码: {status_code}")
|
||
print("请求失败,完整的响应内容如下:")
|
||
if response is not None:
|
||
print(response.text) # 打印原始的响应内容,可能是 JSON 格式,也可能是其他格式
|
||
|
||
# 如果是 429 错误
|
||
if status_code == 429:
|
||
if attempt < max_retries_429:
|
||
if attempt == 0:
|
||
wait_time = 3
|
||
elif attempt == 1:
|
||
wait_time = 6
|
||
elif attempt == 2:
|
||
# 第三次重试时切换模型
|
||
alternative_model = "pro_128k" if selected_model == "pro_32k" else "pro_32k"
|
||
print(f"状态码为 429,切换模型从 {selected_model} 到 {alternative_model} 并重试...")
|
||
selected_model = alternative_model
|
||
wait_time = 0 # 立即重试,无需等待
|
||
print(f"等待 {wait_time} 秒后重试...")
|
||
if wait_time > 0:
|
||
time.sleep(wait_time)
|
||
else:
|
||
print(f"状态码为 429,已达到最大重试次数 {max_retries_429} 次。")
|
||
break # 超过最大重试次数,退出循环
|
||
else:
|
||
# 针对其他错误
|
||
if attempt < max_retries_other:
|
||
print("非 429 错误,等待 1 秒后重试...")
|
||
time.sleep(1)
|
||
else:
|
||
print(f"非 429 错误,已达到最大重试次数 {max_retries_other} 次。")
|
||
break # 超过最大重试次数,退出循环
|
||
|
||
attempt += 1 # 增加重试计数
|
||
|
||
# 如果到这里,说明所有尝试都失败了
|
||
print(f"请求失败,已达到最大重试次数。")
|
||
if need_extra:
|
||
return None, 0
|
||
else:
|
||
return None
|
||
|
||
def generate_full_user_query(file_path, prompt_template):
|
||
"""
|
||
根据文件路径和提示词模板生成完整的user_query。
|
||
|
||
参数:
|
||
- file_path (str): 需要解析的文件路径。
|
||
- prompt_template (str): 包含{full_text}占位符的提示词模板。
|
||
|
||
返回:
|
||
- str: 完整的user_query。
|
||
"""
|
||
# 假设extract_text_by_page已经定义,用于提取文件内容
|
||
full_text=read_txt_to_string(file_path)
|
||
# 格式化提示词,将提取的文件内容插入到模板中
|
||
user_query = prompt_template.format(full_text=full_text)
|
||
return user_query
|
||
|
||
if __name__ == "__main__":
|
||
txt_path = r"output.txt"
|
||
pdf_path_1 = "D:/bid_generator/task_folder/9a447eb0-24b8-4f51-8164-d91a62edea25/tmp/bid_format.pdf"
|
||
pdf_path_2 = r"C:\Users\Administrator\Desktop\货物标\output1\竞争性谈判文件_procurement.pdf"
|
||
prompt_template = '''
|
||
任务:解析采购文件,提取采购需求,并以JSON格式返回。
|
||
|
||
要求与指南:
|
||
1. 精准定位:运用文档理解能力,找到文件中的采购需求部分。
|
||
2. 系统归属:若货物明确属于某个系统,则将其作为该系统的二级键。
|
||
3. 非清单形式处理:若未出现采购清单,则从表格或文字中摘取系统和货物信息。
|
||
4. 软件需求:对于软件应用需求,列出系统模块构成,并作为系统键值的一部分。
|
||
5. 系统功能:若文中提及系统功能,则在系统值中添加'系统功能'二级键,不展开具体内容。
|
||
6. 完整性:确保不遗漏系统内的货物,也不添加未提及的内容。
|
||
|
||
输出格式:
|
||
1.JSON格式,最外层键名为'采购需求'。
|
||
2.嵌套键名为系统或货物名称,与原文保持一致。
|
||
3.键值应为空对象({{}}),仅返回名称。
|
||
4.不包含'说明'、'规格'、'技术参数'等列内容。
|
||
5.层次关系用嵌套键值对表示。
|
||
6.最后一级键内值留空或填'未知'(如数量较多或未知内容)。
|
||
|
||
特殊情况处理:
|
||
同一层级下同名但采购要求不同的货物,以'货物名-编号'区分,编号从1递增。
|
||
|
||
示例输出结构:
|
||
{{
|
||
"采购需求": {{
|
||
"交换机-1": {{}},
|
||
"交换机-2": {{}},
|
||
"门禁管理系统": {{
|
||
// 可包含其他货物或模块
|
||
}},
|
||
"交通监控视频子系统": {{
|
||
"系统功能": {{}},
|
||
"高清视频抓拍像机": {{}},
|
||
"补光灯": {{}}
|
||
}},
|
||
"LED全彩显示屏": {{}}
|
||
// 其他系统和货物
|
||
}}
|
||
}}
|
||
|
||
文件内容(已包含):{full_text}
|
||
|
||
注意事项:
|
||
1.严格按照上述要求执行,确保输出准确性和规范性。
|
||
2.如有任何疑问或不确定内容,请保留原文描述,必要时使用'未知'标注。
|
||
'''
|
||
# processed_filepath = convert_pdf_to_markdown(pdf_path_2) # 转markdown格式
|
||
# processed_filepath = pdf2txt(pdf_path_2) #纯文本提取
|
||
# user_query=generate_full_user_query(processed_filepath,prompt_template)
|
||
user_query="一年有多少天?"
|
||
res=doubao_model(user_query)
|
||
# res=get_total_tokens("hh我是天才")
|
||
print(res)
|
||
# print("--------------------")
|
||
# print(user_query) |