256 lines
11 KiB
Python
256 lines
11 KiB
Python
import json
|
||
import logging
|
||
import threading
|
||
from ratelimit import limits, sleep_and_retry
|
||
import time
|
||
from pathlib import Path
|
||
from openai import OpenAI
|
||
import os
|
||
from flask_app.general.llm.大模型通用函数 import extract_error_details, shared_rate_limit
|
||
|
||
# Module-wide lock held by upload_file() while it appends a file ID to
# 'file_ids.txt', so that threads in this process do not interleave writes.
file_write_lock = threading.Lock()
|
||
@sleep_and_retry
@limits(calls=2, period=1)
def upload_file(file_path, output_folder=""):
    """
    Upload a file to DashScope and return its file ID.

    The ID is also appended, one per line, to 'file_ids.txt' inside
    ``output_folder``.

    Args:
        file_path: Path of the local file to upload.
        output_folder: Directory that receives 'file_ids.txt'. When empty,
            the directory containing ``file_path`` is used.

    Returns:
        The ``id`` attribute of the object returned by
        ``client.files.create``.
    """
    if not output_folder:
        # dirname() of a bare filename is "", and os.makedirs("") raises,
        # so fall back to the current directory in that case.
        output_folder = os.path.dirname(file_path) or os.curdir

    client = OpenAI(
        api_key=os.getenv("DASHSCOPE_API_KEY"),
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
    )

    # Upload the file and keep the ID assigned by the service.
    uploaded = client.files.create(file=Path(file_path), purpose="file-extract")
    file_id = uploaded.id

    # exist_ok=True avoids the check-then-create race when several threads
    # target the same, not-yet-existing folder.
    os.makedirs(output_folder, exist_ok=True)

    # Serialize appends so concurrent uploads do not interleave their lines.
    file_ids_path = os.path.join(output_folder, 'file_ids.txt')
    with file_write_lock:
        # Mode 'a' creates the file on first use.
        with open(file_ids_path, 'a', encoding='utf-8') as f:
            f.write(f'{file_id}\n')

    return file_id
|
||
|
||
@shared_rate_limit
def qianwen_long(file_id, user_query, max_retries=2, backoff_factor=1.0, need_extra=False):
    """
    Answer ``user_query`` against an uploaded file using the qwen-long model.

    Error handling, based on the values returned by ``extract_error_details``:
      * code 429: retry with exponential backoff, up to ``max_retries`` times.
      * code 400 with one of the listed timeout / inspection / parameter
        error strings: make a single call to ``qianwen_long_stream`` (with its
        own retries disabled) and return whatever it returns.
      * anything else: stop immediately.

    Args:
        file_id: ID of the uploaded file.
        user_query: The user's question.
        max_retries: Maximum number of retries after the first call (default 2).
        backoff_factor: Base wait in seconds for exponential backoff (default 1.0).
        need_extra: When True, also return the completion token count.

    Returns:
        The response text when ``need_extra`` is False, otherwise a
        ``(text, completion_tokens)`` tuple. On failure the text is "" and
        the token count is 0.
    """
    logger = logging.getLogger('model_log')  # look the logger up by name

    # 400-level error strings that trigger the streaming fallback.
    stream_fallback_errors = (
        'data_inspection_failed', 'ResponseTimeout', 'DataInspectionFailed',
        'response_timeout', 'request_timeout', "RequestTimeOut", "invalid_parameter_error"
    )

    client = OpenAI(
        api_key=os.getenv("DASHSCOPE_API_KEY"),
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
    )

    for attempt in range(1, max_retries + 2):  # +1 so the initial call is included
        try:
            completion = client.chat.completions.create(
                model="qwen-long",
                temperature=0.5,
                messages=[
                    {
                        'role': 'system',
                        'content': f'fileid://{file_id}'
                    },
                    {
                        'role': 'user',
                        'content': user_query
                    }
                ],
                stream=False
            )
            # Generated tokens only; input tokens are not counted.
            token_usage = completion.usage.completion_tokens
            answer = completion.choices[0].message.content
            if need_extra:
                return answer, token_usage
            return answer

        except Exception as exc:
            error_code, error_code_string, request_id = extract_error_details(str(exc))
            logger.error(
                f"第 {attempt} 次尝试失败,查询:'{user_query}',error_code:{error_code}, request_id={request_id}",
                exc_info=True
            )

            if error_code == 429:  # QPS / QPM limit exceeded
                if attempt <= max_retries:
                    sleep_time = backoff_factor * (2 ** (attempt - 1))  # exponential backoff
                    logger.warning(f"错误代码为 429,将在 {sleep_time} 秒后重试...")
                    time.sleep(sleep_time)
                else:
                    # Reported through the logger (not print) so the failure lands
                    # in the same log as every other message of this function.
                    logger.error(f"查询 '{user_query}' 的所有 {max_retries + 1} 次尝试均失败(429 错误)。")
                    break
            elif error_code == 400 and error_code_string in stream_fallback_errors:
                logger.warning(f"错误代码为 400 - {error_code_string},将调用 qianwen_long_stream 执行一次...")
                try:
                    # max_retries=0 disables the stream function's own retries.
                    stream_result = qianwen_long_stream(file_id, user_query, max_retries=0, backoff_factor=1, need_extra=need_extra)
                    if not need_extra:
                        return stream_result
                    if isinstance(stream_result, tuple) and len(stream_result) == 2:
                        # Pass through the text and token count reported by the stream call.
                        return stream_result[0], stream_result[1]
                    logger.error(f"qianwen_long_stream 返回值不符合预期(需要元组)。返回值: {stream_result}")
                    return "", 0
                except Exception as stream_exc:
                    logger.error(f"调用 qianwen_long_stream 时出错:{stream_exc}", exc_info=True)
                    break  # give up; no further retries
            else:
                # Neither a 429 nor one of the listed 400 errors: do not retry.
                logger.error(f"遇到非 429 或非 'data_inspection_failed' 的 400 错误(错误代码:{error_code}),不进行重试。")
                break

    # Reached only when every attempt failed.
    if need_extra:
        return "", 0
    return ""
