agent.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. from config import MODEL_PATH, PROMPT_EXTRACT_NAME, PROMPT_EXTRACT_COMPONENTS, PROMPT_EXTRACT_KEYWORD, PROMPT_EXTRACT_PREVENTION,PROMPT_EXTRACT_SUPPLIER,PROMPT_EXTRACT_ICON
  2. from model import QwenOcr
  3. from io import BytesIO
  4. import base64
  5. import json
  6. from PIL import Image, ImageFilter, ImageEnhance
  7. import time
  8. import requests
  9. def image_to_base64(pil_image, image_format="JPEG"):
  10. """将PIL Image图像转换为Base64编码"""
  11. buffered = BytesIO()
  12. pil_image.save(buffered, format=image_format)
  13. img_byte_array = buffered.getvalue()
  14. encode_image = base64.b64encode(img_byte_array).decode('utf-8')
  15. return encode_image
  16. def resize_image(image, max_size=512):
  17. """缩放图像尺寸,保持 OCR 质量"""
  18. width, height = image.size
  19. max_dim = max(width, height)
  20. # 如果图像不需要缩小,直接返回
  21. if max_dim <= max_size:
  22. return image
  23. scaling_factor = max_size / max_dim
  24. new_width = int(width * scaling_factor)
  25. new_height = int(height * scaling_factor)
  26. # 使用 LANCZOS 高质量缩放
  27. resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
  28. # 应用 UnsharpMask 锐化,补偿缩放损失
  29. resized = resized.filter(ImageFilter.UnsharpMask(radius=1, percent=120, threshold=3))
  30. # 轻微增强对比度,提高文字识别率
  31. enhancer = ImageEnhance.Contrast(resized)
  32. resized = enhancer.enhance(1.1)
  33. return resized
  34. class OcrAgent:
  35. def __init__(self):
  36. self._url = "http://127.0.0.1:8000/api/v1/ocr"
  37. def extract_single(self, image_base64: str, prompt: str, index: int):
  38. """单个任务请求,返回 (index, 结果文本)"""
  39. response = requests.post(
  40. self._url,
  41. json={
  42. "model": "Qwen3-VL-32B-Instruct",
  43. "messages": [
  44. {"role": "system", "content": "You are a helpful assistant."},
  45. {
  46. "role": "user",
  47. "content": [
  48. {
  49. "type": "image_url",
  50. "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}
  51. },
  52. {"type": "text", "text": prompt}
  53. ]
  54. }
  55. ],
  56. "max_tokens": 4096,
  57. "stream": False,
  58. "temperature": 0
  59. },
  60. timeout=600
  61. )
  62. response.raise_for_status()
  63. content = response.json()["choices"][0]["message"]["content"]
  64. return index, content
  65. def agent_ocr(self, image):
  66. """qwen_ocr提取化学品安全标签信息"""
  67. image = resize_image(image, max_size=512)
  68. image_base64 = image_to_base64(image)
  69. start_time = time.perf_counter()
  70. # 定义需要并行执行的任务(顺序固定,用 index 保序)
  71. prompts = [
  72. PROMPT_EXTRACT_ICON, # 0
  73. PROMPT_EXTRACT_NAME, # 1
  74. PROMPT_EXTRACT_COMPONENTS, # 2
  75. PROMPT_EXTRACT_KEYWORD, # 3
  76. PROMPT_EXTRACT_PREVENTION, # 4
  77. PROMPT_EXTRACT_SUPPLIER # 5
  78. ]
  79. # 串行发送 6 个请求
  80. results = []
  81. for idx, prompt in enumerate(prompts):
  82. _, content = self.extract_single(image_base64, prompt, idx)
  83. results.append(content)
  84. # 从结果中提取数据(顺序已由 index 保证)
  85. icon = json.loads(results[0])
  86. name = json.loads(results[1])
  87. tag = json.loads(results[2])
  88. risk_notice = json.loads(results[3])
  89. pre_notice = json.loads(results[4])
  90. suppliers = json.loads(results[5])
  91. end_time = time.perf_counter()
  92. elapsed_time = end_time - start_time
  93. print(f"推理时间: {elapsed_time:.6f} 秒")
  94. result = {
  95. "tag": {
  96. "name_cn": name["name_cn"],
  97. "name_en": name["name_en"],
  98. "cf_list": tag["cf_list"]
  99. },
  100. "tag_images": icon["tag_images"],
  101. "key_word": risk_notice["key_word"],
  102. "risk_notice": risk_notice["risk_notice"],
  103. "pre_notice": pre_notice["pre_notice"],
  104. "supplier": suppliers["supplier"],
  105. "acc_tel": suppliers["acc_tel"],
  106. }
  107. return result
  108. if __name__ == "__main__":
  109. image = Image.open("./test1.jpg").convert("RGB")
  110. agent = OcrAgent()
  111. res = agent.agent_ocr(image)
  112. print(res)