# qwen_ocr.py
import base64
import json
import os
import time
from io import BytesIO
from pathlib import Path

import torch
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import Qwen3VLForConditionalGeneration, AutoTokenizer, AutoProcessor

from config import MODEL_PATH, PROMPT_EXTRACT_NAME, PROMPT_EXTRACT_ICON
  12. def image_to_base64(pil_image, image_format="JPEG"):
  13. """将PIL Image图像转换为Base64编码"""
  14. buffered = BytesIO()
  15. pil_image.save(buffered, format=image_format)
  16. img_byte_array = buffered.getvalue()
  17. encode_image = base64.b64encode(img_byte_array).decode('utf-8')
  18. return encode_image
  19. class QwenOcr:
  20. def __init__(self, icon_dir="./icon"):
  21. self.model = Qwen3VLForConditionalGeneration.from_pretrained(
  22. MODEL_PATH,
  23. # torch_dtype="auto",
  24. torch_dtype=torch.bfloat16,
  25. attn_implementation="flash_attention_2",
  26. device_map="auto"
  27. )
  28. # 优化1: 设置为评估模式并禁用梯度计算
  29. self.model.eval()
  30. torch.set_grad_enabled(False)
  31. self.processor = AutoProcessor.from_pretrained(MODEL_PATH)
  32. # 加载icon参考图像
  33. self.icon_dir = icon_dir
  34. # self.icon_images = self._load_icon_images()
  35. # 优化4: 模型预热 - 运行一次推理以触发编译
  36. print("模型预热中...")
  37. self._warmup()
  38. print("模型预热完成")
  39. def _load_icon_images(self):
  40. """加载icon目录下的所有参考图像"""
  41. icon_images = {}
  42. icon_path = Path(self.icon_dir)
  43. if not icon_path.exists():
  44. print(f"警告: icon目录 {self.icon_dir} 不存在")
  45. return icon_images
  46. # 加载所有png图像文件
  47. for icon_file in icon_path.glob("*.png"):
  48. icon_name = icon_file.stem # 获取文件名(不含扩展名), 如 GHS01
  49. try:
  50. icon_image = Image.open(icon_file).convert("RGB")
  51. icon_images[icon_name] = icon_image
  52. print(f"已加载icon参考图像: {icon_name}")
  53. except Exception as e:
  54. print(f"加载icon图像 {icon_file} 失败: {e}")
  55. return icon_images
  56. def _warmup(self):
  57. """预热模型以触发编译和优化"""
  58. dummy_image = Image.new('RGB', (224, 224), color='white')
  59. prompt = PROMPT_EXTRACT_NAME
  60. try:
  61. self.inference(dummy_image, prompt, warmup=True)
  62. except Exception as e:
  63. print(f"预热过程中出现警告(可忽略): {e}")
  64. def inference(self, image, prompt, warmup=False):
  65. """ocr推理
  66. Args:
  67. image: PIL Image对象
  68. warmup: 是否为预热模式(预热时不打印详细信息)
  69. """
  70. messages = [
  71. {
  72. "role": "user",
  73. "content": [
  74. {
  75. "type": "image",
  76. "image": image, # 直接传递PIL图像对象
  77. },
  78. {"type": "text", "text": prompt},
  79. ],
  80. }
  81. ]
  82. text = self.processor.apply_chat_template(
  83. messages, tokenize=False, add_generation_prompt=True
  84. )
  85. image_inputs, video_inputs = process_vision_info(messages)
  86. inputs = self.processor(
  87. text=[text],
  88. images=image_inputs,
  89. videos=video_inputs,
  90. padding=True,
  91. return_tensors="pt",
  92. )
  93. inputs = inputs.to("npu")
  94. # 优化1: 添加KV Cache和生成参数优化
  95. generated_ids = self.model.generate(
  96. **inputs,
  97. max_new_tokens=512, # 根据实际需求减少生成长度
  98. use_cache=True, # 启用KV cache加速
  99. do_sample=False, # 使用贪婪解码,更快且稳定
  100. num_beams=1, # 不使用束搜索,进一步加速
  101. )
  102. generated_ids_trimmed = [
  103. out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
  104. ]
  105. output_text = self.processor.batch_decode(
  106. generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
  107. )
  108. return output_text
  109. def extract_icons(self, image):
  110. """识别图像中的象形图标识
  111. Args:
  112. image: PIL Image对象 - 待识别的化学品标签图像
  113. Returns:
  114. dict: 包含识别结果的字典,格式为 {"tag_images": ["GHS06", "GHS08", ""]}
  115. """
  116. # 构建包含所有参考图像的messages
  117. messages = [
  118. {
  119. "role": "user",
  120. "content": []
  121. }
  122. ]
  123. # 添加所有icon参考图像
  124. content_list = messages[0]["content"]
  125. # 按GHS编号顺序添加参考图像
  126. sorted_icons = sorted(self.icon_images.items(), key=lambda x: x[0])
  127. for icon_name, icon_image in sorted_icons:
  128. content_list.append({
  129. "type": "image",
  130. "image": icon_image,
  131. })
  132. content_list.append({
  133. "type": "text",
  134. "text": f"参考图像:{icon_name}"
  135. })
  136. # 添加待识别的图像
  137. content_list.append({
  138. "type": "image",
  139. "image": image,
  140. })
  141. # 添加提示词
  142. content_list.append({
  143. "type": "text",
  144. "text": PROMPT_EXTRACT_ICON
  145. })
  146. # 处理消息并进行推理
  147. text = self.processor.apply_chat_template(
  148. messages, tokenize=False, add_generation_prompt=True
  149. )
  150. image_inputs, video_inputs = process_vision_info(messages)
  151. inputs = self.processor(
  152. text=[text],
  153. images=image_inputs,
  154. videos=video_inputs,
  155. padding=True,
  156. return_tensors="pt",
  157. )
  158. inputs = inputs.to("npu")
  159. # 生成结果
  160. generated_ids = self.model.generate(
  161. **inputs,
  162. max_new_tokens=512,
  163. use_cache=True,
  164. do_sample=False,
  165. num_beams=1,
  166. )
  167. generated_ids_trimmed = [
  168. out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
  169. ]
  170. output_text = self.processor.batch_decode(
  171. generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
  172. )
  173. return output_text
  174. if __name__ == '__main__':
  175. qwen_ocr = QwenOcr()
  176. image = Image.open("./test3.jpg").convert("RGB")
  177. clear()