Browse Source

优化象形图识别:基于参考图的二分类+并行推理,调整prompt和模型配置

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
huanghongbo 3 weeks ago
parent
commit
f6f5c25c21

+ 363 - 61
agent/agent.py

@@ -1,16 +1,96 @@
 from config import MODEL_PATH, INFERENCE_URL, INFERENCE_AUTH_TOKEN, INFERENCE_MODEL, PROMPT_EXTRACT_NAME, PROMPT_EXTRACT_COMPONENTS, PROMPT_EXTRACT_KEYWORD, PROMPT_EXTRACT_PREVENTION,PROMPT_EXTRACT_SUPPLIER,PROMPT_EXTRACT_ICON
 
 from io import BytesIO
+from concurrent.futures import ThreadPoolExecutor, as_completed
 import base64
 import json
+import logging
+import os
 from PIL import Image, ImageFilter, ImageEnhance
 import time
+import re
 import requests
 
-def image_to_base64(pil_image, image_format="JPEG"):
+logger = logging.getLogger(__name__)
+
+# 从环境变量读取图像预处理配置(由 start.py 启动时注入)
+_IMAGE_MAX_SIZE         = int(os.environ.get("IMAGE_MAX_SIZE", 512))
+_IMAGE_COMPRESS         = os.environ.get("IMAGE_COMPRESS", "false").lower() == "true"
+_IMAGE_COMPRESS_QUALITY = int(os.environ.get("IMAGE_COMPRESS_QUALITY", 70))
+
+# GHS 参考图目录(与 agent.py 同级的上级目录下的 ghs_icons/)
+_GHS_ICONS_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "ghs_icons")
+_GHS_ICON_NAMES = [f"GHS{i:02d}" for i in range(1, 10)]
+
+def _load_ghs_reference_images() -> dict:
+    """加载 GHS01-GHS09 参考图,返回 {名称: base64} 字典。缺失的图跳过。"""
+    refs = {}
+    for name in _GHS_ICON_NAMES:
+        path = os.path.join(_GHS_ICONS_DIR, f"{name}.png")
+        if os.path.exists(path):
+            with open(path, "rb") as f:
+                refs[name] = base64.b64encode(f.read()).decode("utf-8")
+        else:
+            logger.warning(f"GHS 参考图缺失: {path}")
+    return refs
+
+# 启动时加载一次,避免每次请求重复读文件
+_GHS_REFERENCE_IMAGES: dict = _load_ghs_reference_images()
+
+# 容易混淆的图标对,判断目标图标时同时发送干扰项供模型对比排除
+# GHS09 外形独特(枯树+死鱼),不发送混淆图,避免反向干扰
+_GHS_CONFUSABLE = {
+    "GHS02": ["GHS03"],
+    "GHS03": ["GHS02"],
+    "GHS07": ["GHS08"],
+    "GHS08": ["GHS07", "GHS06", "GHS09"],
+    "GHS09": ["GHS08"],
+}
+
+# 每个图标的判定关键特征,嵌入提示词中强化区分
+_GHS_DISCRIMINATIVE_HINT = {
+    "GHS01": "关键特征:圆形炸弹向四周炸裂,有大量射线和碎片飞散。",
+    "GHS02": "关键特征:只有火焰图案,火焰底部有横线底座。火焰中间没有任何圆圈/圆环,这是与GHS03的根本区别。",
+    "GHS03": (
+        "【GHS03唯一判定标准】:图标内部必须同时存在两个元素:(1)火焰 AND (2)火焰包围的空心大圆圈/圆环(圆圈内部镂空呈白色)。"
+        "缺少圆圈就是GHS02,不是GHS03。"
+        "请先回答:图标中有没有空心圆圈?如果没有圆圈,直接回答NO。"
+    ),
+    "GHS04": (
+        "关键特征:一个横置的粗短圆柱形气瓶/钢瓶,右侧伸出一根细长的阀门管道。"
+        "整体形状类似一个横向的短粗矩形/圆柱体,右侧有细管伸出。"
+        "在小尺寸印刷版本中,整个图案看起来像一个横向的短粗横条(比感叹号的竖线更粗更短,且是水平的)。"
+        "图标内没有火焰、没有感叹号竖线(感叹号是竖向的,气瓶是横向的)、没有人形、没有树。"
+        "即使只看到一个横向的粗短形状,也可以回答YES。"
+    ),
+    "GHS05": (
+        "关键特征:图标内有腐蚀性液体滴落的场景——"
+        "上方有试管/容器,液体向下滴落,腐蚀下方的物体(金属板或手掌)。"
+        "整体看起来像'液体从上往下滴,下方被腐蚀出缺口'的形状。"
+        "印刷版可能很小,但可以看出上方有细管状物体、下方有不规则缺口形状。"
+        "即使细节模糊,只要能看出'滴落腐蚀'的大致形状,就回答YES。"
+    ),
+    "GHS06": "关键特征:骷髅头(空心白色)加下方两根交叉骨头。",
+    "GHS07": '关键特征:只有一个感叹号"!"(上方竖条加下方圆点),无任何其他图形,没有人形,没有骷髅。',
+    "GHS08": (
+        "关键特征:实心黑色人体上半身剪影(有明确的头部+肩膀+躯干轮廓),胸口有白色裂缝/射线向四周放射。"
+        "【重要】树木/植物形状不是GHS08——如果图案看起来像树枝或植物,那是GHS09而不是GHS08,回答NO。"
+        "不是骷髅,不是感叹号,不是树。"
+    ),
+    "GHS09": (
+        "GHS09是环境危害象形图,图标内有两个有机生物形状:"
+        "左边是一棵枯树(竖直树干+向两侧伸出的树枝,整体呈Y形或T形的树状轮廓),"
+        "右边是一条死鱼(横向椭圆形鱼身轮廓,肚皮朝上翻转)。"
+        "这两个形状与其他GHS图标完全不同——其他图标内没有植物或鱼类形状。"
+        "请对照参考图:如果标签中某个菱形图标内的图案与参考图相似(有树形和/或鱼形),回答YES。"
+    ),
+}
+
+
+def image_to_base64(pil_image, image_format="JPEG", quality=95):
     """将PIL Image图像转换为Base64编码"""
     buffered = BytesIO()
