151 lines
5.3 KiB
Python
151 lines
5.3 KiB
Python
import ast
|
||
import os
|
||
import re
|
||
from functools import wraps
|
||
|
||
import PyPDF2
|
||
from ratelimit import sleep_and_retry, limits
|
||
import requests
|
||
from flask_app.general.读取文件.clean_pdf import extract_common_header, clean_page_content
|
||
|
||
@sleep_and_retry
# Original author's note: the service allows at most 20 calls per second
# (quota quoted as qpm = 12 million); traffic is split across two servers at
# 10 calls each, and running right at that cap may cause problems, so the
# per-process limit was lowered to 9 calls per 1-second period.
@limits(calls=9, period=1)
def rate_limiter():
    """Do nothing; exist only so the decorators above can throttle callers.

    Every call counts against the 9-calls-per-second budget declared by
    ``@limits``. The ``@sleep_and_retry`` decorator comes from the
    ``ratelimit`` package and is presumably what makes an over-budget call
    wait instead of fail -- confirm against that package's documentation.
    """
    return None
|
||
|
||
# 创建一个共享的装饰器
|
||
def shared_rate_limit(func):
    """Decorate ``func`` so each call first passes through ``rate_limiter``.

    All functions decorated with this share the single module-level limiter,
    so their calls are counted against one common budget.

    Args:
        func: The callable to throttle.

    Returns:
        A wrapper with the same metadata as ``func`` (via ``functools.wraps``)
        that calls ``rate_limiter()`` and then forwards all arguments to
        ``func``, returning its result unchanged.
    """
    @wraps(func)
    def throttled(*call_args, **call_kwargs):
        # Go through the shared limiter before doing the real work.
        rate_limiter()
        outcome = func(*call_args, **call_kwargs)
        return outcome

    return throttled
|
||
|
||
def extract_error_details(error_message):
    """Extract the numeric error code and the inner error-code string.

    The message is expected to contain ``Error code: XXX - {...}``, where the
    braces hold a Python-literal dict shaped like
    ``{'error': {'code': 'data_inspection_failed', ...}, ...}``.

    Args:
        error_message: The message text, or any object whose ``str()`` yields
            it (for example a caught exception).

    Returns:
        tuple: ``(error_code, error_code_string)``.
            ``error_code`` is the integer after ``Error code:`` or ``None``
            when the pattern is absent. ``error_code_string`` is the value of
            ``payload['error']['code']`` or ``None`` when the payload is
            missing, cannot be parsed, or does not have that shape.
    """
    # Accept exception objects as well as plain strings; re.search would
    # raise TypeError on a non-string argument.
    message = str(error_message)

    # Numeric error code.
    code_match = re.search(r'Error code:\s*(\d+)', message)
    error_code = int(code_match.group(1)) if code_match else None

    # Inner error code string (e.g. 'data_inspection_failed').
    error_code_string = None
    payload_match = re.search(r'Error code:\s*\d+\s*-\s*(\{.*\})', message)
    if payload_match:
        try:
            # literal_eval only evaluates literals, so it is safe to run on
            # text coming from a remote service.
            payload = ast.literal_eval(payload_match.group(1))
        except (ValueError, TypeError, SyntaxError,
                MemoryError, RecursionError) as e:
            print(f"解析错误消息失败: {e}")
        else:
            # The payload may parse to something other than the expected
            # nested dict (a set, or 'error' holding a plain string); treat
            # that as "no inner code" rather than as a parse failure.
            inner = payload.get('error') if isinstance(payload, dict) else None
            if isinstance(inner, dict):
                error_code_string = inner.get('code')

    return error_code, error_code_string