|
||
|
||
@shared_rate_limit
def qianwen_long_stream(file_id, user_query, max_retries=2, backoff_factor=1.0, need_extra=False):
    """
    Answer ``user_query`` against an uploaded file using qwen-long in streaming mode.

    The streamed chunks are concatenated into one string; the completion token
    count is taken from whichever chunk carries a ``usage`` object.

    Error handling, based on the values returned by ``extract_error_details``:
      * code 429: retry with exponential backoff, up to ``max_retries`` times.
      * code 400 with one of the listed timeout / inspection / parameter
        error strings: retry immediately, but only after the first attempt
        and only if a retry is still available.
      * anything else: stop immediately.

    Args:
        file_id: ID of the uploaded file.
        user_query: The user's question.
        max_retries: Maximum number of retries after the first call (default 2).
        backoff_factor: Base wait in seconds for exponential backoff (default 1.0).
        need_extra: When True, also return the completion token count.

    Returns:
        The response text when ``need_extra`` is False, otherwise a
        ``(text, completion_tokens)`` tuple. On failure the text is "" and
        the token count is 0.
    """
    logger = logging.getLogger('model_log')  # look the logger up by name

    # 400-level error strings that are retried once.
    retryable_400_errors = (
        'data_inspection_failed', 'ResponseTimeout', 'DataInspectionFailed',
        'response_timeout', 'request_timeout', "RequestTimeOut", "invalid_parameter_error"
    )

    client = OpenAI(
        api_key=os.getenv("DASHSCOPE_API_KEY"),
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
    )

    for attempt in range(1, max_retries + 2):  # +1 so the initial call is included
        try:
            completion_tokens = 0
            completion = client.chat.completions.create(
                model="qwen-long",
                temperature=0.4,
                messages=[
                    {
                        'role': 'system',
                        'content': f'fileid://{file_id}'
                    },
                    {
                        'role': 'user',
                        'content': user_query
                    }
                ],
                stream=True,  # enable streaming
                stream_options={"include_usage": True}
            )

            # Collect pieces and join once instead of repeated string +=.
            pieces = []

            # The loop is never exited early on finish_reason, so that a later
            # chunk carrying the usage object is still processed.
            for chunk in completion:
                if hasattr(chunk, 'to_dict'):
                    chunk_data = chunk.to_dict()
                else:
                    chunk_data = json.loads(chunk.model_dump_json())

                usage = chunk_data.get('usage')
                if usage is not None:
                    # Generated tokens only; input tokens are not counted.
                    completion_tokens = usage.get('completion_tokens', 0)

                choices = chunk_data.get('choices', [])
                if choices:
                    # `or {}` guards against a delta that is present but null.
                    delta = choices[0].get('delta') or {}
                    content = delta.get('content', '')
                    if content:
                        pieces.append(content)

            full_response = "".join(pieces)
            if need_extra:
                return full_response, completion_tokens
            return full_response

        except Exception as exc:
            error_code, error_code_string, request_id = extract_error_details(str(exc))
            logger.error(
                f"第 {attempt} 次尝试失败,查询:'{user_query}',error_code:{error_code}, request_id={request_id}",
                exc_info=True
            )

            if error_code == 429:
                if attempt <= max_retries:
                    sleep_time = backoff_factor * (2 ** (attempt - 1))  # exponential backoff
                    logger.warning(f"错误代码为 429,将在 {sleep_time} 秒后重试...")
                    time.sleep(sleep_time)
                else:
                    logger.error(f"查询 '{user_query}' 的所有 {max_retries + 1} 次尝试均失败(429 错误)。")
                    break
            elif error_code == 400 and error_code_string in retryable_400_errors:
                # Retry once, and only when another attempt actually exists.
                # With max_retries=0 the loop has no further iteration, so the
                # "retrying" warning would be false; log the failure instead.
                if attempt == 1 and attempt <= max_retries:
                    logger.warning(f"错误代码为 400 - {error_code_string},将立即重试...")
                    continue
                else:
                    logger.error(f"查询 '{user_query}' 的所有 {max_retries + 1} 次尝试均失败(400 - {error_code_string})。")
                    break
            else:
                # Neither a 429 nor one of the listed 400 errors: do not retry.
                logger.error(f"遇到非 429 或非 'data_inspection_failed...' 的 400 错误(错误代码:{error_code}),不进行重试。")
                break

    # Reached only when every attempt failed.
    if need_extra:
        return "", 0
    return ""
|
||
|
||
if __name__ == "__main__":
    # Manual smoke test: upload a local PDF, ask one question, print the answer.
    # Replace the path below with a file that exists on your machine.
    sample_pdf = r"C:\Users\Administrator\Downloads\2022-广东-鹏华基金管理有限公司深圳深业上城办公室装修项目.pdf"
    uploaded_id = upload_file(sample_pdf)

    question = "该招标文件的项目编号是什么?"

    # Streaming alternative, which also returns the token count:
    #   text, tokens = qianwen_long_stream(uploaded_id, question, 2, 1, True)
    print(qianwen_long(uploaded_id, question))
|
||
|