-    pil_image.save(buffered, format=image_format)
+    pil_image.save(buffered, format=image_format, quality=quality)
     img_byte_array = buffered.getvalue()
     encode_image = base64.b64encode(img_byte_array).decode('utf-8')
     return encode_image
@@ -20,7 +100,6 @@ def resize_image(image, max_size=512):
     width, height = image.size
     max_dim = max(width, height)
 
-    # 如果图像不需要缩小,直接返回
     if max_dim <= max_size:
         return image
 
@@ -28,13 +107,8 @@ def resize_image(image, max_size=512):
     new_width = int(width * scaling_factor)
     new_height = int(height * scaling_factor)
 
-    # 使用 LANCZOS 高质量缩放
     resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
-
-    # 应用 UnsharpMask 锐化,补偿缩放损失
     resized = resized.filter(ImageFilter.UnsharpMask(radius=1, percent=120, threshold=3))
-
-    # 轻微增强对比度,提高文字识别率
     enhancer = ImageEnhance.Contrast(resized)
     resized = enhancer.enhance(1.1)
 
@@ -44,70 +118,293 @@ class OcrAgent:
     def __init__(self):
         self._url = INFERENCE_URL
 
-    def extract_single(self, image_base64: str, prompt: str, index: int):
-        """单个任务请求,返回 (index, 结果文本)"""
-        response = requests.post(
-            self._url,
-            headers={
-                "Authorization": INFERENCE_AUTH_TOKEN,
-                "Content-Type": "application/json"
-            },
-            json={
-                "model": INFERENCE_MODEL,
-                "messages": [
-                    {"role": "system", "content": "You are a helpful assistant."},
-                    {
-                        "role": "user",
-                        "content": [
+    def _check_single_icon(self, ghs_name: str, ref_b64: str, image_base64: str, max_retries: int = 2) -> bool:
+        """二分类:判断标签图中是否存在指定的 GHS 象形图,返回 True/False。
+        对容易混淆的图标,同时发送干扰项参考图,让模型对比排除。
+        """
+        content = [
+            {"type": "text", "text": f"以下是标准 {ghs_name} 象形图的参考图(目标图标):"},
+            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{ref_b64}"}},
+        ]
+
+        # 如果该图标有容易混淆的图标,一并发送供对比
+        confusable = _GHS_CONFUSABLE.get(ghs_name, [])
+        if confusable:
+            content.append({"type": "text", "text": f"以下是容易与 {ghs_name} 混淆的图标,注意区分:"})
+            for conf_name in confusable:
+                if conf_name in _GHS_REFERENCE_IMAGES:
+                    content.append({"type": "text", "text": f"({conf_name},不是目标图标)"})
+                    content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{_GHS_REFERENCE_IMAGES[conf_name]}"}})
+
+        content.append({"type": "text", "text": "以下是需要识别的化学品安全标签图像:"})
+        content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}})
+
+        discriminative = _GHS_DISCRIMINATIVE_HINT.get(ghs_name, "")
+        confusable_hint = f"注意不要把它和 {'/'.join(confusable)} 混淆。\n" if confusable else ""
+
+        if ghs_name == "GHS03":
+            final_question = (
+                f"请仔细检查最后一张标签图像中的每个菱形图标。\n"
+                f"{discriminative}\n"
+                f"关键问题:标签中是否有图标在火焰图案内部包含了一个明显的空心圆圈/圆环?\n"
+                f"如果有圆圈→YES,如果只有火焰没有圆圈→NO。\n"
+                f"只回答 YES 或 NO,不要输出其他任何内容。"
+            )
+        elif ghs_name == "GHS09":
+            final_question = (
+                f"请仔细观察最后一张化学品安全标签图像中的所有菱形图标。\n"
+                f"参考图是GHS09(环境危害)标准图标,内含枯树和死鱼。\n"
+                f"{discriminative}\n"
+                f"注意:印刷版标签中的GHS09图标为黑白色(无红色边框),"
+                f"图案可能很小,但可以看出树形轮廓(树干+分叉枝条)和/或鱼形(椭圆形鱼身)。\n"
+                f"对照参考图,标签中是否有任何一个菱形图标包含了【枯树】或【死鱼】图案?\n"
+                f"只要识别出树形或鱼形,就回答YES;完全看不出来才回答NO。\n"
+                f"只回答 YES 或 NO,不要输出其他任何内容。"
+            )
+        else:
+            final_question = (
+                f"请仔细观察上方图片。\n"
+                f"第一张是目标图标 {ghs_name} 的标准参考图。\n"
+                f"{discriminative}\n"
+                f"{confusable_hint}"
+                f"请严格对照上述关键特征,判断:最后一张化学品安全标签图像中是否包含 {ghs_name} 图标?\n"
+                f"必须所有关键特征都匹配才回答YES,有任何一条不符合就回答NO。\n"
+                f"只回答 YES 或 NO,不要输出其他任何内容。"
+            )
+        content.append({"type": "text", "text": final_question})
+
+        last_err = None
+        for attempt in range(max_retries + 1):
+            try:
+                response = requests.post(
+                    self._url,
+                    headers={
+                        "Authorization": INFERENCE_AUTH_TOKEN,
+                        "Content-Type": "application/json"
+                    },
+                    json={
+                        "model": INFERENCE_MODEL,
+                        "messages": [
+                            {"role": "system", "content": "You are a helpful assistant. Answer only YES or NO."},
+                            {"role": "user", "content": content}
+                        ],
+                        "max_tokens": 16,
+                        "stream": False,
+                        "temperature": 0
+                    },
+                    timeout=600
+                )
+                response.raise_for_status()
+
+                resp_json = response.json()
+                answer = resp_json["choices"][0]["message"]["content"].strip().upper()
+                logger.info(f"[icon binary] {ghs_name} -> {answer}")
+                return answer.startswith("YES")
+
+            except requests.RequestException as e:
+                last_err = e
+                if attempt < max_retries:
+                    wait = 2 ** attempt
+                    logger.warning(f"[icon binary] {ghs_name} 请求异常: {e},{wait}s 后重试...")
+                    time.sleep(wait)
+
+        logger.error(f"[icon binary] {ghs_name} 重试 {max_retries} 次后仍失败: {last_err}")
+        return False
+
+    def _confirm_ghs03(self, image_base64: str) -> bool:
+        """GHS03 二次确认:直接询问标签中是否存在内部含空心圆圈的火焰图标。
+        用于过滤 _check_single_icon 的 false-positive。"""
+        ghs02_b64 = _GHS_REFERENCE_IMAGES.get("GHS02", "")
+        ghs03_b64 = _GHS_REFERENCE_IMAGES.get("GHS03", "")
+        content = []
+        if ghs02_b64:
+            content.append({"type": "text", "text": "参考图A — GHS02(仅火焰,无圆圈):"})
+            content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{ghs02_b64}"}})
+        if ghs03_b64:
+            content.append({"type": "text", "text": "参考图B — GHS03(火焰内部有空心圆圈):"})
+            content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{ghs03_b64}"}})
+        content.append({"type": "text", "text": "待检查的化学品安全标签:"})
+        content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}})
+        content.append({"type": "text", "text": (
+            "请仔细观察标签中所有菱形图标内的图案。\n"
+            "问题:标签中是否存在【参考图B】所示的GHS03图标——即火焰图案内部有一个明显的空心圆圈/圆环?\n"
+            "注意:\n"
+            "- 如果只看到【参考图A】类型的纯火焰(无圆圈),回答NO\n"
+            "- 菱形框本身不算圆圈,必须是火焰内部的圆圈\n"
+            "- 如果没有火焰图案,回答NO\n"
+            "只回答 YES 或 NO,不要输出其他内容。"
+        )})
+        try:
+            response = requests.post(
+                self._url,
+                headers={"Authorization": INFERENCE_AUTH_TOKEN, "Content-Type": "application/json"},
+                json={
+                    "model": INFERENCE_MODEL,
+                    "messages": [
+                        {"role": "system", "content": "You are a helpful assistant. Answer only YES or NO."},
+                        {"role": "user", "content": content}
+                    ],
+                    "max_tokens": 16,
+                    "stream": False,
+                    "temperature": 0
+                },
+                timeout=600
+            )
+            response.raise_for_status()
+            answer = response.json()["choices"][0]["message"]["content"].strip().upper()
+            logger.info(f"[icon GHS03 confirm] -> {answer}")
+            return answer.startswith("YES")
+        except Exception as e:
+            logger.warning(f"[icon GHS03 confirm] 请求失败: {e},保守返回False")
+            return False
+
+    def extract_icon(self, image_base64: str, max_retries: int = 2):
+        """象形图识别:对 GHS01-GHS09 逐个并行做二分类,返回 (0, JSON字符串)。"""
+        with ThreadPoolExecutor(max_workers=len(_GHS_REFERENCE_IMAGES)) as executor:
+            futures = {
+                executor.submit(self._check_single_icon, name, ref_b64, image_base64, max_retries): name
+                for name, ref_b64 in _GHS_REFERENCE_IMAGES.items()
+            }
+            results = {}
+            for future in as_completed(futures):
+                name = futures[future]
+                results[name] = future.result()
+
+        matched = [name for name in _GHS_ICON_NAMES if results.get(name)]
+
+        # GHS02 和 GHS03 互斥:当两者同时出现时保留GHS02、丢弃GHS03
+        if "GHS02" in matched and "GHS03" in matched:
+            logger.info("[icon] GHS02/GHS03 冲突,保留GHS02,丢弃GHS03")
+            matched.remove("GHS03")
+
+        # 对 GHS03 做二次确认,减少 false-positive
+        if "GHS03" in matched:
+            confirmed = self._confirm_ghs03(image_base64)
+            if not confirmed:
+                logger.info("[icon] GHS03 二次确认为 NO,移除")
+                matched.remove("GHS03")
+
+        logger.info(f"[icon] 识别结果: {matched}")
+        return 0, json.dumps({"tag_images": matched}, ensure_ascii=False)
+
+    def extract_single(self, image_base64: str, prompt: str, index: int, max_retries: int = 2):
+        """单个任务请求,返回 (index, 结果文本)。失败时最多重试 max_retries 次。"""
+        last_err = None
+        for attempt in range(max_retries + 1):
+            try:
+                response = requests.post(
+                    self._url,
+                    headers={
+                        "Authorization": INFERENCE_AUTH_TOKEN,
+                        "Content-Type": "application/json"
+                    },
+                    json={
+                        "model": INFERENCE_MODEL,
+                        "messages": [
+                            {"role": "system", "content": "You are a helpful assistant."},
                             {
-                                "type": "image_url",
-                                "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}
-                            },
-                            {"type": "text", "text": prompt}
-                        ]
-                    }
-                ],
-                "max_tokens": 4096,
-                "stream": False,
-                "temperature": 0
-            },
-            timeout=600
-        )
-        response.raise_for_status()
-        content = response.json()["choices"][0]["message"]["content"]
-        return index, content
+                                "role": "user",
+                                "content": [
+                                    {
+                                        "type": "image_url",
+                                        "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}
+                                    },
+                                    {"type": "text", "text": prompt}
+                                ]
+                            }
+                        ],
+                        "max_tokens": 4096,
+                        "stream": False,
+                        "temperature": 0
+                    },
+                    timeout=600
+                )
+                response.raise_for_status()
+
+                resp_json = response.json()
+                choice = resp_json["choices"][0]
+                finish_reason = choice.get("finish_reason", "")
+                content = choice["message"]["content"]
+
+                if finish_reason == "length":
+                    # 输出被 token 数截断,内容不完整,重试
+                    raise RuntimeError(
+                        f"步骤[{index}] 模型输出被截断(finish_reason=length),"
+                        f"第 {attempt + 1} 次尝试"
+                    )
+
+                logger.info(f"步骤[{index}] finish_reason={finish_reason}")
+                return index, content
+
+            except RuntimeError as e:
+                last_err = e
+                if attempt < max_retries:
+                    wait = 2 ** attempt
+                    logger.warning(f"{e},{wait}s 后重试...")
+                    time.sleep(wait)
+            except requests.RequestException as e:
+                last_err = e
+                if attempt < max_retries:
+                    wait = 2 ** attempt
+                    logger.warning(f"步骤[{index}] 请求异常: {e},{wait}s 后重试...")
+                    time.sleep(wait)
+
+        raise RuntimeError(f"步骤[{index}] 重试 {max_retries} 次后仍失败: {last_err}")
 
     @staticmethod
     def _parse_json(text: str, step_name: str) -> dict:
         """
         解析模型返回的 JSON 文本,自动清洗 ```json``` 标记。
-        解析失败时抛出 RuntimeError(不会被 ValueError 捕获误报为"参数验证失败")。
+        若直接解析失败,尝试用正则从文本中提取第一个 JSON 对象/数组。
+        解析失败时抛出 RuntimeError。
         """
