agent.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. from config import MODEL_PATH, INFERENCE_URL, INFERENCE_AUTH_TOKEN, INFERENCE_MODEL, PROMPT_EXTRACT_NAME, PROMPT_EXTRACT_COMPONENTS, PROMPT_EXTRACT_KEYWORD, PROMPT_EXTRACT_PREVENTION,PROMPT_EXTRACT_SUPPLIER,PROMPT_EXTRACT_ICON
  2. from io import BytesIO
  3. from concurrent.futures import ThreadPoolExecutor, as_completed
  4. import base64
  5. import json
  6. import logging
  7. import os
  8. from PIL import Image, ImageFilter, ImageEnhance
  9. import time
  10. import re
  11. import requests
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Image preprocessing settings, read from environment variables
# (injected by start.py at launch, per the original note).
# IMAGE_MAX_SIZE: longest side in pixels passed to resize_image (default 512).
_IMAGE_MAX_SIZE = int(os.environ.get("IMAGE_MAX_SIZE", 512))
# IMAGE_COMPRESS: only the string "true" (case-insensitive) enables compression.
_IMAGE_COMPRESS = os.environ.get("IMAGE_COMPRESS", "false").lower() == "true"
# IMAGE_COMPRESS_QUALITY: JPEG quality used when compression is enabled (default 70).
_IMAGE_COMPRESS_QUALITY = int(os.environ.get("IMAGE_COMPRESS_QUALITY", 70))
# GHS reference image directory: "ghs_icons" under the parent of the directory
# that contains this file.
_GHS_ICONS_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "ghs_icons")
# Icon names "GHS01" .. "GHS09".
_GHS_ICON_NAMES = [f"GHS{i:02d}" for i in range(1, 10)]
  20. def _load_ghs_reference_images() -> dict:
  21. """加载 GHS01-GHS09 参考图,返回 {名称: base64} 字典。缺失的图跳过。"""
  22. refs = {}
  23. for name in _GHS_ICON_NAMES:
  24. path = os.path.join(_GHS_ICONS_DIR, f"{name}.png")
  25. if os.path.exists(path):
  26. with open(path, "rb") as f:
  27. refs[name] = base64.b64encode(f.read()).decode("utf-8")
  28. else:
  29. logger.warning(f"GHS 参考图缺失: {path}")
  30. return refs
