# Source listing header: 335 lines, 14 KiB, Python
import ast
|
||
import json
|
||
import logging
|
||
import re
|
||
import threading
|
||
from functools import wraps
|
||
from ratelimit import limits, sleep_and_retry
|
||
import time
|
||
from pathlib import Path
|
||
from openai import OpenAI
|
||
import os
|
||
|
||
# Serializes appends to file_ids.txt so concurrent uploads don't interleave lines.
file_write_lock = threading.Lock()


@sleep_and_retry
@limits(calls=2, period=1)
def upload_file(file_path, output_folder=""):
    """
    Upload a file to DashScope (OpenAI-compatible endpoint) and return its file ID.

    The returned ID is also appended, one per line, to ``file_ids.txt`` inside
    *output_folder*. The ``ratelimit`` decorators above cap this function at
    2 calls per second.

    Args:
        file_path: Path of the local file to upload.
        output_folder: Directory in which ``file_ids.txt`` is kept. Defaults to
            the directory containing *file_path*, or the current directory when
            *file_path* has no directory component.

    Returns:
        The ``id`` of the file object created by ``client.files.create``.
    """
    if not output_folder:
        # os.path.dirname("report.pdf") == "" -- without the fallback,
        # os.makedirs("") below would raise FileNotFoundError.
        output_folder = os.path.dirname(file_path) or "."

    client = OpenAI(
        api_key=os.getenv("DASHSCOPE_API_KEY"),
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )

    # Upload the file and obtain its file_id.
    uploaded = client.files.create(file=Path(file_path), purpose="file-extract")
    file_id = uploaded.id

    # exist_ok avoids the check-then-create race when several threads target
    # the same folder at once.
    os.makedirs(output_folder, exist_ok=True)

    file_ids_path = os.path.join(output_folder, "file_ids.txt")
    # Only one thread at a time may append to file_ids.txt.
    with file_write_lock:
        # Mode 'a' creates the file if it is missing and otherwise appends.
        with open(file_ids_path, "a", encoding="utf-8") as f:
            f.write(f"{file_id}\n")

    return file_id
|
||
|
||
@sleep_and_retry
@limits(calls=8, period=1)  # at most 8 calls per second
def rate_limiter():
    """
    No-op gate whose only purpose is to be rate limited.

    Each call counts against a budget of 8 calls per one-second period, as set
    by ``limits(calls=8, period=1)``. ``shared_rate_limit`` calls this before
    every wrapped function, so all wrapped functions share this one budget.
    """
    pass
|
||
|
||
# Shared decorator: every function it wraps draws from the same rate limit.
def shared_rate_limit(func):
    """
    Wrap *func* so that each call first goes through ``rate_limiter``.

    Because every decorated function calls the same ``rate_limiter``, they
    share a single call budget instead of each having its own.
    The wrapper keeps *func*'s metadata (name, docstring, ``__wrapped__``).
    """

    def limited(*args, **kwargs):
        # Pass through the shared limiter before doing the real work.
        rate_limiter()
        return func(*args, **kwargs)

    return wraps(func)(limited)
|
||
|
||
def extract_error_details(error_message):
    """
    Pull the numeric error code and the inner error-code string out of an error message.

    The message is expected to look like the OpenAI SDK's error text, e.g.::

        Error code: 400 - {'error': {'code': 'data_inspection_failed', ...}}

    The payload after the dash may be a Python-repr dict (single quotes,
    ``None``) or JSON (double quotes, ``null``); both are accepted.

    Args:
        error_message: The stringified exception.

    Returns:
        A tuple ``(error_code, error_code_string)``:

        - ``error_code`` is the integer after ``Error code:``, or ``None``.
        - ``error_code_string`` is ``payload['error']['code']``, or ``None``
          when it is absent or the payload cannot be parsed.
    """
    logger = logging.getLogger('model_log')

    # Numeric (HTTP-style) error code.
    code_match = re.search(r'Error code:\s*(\d+)', error_message)
    error_code = int(code_match.group(1)) if code_match else None

    # Inner error-code string such as 'data_inspection_failed'.
    error_code_string = None
    # re.DOTALL so a payload spanning several lines is still captured.
    payload_match = re.search(r'Error code:\s*\d+\s*-\s*(\{.*\})', error_message, re.DOTALL)
    if payload_match is None:
        return error_code, error_code_string

    payload_text = payload_match.group(1)
    payload = None
    last_error = None
    # literal_eval only accepts Python literals (safe on untrusted text);
    # json.loads covers payloads containing null/true/false.
    for parse in (ast.literal_eval, json.loads):
        try:
            payload = parse(payload_text)
            break
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as exc:
            last_error = exc
    else:
        logger.warning(f"解析错误消息失败: {last_error}")
        return error_code, error_code_string

    # Guard every level: the payload or its 'error' value may not be a dict.
    if isinstance(payload, dict):
        inner = payload.get('error')
        if isinstance(inner, dict):
            error_code_string = inner.get('code')

    logger.debug(f"error_code_string: {error_code_string}")
    return error_code, error_code_string