-        # 去除首尾空白
         text = text.strip()
-        # 兼容模型偶尔返回 ```json ... ``` 包裹的情况
         if text.startswith("```"):
             lines = text.splitlines()
-            # 去掉首行的 ```json 或 ``` 和末行的 ```
             text = "\n".join(
                 line for line in lines
                 if not line.strip().startswith("```")
             ).strip()
         try:
             return json.loads(text)
-        except json.JSONDecodeError as e:
+        except json.JSONDecodeError:
+            # 模型返回了思考过程 + JSON 混合内容,尝试提取第一个 JSON 块
+            match = re.search(r'\{[\s\S]*\}', text)
+            if match:
+                try:
+                    return json.loads(match.group())
+                except json.JSONDecodeError:
+                    pass
             raise RuntimeError(
-                f"步骤[{step_name}]模型返回内容无法解析为 JSON: {e}\n原始内容: {text[:200]}"
+                f"步骤[{step_name}]模型返回内容无法解析为 JSON\n原始内容: {text[:200]}"
             )
 
     def agent_ocr(self, image):
         """qwen_ocr提取化学品安全标签信息"""
-        image = resize_image(image, max_size=512)
-        image_base64 = image_to_base64(image)
+        image = resize_image(image, max_size=_IMAGE_MAX_SIZE)
+        quality = _IMAGE_COMPRESS_QUALITY if _IMAGE_COMPRESS else 95
+        image_base64 = image_to_base64(image, quality=quality)
+        logger.info(f"图像预处理: max_size={_IMAGE_MAX_SIZE}, compress={_IMAGE_COMPRESS}, quality={quality}")
+
+        # 为象形图识别单独准备高分辨率裁剪图
+        # 竖向长图(高/宽 > 1.5)只取上部 30%,横向图取上部 60%
+        w, h = image.size
+        ratio = h / w
+        crop_ratio = 0.30 if ratio > 1.5 else 0.60
+        icon_crop = image.crop((0, 0, w, int(h * crop_ratio)))
+        # 放大使图标细节清晰,长边不超过 1600px
+        scale = min(1600 / icon_crop.width, 1600 / icon_crop.height, 3)
+        icon_crop = icon_crop.resize(
+            (int(icon_crop.width * scale), int(icon_crop.height * scale)),
+            Image.Resampling.LANCZOS
+        )
+        enhancer = ImageEnhance.Contrast(icon_crop)
+        icon_crop = enhancer.enhance(1.3)
+        icon_image_base64 = image_to_base64(icon_crop, quality=95)
+        logger.info(f"象形图识别用裁剪图: {icon_crop.size} (crop_ratio={crop_ratio})")
 
         start_time = time.perf_counter()
 
