import base64
from io import BytesIO
from PIL import Image
import json
import os
from pathlib import Path
from qwen_vl_utils import process_vision_info
from transformers import Qwen3VLForConditionalGeneration, AutoTokenizer, AutoProcessor
import time
import torch

from config import MODEL_PATH, PROMPT_EXTRACT_NAME, PROMPT_EXTRACT_ICON


def image_to_base64(pil_image, image_format="JPEG"):
    """Encode a PIL image as a base64 string.

    Args:
        pil_image: PIL Image object to encode.
        image_format: format name passed to PIL's ``save()`` (default "JPEG").

    Returns:
        str: base64-encoded image bytes, decoded as a UTF-8 string.
    """
    buffered = BytesIO()
    pil_image.save(buffered, format=image_format)
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


class QwenOcr:
    """OCR / visual-extraction wrapper around a Qwen3-VL model.

    Loads the model once, warms it up, and exposes two entry points:
    ``inference`` (single image + prompt) and ``extract_icons`` (GHS
    pictogram identification against reference images in ``icon_dir``).
    """

    def __init__(self, icon_dir="./icon"):
        self.model = Qwen3VLForConditionalGeneration.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )
        # Inference only: eval mode and no autograd bookkeeping.
        self.model.eval()
        torch.set_grad_enabled(False)

        self.processor = AutoProcessor.from_pretrained(MODEL_PATH)

        # Load GHS icon reference images. extract_icons() reads
        # self.icon_images, so this assignment must run — in the previous
        # revision it was commented out, which made extract_icons() raise
        # AttributeError on every call.
        self.icon_dir = icon_dir
        self.icon_images = self._load_icon_images()

        # Warm up once so any compilation/graph capture happens before the
        # first real request.
        print("模型预热中...")
        self._warmup()
        print("模型预热完成")

    def _load_icon_images(self):
        """Load every ``*.png`` reference image from the icon directory.

        Returns:
            dict: mapping of icon name (file stem, e.g. "GHS01") to a
            PIL Image in RGB mode. Empty if the directory is missing.
        """
        icon_images = {}
        icon_path = Path(self.icon_dir)
        if not icon_path.exists():
            print(f"警告: icon目录 {self.icon_dir} 不存在")
            return icon_images

        for icon_file in icon_path.glob("*.png"):
            icon_name = icon_file.stem  # file name without extension, e.g. GHS01
            try:
                icon_images[icon_name] = Image.open(icon_file).convert("RGB")
                print(f"已加载icon参考图像: {icon_name}")
            except Exception as e:
                # Best-effort: a single unreadable file should not abort startup.
                print(f"加载icon图像 {icon_file} 失败: {e}")
        return icon_images

    def _warmup(self):
        """Run one dummy inference to trigger compilation and optimization."""
        dummy_image = Image.new('RGB', (224, 224), color='white')
        try:
            self.inference(dummy_image, PROMPT_EXTRACT_NAME, warmup=True)
        except Exception as e:
            # Warmup failures are non-fatal; the first real call will retry.
            print(f"预热过程中出现警告(可忽略): {e}")

    def _generate(self, messages):
        """Shared pipeline: render chat template, run the model, decode output.

        Factored out of ``inference`` and ``extract_icons``, which previously
        duplicated this entire sequence.

        Args:
            messages: chat-format message list containing image/text content.

        Returns:
            list[str]: decoded model output (batch of one).
        """
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # NOTE(review): device is hard-coded to Ascend NPU; confirm this
        # matches where device_map="auto" placed the model.
        inputs = inputs.to("npu")

        generated_ids = self.model.generate(
            **inputs,
            max_new_tokens=512,  # bound output length for this workload
            use_cache=True,      # KV cache speeds up autoregressive decoding
            do_sample=False,     # greedy decoding: deterministic and fast
            num_beams=1,         # no beam search
        )
        # Drop the prompt tokens; keep only the newly generated ones.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        return self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

    def inference(self, image, prompt, warmup=False):
        """Run OCR-style inference on a single image.

        Args:
            image: PIL Image object.
            prompt: instruction text for the model.
            warmup: True when called from ``_warmup`` (kept for interface
                compatibility; behavior is identical either way).

        Returns:
            list[str]: decoded model output.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},  # pass the PIL image directly
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        return self._generate(messages)

    def extract_icons(self, image):
        """Identify GHS pictograms present in a chemical-label image.

        Args:
            image: PIL Image object of the label to analyze.

        Returns:
            list[str]: decoded model output, expected to contain JSON like
            ``{"tag_images": ["GHS06", "GHS08", ""]}``.
        """
        content_list = []
        # Reference images first, in GHS-number order, each followed by a
        # text label naming it.
        for icon_name, icon_image in sorted(self.icon_images.items(), key=lambda x: x[0]):
            content_list.append({"type": "image", "image": icon_image})
            content_list.append({"type": "text", "text": f"参考图像:{icon_name}"})
        # Then the image under analysis and the extraction prompt.
        content_list.append({"type": "image", "image": image})
        content_list.append({"type": "text", "text": PROMPT_EXTRACT_ICON})

        messages = [{"role": "user", "content": content_list}]
        return self._generate(messages)


if __name__ == '__main__':
    qwen_ocr = QwenOcr()
    image = Image.open("./test3.jpg").convert("RGB")
    # Fixed: the previous revision called undefined clear() here (NameError)
    # and never ran any inference on the loaded test image.
    result = qwen_ocr.inference(image, PROMPT_EXTRACT_NAME)
    print(result)