zbparse/flask_app/general/test_doubao.py

178 lines
6.1 KiB
Python
Raw Normal View History

2024-12-23 15:47:41 +08:00
import os
import time
import requests
from ratelimit import sleep_and_retry, limits
def read_txt_to_string(file_path):
"""
读取txt文件内容并返回一个包含所有内容的字符串保持原有格式
参数:
- file_path (str): txt文件的路径
返回:
- str: 包含文件内容的字符串
"""
try:
with open(file_path, 'r', encoding='utf-8') as file: # 确保使用适当的编码
content = file.read() # 使用 read() 保持文件格式
return content
except FileNotFoundError:
return "错误:文件未找到。"
except Exception as e:
return f"错误:读取文件时发生错误。详细信息:{e}"
def generate_full_user_query(file_path, prompt_template):
"""
根据文件路径和提示词模板生成完整的user_query
参数
- file_path (str): 需要解析的文件路径
- prompt_template (str): 包含{full_text}占位符的提示词模板
返回
- str: 完整的user_query
"""
# 假设extract_text_by_page已经定义用于提取文件内容
full_text=read_txt_to_string(file_path)
# 格式化提示词,将提取的文件内容插入到模板中
user_query = prompt_template.format(full_text=full_text)
return user_query
def get_total_tokens(text):
"""
调用 API 计算给定文本的总 Token 数量
参数
- text (str): 需要计算 Token 的文本
- model (str): 使用的模型名称默认值为 "ep-20241119121710-425g6"
返回
- int: 文本的 total_tokens 数量
"""
# API 请求 URL
url = "https://ark.cn-beijing.volces.com/api/v3/tokenization"
# 获取 API 密钥
doubao_api_key = os.getenv("DOUBAO_API_KEY")
if not doubao_api_key:
raise ValueError("DOUBAO_API_KEY 环境变量未设置")
# 请求头
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + doubao_api_key
}
model = "ep-20241119121710-425g6"
# 请求体
payload = {
"model": model,
"text": [text] # API 文档中要求 text 是一个列表
}
try:
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
response_data = response.json()
total_tokens=response_data["data"][0]["total_tokens"]
return total_tokens
except Exception as e:
print(f"获取 Token 数量失败:{e}")
return 0
@sleep_and_retry
@limits(calls=10, period=1) # 每秒最多调用10次
def doubao_model(full_user_query, need_extra=False):
print("call doubao...")
# 相关参数
url = "https://ark.cn-beijing.volces.com/api/v3/chat/completions"
doubao_api_key = os.getenv("DOUBAO_API_KEY")
# 定义主模型和备用模型
models = {
"pro_32k": "ep-20241119121710-425g6", # 豆包Pro 32k模型
"pro_128k": "ep-20241119121743-xt6wg" # 128k模型
}
# 判断用户查询字符串的长度
token_count = get_total_tokens(full_user_query)
if token_count > 31500:
selected_model = models["pro_128k"] # 如果长度超过32k直接使用128k模型
else:
selected_model = models["pro_32k"] # 默认使用32k模型
# 请求头
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + doubao_api_key
}
max_retries_429 = 2 # 针对 429 错误的最大重试次数
max_retries_other = 1 # 针对其他错误的最大重试次数
attempt = 0
response = None # 确保 response 被定义
while True:
# 请求数据
data = {
"model": selected_model,
"messages": [
{
"role": "user",
"content": full_user_query
}
],
"temperature": 0.2
}
try:
response = requests.post(url, headers=headers, json=data) # 设置超时时间为10秒
response.raise_for_status() # 如果响应状态码不是200将引发HTTPError
# 获取响应 JSON
response_json = response.json()
# 获取返回内容
content = response_json["choices"][0]["message"]["content"]
# 获取 completion_tokens
completion_tokens = response_json["usage"].get("completion_tokens", 0)
# 根据 need_extra 返回不同的结果
if need_extra:
return content, completion_tokens
else:
return content
except requests.exceptions.RequestException as e:
# 获取状态码并处理不同的重试逻辑
status_code = response.status_code if response is not None else None
print(f"请求失败,状态码: {status_code}")
print("请求失败,完整的响应内容如下:")
if response is not None:
print(response.text) # 打印原始的响应内容,可能是 JSON 格式,也可能是其他格式
# 如果是 429 错误
if status_code == 429:
if attempt < max_retries_429:
wait_time = 2 if attempt == 0 else 4
print(f"状态码为 429等待 {wait_time} 秒后重试...")
time.sleep(wait_time)
else:
print(f"状态码为 429已达到最大重试次数 {max_retries_429} 次。")
break # 超过最大重试次数,退出循环
else:
# 针对其他错误
if attempt < max_retries_other:
print("非 429 错误,等待 1 秒后重试...")
time.sleep(1)
else:
print(f"非 429 错误,已达到最大重试次数 {max_retries_other} 次。")
break # 超过最大重试次数,退出循环
attempt += 1 # 增加重试计数
# 如果到这里,说明所有尝试都失败了
print(f"请求失败,已达到最大重试次数。")
if need_extra:
return None, 0
else:
return None