|
||
@shared_rate_limit
def qianwen_long(file_id, user_query, max_retries=2, backoff_factor=1.0, need_extra=False):
    """
    Query the ``qwen-long`` model about a previously uploaded file, retrying on failure.

    Retry policy:

    - HTTP 429: retried up to *max_retries* times, sleeping
      ``backoff_factor * 2 ** (attempt - 1)`` seconds between attempts.
    - HTTP 400 whose inner code is an inspection/timeout code: falls back once
      to ``qianwen_long_stream`` with its own retries disabled.
    - Any other error: no retry.

    Args:
        file_id: ID of the uploaded file.
        user_query: The user's question.
        max_retries: Extra attempts allowed after the first call (default 2).
        backoff_factor: Base wait in seconds for exponential backoff (default 1.0).
        need_extra: When True, also return the completion token count (default False).

    Returns:
        The response text. When *need_extra* is True, returns
        ``(text, completion_tokens)`` instead. If every attempt fails, returns
        ``""`` (or ``("", 0)`` when *need_extra* is True).
    """
    logger = logging.getLogger('model_log')
    # Inner 400 error codes that trigger the one-shot streaming fallback.
    fallback_400_codes = {
        'data_inspection_failed', 'ResponseTimeout', 'DataInspectionFailed',
        'response_timeout', 'request_timeout', "RequestTimeOut",
    }

    client = OpenAI(
        api_key=os.getenv("DASHSCOPE_API_KEY"),
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )

    for attempt in range(1, max_retries + 2):  # +1 to include the initial call
        try:
            completion = client.chat.completions.create(
                model="qwen-long",
                temperature=0.5,
                messages=[
                    {'role': 'system', 'content': f'fileid://{file_id}'},
                    {'role': 'user', 'content': user_query},
                ],
                stream=False,
            )
            content = completion.choices[0].message.content
            if not need_extra:
                return content
            # Don't turn a good answer into a failure if usage is missing.
            usage = completion.usage
            token_usage = usage.completion_tokens if usage is not None else 0
            return content, token_usage

        except Exception as exc:
            error_code, error_code_string = extract_error_details(str(exc))
            logger.error(f"第 {attempt} 次尝试失败,查询:'{user_query}',错误:{exc}", exc_info=True)

            if error_code == 429:  # QPS/QPM limit exceeded
                if attempt <= max_retries:
                    sleep_time = backoff_factor * (2 ** (attempt - 1))  # exponential backoff
                    logger.warning(f"错误代码为 429,将在 {sleep_time} 秒后重试...")
                    time.sleep(sleep_time)
                    continue
                logger.error(f"查询 '{user_query}' 的所有 {max_retries + 1} 次尝试均失败(429 错误)。")
                break

            if error_code == 400 and error_code_string in fallback_400_codes:
                logger.warning(f"错误代码为 400 - {error_code_string},将调用 qianwen_long_stream 执行一次...")
                try:
                    # Forward need_extra so the stream returns (text, tokens) when the
                    # caller asked for it; retries are disabled inside the fallback.
                    return qianwen_long_stream(
                        file_id, user_query, max_retries=0, need_extra=need_extra
                    )
                except Exception as stream_exc:
                    logger.error(f"调用 qianwen_long_stream 时出错:{stream_exc}", exc_info=True)
                    break  # give up; no further retries

            # Neither 429 nor a fallback-eligible 400: do not retry.
            logger.error(f"遇到非 429 或非 'data_inspection_failed' 的 400 错误(错误代码:{error_code}),不进行重试。")
            break

    # Every attempt failed.
    if need_extra:
        return "", 0
    return ""