-        # 定义需要并行执行的任务(顺序固定,用 index 保序)
         prompts = [
             PROMPT_EXTRACT_ICON,        # 0
             PROMPT_EXTRACT_NAME,        # 1
@@ -117,13 +414,24 @@ class OcrAgent:
             PROMPT_EXTRACT_SUPPLIER     # 5
         ]
 
-        # 串行发送 6 个请求
-        results = []
-        for idx, prompt in enumerate(prompts):
-            _, content = self.extract_single(image_base64, prompt, idx)
-            results.append(content)
+        # 并行发送 6 个请求,按 index 填回保证顺序
+        # index=0 的象形图识别使用带参考图的专用方法
+        results = [None] * len(prompts)
+        with ThreadPoolExecutor(max_workers=len(prompts)) as executor:
+            futures = {}
+            futures[executor.submit(self.extract_icon, icon_image_base64)] = 0
+            for idx, prompt in enumerate(prompts):
+                if idx == 0:
+                    continue  # icon 已单独提交
+                futures[executor.submit(self.extract_single, image_base64, prompt, idx)] = idx
+            for future in as_completed(futures):
+                idx, content = future.result()  # 任意一个步骤失败会在此抛出
+                results[idx] = content
 
-        # 从结果中提取数据(顺序已由 index 保证)
+        end_time = time.perf_counter()
+        logger.info(f"推理时间: {end_time - start_time:.3f} 秒")
+
+        # 解析各步骤结果(顺序由 index 保证,与串行时完全一致)
         step_names = ["icon", "name", "components", "keyword", "prevention", "supplier"]
         icon        = self._parse_json(results[0], step_names[0])
         name        = self._parse_json(results[1], step_names[1])
