agent.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. from config import MODEL_PATH, PROMPT_EXTRACT_NAME, PROMPT_EXTRACT_COMPONENTS, PROMPT_EXTRACT_KEYWORD, PROMPT_EXTRACT_PREVENTION,PROMPT_EXTRACT_SUPPLIER,PROMPT_EXTRACT_ICON
  2. from model import QwenOcr
  3. from io import BytesIO
  4. import base64
  5. import json
  6. from PIL import Image, ImageFilter, ImageEnhance
  7. import time
  8. from concurrent.futures import ThreadPoolExecutor, as_completed
  9. import requests
  10. def image_to_base64(pil_image, image_format="JPEG"):
  11. """将PIL Image图像转换为Base64编码"""
  12. buffered = BytesIO()
  13. pil_image.save(buffered, format=image_format)
  14. img_byte_array = buffered.getvalue()
  15. encode_image = base64.b64encode(img_byte_array).decode('utf-8')
  16. return encode_image
  17. def resize_image(image, max_size=512):
  18. """缩放图像尺寸,保持 OCR 质量"""
  19. width, height = image.size
  20. max_dim = max(width, height)
  21. # 如果图像不需要缩小,直接返回
  22. if max_dim <= max_size:
  23. return image
  24. scaling_factor = max_size / max_dim
  25. new_width = int(width * scaling_factor)
  26. new_height = int(height * scaling_factor)
  27. # 使用 LANCZOS 高质量缩放
  28. resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
  29. # 应用 UnsharpMask 锐化,补偿缩放损失
  30. resized = resized.filter(ImageFilter.UnsharpMask(radius=1, percent=120, threshold=3))
  31. # 轻微增强对比度,提高文字识别率
  32. enhancer = ImageEnhance.Contrast(resized)
  33. resized = enhancer.enhance(1.1)
  34. return resized
  35. class OcrAgent:
  36. def __init__(self):
  37. self._url = "http://127.0.0.1:8000/api/v1/ocr"
  38. def extract_part_info(self, image_base64, prompts):
  39. """根据提示词提取信息"""
  40. response = requests.post(
  41. self._url,
  42. json={
  43. "image": image_base64,
  44. "text": prompts
  45. }
  46. )
  47. result = response.json()
  48. return result
  49. def agent_ocr(self, image):
  50. """qwen_ocr提取化学品安全标签信息"""
  51. image = resize_image(image, max_size=1024)
  52. image_base64 = image_to_base64(image)
  53. start_time = time.perf_counter()
  54. # 定义需要并行执行的任务
  55. prompts = [
  56. PROMPT_EXTRACT_ICON,
  57. PROMPT_EXTRACT_NAME,
  58. PROMPT_EXTRACT_COMPONENTS,
  59. PROMPT_EXTRACT_KEYWORD,
  60. PROMPT_EXTRACT_PREVENTION,
  61. PROMPT_EXTRACT_SUPPLIER
  62. ]
  63. results = self.extract_part_info(image_base64, prompts)
  64. results = results["data"]
  65. # 从结果中提取数据
  66. icon = json.loads(results[0])
  67. name = json.loads(results[1])
  68. tag = json.loads(results[2])
  69. risk_notice = json.loads(results[3])
  70. pre_notice = json.loads(results[4])
  71. suppliers = json.loads(results[5])
  72. end_time = time.perf_counter()
  73. elapsed_time = end_time - start_time
  74. print(f"推理时间: {elapsed_time:.6f} 秒")
  75. result = {
  76. "tag": {
  77. "name_cn": name["name_cn"],
  78. "name_en": name["name_en"],
  79. "cf_list": tag["cf_list"]
  80. },
  81. "tag_images": icon["tag_images"],
  82. "key_word": risk_notice["key_word"],
  83. "risk_notice": risk_notice["risk_notice"],
  84. "pre_notice": pre_notice["pre_notice"],
  85. "supplier": suppliers["supplier"],
  86. "acc_tel": suppliers["acc_tel"],
  87. }
  88. return result
  89. if __name__ == "__main__":
  90. image = Image.open("./test1.jpg").convert("RGB")
  91. agent = OcrAgent()
  92. res = agent.agent_ocr(image)
  93. print(res)