|
||
|
||
@shared_rate_limit
def qianwen_long_stream(file_id, user_query, max_retries=2, backoff_factor=1.0, need_extra=False):
    """
    Query ``qwen-long`` about a previously uploaded file using a streaming response.

    The streamed deltas are concatenated into the full answer. Token usage is
    requested via ``stream_options={"include_usage": True}``.

    Retry policy:

    - HTTP 429: retried up to *max_retries* times, sleeping
      ``backoff_factor * 2 ** (attempt - 1)`` seconds between attempts.
    - HTTP 400 whose inner code is an inspection/timeout code: retried
      immediately at most once, and only if *max_retries* >= 1.
    - Any other error: no retry.

    Args:
        file_id: ID of the uploaded file.
        user_query: The user's question.
        max_retries: Extra attempts allowed after the first call (default 2).
        backoff_factor: Base wait in seconds for exponential backoff (default 1.0).
        need_extra: When True, also return the completion token count (default False).

    Returns:
        The response text (str). When *need_extra* is True, returns
        ``(text, completion_tokens)`` instead. If every attempt fails, returns
        ``""`` (or ``("", 0)`` when *need_extra* is True).
    """
    logger = logging.getLogger('model_log')
    # Inner 400 error codes that are worth one immediate retry.
    retryable_400_codes = {
        'data_inspection_failed', 'ResponseTimeout', 'DataInspectionFailed',
        'response_timeout', 'request_timeout', "RequestTimeOut",
    }

    print("调用 qianwen-long stream...")

    client = OpenAI(
        api_key=os.getenv("DASHSCOPE_API_KEY"),
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )

    for attempt in range(1, max_retries + 2):  # +1 to include the initial call
        try:
            completion = client.chat.completions.create(
                model="qwen-long",
                temperature=0.4,
                messages=[
                    {'role': 'system', 'content': f'fileid://{file_id}'},
                    {'role': 'user', 'content': user_query},
                ],
                stream=True,
                stream_options={"include_usage": True},
            )

            parts = []  # collected deltas; joined once at the end
            completion_tokens = 0
            for chunk in completion:
                if hasattr(chunk, 'to_dict'):
                    chunk_data = chunk.to_dict()
                else:
                    chunk_data = json.loads(chunk.model_dump_json())

                usage = chunk_data.get('usage')
                if usage is not None:
                    completion_tokens = usage.get('completion_tokens', 0)

                choices = chunk_data.get('choices', [])
                if choices:
                    content = choices[0].get('delta', {}).get('content', '')
                    if content:
                        parts.append(content)
                # Deliberately no break on finish_reason: keep consuming so a
                # later chunk carrying 'usage' is still read.

            full_response = "".join(parts)
            if need_extra:
                return full_response, completion_tokens
            return full_response

        except Exception as exc:
            error_code, error_code_string = extract_error_details(str(exc))
            logger.error(f"第 {attempt} 次尝试失败,查询:'{user_query}',错误:{exc}", exc_info=True)

            if error_code == 429:
                if attempt <= max_retries:
                    sleep_time = backoff_factor * (2 ** (attempt - 1))  # exponential backoff
                    logger.warning(f"错误代码为 429,将在 {sleep_time} 秒后重试...")
                    time.sleep(sleep_time)
                    continue
                logger.error(f"查询 '{user_query}' 的所有 {max_retries + 1} 次尝试均失败(429 错误)。")
                break

            if error_code == 400 and error_code_string in retryable_400_codes:
                # Retry at most once, and only if the retry budget allows it.
                if attempt == 1 and max_retries >= 1:
                    logger.warning(f"错误代码为 400 - {error_code_string},将立即重试...")
                    continue
                # Report the attempts actually made, not the configured maximum.
                logger.error(f"查询 '{user_query}' 的所有 {attempt} 次尝试均失败(400 - {error_code_string})。")
                break

            # Neither 429 nor a retryable 400: do not retry.
            logger.error(f"遇到非 429 或非 'data_inspection_failed...' 的 400 错误(错误代码:{error_code}),不进行重试。")
            break

    # Every attempt failed.
    if need_extra:
        return "", 0
    return ""
|
||
|
||
@shared_rate_limit
def qianwen_long_text(file_id, user_query):
    """
    Ask ``qwen-long`` a question about a previously uploaded file (non-streaming).

    Unlike ``qianwen_long``, this makes a single call with no retry or
    fallback; any API exception propagates to the caller.

    Args:
        file_id: ID of the uploaded file, referenced as ``fileid://<id>``.
        user_query: The user's question.

    Returns:
        The content of the first choice's message.
    """
    print("call qianwen-long text...")

    client = OpenAI(
        api_key=os.getenv("DASHSCOPE_API_KEY"),
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )

    # Generate a response grounded in the uploaded file.
    completion = client.chat.completions.create(
        model="qwen-long",
        temperature=0.5,
        messages=[
            {'role': 'system', 'content': f'fileid://{file_id}'},
            {'role': 'user', 'content': user_query},
        ],
        stream=False,
    )

    return completion.choices[0].message.content
|
||
|
||
#TODO:若采购需求和评分那块响应超时比较多,考虑都改为流式
|
||
if __name__ == "__main__":
    import sys

    # Manual smoke test: upload a file, then run one streaming query against it.
    # Pass a file path as the first command-line argument to override the default.
    default_path = r"C:\Users\Administrator\Desktop\货物标\截取test\2-招标文件_before.pdf"
    file_path = sys.argv[1] if len(sys.argv) > 1 else default_path

    file_id = upload_file(file_path)
    print(file_id)

    user_query1 = "该招标文件前附表中的项目名称是什么,请以json格式返回给我"
    # need_extra=True -> (response text, completion tokens)
    res1, res2 = qianwen_long_stream(file_id, user_query1, 2, 1, True)
    print(res1)
    print(res2)
|
||
|