agent.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. from config import MODEL_PATH, PROMPT_EXTRACT_NAME, PROMPT_EXTRACT_COMPONENTS, PROMPT_EXTRACT_KEYWORD, PROMPT_EXTRACT_PREVENTION,PROMPT_EXTRACT_SUPPLIER,PROMPT_EXTRACT_ICON
  2. from model import QwenOcr
  3. from io import BytesIO
  4. import base64
  5. import json
  6. from PIL import Image, ImageFilter, ImageEnhance
  7. import time
  8. from concurrent.futures import ThreadPoolExecutor, as_completed
  9. import requests
  10. def image_to_base64(pil_image, image_format="JPEG"):
  11. """将PIL Image图像转换为Base64编码"""
  12. buffered = BytesIO()
  13. pil_image.save(buffered, format=image_format)
  14. img_byte_array = buffered.getvalue()
  15. encode_image = base64.b64encode(img_byte_array).decode('utf-8')
  16. return encode_image
  17. def resize_image(image, max_size=512):
  18. """缩放图像尺寸,保持 OCR 质量"""
  19. width, height = image.size
  20. max_dim = max(width, height)
  21. # 如果图像不需要缩小,直接返回
  22. if max_dim <= max_size:
  23. return image
  24. scaling_factor = max_size / max_dim
  25. new_width = int(width * scaling_factor)
  26. new_height = int(height * scaling_factor)
  27. # 使用 LANCZOS 高质量缩放
  28. resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
  29. # 应用 UnsharpMask 锐化,补偿缩放损失
  30. resized = resized.filter(ImageFilter.UnsharpMask(radius=1, percent=120, threshold=3))
  31. # 轻微增强对比度,提高文字识别率
  32. enhancer = ImageEnhance.Contrast(resized)
  33. resized = enhancer.enhance(1.1)
  34. return resized
  35. class OcrAgent:
  36. def __init__(self):
  37. self._url = "http://127.0.0.1:8000/api/v1/ocr"
  38. def extract_part_info(self, image_base64, prompt):
  39. """根据提示词提取信息"""
  40. response = requests.post(
  41. self._url,
  42. json={
  43. "image": image_base64,
  44. "text": prompt
  45. }
  46. )
  47. result = response.json()
  48. return json.loads(result['data'][0])
  49. def agent_ocr(self, image):
  50. """qwen_ocr提取化学品安全标签信息"""
  51. image = resize_image(image, max_size=1024)
  52. image_base64 = image_to_base64(image)
  53. start_time = time.perf_counter()
  54. # 定义需要并行执行的任务
  55. tasks = {
  56. 'icon': PROMPT_EXTRACT_ICON,
  57. 'name': PROMPT_EXTRACT_NAME,
  58. 'tag': PROMPT_EXTRACT_COMPONENTS,
  59. 'risk_notice': PROMPT_EXTRACT_KEYWORD,
  60. 'pre_notice': PROMPT_EXTRACT_PREVENTION,
  61. 'suppliers': PROMPT_EXTRACT_SUPPLIER
  62. }
  63. # 使用线程池并行执行所有提取任务
  64. results = {}
  65. with ThreadPoolExecutor(max_workers=6) as executor:
  66. # 提交所有任务
  67. future_to_task = {
  68. executor.submit(self.extract_part_info, image_base64, prompt): task_name
  69. for task_name, prompt in tasks.items()
  70. }
  71. # 收集结果
  72. for future in as_completed(future_to_task):
  73. task_name = future_to_task[future]
  74. try:
  75. results[task_name] = future.result()
  76. except Exception as e:
  77. print(f"任务 {task_name} 执行失败: {e}")
  78. results[task_name] = {}
  79. # 从结果中提取数据
  80. icon = results.get('icon', {})
  81. name = results.get('name', {})
  82. tag = results.get('tag', {})
  83. risk_notice = results.get('risk_notice', {})
  84. pre_notice = results.get('pre_notice', {})
  85. suppliers = results.get('suppliers', {})
  86. end_time = time.perf_counter()
  87. elapsed_time = end_time - start_time
  88. print(f"推理时间: {elapsed_time:.6f} 秒")
  89. result = {
  90. "tag": {
  91. "name_cn": name["name_cn"],
  92. "name_en": name["name_en"],
  93. "cf_list": tag["cf_list"]
  94. },
  95. "tag_images": icon["tag_images"],
  96. "key_word": risk_notice["key_word"],
  97. "risk_notice": risk_notice["risk_notice"],
  98. "pre_notice": pre_notice["pre_notice"],
  99. "supplier": suppliers["supplier"],
  100. "acc_tel": suppliers["acc_tel"],
  101. }
  102. return result
  103. if __name__ == "__main__":
  104. image = Image.open("./test1.jpg").convert("RGB")
  105. agent = OcrAgent()
  106. agent.agent_ocr(image)