zbparse/flask_app/general/test_doubao.py

178 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import time
import requests
from ratelimit import sleep_and_retry, limits
def read_txt_to_string(file_path):
"""
读取txt文件内容并返回一个包含所有内容的字符串保持原有格式。
参数:
- file_path (str): txt文件的路径
返回:
- str: 包含文件内容的字符串
"""
try:
with open(file_path, 'r', encoding='utf-8') as file: # 确保使用适当的编码
content = file.read() # 使用 read() 保持文件格式
return content
except FileNotFoundError:
return "错误:文件未找到。"
except Exception as e:
return f"错误:读取文件时发生错误。详细信息:{e}"
def generate_full_user_query(file_path, prompt_template):
"""
根据文件路径和提示词模板生成完整的user_query。
参数:
- file_path (str): 需要解析的文件路径。
- prompt_template (str): 包含{full_text}占位符的提示词模板。
返回:
- str: 完整的user_query。
"""
# 假设extract_text_by_page已经定义用于提取文件内容
full_text=read_txt_to_string(file_path)
# 格式化提示词,将提取的文件内容插入到模板中
user_query = prompt_template.format(full_text=full_text)
return user_query
def get_total_tokens(text):
"""
调用 API 计算给定文本的总 Token 数量。
参数:
- text (str): 需要计算 Token 的文本。
- model (str): 使用的模型名称,默认值为 "ep-20241119121710-425g6"
返回:
- int: 文本的 total_tokens 数量。
"""
# API 请求 URL
url = "https://ark.cn-beijing.volces.com/api/v3/tokenization"
# 获取 API 密钥
doubao_api_key = os.getenv("DOUBAO_API_KEY")
if not doubao_api_key:
raise ValueError("DOUBAO_API_KEY 环境变量未设置")
# 请求头
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + doubao_api_key
}
model = "ep-20241119121710-425g6"
# 请求体
payload = {
"model": model,
"text": [text] # API 文档中要求 text 是一个列表
}
try:
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
response_data = response.json()
total_tokens=response_data["data"][0]["total_tokens"]
return total_tokens
except Exception as e:
print(f"获取 Token 数量失败:{e}")
return 0
@sleep_and_retry
@limits(calls=10, period=1) # 每秒最多调用10次
def doubao_model(full_user_query, need_extra=False):
print("call doubao...")
# 相关参数
url = "https://ark.cn-beijing.volces.com/api/v3/chat/completions"
doubao_api_key = os.getenv("DOUBAO_API_KEY")
# 定义主模型和备用模型
models = {
"pro_32k": "ep-20241119121710-425g6", # 豆包Pro 32k模型
"pro_128k": "ep-20241119121743-xt6wg" # 128k模型
}
# 判断用户查询字符串的长度
token_count = get_total_tokens(full_user_query)
if token_count > 31500:
selected_model = models["pro_128k"] # 如果长度超过32k直接使用128k模型
else:
selected_model = models["pro_32k"] # 默认使用32k模型
# 请求头
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + doubao_api_key
}
max_retries_429 = 2 # 针对 429 错误的最大重试次数
max_retries_other = 1 # 针对其他错误的最大重试次数
attempt = 0
response = None # 确保 response 被定义
while True:
# 请求数据
data = {
"model": selected_model,
"messages": [
{
"role": "user",
"content": full_user_query
}
],
"temperature": 0.2
}
try:
response = requests.post(url, headers=headers, json=data) # 设置超时时间为10秒
response.raise_for_status() # 如果响应状态码不是200将引发HTTPError
# 获取响应 JSON
response_json = response.json()
# 获取返回内容
content = response_json["choices"][0]["message"]["content"]
# 获取 completion_tokens
completion_tokens = response_json["usage"].get("completion_tokens", 0)
# 根据 need_extra 返回不同的结果
if need_extra:
return content, completion_tokens
else:
return content
except requests.exceptions.RequestException as e:
# 获取状态码并处理不同的重试逻辑
status_code = response.status_code if response is not None else None
print(f"请求失败,状态码: {status_code}")
print("请求失败,完整的响应内容如下:")
if response is not None:
print(response.text) # 打印原始的响应内容,可能是 JSON 格式,也可能是其他格式
# 如果是 429 错误
if status_code == 429:
if attempt < max_retries_429:
wait_time = 2 if attempt == 0 else 4
print(f"状态码为 429等待 {wait_time} 秒后重试...")
time.sleep(wait_time)
else:
print(f"状态码为 429已达到最大重试次数 {max_retries_429} 次。")
break # 超过最大重试次数,退出循环
else:
# 针对其他错误
if attempt < max_retries_other:
print("非 429 错误,等待 1 秒后重试...")
time.sleep(1)
else:
print(f"非 429 错误,已达到最大重试次数 {max_retries_other} 次。")
break # 超过最大重试次数,退出循环
attempt += 1 # 增加重试计数
# 如果到这里,说明所有尝试都失败了
print(f"请求失败,已达到最大重试次数。")
if need_extra:
return None, 0
else:
return None