@@ -132,11 +440,7 @@ class OcrAgent:
         pre_notice  = self._parse_json(results[4], step_names[4])
         suppliers   = self._parse_json(results[5], step_names[5])
 
-        end_time = time.perf_counter()
-        elapsed_time = end_time - start_time
-        print(f"推理时间: {elapsed_time:.6f} 秒")
-
-        result = {
+        return {
             "tag": {
                 "name_cn": name["name_cn"],
                 "name_en": name["name_en"],
@@ -150,8 +454,6 @@ class OcrAgent:
             "acc_tel": suppliers["acc_tel"],
         }
 
-        return result
-
 
 if __name__ == "__main__":
     image = Image.open("./test1.jpg").convert("RGB")

+ 34 - 32
config/config.py

@@ -33,11 +33,12 @@ INFERENCE_MODEL = _os.environ.get(
 PROMPT_EXTRACT_NAME = """
 你是一个专业的化学品安全标签说明识别助手。
 请从图像中提取化学品的中文名称和英文名称(如有)。
+化学品名称通常出现在标签最顶部,字号最大,可能同时包含中文名和英文名及化学式。
 
 按照以下JSON格式输出结果:
 {
     "name_cn": "化学品中文名称",
-    "name_en": "化学品英文名称"
+    "name_en": "化学品英文名称(含化学式,如无则填空字符串)"
 }
 
 注意:返回结果必须是标准JSON格式,不要包含```json```标记。
@@ -68,13 +69,13 @@ PROMPT_EXTRACT_COMPONENTS = """
 # 步骤3:提取安全提醒信号词
 PROMPT_EXTRACT_KEYWORD = """
 你是一个专业的化学品安全标签说明识别助手。
-请从图像中提取安全提醒信号词和危险性说明
-安全信号词:通常以比较醒目的方式显示,只识别"危险"、"警告"
-危险性说明:通常在安全提醒词附近,只返回中文的表示
+请从图像中提取安全提醒信号词和危险性说明
+- 信号词:通常在一个醒目的方框内(黑底白字或红底白字),只识别"危险"或"警告"这两种
+- 危险性说明:通常紧跟在信号词和象形图下方,是一段描述危害的文字,只返回中文内容
 
 按照以下JSON格式输出结果:
 {
-    "key_word": "安全提醒信号词",
+    "key_word": "危险或警告",
     "risk_notice": "危险性说明内容"
 }
 
@@ -92,7 +93,7 @@ PROMPT_EXTRACT_PREVENTION = """
 按照以下JSON格式输出结果:
 {
     "pre_notice": {
-        "pre_method": "预防措施",
+        "pre_methoud": "预防措施",
         "acc_response": "事故响应",
         "safe_keep": "安全存储",
         "abandon_deal": "废弃处置"
@@ -105,8 +106,9 @@ PROMPT_EXTRACT_PREVENTION = """
 # 步骤5:提取供应商标识
 PROMPT_EXTRACT_SUPPLIER = """
 你是一个专业的化学品安全标签说明识别助手。
-请从图像中提取所有供应商信息和应急咨询电话,
-供应商信息包括:供应商名称、供应商地址、供应商电话、供应商邮编;
+请从图像中提取供应商信息和应急咨询电话。
+供应商信息通常在标签底部,包括供应商名称、地址、电话、邮编;部分标签可能没有邮编字段,填空字符串即可。
+应急咨询电话通常单独标注,关键字为"应急咨询电话"或"化学事故应急咨询电话"。
 
 按照以下JSON格式输出结果:
 {
@@ -114,13 +116,13 @@ PROMPT_EXTRACT_SUPPLIER = """
         "name": "供应商名称",
         "address": "供应商地址",
         "tel": "供应商电话",
-        "post": "供应商邮编",
+        "post": "供应商邮编(无则填空字符串)"
     }],
-    "acc_tel": "应急咨询电话"
+    "acc_tel": "应急咨询电话(无则填空字符串)"
 }
 
 注意:
-供应商的信息可能有多个,请提取对应的多个供应商的信息
+供应商的信息可能有多个,请提取所有供应商信息
 返回结果必须是标准JSON格式,不要包含```json```标记。
 """
 
@@ -163,30 +165,30 @@ OCR_PROMPT_FULL = """
 
 # 步骤6:提取象形图标识
 PROMPT_EXTRACT_ICON = """
-你是一个专业的化学品安全标签说明识别助手。
-请识别图像中的GHS危险象形图标识。这些象形图通常是红色菱形框内的黑色符号图案,包括但不限于:
-- GHS01:爆炸物(爆炸图案)
-- GHS02:易燃物(火焰图案)
-- GHS03:氧化剂(火焰与圆圈图案)
-- GHS04:压缩气体(气瓶图案)
-- GHS05:腐蚀性物质(手和金属被腐蚀图案)
-- GHS06:急性毒性(骷髅和交叉骨头图案)
-- GHS07:有害物质(感叹号图案)
-- GHS08:健康危害(人体剪影图案)
-- GHS09:环境危害(死鱼和枯树图案)
-
-请仔细对比参考图像和待识别图像中的象形图,按照图像中从左到右的顺序识别这些象形图的类别。
+你是一个专业的化学品安全标签识别助手,专门负责识别GHS危险象形图。
+
+我已在上方依次提供了GHS01至GHS09的标准参考图像,请仔细查看每一张。
+
+现在请直接对比最后一张化学品安全标签图像,找出其中所有菱形边框内的象形图,
+将每个象形图与上方参考图逐一对比,确定对应编号。
+
+关键区分特征(仅作辅助,以参考图为准):
+- GHS01:圆形炸弹向四周炸裂,有大量射线和碎片飞散
+- GHS02:火焰图案,火焰底部有一条黑色横线底座,无圆圈
+- GHS03:火焰包围一个空心大圆圈(圆圈内部是白色),底部也有横线底座,与GHS02的区别是中间有空心圆圈
+- GHS04:一个横置的粗短气瓶/钢瓶,右侧有细长阀门
+- GHS05:顶部有两根试管,左边液体滴腐蚀金属板,右边液体滴腐蚀手掌
+- GHS06:骷髅头(空心白色)+下方两根交叉骨头
+- GHS07:只有一个感叹号"!"(上方竖条+下方圆点),无任何其他图形
+- GHS08:实心黑色人体上半身剪影,胸口有白色裂缝向四周放射(不是骷髅,不是感叹号)
+- GHS09:左侧一棵枯树(无叶子只有树枝),右侧一条翻肚白色死鱼,底部有水面横线,图案独特无火焰无人形
+
+无法确认的跳过,不要猜测。
 
 按照以下JSON格式输出结果:
-{
-    "tag_images": [识别到的形象图]
-}
+{"tag_images": ["GHS02", "GHS07"]}
 
-注意:
-1. 必须按照图像中象形图从左到右的实际顺序排列
-2. 如果某个位置的象形图无法识别则跳过
-3. 识别出的象形图用对应的GHS编号(如GHS01-GHS09)表示
-4. 返回结果必须是标准JSON格式,不要包含```json```标记
+注意:tag_images 中只填GHS编号,不要求顺序,返回结果必须是标准JSON格式,不要包含```json```标记。
 """
 
 # 默认使用的提示词(向后兼容)

BIN
ghs_icons/GHS01.png


BIN
ghs_icons/GHS02.png


BIN
ghs_icons/GHS03.png


BIN
ghs_icons/GHS04.png


BIN
ghs_icons/GHS05.png


BIN
ghs_icons/GHS06.png


BIN
ghs_icons/GHS07.png


BIN
ghs_icons/GHS08.png


BIN
ghs_icons/GHS09.png


+ 11 - 2
ocr/ocr_config.yaml

@@ -4,11 +4,11 @@
 # ---------- 推理服务配置 ----------
 inference:
   # 推理服务请求地址
-  url: "http://10.69.29.202:31277/inference-api/exp-api/inf-1480928240416935936/v1/chat/completions"
+  url: "http://10.69.29.202:31277/inference-api/exp-api/inf-1504078419424792576/v1/chat/completions"
   # 鉴权 Token(格式:Bearer <token>)
   auth_token: "Bearer QDiS42vR9EqP-j73zeeyWB8zSJ4juheflm6yDKUDz5c"
   # 使用的模型名称
-  model: "Qwen3-VL-32B-Instruct"
+  model: "Qwen3.5-35B-A3B"
 
 # ---------- API 服务配置 ----------
 server:
@@ -18,3 +18,12 @@ server:
   port: 6006
   # 最大并发请求数
   max_concurrent: 5
+
+# ---------- 图像预处理配置 ----------
+image:
+  # 图像长边最大像素,超过则等比缩小(越大识别越准但请求越慢)
+  max_size: 1600
+  # 是否压缩图像质量(true=压缩,false=不压缩,用于对比测试)
+  compress: false
+  # compress=true 时生效,JPEG 压缩质量 1-95(越低体积越小,建议不低于 60)
+  compress_quality: 70

+ 19 - 0
ocr/ocr_config_compress.yaml

@@ -0,0 +1,19 @@
+# ========== Agent OCR 服务配置文件(压缩模式)==========
+
+# ---------- 推理服务配置 ----------
+inference:
+  url: "http://10.69.29.202:31277/inference-api/exp-api/inf-1480928240416935936/v1/chat/completions"
+  auth_token: "Bearer QDiS42vR9EqP-j73zeeyWB8zSJ4juheflm6yDKUDz5c"
+  model: "Qwen3-VL-32B-Instruct"
+
+# ---------- API 服务配置 ----------
+server:
+  host: "0.0.0.0"
+  port: 6006
+  max_concurrent: 5
+
+# ---------- 图像预处理配置 ----------
+image:
+  max_size: 512
+  compress: true
+  compress_quality: 70

+ 13 - 2
ocr/start.py

@@ -44,11 +44,18 @@ def apply_server_config(cfg: dict) -> tuple:
     host           = str(srv.get("host", "0.0.0.0"))
     port           = int(srv.get("port", 6006))
     max_concurrent = int(srv.get("max_concurrent", 5))
-    # 写入环境变量,供 run_api.py 的 lifespan 读取
     os.environ["MAX_CONCURRENT"] = str(max_concurrent)
     return host, port, max_concurrent
 
 
+def apply_image_config(cfg: dict):
+    """将图像预处理配置写入环境变量,供 agent/agent.py 读取"""
+    img = cfg.get("image", {})
+    os.environ["IMAGE_MAX_SIZE"]        = str(img.get("max_size", 512))
+    os.environ["IMAGE_COMPRESS"]        = str(img.get("compress", False)).lower()
+    os.environ["IMAGE_COMPRESS_QUALITY"]= str(img.get("compress_quality", 70))
+
+
 def setup_path():
     """将项目根目录(ocr/ 的上级)加入 sys.path,确保能找到 agent/api/config 包"""
     ocr_dir  = os.path.dirname(os.path.abspath(__file__))
@@ -59,8 +66,8 @@ def setup_path():
 
 def print_config(cfg: dict, host: str, port: int, max_concurrent: int):
     inf = cfg.get("inference", {})
+    img = cfg.get("image", {})
     token = inf.get("auth_token", "")
-    # 只显示 token 末尾 6 位,避免泄露
     masked = ("*" * max(0, len(token) - 6)) + token[-6:] if token else ""
     print("=" * 50)
     print("  Agent OCR 服务配置")
@@ -70,6 +77,9 @@ def print_config(cfg: dict, host: str, port: int, max_concurrent: int):
     print(f"  模型名称  : {inf.get('model', '')}")
     print(f"  服务地址  : http://{host}:{port}")
     print(f"  最大并发  : {max_concurrent}")
+    print(f"  图像最大边 : {img.get('max_size', 512)}px")
+    compress = img.get("compress", False)
+    print(f"  图像压缩  : {'开启,质量=' + str(img.get('compress_quality', 70)) if compress else '关闭'}")
     print("=" * 50)
 
 
@@ -85,6 +95,7 @@ def main():
     cfg = load_config(args.config)
     apply_inference_config(cfg)
     host, port, max_concurrent = apply_server_config(cfg)
+    apply_image_config(cfg)
     setup_path()
 
     print_config(cfg, host, port, max_concurrent)

+ 43 - 0
test_icon.py

@@ -0,0 +1,43 @@
+"""
+象形图识别测试脚本
+用法:python test_icon.py
+"""
+import sys
+import os
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+from PIL import Image
+from agent.agent import OcrAgent, image_to_base64, resize_image
+import json
+
+# 预期结果
+CASES = [
+    {"file": "pic/A.png", "name": "镀铜药水", "expected": {"GHS05", "GHS09"}},
+    {"file": "pic/B.png", "name": "液氨",     "expected": {"GHS04", "GHS05", "GHS06", "GHS09"}},
+    {"file": "pic/C.png", "name": "三氯苯",   "expected": {"GHS07", "GHS09"}},
+]
+
+agent = OcrAgent()
+
+print("=" * 60)
+for case in CASES:
+    img = Image.open(case["file"]).convert("RGB")
+    img = resize_image(img, max_size=1600)
+    b64 = image_to_base64(img, quality=95)
+
+    _, result_json = agent.extract_icon(b64)
+    got = set(json.loads(result_json)["tag_images"])
+    expected = case["expected"]
+
+    hit     = got & expected        # 正确识别
+    missed  = expected - got        # 漏识别
+    extra   = got - expected        # 多识别
+
+    status = "✓ PASS" if not missed and not extra else "✗ FAIL"
+    print(f"\n[{status}] {case['name']}")
+    print(f"  期望: {sorted(expected)}")
+    print(f"  识别: {sorted(got)}")
+    if missed: print(f"  漏识别: {sorted(missed)}")
+    if extra:  print(f"  误报:   {sorted(extra)}")
+
+print("\n" + "=" * 60)