# Loaded once at import time so that requests do not re-read the files.
_GHS_REFERENCE_IMAGES: dict = _load_ghs_reference_images()
# Easily confused icon pairs: when checking a target icon, the reference images
# of the listed distractors are sent as well so the model can compare and rule
# them out.
# The original note states that GHS09 has a distinctive shape (dead tree + dead
# fish) and that no distractor is sent for it, to avoid reverse interference.
# NOTE(review): the mapping below nevertheless lists GHS08 as a distractor for
# GHS09 — confirm which of the two is intended.
_GHS_CONFUSABLE = {
    "GHS02": ["GHS03"],
    "GHS03": ["GHS02"],
    "GHS07": ["GHS08"],
    "GHS08": ["GHS07", "GHS06", "GHS09"],
    "GHS09": ["GHS08"],
}
# Key distinguishing features of each icon, keyed by icon name. The text is
# embedded verbatim into the prompt to sharpen the model's discrimination, so
# these strings are runtime prompt content and must not be edited casually.
_GHS_DISCRIMINATIVE_HINT = {
    "GHS01": "关键特征:圆形炸弹向四周炸裂,有大量射线和碎片飞散。",
    "GHS02": "关键特征:只有火焰图案,火焰底部有横线底座。火焰中间没有任何圆圈/圆环,这是与GHS03的根本区别。",
    "GHS03": (
        "【GHS03唯一判定标准】:图标内部必须同时存在两个元素:(1)火焰 AND (2)火焰包围的空心大圆圈/圆环(圆圈内部镂空呈白色)。"
        "缺少圆圈就是GHS02,不是GHS03。"
        "请先回答:图标中有没有空心圆圈?如果没有圆圈,直接回答NO。"
    ),
    "GHS04": (
        "关键特征:一个横置的粗短圆柱形气瓶/钢瓶,右侧伸出一根细长的阀门管道。"
        "整体形状类似一个横向的短粗矩形/圆柱体,右侧有细管伸出。"
        "在小尺寸印刷版本中,整个图案看起来像一个横向的短粗横条(比感叹号的竖线更粗更短,且是水平的)。"
        "图标内没有火焰、没有感叹号竖线(感叹号是竖向的,气瓶是横向的)、没有人形、没有树。"
        "即使只看到一个横向的粗短形状,也可以回答YES。"
    ),
    "GHS05": (
        "关键特征:图标内有腐蚀性液体滴落的场景——"
        "上方有试管/容器,液体向下滴落,腐蚀下方的物体(金属板或手掌)。"
        "整体看起来像'液体从上往下滴,下方被腐蚀出缺口'的形状。"
        "印刷版可能很小,但可以看出上方有细管状物体、下方有不规则缺口形状。"
        "即使细节模糊,只要能看出'滴落腐蚀'的大致形状,就回答YES。"
    ),
    "GHS06": "关键特征:骷髅头(空心白色)加下方两根交叉骨头。",
    "GHS07": '关键特征:只有一个感叹号"!"(上方竖条加下方圆点),无任何其他图形,没有人形,没有骷髅。',
    "GHS08": (
        "关键特征:实心黑色人体上半身剪影(有明确的头部+肩膀+躯干轮廓),胸口有白色裂缝/射线向四周放射。"
        "【重要】树木/植物形状不是GHS08——如果图案看起来像树枝或植物,那是GHS09而不是GHS08,回答NO。"
        "不是骷髅,不是感叹号,不是树。"
    ),
    "GHS09": (
        "GHS09是环境危害象形图,图标内有两个有机生物形状:"
        "左边是一棵枯树(竖直树干+向两侧伸出的树枝,整体呈Y形或T形的树状轮廓),"
        "右边是一条死鱼(横向椭圆形鱼身轮廓,肚皮朝上翻转)。"
        "这两个形状与其他GHS图标完全不同——其他图标内没有植物或鱼类形状。"
        "请对照参考图:如果标签中某个菱形图标内的图案与参考图相似(有树形和/或鱼形),回答YES。"
    ),
}
  80. def image_to_base64(pil_image, image_format="JPEG", quality=95):
  81. """将PIL Image图像转换为Base64编码"""
  82. buffered = BytesIO()
  83. pil_image.save(buffered, format=image_format, quality=quality)
  84. img_byte_array = buffered.getvalue()
  85. encode_image = base64.b64encode(img_byte_array).decode('utf-8')
  86. return encode_image
  87. def resize_image(image, max_size=512):
  88. """缩放图像尺寸,保持 OCR 质量"""
  89. width, height = image.size
  90. max_dim = max(width, height)
  91. if max_dim <= max_size:
  92. return image
  93. scaling_factor = max_size / max_dim
  94. new_width = int(width * scaling_factor)
  95. new_height = int(height * scaling_factor)
  96. resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
  97. resized = resized.filter(ImageFilter.UnsharpMask(radius=1, percent=120, threshold=3))
  98. enhancer = ImageEnhance.Contrast(resized)
  99. resized = enhancer.enhance(1.1)
  100. return resized
  101. class OcrAgent:
  102. def __init__(self):
  103. self._url = INFERENCE_URL
  104. def _check_single_icon(self, ghs_name: str, ref_b64: str, image_base64: str, max_retries: int = 2) -> bool:
  105. """二分类:判断标签图中是否存在指定的 GHS 象形图,返回 True/False。
  106. 对容易混淆的图标,同时发送干扰项参考图,让模型对比排除。
  107. """
  108. content = [
  109. {"type": "text", "text": f"以下是标准 {ghs_name} 象形图的参考图(目标图标):"},
  110. {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{ref_b64}"}},
  111. ]
  112. # 如果该图标有容易混淆的图标,一并发送供对比
  113. confusable = _GHS_CONFUSABLE.get(ghs_name, [])
  114. if confusable:
  115. content.append({"type": "text", "text": f"以下是容易与 {ghs_name} 混淆的图标,注意区分:"})
  116. for conf_name in confusable:
  117. if conf_name in _GHS_REFERENCE_IMAGES:
  118. content.append({"type": "text", "text": f"({conf_name},不是目标图标)"})
  119. content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{_GHS_REFERENCE_IMAGES[conf_name]}"}})
  120. content.append({"type": "text", "text": "以下是需要识别的化学品安全标签图像:"})
  121. content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}})
  122. discriminative = _GHS_DISCRIMINATIVE_HINT.get(ghs_name, "")
  123. confusable_hint = f"注意不要把它和 {'/'.join(confusable)} 混淆。\n" if confusable else ""
  124. if ghs_name == "GHS03":
  125. final_question = (
  126. f"请仔细检查最后一张标签图像中的每个菱形图标。\n"
  127. f"{discriminative}\n"
  128. f"关键问题:标签中是否有图标在火焰图案内部包含了一个明显的空心圆圈/圆环?\n"
  129. f"如果有圆圈→YES,如果只有火焰没有圆圈→NO。\n"
  130. f"只回答 YES 或 NO,不要输出其他任何内容。"
  131. )
  132. elif ghs_name == "GHS09":
  133. final_question = (
  134. f"请仔细观察最后一张化学品安全标签图像中的所有菱形图标。\n"
  135. f"参考图是GHS09(环境危害)标准图标,内含枯树和死鱼。\n"
  136. f"{discriminative}\n"
  137. f"注意:印刷版标签中的GHS09图标为黑白色(无红色边框),"
  138. f"图案可能很小,但可以看出树形轮廓(树干+分叉枝条)和/或鱼形(椭圆形鱼身)。\n"
  139. f"对照参考图,标签中是否有任何一个菱形图标包含了【枯树】或【死鱼】图案?\n"
  140. f"只要识别出树形或鱼形,就回答YES;完全看不出来才回答NO。\n"
  141. f"只回答 YES 或 NO,不要输出其他任何内容。"
  142. )
  143. else:
  144. final_question = (
  145. f"请仔细观察上方图片。\n"
  146. f"第一张是目标图标 {ghs_name} 的标准参考图。\n"
  147. f"{discriminative}\n"
  148. f"{confusable_hint}"
  149. f"请严格对照上述关键特征,判断:最后一张化学品安全标签图像中是否包含 {ghs_name} 图标?\n"
  150. f"必须所有关键特征都匹配才回答YES,有任何一条不符合就回答NO。\n"
  151. f"只回答 YES 或 NO,不要输出其他任何内容。"
  152. )
  153. content.append({"type": "text", "text": final_question})
  154. last_err = None
  155. for attempt in range(max_retries + 1):
  156. try:
  157. response = requests.post(
  158. self._url,
  159. headers={
  160. "Authorization": INFERENCE_AUTH_TOKEN,
  161. "Content-Type": "application/json"
  162. },
  163. json={
  164. "model": INFERENCE_MODEL,
  165. "messages": [
  166. {"role": "system", "content": "You are a helpful assistant. Answer only YES or NO."},
  167. {"role": "user", "content": content}
  168. ],
  169. "max_tokens": 16,
  170. "stream": False,
  171. "temperature": 0
  172. },
  173. timeout=600
  174. )
  175. response.raise_for_status()
  176. resp_json = response.json()
  177. answer = resp_json["choices"][0]["message"]["content"].strip().upper()
  178. logger.info(f"[icon binary] {ghs_name} -> {answer}")
  179. return answer.startswith("YES")
  180. except requests.RequestException as e:
  181. last_err = e
  182. if attempt < max_retries:
  183. wait = 2 ** attempt
  184. logger.warning(f"[icon binary] {ghs_name} 请求异常: {e},{wait}s 后重试...")
  185. time.sleep(wait)
  186. logger.error(f"[icon binary] {ghs_name} 重试 {max_retries} 次后仍失败: {last_err}")
  187. return False
  188. def _confirm_ghs03(self, image_base64: str) -> bool:
  189. """GHS03 二次确认:直接询问标签中是否存在内部含空心圆圈的火焰图标。
  190. 用于过滤 _check_single_icon 的 false-positive。"""
  191. ghs02_b64 = _GHS_REFERENCE_IMAGES.get("GHS02", "")
  192. ghs03_b64 = _GHS_REFERENCE_IMAGES.get("GHS03", "")
  193. content = []
  194. if ghs02_b64:
  195. content.append({"type": "text", "text": "参考图A — GHS02(仅火焰,无圆圈):"})
  196. content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{ghs02_b64}"}})
  197. if ghs03_b64:
  198. content.append({"type": "text", "text": "参考图B — GHS03(火焰内部有空心圆圈):"})
  199. content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{ghs03_b64}"}})
  200. content.append({"type": "text", "text": "待检查的化学品安全标签:"})
  201. content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}})
  202. content.append({"type": "text", "text": (
  203. "请仔细观察标签中所有菱形图标内的图案。\n"
  204. "问题:标签中是否存在【参考图B】所示的GHS03图标——即火焰图案内部有一个明显的空心圆圈/圆环?\n"
  205. "注意:\n"
  206. "- 如果只看到【参考图A】类型的纯火焰(无圆圈),回答NO\n"
  207. "- 菱形框本身不算圆圈,必须是火焰内部的圆圈\n"
  208. "- 如果没有火焰图案,回答NO\n"
  209. "只回答 YES 或 NO,不要输出其他内容。"
  210. )})
  211. try:
  212. response = requests.post(
  213. self._url,
  214. headers={"Authorization": INFERENCE_AUTH_TOKEN, "Content-Type": "application/json"},
  215. json={
  216. "model": INFERENCE_MODEL,
  217. "messages": [
  218. {"role": "system", "content": "You are a helpful assistant. Answer only YES or NO."},
  219. {"role": "user", "content": content}
  220. ],
  221. "max_tokens": 16,
  222. "stream": False,
  223. "temperature": 0
  224. },
  225. timeout=600
  226. )
  227. response.raise_for_status()
  228. answer = response.json()["choices"][0]["message"]["content"].strip().upper()
  229. logger.info(f"[icon GHS03 confirm] -> {answer}")
  230. return answer.startswith("YES")
  231. except Exception as e:
  232. logger.warning(f"[icon GHS03 confirm] 请求失败: {e},保守返回False")
  233. return False
  234. def extract_icon(self, image_base64: str, max_retries: int = 2):
  235. """象形图识别:对 GHS01-GHS09 逐个并行做二分类,返回 (0, JSON字符串)。"""
  236. with ThreadPoolExecutor(max_workers=len(_GHS_REFERENCE_IMAGES)) as executor:
  237. futures = {
  238. executor.submit(self._check_single_icon, name, ref_b64, image_base64, max_retries): name
  239. for name, ref_b64 in _GHS_REFERENCE_IMAGES.items()
  240. }
  241. results = {}
  242. for future in as_completed(futures):
  243. name = futures[future]
  244. results[name] = future.result()
  245. matched = [name for name in _GHS_ICON_NAMES if results.get(name)]
  246. # GHS02 和 GHS03 互斥:当两者同时出现时保留GHS02、丢弃GHS03
  247. if "GHS02" in matched and "GHS03" in matched:
  248. logger.info("[icon] GHS02/GHS03 冲突,保留GHS02,丢弃GHS03")
  249. matched.remove("GHS03")
  250. # 对 GHS03 做二次确认,减少 false-positive
  251. if "GHS03" in matched:
  252. confirmed = self._confirm_ghs03(image_base64)
  253. if not confirmed:
  254. logger.info("[icon] GHS03 二次确认为 NO,移除")
  255. matched.remove("GHS03")
  256. logger.info(f"[icon] 识别结果: {matched}")
  257. return 0, json.dumps({"tag_images": matched}, ensure_ascii=False)
  258. def extract_single(self, image_base64: str, prompt: str, index: int, max_retries: int = 2):
  259. """单个任务请求,返回 (index, 结果文本)。失败时最多重试 max_retries 次。"""
  260. last_err = None
  261. for attempt in range(max_retries + 1):
  262. try:
  263. response = requests.post(
  264. self._url,
  265. headers={
  266. "Authorization": INFERENCE_AUTH_TOKEN,
  267. "Content-Type": "application/json"
  268. },
  269. json={
  270. "model": INFERENCE_MODEL,
  271. "messages": [
  272. {"role": "system", "content": "You are a helpful assistant."},
  273. {
  274. "role": "user",
  275. "content": [
  276. {
  277. "type": "image_url",
  278. "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}
  279. },
  280. {"type": "text", "text": prompt}
  281. ]
  282. }
  283. ],
  284. "max_tokens": 4096,
  285. "stream": False,
  286. "temperature": 0
  287. },
  288. timeout=600
  289. )
  290. response.raise_for_status()
  291. resp_json = response.json()
  292. choice = resp_json["choices"][0]
  293. finish_reason = choice.get("finish_reason", "")
  294. content = choice["message"]["content"]
  295. if finish_reason == "length":
  296. # 输出被 token 数截断,内容不完整,重试
  297. raise RuntimeError(
  298. f"步骤[{index}] 模型输出被截断(finish_reason=length),"
  299. f"第 {attempt + 1} 次尝试"
  300. )
  301. logger.info(f"步骤[{index}] finish_reason={finish_reason}")
  302. return index, content
  303. except RuntimeError as e:
  304. last_err = e
  305. if attempt < max_retries:
  306. wait = 2 ** attempt
  307. logger.warning(f"{e},{wait}s 后重试...")
  308. time.sleep(wait)
  309. except requests.RequestException as e:
  310. last_err = e
  311. if attempt < max_retries:
  312. wait = 2 ** attempt
  313. logger.warning(f"步骤[{index}] 请求异常: {e},{wait}s 后重试...")
  314. time.sleep(wait)
  315. raise RuntimeError(f"步骤[{index}] 重试 {max_retries} 次后仍失败: {last_err}")
  316. @staticmethod
  317. def _parse_json(text: str, step_name: str) -> dict:
  318. """
  319. 解析模型返回的 JSON 文本,自动清洗 ```json``` 标记。
  320. 若直接解析失败,尝试用正则从文本中提取第一个 JSON 对象/数组。
  321. 解析失败时抛出 RuntimeError。
  322. """
  323. text = text.strip()
  324. if text.startswith("```"):
  325. lines = text.splitlines()
  326. text = "\n".join(
  327. line for line in lines
  328. if not line.strip().startswith("```")
  329. ).strip()
  330. try:
  331. return json.loads(text)
  332. except json.JSONDecodeError:
  333. # 模型返回了思考过程 + JSON 混合内容,尝试提取第一个 JSON 块
  334. match = re.search(r'\{[\s\S]*\}', text)
  335. if match:
  336. try:
  337. return json.loads(match.group())
  338. except json.JSONDecodeError:
  339. pass
  340. raise RuntimeError(
  341. f"步骤[{step_name}]模型返回内容无法解析为 JSON\n原始内容: {text[:200]}"
  342. )
  343. def agent_ocr(self, image):
  344. """qwen_ocr提取化学品安全标签信息"""
  345. image = resize_image(image, max_size=_IMAGE_MAX_SIZE)
  346. quality = _IMAGE_COMPRESS_QUALITY if _IMAGE_COMPRESS else 95
  347. image_base64 = image_to_base64(image, quality=quality)
  348. logger.info(f"图像预处理: max_size={_IMAGE_MAX_SIZE}, compress={_IMAGE_COMPRESS}, quality={quality}")
  349. # 为象形图识别单独准备高分辨率裁剪图
  350. # 竖向长图(高/宽 > 1.5)只取上部 30%,横向图取上部 60%
  351. w, h = image.size
  352. ratio = h / w
  353. crop_ratio = 0.30 if ratio > 1.5 else 0.60
  354. icon_crop = image.crop((0, 0, w, int(h * crop_ratio)))
  355. # 放大使图标细节清晰,长边不超过 1600px
  356. scale = min(1600 / icon_crop.width, 1600 / icon_crop.height, 3)
  357. icon_crop = icon_crop.resize(
  358. (int(icon_crop.width * scale), int(icon_crop.height * scale)),
  359. Image.Resampling.LANCZOS
  360. )
  361. enhancer = ImageEnhance.Contrast(icon_crop)
  362. icon_crop = enhancer.enhance(1.3)
  363. icon_image_base64 = image_to_base64(icon_crop, quality=95)
  364. logger.info(f"象形图识别用裁剪图: {icon_crop.size} (crop_ratio={crop_ratio})")
  365. start_time = time.perf_counter()
  366. prompts = [
  367. PROMPT_EXTRACT_ICON, # 0
  368. PROMPT_EXTRACT_NAME, # 1
  369. PROMPT_EXTRACT_COMPONENTS, # 2
  370. PROMPT_EXTRACT_KEYWORD, # 3
  371. PROMPT_EXTRACT_PREVENTION, # 4
  372. PROMPT_EXTRACT_SUPPLIER # 5
  373. ]
  374. # 并行发送 6 个请求,按 index 填回保证顺序
  375. # index=0 的象形图识别使用带参考图的专用方法
  376. results = [None] * len(prompts)
  377. with ThreadPoolExecutor(max_workers=len(prompts)) as executor:
  378. futures = {}
  379. futures[executor.submit(self.extract_icon, icon_image_base64)] = 0
  380. for idx, prompt in enumerate(prompts):
  381. if idx == 0:
  382. continue # icon 已单独提交
  383. futures[executor.submit(self.extract_single, image_base64, prompt, idx)] = idx
  384. for future in as_completed(futures):
  385. idx, content = future.result() # 任意一个步骤失败会在此抛出
  386. results[idx] = content
  387. end_time = time.perf_counter()
  388. logger.info(f"推理时间: {end_time - start_time:.3f} 秒")
  389. # 解析各步骤结果(顺序由 index 保证,与串行时完全一致)
  390. step_names = ["icon", "name", "components", "keyword", "prevention", "supplier"]
  391. icon = self._parse_json(results[0], step_names[0])
  392. name = self._parse_json(results[1], step_names[1])
  393. tag = self._parse_json(results[2], step_names[2])
  394. risk_notice = self._parse_json(results[3], step_names[3])
  395. pre_notice = self._parse_json(results[4], step_names[4])
  396. suppliers = self._parse_json(results[5], step_names[5])
  397. return {
  398. "tag": {
  399. "name_cn": name["name_cn"],
  400. "name_en": name["name_en"],
  401. "cf_list": tag["cf_list"]
  402. },
  403. "tag_images": icon["tag_images"],
  404. "key_word": risk_notice["key_word"],
  405. "risk_notice": risk_notice["risk_notice"],
  406. "pre_notice": pre_notice["pre_notice"],
  407. "supplier": suppliers["supplier"],
  408. "acc_tel": suppliers["acc_tel"],
  409. }
  410. if __name__ == "__main__":
  411. image = Image.open("./test1.jpg").convert("RGB")
  412. agent = OcrAgent()
  413. res = agent.agent_ocr(image)
  414. print(res)