|
||
|
||
|
||
def get_total_tokens(text, model="ep-20241119121710-425g6", timeout=30):
    """Ask the Doubao tokenization API how many tokens ``text`` contains.

    Note: this is Doubao's token count; it differs from Qianwen's.

    Args:
        text (str): The text to tokenize.
        model (str): Endpoint/model id sent in the request body. Defaults to
            the endpoint that was previously hard-coded here.
        timeout (float | tuple): Seconds to wait for the HTTP request, passed
            straight to ``requests.post``. Without it a stalled connection
            would block the caller indefinitely.

    Returns:
        int: ``total_tokens`` reported by the API, or ``0`` when the request
        fails or the response does not have the expected shape.

    Raises:
        ValueError: If the ``DOUBAO_API_KEY`` environment variable is unset
            or empty.
    """
    url = "https://ark.cn-beijing.volces.com/api/v3/tokenization"

    doubao_api_key = os.getenv("DOUBAO_API_KEY")
    if not doubao_api_key:
        raise ValueError("DOUBAO_API_KEY 环境变量未设置")

    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + doubao_api_key,
    }
    payload = {
        "model": model,
        "text": [text],  # the API expects ``text`` to be a list
    }

    try:
        response = requests.post(
            url, headers=headers, json=payload, timeout=timeout
        )
        response.raise_for_status()
        return response.json()["data"][0]["total_tokens"]
    except (requests.RequestException, ValueError,
            KeyError, IndexError, TypeError) as e:
        # Best effort: network/HTTP errors, an undecodable body, or an
        # unexpected response shape all degrade to "0 tokens".
        print(f"获取 Token 数量失败:{e}")
        return 0
|
||
|
||
def pdf2txt(file_path):
    """Extract the text of a PDF and save it as ``extract.txt`` beside it.

    Each page's text is passed through ``clean_page_content`` together with
    the header returned by ``extract_common_header`` for the file; the
    cleaned pages are concatenated in page order with no separator.

    Args:
        file_path (str): Path of the PDF to read.

    Returns:
        str: Path of ``extract.txt`` in the PDF's directory. The path is
        returned even when writing fails; the failure is only printed.
    """
    shared_header = extract_common_header(file_path)

    cleaned_pages = []
    with open(file_path, 'rb') as pdf_file:
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        for page_no, pdf_page in enumerate(pdf_reader.pages, start=1):
            raw_text = pdf_page.extract_text()
            if not raw_text:
                print(f"Page {page_no} is empty or text could not be extracted.")
                continue
            cleaned_pages.append(clean_page_content(raw_text, shared_header))
    # assumes clean_page_content returns str -- TODO confirm against clean_pdf
    full_text = "".join(cleaned_pages)

    output_path = os.path.join(
        os.path.dirname(os.path.abspath(file_path)), 'extract.txt'
    )
    # Persist the extracted text next to the source PDF.
    try:
        with open(output_path, 'w', encoding='utf-8') as sink:
            sink.write(full_text)
        print(f"提取内容已保存到: {output_path}")
    except IOError as write_error:
        print(f"写入文件时发生错误: {write_error}")

    return output_path
|
||
|
||
def read_txt_to_string(file_path):
    """Return the whole content of a UTF-8 text file as one string.

    The file is read in a single ``read()`` call, so its layout is returned
    as-is.

    Args:
        file_path (str): Path of the text file.

    Returns:
        str: The file content. On failure no exception is raised; instead an
        error message string is returned -- one message when the file does
        not exist, and a message that includes the exception text for any
        other error.
    """
    try:
        with open(file_path, mode='r', encoding='utf-8') as handle:
            outcome = handle.read()
    except FileNotFoundError:
        outcome = "错误:文件未找到。"
    except Exception as read_error:
        outcome = f"错误:读取文件时发生错误。详细信息:{read_error}"
    return outcome
|
||
|
||
def generate_full_user_query(file_path, prompt_template):
    """Build the complete user query by filling a prompt template.

    The text file at ``file_path`` is read with ``read_txt_to_string`` and
    substituted for the ``{full_text}`` placeholder using ``str.format``.
    Because ``str.format`` is used, any other literal braces in the template
    must be doubled (``{{`` / ``}}``).

    Args:
        file_path (str): Path of the text file whose content is inserted.
        prompt_template (str): Template containing a ``{full_text}``
            placeholder.

    Returns:
        str: The template with the file content inserted.
    """
    return prompt_template.format(full_text=read_txt_to_string(file_path))