zbparse/flask_app/general/llm/大模型通用函数.py

151 lines
5.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import ast
import os
import re
from functools import wraps
import PyPDF2
from ratelimit import sleep_and_retry, limits
import requests
from flask_app.general.读取文件.clean_pdf import extract_common_header, clean_page_content
@sleep_and_retry
@limits(calls=9, period=1)
def rate_limiter():
    """
    No-op that exists only to consume one slot of the shared rate limit.

    The decorators enforce at most 9 calls per 1-second period. Per the original
    note, the upstream quota is about 20 calls/s (QPM ~1200) split across two
    servers, i.e. 10 per server. 9 is used because running at the full 10 may
    cause problems.

    ``limits`` (from the ``ratelimit`` package) rejects calls over that budget.
    ``sleep_and_retry`` wraps it so the caller presumably sleeps until the
    window resets instead of getting an exception.
    """
# Shared decorator: every function wrapped with it draws from the single rate_limiter budget.
def shared_rate_limit(func):
    """
    Wrap ``func`` so that every call first passes through the module-wide ``rate_limiter``.

    All decorated functions call the same ``rate_limiter``, so they share one call
    budget rather than each having its own.

    Args:
        func: The callable to throttle.

    Returns:
        A wrapper with the same name and docstring as ``func`` (via ``functools.wraps``).
    """
    @wraps(func)
    def limited_call(*call_args, **call_kwargs):
        # Wait on the shared limiter before delegating to the real function.
        rate_limiter()
        outcome = func(*call_args, **call_kwargs)
        return outcome

    return limited_call
