2024-12-30 10:26:08 +08:00

421 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import re
import time
import fitz
import PyPDF2
import tempfile
import requests
from ratelimit import sleep_and_retry, limits
from flask_app.general.clean_pdf import extract_common_header, clean_page_content
from flask_app.general.table_ocr import CommonOcr
# 调用豆包对json形式的表格数据进行重构
def extract_img_table_text(ocr_result_pages):
print(ocr_result_pages)
base_prompt = '''
任务你负责解析以json形式输入的表格信息根据提供的文件内容来恢复表格信息并输出不要遗漏任何一个文字。
要求与指南:
1. 请运用文档表格理解能力,根据文件内容定位表格的每列和每行,列与列之间用丨分隔,若某列或者某行没有信息则用/填充。
2. 请不要遗漏任何一个文字,同时不要打乱行与行之间的顺序,也不要打乱列与列之间的顺序,严格按照文字的位置信息来恢复表格信息。
示例输出:
表格标题:
|序号|名称|数量|单位|单价(元)|总价(元)|技术参数|备注|
|形象展示区|
|1|公园主E题雕塑|1|套|/|/|根据江夏体育与文化元素定制|/|
..........
'''
base_prompt += f"\n\n文件内容:\n{json.dumps(ocr_result_pages, ensure_ascii=False, indent=4)}"
model_res = doubao_model(base_prompt)
return model_res
# 判断pdf中是否有图片, 并输出含有图片的页面列表
def has_images(pdf_path):
# 打开PDF文件
pdf_document = fitz.open(pdf_path)
# 存储包含图片的页面页数
pages_with_imgs = {}
# 遍历PDF的每一页
for page_num in range(pdf_document.page_count):
page = pdf_document.load_page(page_num)
# 获取页面的图片列表
images = page.get_images(full=True)
# 如果页面中有图片返回True
if images:
pages_with_imgs[page_num + 1] = images
# 如果遍历了所有页面都没有图片则返回False
return pages_with_imgs
# 调用通用表格识别对图片中的表格进行提取放回json形式的表格结构
def table_ocr_extract(image_path):
table_ocr = CommonOcr(img_path=image_path) # 创建时传递 img_path
return table_ocr.recognize()
# def ocr_extract(image_path):
# # 调用您的OCR引擎来识别图像中的文本
# # return OcrEngine.recognize_text_from_image(image_path)
# # 调用本地ocr
# return local_ocr.run(image_path)
# 提取pdf中某一页的所有图片
def extract_images_from_page(pdf_path, image_list, page_num):
images = []
try:
doc = fitz.open(pdf_path)
for img in image_list:
xref = img[0]
base_image = doc.extract_image(xref)
image_bytes = base_image['image']
image_ext = base_image.get('ext', 'png')
images.append({'data': image_bytes, 'ext': image_ext, 'page_num': page_num + 1})
except Exception as e:
print(f"提取图片时出错: {e}")
return images
def pdf_image2txt(file_path, img_pdf_list):
common_header = extract_common_header(file_path)
# print(f"公共抬头:{common_header}")
# print("--------------------正文开始-------------------")
result = ""
pdf_document = fitz.open(file_path)
with open(file_path, 'rb') as file:
reader = PyPDF2.PdfReader(file)
num_pages = len(reader.pages)
# print(f"Total pages: {num_pages}")
for page_num in range(num_pages):
page = reader.pages[page_num]
# text = page.extract_text()
text = page.extract_text() or ""
cleaned_text = clean_page_content(text, common_header)
# # print(f"--------第{page_num}页-----------")
if (page_num + 1) in img_pdf_list:
print(f"{page_num + 1} 页含有图片开始提取图片并OCR")
images = extract_images_from_page(file_path, img_pdf_list[page_num + 1], page_num)
for img in images:
try:
with tempfile.NamedTemporaryFile(delete=False, suffix='.' + img['ext']) as temp_image:
temp_image.write(img['data'])
temp_image.flush()
# 调用OCR函数
ocr_result = table_ocr_extract(temp_image.name)
ocr_result = json.loads(ocr_result)
# 判断是否提取成功并且 pages 中有数据
if ocr_result['code'] == 200 and len(ocr_result['result']['pages']) > 0:
print("提取成功,图片数据已提取。")
ocr_result_pages = ocr_result['result']['pages']
table_text = extract_img_table_text(ocr_result_pages)
if table_text.strip():
cleaned_text += "\n" + table_text
else:
print("提取失败或没有页面数据。")
except Exception as e:
print(f"OCR处理失败: {e}")
finally:
try:
os.remove(temp_image.name)
except Exception as e:
print(f"删除临时文件失败: {e}")
result += cleaned_text
directory = os.path.dirname(os.path.abspath(file_path))
output_path = os.path.join(directory, 'extract.txt')
# 将结果保存到 extract.txt 文件中
try:
with open(output_path, 'w', encoding='utf-8') as output_file:
output_file.write(result)
print(f"提取内容已保存到: {output_path}")
except IOError as e:
print(f"写入文件时发生错误: {e}")
# 返回保存的文件路径
return output_path
def pdf2txt(file_path):
common_header = extract_common_header(file_path)
# print(f"公共抬头:{common_header}")
# print("--------------------正文开始-------------------")
result = ""
with open(file_path, 'rb') as file:
reader = PyPDF2.PdfReader(file)
num_pages = len(reader.pages)
# print(f"Total pages: {num_pages}")
for page_num in range(num_pages):
page = reader.pages[page_num]
text = page.extract_text()
if text:
# print(f"--------第{page_num}页-----------")
cleaned_text = clean_page_content(text,common_header)
# print(cleaned_text)
result += cleaned_text
# print(f"Page {page_num + 1} Content:\n{cleaned_text}")
else:
print(f"Page {page_num + 1} is empty or text could not be extracted.")
directory = os.path.dirname(os.path.abspath(file_path))
output_path = os.path.join(directory, 'extract.txt')
# 将结果保存到 extract.txt 文件中
try:
with open(output_path, 'w', encoding='utf-8') as output_file:
output_file.write(result)
print(f"提取内容已保存到: {output_path}")
except IOError as e:
print(f"写入文件时发生错误: {e}")
# 返回保存的文件路径
return output_path
def read_txt_to_string(file_path):
"""
读取txt文件内容并返回一个包含所有内容的字符串保持原有格式。
参数:
- file_path (str): txt文件的路径
返回:
- str: 包含文件内容的字符串
"""
try:
with open(file_path, 'r', encoding='utf-8') as file: # 确保使用适当的编码
content = file.read() # 使用 read() 保持文件格式
return content
except FileNotFoundError:
return "错误:文件未找到。"
except Exception as e:
return f"错误:读取文件时发生错误。详细信息:{e}"
def count_tokens(text):
"""
统计文本中的 tokens 数量:
1. 英文字母+数字作为一个 token如 DN90
2. 数字+小数点/百分号作为一个 token如 0.25%)。
3. 单个中文字符作为一个 token。
4. 单个符号或标点符号作为一个 token。
5. 忽略空白字符(空格、空行等)。
"""
# 正则表达式:
# - 英文字母和数字组合DN90
# - 数字+小数点/百分号组合0.25%
# - 单个中文字符:[\u4e00-\u9fff]
# - 单个非空白符号:[^\s]
token_pattern = r'[a-zA-Z0-9]+(?:\.\d+)?%?|[\u4e00-\u9fff]|[^\s]'
tokens = re.findall(token_pattern, text)
return len(tokens)# 返回 tokens 数量和匹配的 token 列表
def get_total_tokens(text):
"""
调用 API 计算给定文本的总 Token 数量。
参数:
- text (str): 需要计算 Token 的文本。
- model (str): 使用的模型名称,默认值为 "ep-20241119121710-425g6"
返回:
- int: 文本的 total_tokens 数量。
"""
# API 请求 URL
url = "https://ark.cn-beijing.volces.com/api/v3/tokenization"
# 获取 API 密钥
doubao_api_key = os.getenv("DOUBAO_API_KEY")
if not doubao_api_key:
raise ValueError("DOUBAO_API_KEY 环境变量未设置")
# 请求头
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + doubao_api_key
}
model = "ep-20241119121710-425g6"
# 请求体
payload = {
"model": model,
"text": [text] # API 文档中要求 text 是一个列表
}
try:
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
response_data = response.json()
total_tokens=response_data["data"][0]["total_tokens"]
return total_tokens
except Exception as e:
print(f"获取 Token 数量失败:{e}")
return 0
@sleep_and_retry
@limits(calls=10, period=1) # 每秒最多调用10次
def doubao_model(full_user_query, need_extra=False):
print("call doubao...")
# 相关参数
url = "https://ark.cn-beijing.volces.com/api/v3/chat/completions"
doubao_api_key = os.getenv("DOUBAO_API_KEY")
# 定义主模型和备用模型
models = {
"pro_32k": "ep-20241119121710-425g6", # 豆包Pro 32k模型
"pro_128k": "ep-20241119121743-xt6wg" # 128k模型
}
# 判断用户查询字符串的长度
token_count = get_total_tokens(full_user_query)
if token_count > 31500:
selected_model = models["pro_128k"] # 如果长度超过32k直接使用128k模型
else:
selected_model = models["pro_32k"] # 默认使用32k模型
# 请求头
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + doubao_api_key
}
max_retries_429 = 2 # 针对 429 错误的最大重试次数
max_retries_other = 1 # 针对其他错误的最大重试次数
attempt = 0
response = None # 确保 response 被定义
while True:
# 请求数据
data = {
"model": selected_model,
"messages": [
{
"role": "user",
"content": full_user_query
}
],
"temperature": 0.2
}
try:
response = requests.post(url, headers=headers, json=data) # 设置超时时间为10秒
response.raise_for_status() # 如果响应状态码不是200将引发HTTPError
# 获取响应 JSON
response_json = response.json()
# 获取返回内容
content = response_json["choices"][0]["message"]["content"]
# 获取 completion_tokens
completion_tokens = response_json["usage"].get("completion_tokens", 0)
# 根据 need_extra 返回不同的结果
if need_extra:
return content, completion_tokens
else:
return content
except requests.exceptions.RequestException as e:
# 获取状态码并处理不同的重试逻辑
status_code = response.status_code if response is not None else None
print(f"请求失败,状态码: {status_code}")
print("请求失败,完整的响应内容如下:")
if response is not None:
print(response.text) # 打印原始的响应内容,可能是 JSON 格式,也可能是其他格式
# 如果是 429 错误
if status_code == 429:
if attempt < max_retries_429:
wait_time = 2 if attempt == 0 else 4
print(f"状态码为 429等待 {wait_time} 秒后重试...")
time.sleep(wait_time)
else:
print(f"状态码为 429已达到最大重试次数 {max_retries_429} 次。")
break # 超过最大重试次数,退出循环
else:
# 针对其他错误
if attempt < max_retries_other:
print("非 429 错误,等待 1 秒后重试...")
time.sleep(1)
else:
print(f"非 429 错误,已达到最大重试次数 {max_retries_other} 次。")
break # 超过最大重试次数,退出循环
attempt += 1 # 增加重试计数
# 如果到这里,说明所有尝试都失败了
print(f"请求失败,已达到最大重试次数。")
if need_extra:
return None, 0
else:
return None
def generate_full_user_query(file_path, prompt_template):
"""
根据文件路径和提示词模板生成完整的user_query。
参数:
- file_path (str): 需要解析的文件路径。
- prompt_template (str): 包含{full_text}占位符的提示词模板。
返回:
- str: 完整的user_query。
"""
# 假设extract_text_by_page已经定义用于提取文件内容
full_text=read_txt_to_string(file_path)
# 格式化提示词,将提取的文件内容插入到模板中
user_query = prompt_template.format(full_text=full_text)
return user_query
#7.文件内容为markdown格式 表格特殊情况处理对于表格数据可能存在原始pdf转换markdown时跨页导致同一个货物名称或系统名称分隔在上下两个单元格内你需要通过上下文语义判断是否合并之后才是完整且正确的货物名称或系统名称
if __name__ == "__main__":
txt_path = r"output.txt"
pdf_path_1 = "D:/bid_generator/task_folder/9a447eb0-24b8-4f51-8164-d91a62edea25/tmp/bid_format.pdf"
pdf_path_2 = r"C:\Users\Administrator\Desktop\货物标\output1\竞争性谈判文件_procurement.pdf"
prompt_template = '''
任务解析采购文件提取采购需求并以JSON格式返回。
要求与指南:
1. 精准定位:运用文档理解能力,找到文件中的采购需求部分。
2. 系统归属:若货物明确属于某个系统,则将其作为该系统的二级键。
3. 非清单形式处理:若未出现采购清单,则从表格或文字中摘取系统和货物信息。
4. 软件需求:对于软件应用需求,列出系统模块构成,并作为系统键值的一部分。
5. 系统功能:若文中提及系统功能,则在系统值中添加'系统功能'二级键,不展开具体内容。
6. 完整性:确保不遗漏系统内的货物,也不添加未提及的内容。
输出格式:
1.JSON格式最外层键名为'采购需求'
2.嵌套键名为系统或货物名称,与原文保持一致。
3.键值应为空对象({{}}),仅返回名称。
4.不包含'说明''规格''技术参数'等列内容。
5.层次关系用嵌套键值对表示。
6.最后一级键内值留空或填'未知'(如数量较多或未知内容)。
特殊情况处理:
同一层级下同名但采购要求不同的货物,以'货物名-编号'区分编号从1递增。
示例输出结构:
{{
"采购需求": {{
"交换机-1": {{}},
"交换机-2": {{}},
"门禁管理系统": {{
// 可包含其他货物或模块
}},
"交通监控视频子系统": {{
"系统功能": {{}},
"高清视频抓拍像机": {{}},
"补光灯": {{}}
}},
"LED全彩显示屏": {{}}
// 其他系统和货物
}}
}}
文件内容(已包含):{full_text}
注意事项:
1.严格按照上述要求执行,确保输出准确性和规范性。
2.如有任何疑问或不确定内容,请保留原文描述,必要时使用'未知'标注。
'''
# processed_filepath = convert_pdf_to_markdown(pdf_path_2) # 转markdown格式
# processed_filepath = pdf2txt(pdf_path_2) #纯文本提取
# user_query=generate_full_user_query(processed_filepath,prompt_template)
# user_query="一年有多少天?"
# res=doubao_model(user_query)
res=get_total_tokens("hh我是天才")
print(res)
# print("--------------------")
# print(user_query)