| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207 |
- import base64
- from io import BytesIO
- from PIL import Image
- import json
- import os
- from pathlib import Path
- from qwen_vl_utils import process_vision_info
- from transformers import Qwen3VLForConditionalGeneration, AutoTokenizer, AutoProcessor
- import time
- import torch
- from config import MODEL_PATH, PROMPT_EXTRACT_NAME, PROMPT_EXTRACT_ICON
def image_to_base64(pil_image, image_format="JPEG"):
    """Encode a PIL image as a Base64 string.

    Args:
        pil_image: PIL Image object to serialize.
        image_format: format name forwarded to PIL's ``save()``; defaults to "JPEG".

    Returns:
        str: Base64-encoded bytes of the serialized image (ASCII text).
    """
    with BytesIO() as buffer:
        pil_image.save(buffer, format=image_format)
        raw_bytes = buffer.getvalue()
    return base64.b64encode(raw_bytes).decode('utf-8')
class QwenOcr:
    """OCR wrapper around Qwen3-VL for chemical-label text and GHS icon extraction.

    Loads the vision-language model once, pre-loads GHS pictogram reference
    images, and warms the model up so the first real request is fast.
    """

    def __init__(self, icon_dir="./icon", device="npu"):
        """Initialize model, processor and icon references.

        Args:
            icon_dir: directory containing GHS reference images (*.png).
            device: accelerator device the processed inputs are moved to
                before generation (default "npu", matching the original
                hard-coded target).
        """
        self.device = device
        self.model = Qwen3VLForConditionalGeneration.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )
        # Inference only: eval mode and globally disable autograd bookkeeping.
        self.model.eval()
        torch.set_grad_enabled(False)
        self.processor = AutoProcessor.from_pretrained(MODEL_PATH)
        self.icon_dir = icon_dir
        # BUG FIX: this assignment was commented out, but extract_icons()
        # reads self.icon_images — calling it would raise AttributeError.
        self.icon_images = self._load_icon_images()
        # Warm-up: run one inference so kernel compilation happens up front.
        print("模型预热中...")
        self._warmup()
        print("模型预热完成")

    def _load_icon_images(self):
        """Load all *.png reference images from ``self.icon_dir``.

        Returns:
            dict: file stem (e.g. "GHS01") -> RGB PIL Image. Empty dict if
            the directory does not exist or contains no readable images.
        """
        icon_images = {}
        icon_path = Path(self.icon_dir)
        if not icon_path.exists():
            print(f"警告: icon目录 {self.icon_dir} 不存在")
            return icon_images
        for icon_file in icon_path.glob("*.png"):
            icon_name = icon_file.stem  # file name without extension, e.g. GHS01
            try:
                icon_images[icon_name] = Image.open(icon_file).convert("RGB")
                print(f"已加载icon参考图像: {icon_name}")
            except Exception as e:
                # Best-effort: skip unreadable files, keep loading the rest.
                print(f"加载icon图像 {icon_file} 失败: {e}")
        return icon_images

    def _warmup(self):
        """Run one dummy inference to trigger compilation/optimization."""
        dummy_image = Image.new('RGB', (224, 224), color='white')
        try:
            self.inference(dummy_image, PROMPT_EXTRACT_NAME, warmup=True)
        except Exception as e:
            # Warm-up failures are non-fatal; real requests may still work.
            print(f"预热过程中出现警告(可忽略): {e}")

    def _generate(self, messages):
        """Shared chat-template -> process -> generate -> decode pipeline.

        Extracted so inference() and extract_icons() no longer duplicate
        ~30 lines of identical generation code.

        Args:
            messages: chat messages in Qwen-VL format (interleaved
                image/text content items).

        Returns:
            list[str]: decoded model output, one entry per batch item.
        """
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.device)
        generated_ids = self.model.generate(
            **inputs,
            max_new_tokens=512,  # cap output length for speed
            use_cache=True,      # KV cache for faster decoding
            do_sample=False,     # greedy decoding: fast and deterministic
            num_beams=1,         # no beam search
        )
        # Strip the prompt tokens; keep only the newly generated ones.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        return self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

    def inference(self, image, prompt, warmup=False):
        """Run single-image OCR inference.

        Args:
            image: PIL Image object to recognize.
            prompt: instruction text for the model.
            warmup: True when called during warm-up (kept for interface
                compatibility; generation is identical either way).

        Returns:
            list[str]: decoded output text.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},  # PIL image passed directly
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        return self._generate(messages)

    def extract_icons(self, image):
        """Identify GHS pictograms present in a chemical-label image.

        Builds a multi-image prompt: every reference pictogram (in GHS-number
        order) followed by its name, then the image under test, then the
        extraction instruction.

        Args:
            image: PIL Image of the chemical label to analyze.

        Returns:
            list[str]: decoded model output (expected to contain a JSON-like
            structure such as {"tag_images": ["GHS06", "GHS08"]}).
        """
        content_list = []
        # Reference images in GHS-number order, each labeled with its name.
        for icon_name, icon_image in sorted(self.icon_images.items(), key=lambda x: x[0]):
            content_list.append({"type": "image", "image": icon_image})
            content_list.append({"type": "text", "text": f"参考图像:{icon_name}"})
        # The image under test, then the instruction prompt.
        content_list.append({"type": "image", "image": image})
        content_list.append({"type": "text", "text": PROMPT_EXTRACT_ICON})
        messages = [{"role": "user", "content": content_list}]
        return self._generate(messages)
-
if __name__ == '__main__':
    qwen_ocr = QwenOcr()
    image = Image.open("./test3.jpg").convert("RGB")
    # BUG FIX: the original called clear(), an undefined name (NameError).
    # Run an actual extraction on the test image and show the result instead.
    result = qwen_ocr.inference(image, PROMPT_EXTRACT_NAME)
    print(result)
|