def _parse_error_body(body_text):
    """Parse an error body written either as a Python literal (exception repr) or as JSON.

    Returns the parsed object. Raises ``ValueError`` (``json.JSONDecodeError``) or
    ``RecursionError`` if neither parser accepts the text.
    """
    import json

    try:
        # Exception reprs use Python literals such as None / True.
        return ast.literal_eval(body_text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        pass
    # Raw HTTP bodies use JSON literals such as null / true, which literal_eval rejects.
    return json.loads(body_text)


def extract_error_details(error_message):
    """
    Extract the numeric error code and the inner error-code string from an API error message.

    The message is expected to contain ``'Error code: XXX - {...}'``. The ``{...}``
    part is the error body, written either as a Python-literal dict or as JSON.
    The inner code (e.g. ``'data_inspection_failed'``) is read from
    ``body['error']['code']``.

    Args:
        error_message: The error text. Non-string values (e.g. an exception
            object) are converted with ``str()``.

    Returns:
        tuple: ``(error_code, error_code_string)``. ``error_code`` is an ``int``
        or ``None``. ``error_code_string`` is the inner code, or ``None`` when
        the body is missing, unparseable, or lacks ``error.code``.
    """
    message = str(error_message)

    # Numeric HTTP-style error code.
    code_match = re.search(r'Error code:\s*(\d+)', message)
    error_code = int(code_match.group(1)) if code_match else None

    # Inner error-code string such as 'data_inspection_failed'.
    error_code_string = None
    # DOTALL so a body pretty-printed over several lines is still captured.
    body_match = re.search(r'Error code:\s*\d+\s*-\s*(\{.*\})', message, re.DOTALL)
    if body_match:
        try:
            body = _parse_error_body(body_match.group(1))
        except (ValueError, RecursionError) as e:
            print(f"解析错误消息失败: {e}")
        else:
            inner = body.get('error') if isinstance(body, dict) else None
            if isinstance(inner, dict):
                error_code_string = inner.get('code')
    return error_code, error_code_string
def get_total_tokens(text, *, model="ep-20241119121710-425g6", timeout=30):
    """
    Count the tokens in ``text`` via the Doubao (Volcengine Ark) tokenization API.

    NOTE: Doubao tokenizes differently from Qwen, so counts are not comparable
    across the two.

    Args:
        text (str): Text to tokenize.
        model (str, keyword-only): Ark endpoint/model id sent in the request.
            Defaults to the endpoint this function previously hard-coded.
        timeout (float | tuple | None, keyword-only): Passed to
            ``requests.post(timeout=...)``. ``None`` waits indefinitely, which
            was the old behaviour.

    Returns:
        int: ``total_tokens`` of the first (only) result. Returns 0 when the
        request or response parsing fails; the error is printed.

    Raises:
        ValueError: If the ``DOUBAO_API_KEY`` environment variable is not set.
    """
    url = "https://ark.cn-beijing.volces.com/api/v3/tokenization"

    doubao_api_key = os.getenv("DOUBAO_API_KEY")
    if not doubao_api_key:
        raise ValueError("DOUBAO_API_KEY 环境变量未设置")

    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + doubao_api_key,
    }
    payload = {
        "model": model,
        "text": [text],  # the API expects a list of texts; exactly one is sent
    }
    try:
        # Without a timeout, requests can block forever on a stalled connection.
        response = requests.post(url, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()
        response_data = response.json()
        # One text was sent, so its count is in the first element of "data".
        return response_data["data"][0]["total_tokens"]
    except Exception as e:
        # Best-effort: callers get 0 rather than an exception when the count is unavailable.
        print(f"获取 Token 数量失败:{e}")
        return 0
def pdf2txt(file_path):
    """
    Extract a PDF's text, strip its repeated page header, and save the result to disk.

    The common header is detected once with ``extract_common_header``. It is then
    removed from every page by ``clean_page_content``. Cleaned pages are
    concatenated without a separator and written to ``extract.txt`` in the same
    directory as ``file_path``.

    Args:
        file_path (str): Path to the source PDF.

    Returns:
        str: Path of the ``extract.txt`` file. NOTE: the path is returned even if
        writing failed (the error is only printed), so callers should not assume
        the file exists.
    """
    common_header = extract_common_header(file_path)

    # Collect pages and join once instead of growing one string page by page.
    cleaned_pages = []
    with open(file_path, 'rb') as file:
        reader = PyPDF2.PdfReader(file)
        for page_index, page in enumerate(reader.pages):
            text = page.extract_text()
            if text:
                cleaned_pages.append(clean_page_content(text, common_header))
            else:
                print(f"Page {page_index + 1} is empty or text could not be extracted.")
    result = "".join(cleaned_pages)

    directory = os.path.dirname(os.path.abspath(file_path))
    output_path = os.path.join(directory, 'extract.txt')
    try:
        with open(output_path, 'w', encoding='utf-8') as output_file:
            output_file.write(result)
        print(f"提取内容已保存到: {output_path}")
    except OSError as e:  # IOError is an alias of OSError
        print(f"写入文件时发生错误: {e}")
    return output_path
def read_txt_to_string(file_path):
    """
    Return the full contents of a UTF-8 text file as one string, formatting unchanged.

    This function never raises for read failures. Instead it returns a
    human-readable error message string.

    Args:
        file_path (str): Path of the text file.

    Returns:
        str: The file contents, or an error message if the file is missing or
        cannot be read or decoded.
    """
    try:
        with open(file_path, encoding='utf-8') as source:
            text = source.read()
    except FileNotFoundError:
        text = "错误:文件未找到。"
    except Exception as err:
        text = f"错误:读取文件时发生错误。详细信息:{err}"
    return text
def generate_full_user_query(file_path, prompt_template):
    """
    Build the complete user query by inserting a text file's contents into a prompt template.

    Args:
        file_path (str): Path of the text file whose contents are inserted.
        prompt_template (str): ``str.format`` template containing a
            ``{full_text}`` placeholder. Any other literal braces in the
            template must be escaped as ``{{`` / ``}}``.

    Returns:
        str: The formatted user query.

    NOTE(review): ``read_txt_to_string`` returns an error-message string instead
    of raising when the file cannot be read. Such a message would be inserted
    into the query verbatim.
    """
    document_text = read_txt_to_string(file_path)
    return prompt_template.format(full_text=document_text)