178 lines
6.1 KiB
Python
178 lines
6.1 KiB
Python
|
import os
|
|||
|
import time
|
|||
|
|
|||
|
import requests
|
|||
|
from ratelimit import sleep_and_retry, limits
|
|||
|
def read_txt_to_string(file_path):
    """
    Load a txt file and hand back its entire content as a single string.

    Parameters:
    - file_path (str): path of the txt file

    Returns:
    - str: the text of the file; when the file does not exist or cannot be
      read, an error-message string is returned instead of an exception
      being raised
    """
    try:
        # The file is decoded as UTF-8 and read in one call, so the text is
        # returned in one piece rather than line by line.
        with open(file_path, 'r', encoding='utf-8') as handle:
            result = handle.read()
    except FileNotFoundError:
        result = "错误:文件未找到。"
    except Exception as err:
        # Any other failure is reported to the caller as text.
        result = f"错误:读取文件时发生错误。详细信息:{err}"
    return result
|
|||
|
def generate_full_user_query(file_path, prompt_template):
    """
    Build the complete user_query from a file path and a prompt template.

    Parameters:
    - file_path (str): path of the file whose text is inserted.
    - prompt_template (str): prompt template containing a {full_text}
      placeholder.

    Returns:
    - str: the complete user_query.
    """
    # The file text comes from read_txt_to_string and is substituted for the
    # {full_text} placeholder via str.format.
    return prompt_template.format(full_text=read_txt_to_string(file_path))
|
|||
|
def get_total_tokens(text, model="ep-20241119121710-425g6", timeout=30):
    """
    Call the tokenization API to count the total tokens of the given text.

    Parameters:
    - text (str): the text whose tokens are counted.
    - model (str): model name sent in the request, defaults to
      "ep-20241119121710-425g6".
    - timeout (float or tuple): timeout in seconds handed to requests.post,
      so a stalled connection cannot block forever. Defaults to 30.

    Returns:
    - int: the total_tokens value from the response, or 0 when the request
      fails or the response does not have the expected shape.

    Raises:
    - ValueError: if the DOUBAO_API_KEY environment variable is not set.
    """
    # API request URL
    url = "https://ark.cn-beijing.volces.com/api/v3/tokenization"

    # The API key is read from the environment; fail early when it is absent.
    doubao_api_key = os.getenv("DOUBAO_API_KEY")
    if not doubao_api_key:
        raise ValueError("DOUBAO_API_KEY 环境变量未设置")

    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + doubao_api_key,
    }
    payload = {
        "model": model,
        "text": [text],  # the API documentation requires text to be a list
    }

    try:
        response = requests.post(url, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()
        response_data = response.json()
        return response_data["data"][0]["total_tokens"]
    except (
        requests.exceptions.RequestException,  # network / HTTP status failures
        KeyError, IndexError, TypeError, ValueError,  # unexpected response shape or bad JSON
    ) as e:
        # Best effort: report the failure and fall back to 0.
        print(f"获取 Token 数量失败:{e}")
        return 0
|
|||
|
|
|||
|
@sleep_and_retry
@limits(calls=10, period=1)  # at most 10 calls per second
def doubao_model(full_user_query, need_extra=False, timeout=(10, 300)):
    """
    Send a single-turn chat completion request and return the reply text.

    The model is chosen from the token count reported by get_total_tokens:
    above 31500 tokens the 128k endpoint is used, otherwise the 32k endpoint.

    Failed requests (requests.exceptions.RequestException) are retried:
    up to 2 retries for HTTP 429 (waiting 2 s, then 4 s) and up to 1 retry
    for any other request failure (waiting 1 s). Both limits share a single
    attempt counter. Exceptions that are not RequestException (for example a
    KeyError from an unexpected response body) are not caught here.

    Parameters:
    - full_user_query (str): the complete prompt sent as the user message.
    - need_extra (bool): when True, also return the completion token count.
    - timeout (float or tuple): timeout in seconds handed to requests.post;
      the default is (connect, read) = (10, 300).

    Returns:
    - str, or (str, int) when need_extra is True: the reply content and,
      optionally, completion_tokens (0 if the field is absent).
    - None, or (None, 0) when need_extra is True: all attempts failed.

    Raises:
    - ValueError: if the DOUBAO_API_KEY environment variable is not set.
    """
    print("call doubao...")
    url = "https://ark.cn-beijing.volces.com/api/v3/chat/completions"
    doubao_api_key = os.getenv("DOUBAO_API_KEY")
    if not doubao_api_key:
        # Fail with a clear error instead of a TypeError from "Bearer " + None.
        raise ValueError("DOUBAO_API_KEY 环境变量未设置")

    # Primary and fallback model endpoints
    models = {
        "pro_32k": "ep-20241119121710-425g6",   # Doubao Pro 32k model
        "pro_128k": "ep-20241119121743-xt6wg",  # 128k model
    }

    # Pick the endpoint from the token count of the query; the threshold
    # sits just below the 32k model's limit.
    token_count = get_total_tokens(full_user_query)
    if token_count > 31500:
        selected_model = models["pro_128k"]
    else:
        selected_model = models["pro_32k"]

    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + doubao_api_key,
    }

    # The payload does not change between attempts, so build it once.
    data = {
        "model": selected_model,
        "messages": [
            {
                "role": "user",
                "content": full_user_query,
            }
        ],
        "temperature": 0.2,
    }

    max_retries_429 = 2    # maximum retries for 429 errors
    max_retries_other = 1  # maximum retries for other errors
    attempt = 0

    while True:
        # Reset on every attempt so a failure that produced no response
        # (e.g. a connection error) is not reported with the status code and
        # body of an earlier attempt.
        response = None
        try:
            response = requests.post(url, headers=headers, json=data, timeout=timeout)
            response.raise_for_status()  # raises HTTPError for 4xx/5xx status codes

            response_json = response.json()
            content = response_json["choices"][0]["message"]["content"]
            completion_tokens = response_json["usage"].get("completion_tokens", 0)

            if need_extra:
                return content, completion_tokens
            return content

        except requests.exceptions.RequestException:
            status_code = response.status_code if response is not None else None
            print(f"请求失败,状态码: {status_code}")
            print("请求失败,完整的响应内容如下:")
            if response is not None:
                # Raw body; may be JSON or some other format.
                print(response.text)

            if status_code == 429:
                if attempt < max_retries_429:
                    wait_time = 2 if attempt == 0 else 4
                    print(f"状态码为 429,等待 {wait_time} 秒后重试...")
                    time.sleep(wait_time)
                else:
                    print(f"状态码为 429,已达到最大重试次数 {max_retries_429} 次。")
                    break  # retries exhausted
            else:
                if attempt < max_retries_other:
                    print("非 429 错误,等待 1 秒后重试...")
                    time.sleep(1)
                else:
                    print(f"非 429 错误,已达到最大重试次数 {max_retries_other} 次。")
                    break  # retries exhausted

            attempt += 1

    # Reaching this point means every attempt failed.
    print("请求失败,已达到最大重试次数。")
    if need_extra:
        return None, 0
    return None
|