| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462 |
- from config import MODEL_PATH, INFERENCE_URL, INFERENCE_AUTH_TOKEN, INFERENCE_MODEL, PROMPT_EXTRACT_NAME, PROMPT_EXTRACT_COMPONENTS, PROMPT_EXTRACT_KEYWORD, PROMPT_EXTRACT_PREVENTION,PROMPT_EXTRACT_SUPPLIER,PROMPT_EXTRACT_ICON
- from io import BytesIO
- from concurrent.futures import ThreadPoolExecutor, as_completed
- import base64
- import json
- import logging
- import os
- from PIL import Image, ImageFilter, ImageEnhance
- import time
- import re
- import requests
- logger = logging.getLogger(__name__)
- # 从环境变量读取图像预处理配置(由 start.py 启动时注入)
- _IMAGE_MAX_SIZE = int(os.environ.get("IMAGE_MAX_SIZE", 512))
- _IMAGE_COMPRESS = os.environ.get("IMAGE_COMPRESS", "false").lower() == "true"
- _IMAGE_COMPRESS_QUALITY = int(os.environ.get("IMAGE_COMPRESS_QUALITY", 70))
- # GHS 参考图目录(与 agent.py 同级的上级目录下的 ghs_icons/)
- _GHS_ICONS_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "ghs_icons")
- _GHS_ICON_NAMES = [f"GHS{i:02d}" for i in range(1, 10)]
- def _load_ghs_reference_images() -> dict:
- """加载 GHS01-GHS09 参考图,返回 {名称: base64} 字典。缺失的图跳过。"""
- refs = {}
- for name in _GHS_ICON_NAMES:
- path = os.path.join(_GHS_ICONS_DIR, f"{name}.png")
- if os.path.exists(path):
- with open(path, "rb") as f:
- refs[name] = base64.b64encode(f.read()).decode("utf-8")
- else:
- logger.warning(f"GHS 参考图缺失: {path}")
- return refs
- # 启动时加载一次,避免每次请求重复读文件
- _GHS_REFERENCE_IMAGES: dict = _load_ghs_reference_images()
- # 容易混淆的图标对,判断目标图标时同时发送干扰项供模型对比排除
- # GHS09 外形独特(枯树+死鱼),不发送混淆图,避免反向干扰
- _GHS_CONFUSABLE = {
- "GHS02": ["GHS03"],
- "GHS03": ["GHS02"],
- "GHS07": ["GHS08"],
- "GHS08": ["GHS07", "GHS06", "GHS09"],
- "GHS09": ["GHS08"],
- }
- # 每个图标的判定关键特征,嵌入提示词中强化区分
- _GHS_DISCRIMINATIVE_HINT = {
- "GHS01": "关键特征:圆形炸弹向四周炸裂,有大量射线和碎片飞散。",
- "GHS02": "关键特征:只有火焰图案,火焰底部有横线底座。火焰中间没有任何圆圈/圆环,这是与GHS03的根本区别。",
- "GHS03": (
- "【GHS03唯一判定标准】:图标内部必须同时存在两个元素:(1)火焰 AND (2)火焰包围的空心大圆圈/圆环(圆圈内部镂空呈白色)。"
- "缺少圆圈就是GHS02,不是GHS03。"
- "请先回答:图标中有没有空心圆圈?如果没有圆圈,直接回答NO。"
- ),
- "GHS04": (
- "关键特征:一个横置的粗短圆柱形气瓶/钢瓶,右侧伸出一根细长的阀门管道。"
- "整体形状类似一个横向的短粗矩形/圆柱体,右侧有细管伸出。"
- "在小尺寸印刷版本中,整个图案看起来像一个横向的短粗横条(比感叹号的竖线更粗更短,且是水平的)。"
- "图标内没有火焰、没有感叹号竖线(感叹号是竖向的,气瓶是横向的)、没有人形、没有树。"
- "即使只看到一个横向的粗短形状,也可以回答YES。"
- ),
- "GHS05": (
- "关键特征:图标内有腐蚀性液体滴落的场景——"
- "上方有试管/容器,液体向下滴落,腐蚀下方的物体(金属板或手掌)。"
- "整体看起来像'液体从上往下滴,下方被腐蚀出缺口'的形状。"
- "印刷版可能很小,但可以看出上方有细管状物体、下方有不规则缺口形状。"
- "即使细节模糊,只要能看出'滴落腐蚀'的大致形状,就回答YES。"
- ),
- "GHS06": "关键特征:骷髅头(空心白色)加下方两根交叉骨头。",
- "GHS07": '关键特征:只有一个感叹号"!"(上方竖条加下方圆点),无任何其他图形,没有人形,没有骷髅。',
- "GHS08": (
- "关键特征:实心黑色人体上半身剪影(有明确的头部+肩膀+躯干轮廓),胸口有白色裂缝/射线向四周放射。"
- "【重要】树木/植物形状不是GHS08——如果图案看起来像树枝或植物,那是GHS09而不是GHS08,回答NO。"
- "不是骷髅,不是感叹号,不是树。"
- ),
- "GHS09": (
- "GHS09是环境危害象形图,图标内有两个有机生物形状:"
- "左边是一棵枯树(竖直树干+向两侧伸出的树枝,整体呈Y形或T形的树状轮廓),"
- "右边是一条死鱼(横向椭圆形鱼身轮廓,肚皮朝上翻转)。"
- "这两个形状与其他GHS图标完全不同——其他图标内没有植物或鱼类形状。"
- "请对照参考图:如果标签中某个菱形图标内的图案与参考图相似(有树形和/或鱼形),回答YES。"
- ),
- }
- def image_to_base64(pil_image, image_format="JPEG", quality=95):
- """将PIL Image图像转换为Base64编码"""
- buffered = BytesIO()
- pil_image.save(buffered, format=image_format, quality=quality)
- img_byte_array = buffered.getvalue()
- encode_image = base64.b64encode(img_byte_array).decode('utf-8')
- return encode_image
- def resize_image(image, max_size=512):
- """缩放图像尺寸,保持 OCR 质量"""
- width, height = image.size
- max_dim = max(width, height)
- if max_dim <= max_size:
- return image
- scaling_factor = max_size / max_dim
- new_width = int(width * scaling_factor)
- new_height = int(height * scaling_factor)
- resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
- resized = resized.filter(ImageFilter.UnsharpMask(radius=1, percent=120, threshold=3))
- enhancer = ImageEnhance.Contrast(resized)
- resized = enhancer.enhance(1.1)
- return resized
- class OcrAgent:
- def __init__(self):
- self._url = INFERENCE_URL
- def _check_single_icon(self, ghs_name: str, ref_b64: str, image_base64: str, max_retries: int = 2) -> bool:
- """二分类:判断标签图中是否存在指定的 GHS 象形图,返回 True/False。
- 对容易混淆的图标,同时发送干扰项参考图,让模型对比排除。
- """
- content = [
- {"type": "text", "text": f"以下是标准 {ghs_name} 象形图的参考图(目标图标):"},
- {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{ref_b64}"}},
- ]
- # 如果该图标有容易混淆的图标,一并发送供对比
- confusable = _GHS_CONFUSABLE.get(ghs_name, [])
- if confusable:
- content.append({"type": "text", "text": f"以下是容易与 {ghs_name} 混淆的图标,注意区分:"})
- for conf_name in confusable:
- if conf_name in _GHS_REFERENCE_IMAGES:
- content.append({"type": "text", "text": f"({conf_name},不是目标图标)"})
- content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{_GHS_REFERENCE_IMAGES[conf_name]}"}})
- content.append({"type": "text", "text": "以下是需要识别的化学品安全标签图像:"})
- content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}})
- discriminative = _GHS_DISCRIMINATIVE_HINT.get(ghs_name, "")
- confusable_hint = f"注意不要把它和 {'/'.join(confusable)} 混淆。\n" if confusable else ""
- if ghs_name == "GHS03":
- final_question = (
- f"请仔细检查最后一张标签图像中的每个菱形图标。\n"
- f"{discriminative}\n"
- f"关键问题:标签中是否有图标在火焰图案内部包含了一个明显的空心圆圈/圆环?\n"
- f"如果有圆圈→YES,如果只有火焰没有圆圈→NO。\n"
- f"只回答 YES 或 NO,不要输出其他任何内容。"
- )
- elif ghs_name == "GHS09":
- final_question = (
- f"请仔细观察最后一张化学品安全标签图像中的所有菱形图标。\n"
- f"参考图是GHS09(环境危害)标准图标,内含枯树和死鱼。\n"
- f"{discriminative}\n"
- f"注意:印刷版标签中的GHS09图标为黑白色(无红色边框),"
- f"图案可能很小,但可以看出树形轮廓(树干+分叉枝条)和/或鱼形(椭圆形鱼身)。\n"
- f"对照参考图,标签中是否有任何一个菱形图标包含了【枯树】或【死鱼】图案?\n"
- f"只要识别出树形或鱼形,就回答YES;完全看不出来才回答NO。\n"
- f"只回答 YES 或 NO,不要输出其他任何内容。"
- )
- else:
- final_question = (
- f"请仔细观察上方图片。\n"
- f"第一张是目标图标 {ghs_name} 的标准参考图。\n"
- f"{discriminative}\n"
- f"{confusable_hint}"
- f"请严格对照上述关键特征,判断:最后一张化学品安全标签图像中是否包含 {ghs_name} 图标?\n"
- f"必须所有关键特征都匹配才回答YES,有任何一条不符合就回答NO。\n"
- f"只回答 YES 或 NO,不要输出其他任何内容。"
- )
- content.append({"type": "text", "text": final_question})
- last_err = None
- for attempt in range(max_retries + 1):
- try:
- response = requests.post(
- self._url,
- headers={
- "Authorization": INFERENCE_AUTH_TOKEN,
- "Content-Type": "application/json"
- },
- json={
- "model": INFERENCE_MODEL,
- "messages": [
- {"role": "system", "content": "You are a helpful assistant. Answer only YES or NO."},
- {"role": "user", "content": content}
- ],
- "max_tokens": 16,
- "stream": False,
- "temperature": 0
- },
- timeout=600
- )
- response.raise_for_status()
- resp_json = response.json()
- answer = resp_json["choices"][0]["message"]["content"].strip().upper()
- logger.info(f"[icon binary] {ghs_name} -> {answer}")
- return answer.startswith("YES")
- except requests.RequestException as e:
- last_err = e
- if attempt < max_retries:
- wait = 2 ** attempt
- logger.warning(f"[icon binary] {ghs_name} 请求异常: {e},{wait}s 后重试...")
- time.sleep(wait)
- logger.error(f"[icon binary] {ghs_name} 重试 {max_retries} 次后仍失败: {last_err}")
- return False
- def _confirm_ghs03(self, image_base64: str) -> bool:
- """GHS03 二次确认:直接询问标签中是否存在内部含空心圆圈的火焰图标。
- 用于过滤 _check_single_icon 的 false-positive。"""
- ghs02_b64 = _GHS_REFERENCE_IMAGES.get("GHS02", "")
- ghs03_b64 = _GHS_REFERENCE_IMAGES.get("GHS03", "")
- content = []
- if ghs02_b64:
- content.append({"type": "text", "text": "参考图A — GHS02(仅火焰,无圆圈):"})
- content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{ghs02_b64}"}})
- if ghs03_b64:
- content.append({"type": "text", "text": "参考图B — GHS03(火焰内部有空心圆圈):"})
- content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{ghs03_b64}"}})
- content.append({"type": "text", "text": "待检查的化学品安全标签:"})
- content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}})
- content.append({"type": "text", "text": (
- "请仔细观察标签中所有菱形图标内的图案。\n"
- "问题:标签中是否存在【参考图B】所示的GHS03图标——即火焰图案内部有一个明显的空心圆圈/圆环?\n"
- "注意:\n"
- "- 如果只看到【参考图A】类型的纯火焰(无圆圈),回答NO\n"
- "- 菱形框本身不算圆圈,必须是火焰内部的圆圈\n"
- "- 如果没有火焰图案,回答NO\n"
- "只回答 YES 或 NO,不要输出其他内容。"
- )})
- try:
- response = requests.post(
- self._url,
- headers={"Authorization": INFERENCE_AUTH_TOKEN, "Content-Type": "application/json"},
- json={
- "model": INFERENCE_MODEL,
- "messages": [
- {"role": "system", "content": "You are a helpful assistant. Answer only YES or NO."},
- {"role": "user", "content": content}
- ],
- "max_tokens": 16,
- "stream": False,
- "temperature": 0
- },
- timeout=600
- )
- response.raise_for_status()
- answer = response.json()["choices"][0]["message"]["content"].strip().upper()
- logger.info(f"[icon GHS03 confirm] -> {answer}")
- return answer.startswith("YES")
- except Exception as e:
- logger.warning(f"[icon GHS03 confirm] 请求失败: {e},保守返回False")
- return False
- def extract_icon(self, image_base64: str, max_retries: int = 2):
- """象形图识别:对 GHS01-GHS09 逐个并行做二分类,返回 (0, JSON字符串)。"""
- with ThreadPoolExecutor(max_workers=len(_GHS_REFERENCE_IMAGES)) as executor:
- futures = {
- executor.submit(self._check_single_icon, name, ref_b64, image_base64, max_retries): name
- for name, ref_b64 in _GHS_REFERENCE_IMAGES.items()
- }
- results = {}
- for future in as_completed(futures):
- name = futures[future]
- results[name] = future.result()
- matched = [name for name in _GHS_ICON_NAMES if results.get(name)]
- # GHS02 和 GHS03 互斥:当两者同时出现时保留GHS02、丢弃GHS03
- if "GHS02" in matched and "GHS03" in matched:
- logger.info("[icon] GHS02/GHS03 冲突,保留GHS02,丢弃GHS03")
- matched.remove("GHS03")
- # 对 GHS03 做二次确认,减少 false-positive
- if "GHS03" in matched:
- confirmed = self._confirm_ghs03(image_base64)
- if not confirmed:
- logger.info("[icon] GHS03 二次确认为 NO,移除")
- matched.remove("GHS03")
- logger.info(f"[icon] 识别结果: {matched}")
- return 0, json.dumps({"tag_images": matched}, ensure_ascii=False)
- def extract_single(self, image_base64: str, prompt: str, index: int, max_retries: int = 2):
- """单个任务请求,返回 (index, 结果文本)。失败时最多重试 max_retries 次。"""
- last_err = None
- for attempt in range(max_retries + 1):
- try:
- response = requests.post(
- self._url,
- headers={
- "Authorization": INFERENCE_AUTH_TOKEN,
- "Content-Type": "application/json"
- },
- json={
- "model": INFERENCE_MODEL,
- "messages": [
- {"role": "system", "content": "You are a helpful assistant."},
- {
- "role": "user",
- "content": [
- {
- "type": "image_url",
- "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}
- },
- {"type": "text", "text": prompt}
- ]
- }
- ],
- "max_tokens": 4096,
- "stream": False,
- "temperature": 0
- },
- timeout=600
- )
- response.raise_for_status()
- resp_json = response.json()
- choice = resp_json["choices"][0]
- finish_reason = choice.get("finish_reason", "")
- content = choice["message"]["content"]
- if finish_reason == "length":
- # 输出被 token 数截断,内容不完整,重试
- raise RuntimeError(
- f"步骤[{index}] 模型输出被截断(finish_reason=length),"
- f"第 {attempt + 1} 次尝试"
- )
- logger.info(f"步骤[{index}] finish_reason={finish_reason}")
- return index, content
- except RuntimeError as e:
- last_err = e
- if attempt < max_retries:
- wait = 2 ** attempt
- logger.warning(f"{e},{wait}s 后重试...")
- time.sleep(wait)
- except requests.RequestException as e:
- last_err = e
- if attempt < max_retries:
- wait = 2 ** attempt
- logger.warning(f"步骤[{index}] 请求异常: {e},{wait}s 后重试...")
- time.sleep(wait)
- raise RuntimeError(f"步骤[{index}] 重试 {max_retries} 次后仍失败: {last_err}")
- @staticmethod
- def _parse_json(text: str, step_name: str) -> dict:
- """
- 解析模型返回的 JSON 文本,自动清洗 ```json``` 标记。
- 若直接解析失败,尝试用正则从文本中提取第一个 JSON 对象/数组。
- 解析失败时抛出 RuntimeError。
- """
- text = text.strip()
- if text.startswith("```"):
- lines = text.splitlines()
- text = "\n".join(
- line for line in lines
- if not line.strip().startswith("```")
- ).strip()
- try:
- return json.loads(text)
- except json.JSONDecodeError:
- # 模型返回了思考过程 + JSON 混合内容,尝试提取第一个 JSON 块
- match = re.search(r'\{[\s\S]*\}', text)
- if match:
- try:
- return json.loads(match.group())
- except json.JSONDecodeError:
- pass
- raise RuntimeError(
- f"步骤[{step_name}]模型返回内容无法解析为 JSON\n原始内容: {text[:200]}"
- )
- def agent_ocr(self, image):
- """qwen_ocr提取化学品安全标签信息"""
- image = resize_image(image, max_size=_IMAGE_MAX_SIZE)
- quality = _IMAGE_COMPRESS_QUALITY if _IMAGE_COMPRESS else 95
- image_base64 = image_to_base64(image, quality=quality)
- logger.info(f"图像预处理: max_size={_IMAGE_MAX_SIZE}, compress={_IMAGE_COMPRESS}, quality={quality}")
- # 为象形图识别单独准备高分辨率裁剪图
- # 竖向长图(高/宽 > 1.5)只取上部 30%,横向图取上部 60%
- w, h = image.size
- ratio = h / w
- crop_ratio = 0.30 if ratio > 1.5 else 0.60
- icon_crop = image.crop((0, 0, w, int(h * crop_ratio)))
- # 放大使图标细节清晰,长边不超过 1600px
- scale = min(1600 / icon_crop.width, 1600 / icon_crop.height, 3)
- icon_crop = icon_crop.resize(
- (int(icon_crop.width * scale), int(icon_crop.height * scale)),
- Image.Resampling.LANCZOS
- )
- enhancer = ImageEnhance.Contrast(icon_crop)
- icon_crop = enhancer.enhance(1.3)
- icon_image_base64 = image_to_base64(icon_crop, quality=95)
- logger.info(f"象形图识别用裁剪图: {icon_crop.size} (crop_ratio={crop_ratio})")
- start_time = time.perf_counter()
- prompts = [
- PROMPT_EXTRACT_ICON, # 0
- PROMPT_EXTRACT_NAME, # 1
- PROMPT_EXTRACT_COMPONENTS, # 2
- PROMPT_EXTRACT_KEYWORD, # 3
- PROMPT_EXTRACT_PREVENTION, # 4
- PROMPT_EXTRACT_SUPPLIER # 5
- ]
- # 并行发送 6 个请求,按 index 填回保证顺序
- # index=0 的象形图识别使用带参考图的专用方法
- results = [None] * len(prompts)
- with ThreadPoolExecutor(max_workers=len(prompts)) as executor:
- futures = {}
- futures[executor.submit(self.extract_icon, icon_image_base64)] = 0
- for idx, prompt in enumerate(prompts):
- if idx == 0:
- continue # icon 已单独提交
- futures[executor.submit(self.extract_single, image_base64, prompt, idx)] = idx
- for future in as_completed(futures):
- idx, content = future.result() # 任意一个步骤失败会在此抛出
- results[idx] = content
- end_time = time.perf_counter()
- logger.info(f"推理时间: {end_time - start_time:.3f} 秒")
- # 解析各步骤结果(顺序由 index 保证,与串行时完全一致)
- step_names = ["icon", "name", "components", "keyword", "prevention", "supplier"]
- icon = self._parse_json(results[0], step_names[0])
- name = self._parse_json(results[1], step_names[1])
- tag = self._parse_json(results[2], step_names[2])
- risk_notice = self._parse_json(results[3], step_names[3])
- pre_notice = self._parse_json(results[4], step_names[4])
- suppliers = self._parse_json(results[5], step_names[5])
- return {
- "tag": {
- "name_cn": name["name_cn"],
- "name_en": name["name_en"],
- "cf_list": tag["cf_list"]
- },
- "tag_images": icon["tag_images"],
- "key_word": risk_notice["key_word"],
- "risk_notice": risk_notice["risk_notice"],
- "pre_notice": pre_notice["pre_notice"],
- "supplier": suppliers["supplier"],
- "acc_tel": suppliers["acc_tel"],
- }
- if __name__ == "__main__":
- image = Image.open("./test1.jpg").convert("RGB")
- agent = OcrAgent()
- res = agent.agent_ocr(image)
- print(res)
|