|
|
@@ -1,16 +1,96 @@
|
|
|
from config import MODEL_PATH, INFERENCE_URL, INFERENCE_AUTH_TOKEN, INFERENCE_MODEL, PROMPT_EXTRACT_NAME, PROMPT_EXTRACT_COMPONENTS, PROMPT_EXTRACT_KEYWORD, PROMPT_EXTRACT_PREVENTION,PROMPT_EXTRACT_SUPPLIER,PROMPT_EXTRACT_ICON
|
|
|
|
|
|
from io import BytesIO
|
|
|
+from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
|
import base64
|
|
|
import json
|
|
|
+import logging
|
|
|
+import os
|
|
|
from PIL import Image, ImageFilter, ImageEnhance
|
|
|
import time
|
|
|
+import re
|
|
|
import requests
|
|
|
|
|
|
-def image_to_base64(pil_image, image_format="JPEG"):
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+# 从环境变量读取图像预处理配置(由 start.py 启动时注入)
|
|
|
+_IMAGE_MAX_SIZE = int(os.environ.get("IMAGE_MAX_SIZE", 512))
|
|
|
+_IMAGE_COMPRESS = os.environ.get("IMAGE_COMPRESS", "false").lower() == "true"
|
|
|
+_IMAGE_COMPRESS_QUALITY = int(os.environ.get("IMAGE_COMPRESS_QUALITY", 70))
|
|
|
+
|
|
|
+# GHS 参考图目录(与 agent.py 同级的上级目录下的 ghs_icons/)
|
|
|
+_GHS_ICONS_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "ghs_icons")
|
|
|
+_GHS_ICON_NAMES = [f"GHS{i:02d}" for i in range(1, 10)]
|
|
|
+
|
|
|
+def _load_ghs_reference_images() -> dict:
|
|
|
+ """加载 GHS01-GHS09 参考图,返回 {名称: base64} 字典。缺失的图跳过。"""
|
|
|
+ refs = {}
|
|
|
+ for name in _GHS_ICON_NAMES:
|
|
|
+ path = os.path.join(_GHS_ICONS_DIR, f"{name}.png")
|
|
|
+ if os.path.exists(path):
|
|
|
+ with open(path, "rb") as f:
|
|
|
+ refs[name] = base64.b64encode(f.read()).decode("utf-8")
|
|
|
+ else:
|
|
|
+ logger.warning(f"GHS 参考图缺失: {path}")
|
|
|
+ return refs
|
|
|
+
|
|
|
+# 启动时加载一次,避免每次请求重复读文件
|
|
|
+_GHS_REFERENCE_IMAGES: dict = _load_ghs_reference_images()
|
|
|
+
|
|
|
+# 容易混淆的图标对,判断目标图标时同时发送干扰项供模型对比排除
|
|
|
+# GHS09 外形独特(枯树+死鱼),不发送混淆图,避免反向干扰
|
|
|
+_GHS_CONFUSABLE = {
|
|
|
+ "GHS02": ["GHS03"],
|
|
|
+ "GHS03": ["GHS02"],
|
|
|
+ "GHS07": ["GHS08"],
|
|
|
+ "GHS08": ["GHS07", "GHS06", "GHS09"],
|
|
|
+ "GHS09": ["GHS08"],
|
|
|
+}
|
|
|
+
|
|
|
+# 每个图标的判定关键特征,嵌入提示词中强化区分
|
|
|
+_GHS_DISCRIMINATIVE_HINT = {
|
|
|
+ "GHS01": "关键特征:圆形炸弹向四周炸裂,有大量射线和碎片飞散。",
|
|
|
+ "GHS02": "关键特征:只有火焰图案,火焰底部有横线底座。火焰中间没有任何圆圈/圆环,这是与GHS03的根本区别。",
|
|
|
+ "GHS03": (
|
|
|
+ "【GHS03唯一判定标准】:图标内部必须同时存在两个元素:(1)火焰 AND (2)火焰包围的空心大圆圈/圆环(圆圈内部镂空呈白色)。"
|
|
|
+ "缺少圆圈就是GHS02,不是GHS03。"
|
|
|
+ "请先回答:图标中有没有空心圆圈?如果没有圆圈,直接回答NO。"
|
|
|
+ ),
|
|
|
+ "GHS04": (
|
|
|
+ "关键特征:一个横置的粗短圆柱形气瓶/钢瓶,右侧伸出一根细长的阀门管道。"
|
|
|
+ "整体形状类似一个横向的短粗矩形/圆柱体,右侧有细管伸出。"
|
|
|
+ "在小尺寸印刷版本中,整个图案看起来像一个横向的短粗横条(比感叹号的竖线更粗更短,且是水平的)。"
|
|
|
+ "图标内没有火焰、没有感叹号竖线(感叹号是竖向的,气瓶是横向的)、没有人形、没有树。"
|
|
|
+ "即使只看到一个横向的粗短形状,也可以回答YES。"
|
|
|
+ ),
|
|
|
+ "GHS05": (
|
|
|
+ "关键特征:图标内有腐蚀性液体滴落的场景——"
|
|
|
+ "上方有试管/容器,液体向下滴落,腐蚀下方的物体(金属板或手掌)。"
|
|
|
+ "整体看起来像'液体从上往下滴,下方被腐蚀出缺口'的形状。"
|
|
|
+ "印刷版可能很小,但可以看出上方有细管状物体、下方有不规则缺口形状。"
|
|
|
+ "即使细节模糊,只要能看出'滴落腐蚀'的大致形状,就回答YES。"
|
|
|
+ ),
|
|
|
+ "GHS06": "关键特征:骷髅头(空心白色)加下方两根交叉骨头。",
|
|
|
+ "GHS07": '关键特征:只有一个感叹号"!"(上方竖条加下方圆点),无任何其他图形,没有人形,没有骷髅。',
|
|
|
+ "GHS08": (
|
|
|
+ "关键特征:实心黑色人体上半身剪影(有明确的头部+肩膀+躯干轮廓),胸口有白色裂缝/射线向四周放射。"
|
|
|
+ "【重要】树木/植物形状不是GHS08——如果图案看起来像树枝或植物,那是GHS09而不是GHS08,回答NO。"
|
|
|
+ "不是骷髅,不是感叹号,不是树。"
|
|
|
+ ),
|
|
|
+ "GHS09": (
|
|
|
+ "GHS09是环境危害象形图,图标内有两个有机生物形状:"
|
|
|
+ "左边是一棵枯树(竖直树干+向两侧伸出的树枝,整体呈Y形或T形的树状轮廓),"
|
|
|
+ "右边是一条死鱼(横向椭圆形鱼身轮廓,肚皮朝上翻转)。"
|
|
|
+ "这两个形状与其他GHS图标完全不同——其他图标内没有植物或鱼类形状。"
|
|
|
+ "请对照参考图:如果标签中某个菱形图标内的图案与参考图相似(有树形和/或鱼形),回答YES。"
|
|
|
+ ),
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+def image_to_base64(pil_image, image_format="JPEG", quality=95):
|
|
|
"""将PIL Image图像转换为Base64编码"""
|
|
|
buffered = BytesIO()
|
|
|
- pil_image.save(buffered, format=image_format)
|
|
|
+ pil_image.save(buffered, format=image_format, quality=quality)
|
|
|
img_byte_array = buffered.getvalue()
|
|
|
encode_image = base64.b64encode(img_byte_array).decode('utf-8')
|
|
|
return encode_image
|
|
|
@@ -20,7 +100,6 @@ def resize_image(image, max_size=512):
|
|
|
width, height = image.size
|
|
|
max_dim = max(width, height)
|
|
|
|
|
|
- # 如果图像不需要缩小,直接返回
|
|
|
if max_dim <= max_size:
|
|
|
return image
|
|
|
|
|
|
@@ -28,13 +107,8 @@ def resize_image(image, max_size=512):
|
|
|
new_width = int(width * scaling_factor)
|
|
|
new_height = int(height * scaling_factor)
|
|
|
|
|
|
- # 使用 LANCZOS 高质量缩放
|
|
|
resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
|
|
-
|
|
|
- # 应用 UnsharpMask 锐化,补偿缩放损失
|
|
|
resized = resized.filter(ImageFilter.UnsharpMask(radius=1, percent=120, threshold=3))
|
|
|
-
|
|
|
- # 轻微增强对比度,提高文字识别率
|
|
|
enhancer = ImageEnhance.Contrast(resized)
|
|
|
resized = enhancer.enhance(1.1)
|
|
|
|
|
|
@@ -44,70 +118,293 @@ class OcrAgent:
|
|
|
def __init__(self):
|
|
|
self._url = INFERENCE_URL
|
|
|
|
|
|
- def extract_single(self, image_base64: str, prompt: str, index: int):
|
|
|
- """单个任务请求,返回 (index, 结果文本)"""
|
|
|
- response = requests.post(
|
|
|
- self._url,
|
|
|
- headers={
|
|
|
- "Authorization": INFERENCE_AUTH_TOKEN,
|
|
|
- "Content-Type": "application/json"
|
|
|
- },
|
|
|
- json={
|
|
|
- "model": INFERENCE_MODEL,
|
|
|
- "messages": [
|
|
|
- {"role": "system", "content": "You are a helpful assistant."},
|
|
|
- {
|
|
|
- "role": "user",
|
|
|
- "content": [
|
|
|
+ def _check_single_icon(self, ghs_name: str, ref_b64: str, image_base64: str, max_retries: int = 2) -> bool:
|
|
|
+ """二分类:判断标签图中是否存在指定的 GHS 象形图,返回 True/False。
|
|
|
+ 对容易混淆的图标,同时发送干扰项参考图,让模型对比排除。
|
|
|
+ """
|
|
|
+ content = [
|
|
|
+ {"type": "text", "text": f"以下是标准 {ghs_name} 象形图的参考图(目标图标):"},
|
|
|
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{ref_b64}"}},
|
|
|
+ ]
|
|
|
+
|
|
|
+ # 如果该图标有容易混淆的图标,一并发送供对比
|
|
|
+ confusable = _GHS_CONFUSABLE.get(ghs_name, [])
|
|
|
+ if confusable:
|
|
|
+ content.append({"type": "text", "text": f"以下是容易与 {ghs_name} 混淆的图标,注意区分:"})
|
|
|
+ for conf_name in confusable:
|
|
|
+ if conf_name in _GHS_REFERENCE_IMAGES:
|
|
|
+ content.append({"type": "text", "text": f"({conf_name},不是目标图标)"})
|
|
|
+ content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{_GHS_REFERENCE_IMAGES[conf_name]}"}})
|
|
|
+
|
|
|
+ content.append({"type": "text", "text": "以下是需要识别的化学品安全标签图像:"})
|
|
|
+ content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}})
|
|
|
+
|
|
|
+ discriminative = _GHS_DISCRIMINATIVE_HINT.get(ghs_name, "")
|
|
|
+ confusable_hint = f"注意不要把它和 {'/'.join(confusable)} 混淆。\n" if confusable else ""
|
|
|
+
|
|
|
+ if ghs_name == "GHS03":
|
|
|
+ final_question = (
|
|
|
+ f"请仔细检查最后一张标签图像中的每个菱形图标。\n"
|
|
|
+ f"{discriminative}\n"
|
|
|
+ f"关键问题:标签中是否有图标在火焰图案内部包含了一个明显的空心圆圈/圆环?\n"
|
|
|
+ f"如果有圆圈→YES,如果只有火焰没有圆圈→NO。\n"
|
|
|
+ f"只回答 YES 或 NO,不要输出其他任何内容。"
|
|
|
+ )
|
|
|
+ elif ghs_name == "GHS09":
|
|
|
+ final_question = (
|
|
|
+ f"请仔细观察最后一张化学品安全标签图像中的所有菱形图标。\n"
|
|
|
+ f"参考图是GHS09(环境危害)标准图标,内含枯树和死鱼。\n"
|
|
|
+ f"{discriminative}\n"
|
|
|
+ f"注意:印刷版标签中的GHS09图标为黑白色(无红色边框),"
|
|
|
+ f"图案可能很小,但可以看出树形轮廓(树干+分叉枝条)和/或鱼形(椭圆形鱼身)。\n"
|
|
|
+ f"对照参考图,标签中是否有任何一个菱形图标包含了【枯树】或【死鱼】图案?\n"
|
|
|
+ f"只要识别出树形或鱼形,就回答YES;完全看不出来才回答NO。\n"
|
|
|
+ f"只回答 YES 或 NO,不要输出其他任何内容。"
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ final_question = (
|
|
|
+ f"请仔细观察上方图片。\n"
|
|
|
+ f"第一张是目标图标 {ghs_name} 的标准参考图。\n"
|
|
|
+ f"{discriminative}\n"
|
|
|
+ f"{confusable_hint}"
|
|
|
+ f"请严格对照上述关键特征,判断:最后一张化学品安全标签图像中是否包含 {ghs_name} 图标?\n"
|
|
|
+ f"必须所有关键特征都匹配才回答YES,有任何一条不符合就回答NO。\n"
|
|
|
+ f"只回答 YES 或 NO,不要输出其他任何内容。"
|
|
|
+ )
|
|
|
+ content.append({"type": "text", "text": final_question})
|
|
|
+
|
|
|
+ last_err = None
|
|
|
+ for attempt in range(max_retries + 1):
|
|
|
+ try:
|
|
|
+ response = requests.post(
|
|
|
+ self._url,
|
|
|
+ headers={
|
|
|
+ "Authorization": INFERENCE_AUTH_TOKEN,
|
|
|
+ "Content-Type": "application/json"
|
|
|
+ },
|
|
|
+ json={
|
|
|
+ "model": INFERENCE_MODEL,
|
|
|
+ "messages": [
|
|
|
+ {"role": "system", "content": "You are a helpful assistant. Answer only YES or NO."},
|
|
|
+ {"role": "user", "content": content}
|
|
|
+ ],
|
|
|
+ "max_tokens": 16,
|
|
|
+ "stream": False,
|
|
|
+ "temperature": 0
|
|
|
+ },
|
|
|
+ timeout=600
|
|
|
+ )
|
|
|
+ response.raise_for_status()
|
|
|
+
|
|
|
+ resp_json = response.json()
|
|
|
+ answer = resp_json["choices"][0]["message"]["content"].strip().upper()
|
|
|
+ logger.info(f"[icon binary] {ghs_name} -> {answer}")
|
|
|
+ return answer.startswith("YES")
|
|
|
+
|
|
|
+ except requests.RequestException as e:
|
|
|
+ last_err = e
|
|
|
+ if attempt < max_retries:
|
|
|
+ wait = 2 ** attempt
|
|
|
+ logger.warning(f"[icon binary] {ghs_name} 请求异常: {e},{wait}s 后重试...")
|
|
|
+ time.sleep(wait)
|
|
|
+
|
|
|
+ logger.error(f"[icon binary] {ghs_name} 重试 {max_retries} 次后仍失败: {last_err}")
|
|
|
+ return False
|
|
|
+
|
|
|
+ def _confirm_ghs03(self, image_base64: str) -> bool:
|
|
|
+ """GHS03 二次确认:直接询问标签中是否存在内部含空心圆圈的火焰图标。
|
|
|
+ 用于过滤 _check_single_icon 的 false-positive。"""
|
|
|
+ ghs02_b64 = _GHS_REFERENCE_IMAGES.get("GHS02", "")
|
|
|
+ ghs03_b64 = _GHS_REFERENCE_IMAGES.get("GHS03", "")
|
|
|
+ content = []
|
|
|
+ if ghs02_b64:
|
|
|
+ content.append({"type": "text", "text": "参考图A — GHS02(仅火焰,无圆圈):"})
|
|
|
+ content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{ghs02_b64}"}})
|
|
|
+ if ghs03_b64:
|
|
|
+ content.append({"type": "text", "text": "参考图B — GHS03(火焰内部有空心圆圈):"})
|
|
|
+ content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{ghs03_b64}"}})
|
|
|
+ content.append({"type": "text", "text": "待检查的化学品安全标签:"})
|
|
|
+ content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}})
|
|
|
+ content.append({"type": "text", "text": (
|
|
|
+ "请仔细观察标签中所有菱形图标内的图案。\n"
|
|
|
+ "问题:标签中是否存在【参考图B】所示的GHS03图标——即火焰图案内部有一个明显的空心圆圈/圆环?\n"
|
|
|
+ "注意:\n"
|
|
|
+ "- 如果只看到【参考图A】类型的纯火焰(无圆圈),回答NO\n"
|
|
|
+ "- 菱形框本身不算圆圈,必须是火焰内部的圆圈\n"
|
|
|
+ "- 如果没有火焰图案,回答NO\n"
|
|
|
+ "只回答 YES 或 NO,不要输出其他内容。"
|
|
|
+ )})
|
|
|
+ try:
|
|
|
+ response = requests.post(
|
|
|
+ self._url,
|
|
|
+ headers={"Authorization": INFERENCE_AUTH_TOKEN, "Content-Type": "application/json"},
|
|
|
+ json={
|
|
|
+ "model": INFERENCE_MODEL,
|
|
|
+ "messages": [
|
|
|
+ {"role": "system", "content": "You are a helpful assistant. Answer only YES or NO."},
|
|
|
+ {"role": "user", "content": content}
|
|
|
+ ],
|
|
|
+ "max_tokens": 16,
|
|
|
+ "stream": False,
|
|
|
+ "temperature": 0
|
|
|
+ },
|
|
|
+ timeout=600
|
|
|
+ )
|
|
|
+ response.raise_for_status()
|
|
|
+ answer = response.json()["choices"][0]["message"]["content"].strip().upper()
|
|
|
+ logger.info(f"[icon GHS03 confirm] -> {answer}")
|
|
|
+ return answer.startswith("YES")
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"[icon GHS03 confirm] 请求失败: {e},保守返回False")
|
|
|
+ return False
|
|
|
+
|
|
|
+ def extract_icon(self, image_base64: str, max_retries: int = 2):
|
|
|
+ """象形图识别:对 GHS01-GHS09 逐个并行做二分类,返回 (0, JSON字符串)。"""
|
|
|
+ with ThreadPoolExecutor(max_workers=len(_GHS_REFERENCE_IMAGES)) as executor:
|
|
|
+ futures = {
|
|
|
+ executor.submit(self._check_single_icon, name, ref_b64, image_base64, max_retries): name
|
|
|
+ for name, ref_b64 in _GHS_REFERENCE_IMAGES.items()
|
|
|
+ }
|
|
|
+ results = {}
|
|
|
+ for future in as_completed(futures):
|
|
|
+ name = futures[future]
|
|
|
+ results[name] = future.result()
|
|
|
+
|
|
|
+ matched = [name for name in _GHS_ICON_NAMES if results.get(name)]
|
|
|
+
|
|
|
+ # GHS02 和 GHS03 互斥:当两者同时出现时保留GHS02、丢弃GHS03
|
|
|
+ if "GHS02" in matched and "GHS03" in matched:
|
|
|
+ logger.info("[icon] GHS02/GHS03 冲突,保留GHS02,丢弃GHS03")
|
|
|
+ matched.remove("GHS03")
|
|
|
+
|
|
|
+ # 对 GHS03 做二次确认,减少 false-positive
|
|
|
+ if "GHS03" in matched:
|
|
|
+ confirmed = self._confirm_ghs03(image_base64)
|
|
|
+ if not confirmed:
|
|
|
+ logger.info("[icon] GHS03 二次确认为 NO,移除")
|
|
|
+ matched.remove("GHS03")
|
|
|
+
|
|
|
+ logger.info(f"[icon] 识别结果: {matched}")
|
|
|
+ return 0, json.dumps({"tag_images": matched}, ensure_ascii=False)
|
|
|
+
|
|
|
+ def extract_single(self, image_base64: str, prompt: str, index: int, max_retries: int = 2):
|
|
|
+ """单个任务请求,返回 (index, 结果文本)。失败时最多重试 max_retries 次。"""
|
|
|
+ last_err = None
|
|
|
+ for attempt in range(max_retries + 1):
|
|
|
+ try:
|
|
|
+ response = requests.post(
|
|
|
+ self._url,
|
|
|
+ headers={
|
|
|
+ "Authorization": INFERENCE_AUTH_TOKEN,
|
|
|
+ "Content-Type": "application/json"
|
|
|
+ },
|
|
|
+ json={
|
|
|
+ "model": INFERENCE_MODEL,
|
|
|
+ "messages": [
|
|
|
+ {"role": "system", "content": "You are a helpful assistant."},
|
|
|
{
|
|
|
- "type": "image_url",
|
|
|
- "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}
|
|
|
- },
|
|
|
- {"type": "text", "text": prompt}
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "max_tokens": 4096,
|
|
|
- "stream": False,
|
|
|
- "temperature": 0
|
|
|
- },
|
|
|
- timeout=600
|
|
|
- )
|
|
|
- response.raise_for_status()
|
|
|
- content = response.json()["choices"][0]["message"]["content"]
|
|
|
- return index, content
|
|
|
+ "role": "user",
|
|
|
+ "content": [
|
|
|
+ {
|
|
|
+ "type": "image_url",
|
|
|
+ "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}
|
|
|
+ },
|
|
|
+ {"type": "text", "text": prompt}
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "max_tokens": 4096,
|
|
|
+ "stream": False,
|
|
|
+ "temperature": 0
|
|
|
+ },
|
|
|
+ timeout=600
|
|
|
+ )
|
|
|
+ response.raise_for_status()
|
|
|
+
|
|
|
+ resp_json = response.json()
|
|
|
+ choice = resp_json["choices"][0]
|
|
|
+ finish_reason = choice.get("finish_reason", "")
|
|
|
+ content = choice["message"]["content"]
|
|
|
+
|
|
|
+ if finish_reason == "length":
|
|
|
+ # 输出被 token 数截断,内容不完整,重试
|
|
|
+ raise RuntimeError(
|
|
|
+ f"步骤[{index}] 模型输出被截断(finish_reason=length),"
|
|
|
+ f"第 {attempt + 1} 次尝试"
|
|
|
+ )
|
|
|
+
|
|
|
+ logger.info(f"步骤[{index}] finish_reason={finish_reason}")
|
|
|
+ return index, content
|
|
|
+
|
|
|
+ except RuntimeError as e:
|
|
|
+ last_err = e
|
|
|
+ if attempt < max_retries:
|
|
|
+ wait = 2 ** attempt
|
|
|
+ logger.warning(f"{e},{wait}s 后重试...")
|
|
|
+ time.sleep(wait)
|
|
|
+ except requests.RequestException as e:
|
|
|
+ last_err = e
|
|
|
+ if attempt < max_retries:
|
|
|
+ wait = 2 ** attempt
|
|
|
+ logger.warning(f"步骤[{index}] 请求异常: {e},{wait}s 后重试...")
|
|
|
+ time.sleep(wait)
|
|
|
+
|
|
|
+ raise RuntimeError(f"步骤[{index}] 重试 {max_retries} 次后仍失败: {last_err}")
|
|
|
|
|
|
@staticmethod
|
|
|
def _parse_json(text: str, step_name: str) -> dict:
|
|
|
"""
|
|
|
解析模型返回的 JSON 文本,自动清洗 ```json``` 标记。
|
|
|
- 解析失败时抛出 RuntimeError(不会被 ValueError 捕获误报为"参数验证失败")。
|
|
|
+ 若直接解析失败,尝试用正则从文本中提取第一个 JSON 对象/数组。
|
|
|
+ 解析失败时抛出 RuntimeError。
|
|
|
"""
|
|
|
- # 去除首尾空白
|
|
|
text = text.strip()
|
|
|
- # 兼容模型偶尔返回 ```json ... ``` 包裹的情况
|
|
|
if text.startswith("```"):
|
|
|
lines = text.splitlines()
|
|
|
- # 去掉首行的 ```json 或 ``` 和末行的 ```
|
|
|
text = "\n".join(
|
|
|
line for line in lines
|
|
|
if not line.strip().startswith("```")
|
|
|
).strip()
|
|
|
try:
|
|
|
return json.loads(text)
|
|
|
- except json.JSONDecodeError as e:
|
|
|
+ except json.JSONDecodeError:
|
|
|
+ # 模型返回了思考过程 + JSON 混合内容,尝试提取第一个 JSON 块
|
|
|
+ match = re.search(r'\{[\s\S]*\}', text)
|
|
|
+ if match:
|
|
|
+ try:
|
|
|
+ return json.loads(match.group())
|
|
|
+ except json.JSONDecodeError:
|
|
|
+ pass
|
|
|
raise RuntimeError(
|
|
|
- f"步骤[{step_name}]模型返回内容无法解析为 JSON: {e}\n原始内容: {text[:200]}"
|
|
|
+ f"步骤[{step_name}]模型返回内容无法解析为 JSON\n原始内容: {text[:200]}"
|
|
|
)
|
|
|
|
|
|
def agent_ocr(self, image):
|
|
|
"""qwen_ocr提取化学品安全标签信息"""
|
|
|
- image = resize_image(image, max_size=512)
|
|
|
- image_base64 = image_to_base64(image)
|
|
|
+ image = resize_image(image, max_size=_IMAGE_MAX_SIZE)
|
|
|
+ quality = _IMAGE_COMPRESS_QUALITY if _IMAGE_COMPRESS else 95
|
|
|
+ image_base64 = image_to_base64(image, quality=quality)
|
|
|
+ logger.info(f"图像预处理: max_size={_IMAGE_MAX_SIZE}, compress={_IMAGE_COMPRESS}, quality={quality}")
|
|
|
+
|
|
|
+ # 为象形图识别单独准备高分辨率裁剪图
|
|
|
+ # 竖向长图(高/宽 > 1.5)只取上部 30%,横向图取上部 60%
|
|
|
+ w, h = image.size
|
|
|
+ ratio = h / w
|
|
|
+ crop_ratio = 0.30 if ratio > 1.5 else 0.60
|
|
|
+ icon_crop = image.crop((0, 0, w, int(h * crop_ratio)))
|
|
|
+ # 放大使图标细节清晰,长边不超过 1600px
|
|
|
+ scale = min(1600 / icon_crop.width, 1600 / icon_crop.height, 3)
|
|
|
+ icon_crop = icon_crop.resize(
|
|
|
+ (int(icon_crop.width * scale), int(icon_crop.height * scale)),
|
|
|
+ Image.Resampling.LANCZOS
|
|
|
+ )
|
|
|
+ enhancer = ImageEnhance.Contrast(icon_crop)
|
|
|
+ icon_crop = enhancer.enhance(1.3)
|
|
|
+ icon_image_base64 = image_to_base64(icon_crop, quality=95)
|
|
|
+ logger.info(f"象形图识别用裁剪图: {icon_crop.size} (crop_ratio={crop_ratio})")
|
|
|
|
|
|
start_time = time.perf_counter()
|
|
|
|
|
|
- # 定义需要并行执行的任务(顺序固定,用 index 保序)
|
|
|
prompts = [
|
|
|
PROMPT_EXTRACT_ICON, # 0
|
|
|
PROMPT_EXTRACT_NAME, # 1
|
|
|
@@ -117,13 +414,24 @@ class OcrAgent:
|
|
|
PROMPT_EXTRACT_SUPPLIER # 5
|
|
|
]
|
|
|
|
|
|
- # 串行发送 6 个请求
|
|
|
- results = []
|
|
|
- for idx, prompt in enumerate(prompts):
|
|
|
- _, content = self.extract_single(image_base64, prompt, idx)
|
|
|
- results.append(content)
|
|
|
+ # 并行发送 6 个请求,按 index 填回保证顺序
|
|
|
+ # index=0 的象形图识别使用带参考图的专用方法
|
|
|
+ results = [None] * len(prompts)
|
|
|
+ with ThreadPoolExecutor(max_workers=len(prompts)) as executor:
|
|
|
+ futures = {}
|
|
|
+ futures[executor.submit(self.extract_icon, icon_image_base64)] = 0
|
|
|
+ for idx, prompt in enumerate(prompts):
|
|
|
+ if idx == 0:
|
|
|
+ continue # icon 已单独提交
|
|
|
+ futures[executor.submit(self.extract_single, image_base64, prompt, idx)] = idx
|
|
|
+ for future in as_completed(futures):
|
|
|
+ idx, content = future.result() # 任意一个步骤失败会在此抛出
|
|
|
+ results[idx] = content
|
|
|
|
|
|
- # 从结果中提取数据(顺序已由 index 保证)
|
|
|
+ end_time = time.perf_counter()
|
|
|
+ logger.info(f"推理时间: {end_time - start_time:.3f} 秒")
|
|
|
+
|
|
|
+ # 解析各步骤结果(顺序由 index 保证,与串行时完全一致)
|
|
|
step_names = ["icon", "name", "components", "keyword", "prevention", "supplier"]
|
|
|
icon = self._parse_json(results[0], step_names[0])
|
|
|
name = self._parse_json(results[1], step_names[1])
|
|
|
@@ -132,11 +440,7 @@ class OcrAgent:
|
|
|
pre_notice = self._parse_json(results[4], step_names[4])
|
|
|
suppliers = self._parse_json(results[5], step_names[5])
|
|
|
|
|
|
- end_time = time.perf_counter()
|
|
|
- elapsed_time = end_time - start_time
|
|
|
- print(f"推理时间: {elapsed_time:.6f} 秒")
|
|
|
-
|
|
|
- result = {
|
|
|
+ return {
|
|
|
"tag": {
|
|
|
"name_cn": name["name_cn"],
|
|
|
"name_en": name["name_en"],
|
|
|
@@ -150,8 +454,6 @@ class OcrAgent:
|
|
|
"acc_tel": suppliers["acc_tel"],
|
|
|
}
|
|
|
|
|
|
- return result
|
|
|
-
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
image = Image.open("./test1.jpg").convert("RGB")
|