from config import MODEL_PATH, INFERENCE_URL, INFERENCE_AUTH_TOKEN, INFERENCE_MODEL, PROMPT_EXTRACT_NAME, PROMPT_EXTRACT_COMPONENTS, PROMPT_EXTRACT_KEYWORD, PROMPT_EXTRACT_PREVENTION,PROMPT_EXTRACT_SUPPLIER,PROMPT_EXTRACT_ICON from io import BytesIO from concurrent.futures import ThreadPoolExecutor, as_completed import base64 import json import logging import os from PIL import Image, ImageFilter, ImageEnhance import time import re import requests logger = logging.getLogger(__name__) # 从环境变量读取图像预处理配置(由 start.py 启动时注入) _IMAGE_MAX_SIZE = int(os.environ.get("IMAGE_MAX_SIZE", 512)) _IMAGE_COMPRESS = os.environ.get("IMAGE_COMPRESS", "false").lower() == "true" _IMAGE_COMPRESS_QUALITY = int(os.environ.get("IMAGE_COMPRESS_QUALITY", 70)) # GHS 参考图目录(与 agent.py 同级的上级目录下的 ghs_icons/) _GHS_ICONS_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "ghs_icons") _GHS_ICON_NAMES = [f"GHS{i:02d}" for i in range(1, 10)] def _load_ghs_reference_images() -> dict: """加载 GHS01-GHS09 参考图,返回 {名称: base64} 字典。缺失的图跳过。""" refs = {} for name in _GHS_ICON_NAMES: path = os.path.join(_GHS_ICONS_DIR, f"{name}.png") if os.path.exists(path): with open(path, "rb") as f: refs[name] = base64.b64encode(f.read()).decode("utf-8") else: logger.warning(f"GHS 参考图缺失: {path}") return refs # 启动时加载一次,避免每次请求重复读文件 _GHS_REFERENCE_IMAGES: dict = _load_ghs_reference_images() # 容易混淆的图标对,判断目标图标时同时发送干扰项供模型对比排除 # GHS09 外形独特(枯树+死鱼),不发送混淆图,避免反向干扰 _GHS_CONFUSABLE = { "GHS02": ["GHS03"], "GHS03": ["GHS02"], "GHS07": ["GHS08"], "GHS08": ["GHS07", "GHS06", "GHS09"], "GHS09": ["GHS08"], } # 每个图标的判定关键特征,嵌入提示词中强化区分 _GHS_DISCRIMINATIVE_HINT = { "GHS01": "关键特征:圆形炸弹向四周炸裂,有大量射线和碎片飞散。", "GHS02": "关键特征:只有火焰图案,火焰底部有横线底座。火焰中间没有任何圆圈/圆环,这是与GHS03的根本区别。", "GHS03": ( "【GHS03唯一判定标准】:图标内部必须同时存在两个元素:(1)火焰 AND (2)火焰包围的空心大圆圈/圆环(圆圈内部镂空呈白色)。" "缺少圆圈就是GHS02,不是GHS03。" "请先回答:图标中有没有空心圆圈?如果没有圆圈,直接回答NO。" ), "GHS04": ( "关键特征:一个横置的粗短圆柱形气瓶/钢瓶,右侧伸出一根细长的阀门管道。" "整体形状类似一个横向的短粗矩形/圆柱体,右侧有细管伸出。" "在小尺寸印刷版本中,整个图案看起来像一个横向的短粗横条(比感叹号的竖线更粗更短,且是水平的)。" "图标内没有火焰、没有感叹号竖线(感叹号是竖向的,气瓶是横向的)、没有人形、没有树。" "即使只看到一个横向的粗短形状,也可以回答YES。" ), "GHS05": ( "关键特征:图标内有腐蚀性液体滴落的场景——" "上方有试管/容器,液体向下滴落,腐蚀下方的物体(金属板或手掌)。" "整体看起来像'液体从上往下滴,下方被腐蚀出缺口'的形状。" "印刷版可能很小,但可以看出上方有细管状物体、下方有不规则缺口形状。" "即使细节模糊,只要能看出'滴落腐蚀'的大致形状,就回答YES。" ), "GHS06": "关键特征:骷髅头(空心白色)加下方两根交叉骨头。", "GHS07": '关键特征:只有一个感叹号"!"(上方竖条加下方圆点),无任何其他图形,没有人形,没有骷髅。', "GHS08": ( "关键特征:实心黑色人体上半身剪影(有明确的头部+肩膀+躯干轮廓),胸口有白色裂缝/射线向四周放射。" "【重要】树木/植物形状不是GHS08——如果图案看起来像树枝或植物,那是GHS09而不是GHS08,回答NO。" "不是骷髅,不是感叹号,不是树。" ), "GHS09": ( "GHS09是环境危害象形图,图标内有两个有机生物形状:" "左边是一棵枯树(竖直树干+向两侧伸出的树枝,整体呈Y形或T形的树状轮廓)," "右边是一条死鱼(横向椭圆形鱼身轮廓,肚皮朝上翻转)。" "这两个形状与其他GHS图标完全不同——其他图标内没有植物或鱼类形状。" "请对照参考图:如果标签中某个菱形图标内的图案与参考图相似(有树形和/或鱼形),回答YES。" ), } def image_to_base64(pil_image, image_format="JPEG", quality=95): """将PIL Image图像转换为Base64编码""" buffered = BytesIO() pil_image.save(buffered, format=image_format, quality=quality) img_byte_array = buffered.getvalue() encode_image = base64.b64encode(img_byte_array).decode('utf-8') return encode_image def resize_image(image, max_size=512): """缩放图像尺寸,保持 OCR 质量""" width, height = image.size max_dim = max(width, height) if max_dim <= max_size: return image scaling_factor = max_size / max_dim new_width = int(width * scaling_factor) new_height = int(height * scaling_factor) resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS) resized = resized.filter(ImageFilter.UnsharpMask(radius=1, percent=120, threshold=3)) enhancer = ImageEnhance.Contrast(resized) resized = enhancer.enhance(1.1) return resized class OcrAgent: def __init__(self): self._url = INFERENCE_URL def _check_single_icon(self, ghs_name: str, ref_b64: str, image_base64: str, max_retries: int = 2) -> bool: """二分类:判断标签图中是否存在指定的 GHS 象形图,返回 True/False。 对容易混淆的图标,同时发送干扰项参考图,让模型对比排除。 """ content = [ {"type": "text", "text": f"以下是标准 {ghs_name} 象形图的参考图(目标图标):"}, {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{ref_b64}"}}, ] # 如果该图标有容易混淆的图标,一并发送供对比 confusable = _GHS_CONFUSABLE.get(ghs_name, []) if confusable: content.append({"type": "text", "text": f"以下是容易与 {ghs_name} 混淆的图标,注意区分:"}) for conf_name in confusable: if conf_name in _GHS_REFERENCE_IMAGES: content.append({"type": "text", "text": f"({conf_name},不是目标图标)"}) content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{_GHS_REFERENCE_IMAGES[conf_name]}"}}) content.append({"type": "text", "text": "以下是需要识别的化学品安全标签图像:"}) content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}}) discriminative = _GHS_DISCRIMINATIVE_HINT.get(ghs_name, "") confusable_hint = f"注意不要把它和 {'/'.join(confusable)} 混淆。\n" if confusable else "" if ghs_name == "GHS03": final_question = ( f"请仔细检查最后一张标签图像中的每个菱形图标。\n" f"{discriminative}\n" f"关键问题:标签中是否有图标在火焰图案内部包含了一个明显的空心圆圈/圆环?\n" f"如果有圆圈→YES,如果只有火焰没有圆圈→NO。\n" f"只回答 YES 或 NO,不要输出其他任何内容。" ) elif ghs_name == "GHS09": final_question = ( f"请仔细观察最后一张化学品安全标签图像中的所有菱形图标。\n" f"参考图是GHS09(环境危害)标准图标,内含枯树和死鱼。\n" f"{discriminative}\n" f"注意:印刷版标签中的GHS09图标为黑白色(无红色边框)," f"图案可能很小,但可以看出树形轮廓(树干+分叉枝条)和/或鱼形(椭圆形鱼身)。\n" f"对照参考图,标签中是否有任何一个菱形图标包含了【枯树】或【死鱼】图案?\n" f"只要识别出树形或鱼形,就回答YES;完全看不出来才回答NO。\n" f"只回答 YES 或 NO,不要输出其他任何内容。" ) else: final_question = ( f"请仔细观察上方图片。\n" f"第一张是目标图标 {ghs_name} 的标准参考图。\n" f"{discriminative}\n" f"{confusable_hint}" f"请严格对照上述关键特征,判断:最后一张化学品安全标签图像中是否包含 {ghs_name} 图标?\n" f"必须所有关键特征都匹配才回答YES,有任何一条不符合就回答NO。\n" f"只回答 YES 或 NO,不要输出其他任何内容。" ) content.append({"type": "text", "text": final_question}) last_err = None for attempt in range(max_retries + 1): try: response = requests.post( self._url, headers={ "Authorization": INFERENCE_AUTH_TOKEN, "Content-Type": "application/json" }, json={ "model": INFERENCE_MODEL, "messages": [ {"role": "system", "content": "You are a helpful assistant. Answer only YES or NO."}, {"role": "user", "content": content} ], "max_tokens": 16, "stream": False, "temperature": 0 }, timeout=600 ) response.raise_for_status() resp_json = response.json() answer = resp_json["choices"][0]["message"]["content"].strip().upper() logger.info(f"[icon binary] {ghs_name} -> {answer}") return answer.startswith("YES") except requests.RequestException as e: last_err = e if attempt < max_retries: wait = 2 ** attempt logger.warning(f"[icon binary] {ghs_name} 请求异常: {e},{wait}s 后重试...") time.sleep(wait) logger.error(f"[icon binary] {ghs_name} 重试 {max_retries} 次后仍失败: {last_err}") return False def _confirm_ghs03(self, image_base64: str) -> bool: """GHS03 二次确认:直接询问标签中是否存在内部含空心圆圈的火焰图标。 用于过滤 _check_single_icon 的 false-positive。""" ghs02_b64 = _GHS_REFERENCE_IMAGES.get("GHS02", "") ghs03_b64 = _GHS_REFERENCE_IMAGES.get("GHS03", "") content = [] if ghs02_b64: content.append({"type": "text", "text": "参考图A — GHS02(仅火焰,无圆圈):"}) content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{ghs02_b64}"}}) if ghs03_b64: content.append({"type": "text", "text": "参考图B — GHS03(火焰内部有空心圆圈):"}) content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{ghs03_b64}"}}) content.append({"type": "text", "text": "待检查的化学品安全标签:"}) content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}}) content.append({"type": "text", "text": ( "请仔细观察标签中所有菱形图标内的图案。\n" "问题:标签中是否存在【参考图B】所示的GHS03图标——即火焰图案内部有一个明显的空心圆圈/圆环?\n" "注意:\n" "- 如果只看到【参考图A】类型的纯火焰(无圆圈),回答NO\n" "- 菱形框本身不算圆圈,必须是火焰内部的圆圈\n" "- 如果没有火焰图案,回答NO\n" "只回答 YES 或 NO,不要输出其他内容。" )}) try: response = requests.post( self._url, headers={"Authorization": INFERENCE_AUTH_TOKEN, "Content-Type": "application/json"}, json={ "model": INFERENCE_MODEL, "messages": [ {"role": "system", "content": "You are a helpful assistant. Answer only YES or NO."}, {"role": "user", "content": content} ], "max_tokens": 16, "stream": False, "temperature": 0 }, timeout=600 ) response.raise_for_status() answer = response.json()["choices"][0]["message"]["content"].strip().upper() logger.info(f"[icon GHS03 confirm] -> {answer}") return answer.startswith("YES") except Exception as e: logger.warning(f"[icon GHS03 confirm] 请求失败: {e},保守返回False") return False def extract_icon(self, image_base64: str, max_retries: int = 2): """象形图识别:对 GHS01-GHS09 逐个并行做二分类,返回 (0, JSON字符串)。""" with ThreadPoolExecutor(max_workers=len(_GHS_REFERENCE_IMAGES)) as executor: futures = { executor.submit(self._check_single_icon, name, ref_b64, image_base64, max_retries): name for name, ref_b64 in _GHS_REFERENCE_IMAGES.items() } results = {} for future in as_completed(futures): name = futures[future] results[name] = future.result() matched = [name for name in _GHS_ICON_NAMES if results.get(name)] # GHS02 和 GHS03 互斥:当两者同时出现时保留GHS02、丢弃GHS03 if "GHS02" in matched and "GHS03" in matched: logger.info("[icon] GHS02/GHS03 冲突,保留GHS02,丢弃GHS03") matched.remove("GHS03") # 对 GHS03 做二次确认,减少 false-positive if "GHS03" in matched: confirmed = self._confirm_ghs03(image_base64) if not confirmed: logger.info("[icon] GHS03 二次确认为 NO,移除") matched.remove("GHS03") logger.info(f"[icon] 识别结果: {matched}") return 0, json.dumps({"tag_images": matched}, ensure_ascii=False) def extract_single(self, image_base64: str, prompt: str, index: int, max_retries: int = 2): """单个任务请求,返回 (index, 结果文本)。失败时最多重试 max_retries 次。""" last_err = None for attempt in range(max_retries + 1): try: response = requests.post( self._url, headers={ "Authorization": INFERENCE_AUTH_TOKEN, "Content-Type": "application/json" }, json={ "model": INFERENCE_MODEL, "messages": [ {"role": "system", "content": "You are a helpful assistant."}, { "role": "user", "content": [ { "type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"} }, {"type": "text", "text": prompt} ] } ], "max_tokens": 4096, "stream": False, "temperature": 0 }, timeout=600 ) response.raise_for_status() resp_json = response.json() choice = resp_json["choices"][0] finish_reason = choice.get("finish_reason", "") content = choice["message"]["content"] if finish_reason == "length": # 输出被 token 数截断,内容不完整,重试 raise RuntimeError( f"步骤[{index}] 模型输出被截断(finish_reason=length)," f"第 {attempt + 1} 次尝试" ) logger.info(f"步骤[{index}] finish_reason={finish_reason}") return index, content except RuntimeError as e: last_err = e if attempt < max_retries: wait = 2 ** attempt logger.warning(f"{e},{wait}s 后重试...") time.sleep(wait) except requests.RequestException as e: last_err = e if attempt < max_retries: wait = 2 ** attempt logger.warning(f"步骤[{index}] 请求异常: {e},{wait}s 后重试...") time.sleep(wait) raise RuntimeError(f"步骤[{index}] 重试 {max_retries} 次后仍失败: {last_err}") @staticmethod def _parse_json(text: str, step_name: str) -> dict: """ 解析模型返回的 JSON 文本,自动清洗 ```json``` 标记。 若直接解析失败,尝试用正则从文本中提取第一个 JSON 对象/数组。 解析失败时抛出 RuntimeError。 """ text = text.strip() if text.startswith("```"): lines = text.splitlines() text = "\n".join( line for line in lines if not line.strip().startswith("```") ).strip() try: return json.loads(text) except json.JSONDecodeError: # 模型返回了思考过程 + JSON 混合内容,尝试提取第一个 JSON 块 match = re.search(r'\{[\s\S]*\}', text) if match: try: return json.loads(match.group()) except json.JSONDecodeError: pass raise RuntimeError( f"步骤[{step_name}]模型返回内容无法解析为 JSON\n原始内容: {text[:200]}" ) def agent_ocr(self, image): """qwen_ocr提取化学品安全标签信息""" image = resize_image(image, max_size=_IMAGE_MAX_SIZE) quality = _IMAGE_COMPRESS_QUALITY if _IMAGE_COMPRESS else 95 image_base64 = image_to_base64(image, quality=quality) logger.info(f"图像预处理: max_size={_IMAGE_MAX_SIZE}, compress={_IMAGE_COMPRESS}, quality={quality}") # 为象形图识别单独准备高分辨率裁剪图 # 竖向长图(高/宽 > 1.5)只取上部 30%,横向图取上部 60% w, h = image.size ratio = h / w crop_ratio = 0.30 if ratio > 1.5 else 0.60 icon_crop = image.crop((0, 0, w, int(h * crop_ratio))) # 放大使图标细节清晰,长边不超过 1600px scale = min(1600 / icon_crop.width, 1600 / icon_crop.height, 3) icon_crop = icon_crop.resize( (int(icon_crop.width * scale), int(icon_crop.height * scale)), Image.Resampling.LANCZOS ) enhancer = ImageEnhance.Contrast(icon_crop) icon_crop = enhancer.enhance(1.3) icon_image_base64 = image_to_base64(icon_crop, quality=95) logger.info(f"象形图识别用裁剪图: {icon_crop.size} (crop_ratio={crop_ratio})") start_time = time.perf_counter() prompts = [ PROMPT_EXTRACT_ICON, # 0 PROMPT_EXTRACT_NAME, # 1 PROMPT_EXTRACT_COMPONENTS, # 2 PROMPT_EXTRACT_KEYWORD, # 3 PROMPT_EXTRACT_PREVENTION, # 4 PROMPT_EXTRACT_SUPPLIER # 5 ] # 并行发送 6 个请求,按 index 填回保证顺序 # index=0 的象形图识别使用带参考图的专用方法 results = [None] * len(prompts) with ThreadPoolExecutor(max_workers=len(prompts)) as executor: futures = {} futures[executor.submit(self.extract_icon, icon_image_base64)] = 0 for idx, prompt in enumerate(prompts): if idx == 0: continue # icon 已单独提交 futures[executor.submit(self.extract_single, image_base64, prompt, idx)] = idx for future in as_completed(futures): idx, content = future.result() # 任意一个步骤失败会在此抛出 results[idx] = content end_time = time.perf_counter() logger.info(f"推理时间: {end_time - start_time:.3f} 秒") # 解析各步骤结果(顺序由 index 保证,与串行时完全一致) step_names = ["icon", "name", "components", "keyword", "prevention", "supplier"] icon = self._parse_json(results[0], step_names[0]) name = self._parse_json(results[1], step_names[1]) tag = self._parse_json(results[2], step_names[2]) risk_notice = self._parse_json(results[3], step_names[3]) pre_notice = self._parse_json(results[4], step_names[4]) suppliers = self._parse_json(results[5], step_names[5]) return { "tag": { "name_cn": name["name_cn"], "name_en": name["name_en"], "cf_list": tag["cf_list"] }, "tag_images": icon["tag_images"], "key_word": risk_notice["key_word"], "risk_notice": risk_notice["risk_notice"], "pre_notice": pre_notice["pre_notice"], "supplier": suppliers["supplier"], "acc_tel": suppliers["acc_tel"], } if __name__ == "__main__": image = Image.open("./test1.jpg").convert("RGB") agent = OcrAgent() res = agent.agent_ocr(image) print(res)