agent.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. from config import MODEL_PATH, INFERENCE_URL, INFERENCE_AUTH_TOKEN, INFERENCE_MODEL, PROMPT_EXTRACT_NAME, PROMPT_EXTRACT_COMPONENTS, PROMPT_EXTRACT_KEYWORD, PROMPT_EXTRACT_PREVENTION,PROMPT_EXTRACT_SUPPLIER,PROMPT_EXTRACT_ICON
  2. from io import BytesIO
  3. import base64
  4. import json
  5. from PIL import Image, ImageFilter, ImageEnhance
  6. import time
  7. import requests
  8. def image_to_base64(pil_image, image_format="JPEG"):
  9. """将PIL Image图像转换为Base64编码"""
  10. buffered = BytesIO()
  11. pil_image.save(buffered, format=image_format)
  12. img_byte_array = buffered.getvalue()
  13. encode_image = base64.b64encode(img_byte_array).decode('utf-8')
  14. return encode_image
  15. def resize_image(image, max_size=512):
  16. """缩放图像尺寸,保持 OCR 质量"""
  17. width, height = image.size
  18. max_dim = max(width, height)
  19. # 如果图像不需要缩小,直接返回
  20. if max_dim <= max_size:
  21. return image
  22. scaling_factor = max_size / max_dim
  23. new_width = int(width * scaling_factor)
  24. new_height = int(height * scaling_factor)
  25. # 使用 LANCZOS 高质量缩放
  26. resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
  27. # 应用 UnsharpMask 锐化,补偿缩放损失
  28. resized = resized.filter(ImageFilter.UnsharpMask(radius=1, percent=120, threshold=3))
  29. # 轻微增强对比度,提高文字识别率
  30. enhancer = ImageEnhance.Contrast(resized)
  31. resized = enhancer.enhance(1.1)
  32. return resized
  33. class OcrAgent:
  34. def __init__(self):
  35. self._url = INFERENCE_URL
  36. def extract_single(self, image_base64: str, prompt: str, index: int):
  37. """单个任务请求,返回 (index, 结果文本)"""
  38. response = requests.post(
  39. self._url,
  40. headers={
  41. "Authorization": INFERENCE_AUTH_TOKEN,
  42. "Content-Type": "application/json"
  43. },
  44. json={
  45. "model": INFERENCE_MODEL,
  46. "messages": [
  47. {"role": "system", "content": "You are a helpful assistant."},
  48. {
  49. "role": "user",
  50. "content": [
  51. {
  52. "type": "image_url",
  53. "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}
  54. },
  55. {"type": "text", "text": prompt}
  56. ]
  57. }
  58. ],
  59. "max_tokens": 4096,
  60. "stream": False,
  61. "temperature": 0
  62. },
  63. timeout=600
  64. )
  65. response.raise_for_status()
  66. content = response.json()["choices"][0]["message"]["content"]
  67. return index, content
  68. @staticmethod
  69. def _parse_json(text: str, step_name: str) -> dict:
  70. """
  71. 解析模型返回的 JSON 文本,自动清洗 ```json``` 标记。
  72. 解析失败时抛出 RuntimeError(不会被 ValueError 捕获误报为"参数验证失败")。
  73. """
  74. # 去除首尾空白
  75. text = text.strip()
  76. # 兼容模型偶尔返回 ```json ... ``` 包裹的情况
  77. if text.startswith("```"):
  78. lines = text.splitlines()
  79. # 去掉首行的 ```json 或 ``` 和末行的 ```
  80. text = "\n".join(
  81. line for line in lines
  82. if not line.strip().startswith("```")
  83. ).strip()
  84. try:
  85. return json.loads(text)
  86. except json.JSONDecodeError as e:
  87. raise RuntimeError(
  88. f"步骤[{step_name}]模型返回内容无法解析为 JSON: {e}\n原始内容: {text[:200]}"
  89. )
  90. def agent_ocr(self, image):
  91. """qwen_ocr提取化学品安全标签信息"""
  92. image = resize_image(image, max_size=512)
  93. image_base64 = image_to_base64(image)
  94. start_time = time.perf_counter()
  95. # 定义需要并行执行的任务(顺序固定,用 index 保序)
  96. prompts = [
  97. PROMPT_EXTRACT_ICON, # 0
  98. PROMPT_EXTRACT_NAME, # 1
  99. PROMPT_EXTRACT_COMPONENTS, # 2
  100. PROMPT_EXTRACT_KEYWORD, # 3
  101. PROMPT_EXTRACT_PREVENTION, # 4
  102. PROMPT_EXTRACT_SUPPLIER # 5
  103. ]
  104. # 串行发送 6 个请求
  105. results = []
  106. for idx, prompt in enumerate(prompts):
  107. _, content = self.extract_single(image_base64, prompt, idx)
  108. results.append(content)
  109. # 从结果中提取数据(顺序已由 index 保证)
  110. step_names = ["icon", "name", "components", "keyword", "prevention", "supplier"]
  111. icon = self._parse_json(results[0], step_names[0])
  112. name = self._parse_json(results[1], step_names[1])
  113. tag = self._parse_json(results[2], step_names[2])
  114. risk_notice = self._parse_json(results[3], step_names[3])
  115. pre_notice = self._parse_json(results[4], step_names[4])
  116. suppliers = self._parse_json(results[5], step_names[5])
  117. end_time = time.perf_counter()
  118. elapsed_time = end_time - start_time
  119. print(f"推理时间: {elapsed_time:.6f} 秒")
  120. result = {
  121. "tag": {
  122. "name_cn": name["name_cn"],
  123. "name_en": name["name_en"],
  124. "cf_list": tag["cf_list"]
  125. },
  126. "tag_images": icon["tag_images"],
  127. "key_word": risk_notice["key_word"],
  128. "risk_notice": risk_notice["risk_notice"],
  129. "pre_notice": pre_notice["pre_notice"],
  130. "supplier": suppliers["supplier"],
  131. "acc_tel": suppliers["acc_tel"],
  132. }
  133. return result
  134. if __name__ == "__main__":
  135. image = Image.open("./test1.jpg").convert("RGB")
  136. agent = OcrAgent()
  137. res = agent.agent_ocr(image)
  138. print(res)