|
@@ -0,0 +1,404 @@
|
|
|
|
|
+import base64
|
|
|
|
|
+from io import BytesIO
|
|
|
|
|
+from PIL import Image
|
|
|
|
|
+import json
|
|
|
|
|
+import os
|
|
|
|
|
+from pathlib import Path
|
|
|
|
|
+from typing import List, Dict, Any, Optional
|
|
|
|
|
+
|
|
|
|
|
+from qwen_vl_utils import process_vision_info
|
|
|
|
|
+from transformers import AutoProcessor
|
|
|
|
|
+import time
|
|
|
|
|
+import torch
|
|
|
|
|
+
|
|
|
|
|
+from config import MODEL_PATH, PROMPT_EXTRACT_NAME, PROMPT_EXTRACT_ICON, PROMPT_EXTRACT_COMPONENTS, PROMPT_EXTRACT_KEYWORD, PROMPT_EXTRACT_PREVENTION, PROMPT_EXTRACT_SUPPLIER
|
|
|
|
|
+
|
|
|
|
|
+# vLLM imports
|
|
|
|
|
+from vllm import LLM, SamplingParams
|
|
|
|
|
+from vllm.multimodal.utils import fetch_image
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def image_to_base64(pil_image, image_format="JPEG"):
    """Encode a PIL Image as a Base64 string.

    Args:
        pil_image: PIL Image object to serialize.
        image_format: Format name forwarded to ``Image.save`` (default "JPEG").

    Returns:
        UTF-8 string holding the Base64 encoding of the serialized image bytes.
    """
    buffer = BytesIO()
    pil_image.save(buffer, format=image_format)
    raw_bytes = buffer.getvalue()
    return base64.b64encode(raw_bytes).decode('utf-8')
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
class QwenOcrVLLM:
    """Qwen OCR inference wrapper built on the vLLM acceleration framework.

    vLLM advantages:
    1. PagedAttention - efficient KV-cache management
    2. Continuous batching - better GPU utilization
    3. Fast model execution - optimized CUDA/cuDNN kernels
    4. Quantization support - AWQ, GPTQ, etc.
    5. Tensor parallelism - multi-GPU inference
    """

    def __init__(
        self,
        icon_dir: str = "./icon",
        tensor_parallel_size: int = 1,
        gpu_memory_utilization: float = 0.9,
        max_model_len: int = 8192,
        dtype: str = "bfloat16",
        trust_remote_code: bool = True,
    ):
        """Initialize the vLLM engine, the HF processor, and warm the model up.

        Args:
            icon_dir: Directory holding the GHS icon reference images.
            tensor_parallel_size: Tensor-parallel degree (multi-GPU inference).
            gpu_memory_utilization: Fraction of GPU memory to use (0.0-1.0).
            max_model_len: Maximum model sequence length.
            dtype: Data type ("auto", "half", "float16", "bfloat16", "float", "float32").
            trust_remote_code: Whether to trust remote code when loading the model.
        """
        print("=" * 60)
        print("初始化 vLLM 加速推理引擎...")
        print("=" * 60)

        # Build the vLLM engine.
        self.llm = LLM(
            model=MODEL_PATH,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            max_model_len=max_model_len,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            # Vision-model specific: allow up to 10 images per prompt
            # (extract_icons sends all GHS reference icons plus the target).
            limit_mm_per_prompt={"image": 10},
        )

        # Processor used to apply the chat template for prompt construction.
        self.processor = AutoProcessor.from_pretrained(
            MODEL_PATH,
            trust_remote_code=trust_remote_code
        )

        # Load the GHS icon reference images.
        # BUG FIX: extract_icons() reads self.icon_images, but this call was
        # commented out, so every extract_icons() call raised AttributeError.
        self.icon_dir = icon_dir
        self.icon_images = self._load_icon_images()

        # Default sampling parameters (greedy decoding).
        self.default_sampling_params = SamplingParams(
            temperature=0.0,
            top_p=1.0,
            max_tokens=512,
            stop_token_ids=None,
            skip_special_tokens=True,
        )

        print("=" * 60)
        print("vLLM 引擎初始化完成!")
        print(f"- 模型路径: {MODEL_PATH}")
        print(f"- 张量并行: {tensor_parallel_size} GPU(s)")
        print(f"- 显存利用率: {gpu_memory_utilization * 100:.1f}%")
        print(f"- 数据类型: {dtype}")
        print(f"- 最大序列长度: {max_model_len}")
        print("=" * 60)

        # Warm the model up so compilation/caching happens before real traffic.
        print("模型预热中...")
        self._warmup()
        print("模型预热完成!")
        print("=" * 60)

    def _load_icon_images(self) -> Dict[str, Image.Image]:
        """Load every PNG under ``self.icon_dir``, keyed by file stem (e.g. "GHS01").

        Returns:
            Mapping of icon name to RGB PIL image; empty if the directory is
            missing or contains no loadable PNGs.
        """
        icon_images = {}
        icon_path = Path(self.icon_dir)

        if not icon_path.exists():
            print(f"警告: icon目录 {self.icon_dir} 不存在")
            return icon_images

        for icon_file in icon_path.glob("*.png"):
            icon_name = icon_file.stem  # file name without extension, e.g. GHS01
            try:
                icon_images[icon_name] = Image.open(icon_file).convert("RGB")
                print(f"已加载icon参考图像: {icon_name}")
            except Exception as e:
                # Best effort: a single corrupt icon must not abort startup.
                print(f"加载icon图像 {icon_file} 失败: {e}")

        return icon_images

    def _warmup(self):
        """Run one dummy inference to trigger kernel compilation and caching."""
        dummy_image = Image.new('RGB', (224, 224), color='white')
        prompt = PROMPT_EXTRACT_NAME
        try:
            self.inference(dummy_image, prompt, warmup=True)
        except Exception as e:
            # Warmup failures are non-fatal; real inference may still work.
            print(f"预热过程中出现警告(可忽略): {e}")

    def _build_messages(self, image: Image.Image, prompt: str) -> List[Dict]:
        """Build a single-turn user message with one image and one text part.

        Args:
            image: PIL Image object.
            prompt: Prompt text.

        Returns:
            Chat-format message list understood by the processor's template.
        """
        return [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image,
                    },
                    {
                        "type": "text",
                        "text": prompt
                    },
                ],
            }
        ]

    def _prepare_inputs(
        self,
        messages: List[Dict]
    ) -> Dict[str, Any]:
        """Convert chat messages into the vLLM generate() input format.

        Args:
            messages: Chat-format message list.

        Returns:
            Dict with "prompt" (templated text) and "multi_modal_data".
        """
        # Apply the chat template (adds generation prompt, no tokenization).
        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Extract image/video tensors referenced by the messages.
        image_inputs, video_inputs = process_vision_info(messages)

        # BUG FIX: the original always passed image_inputs[0], silently
        # dropping every additional image in multi-image prompts (the
        # extract_icons path sends many reference icons). Pass the full
        # list when there is more than one image; keep the single-image
        # shape unchanged for backward compatibility.
        if not image_inputs:
            mm_image = None
        elif len(image_inputs) == 1:
            mm_image = image_inputs[0]
        else:
            mm_image = image_inputs

        # vLLM 0.6.0+ API: a dict of prompt text plus multi-modal data.
        return {
            "prompt": text,
            "multi_modal_data": {
                "image": mm_image
            }
        }

    def inference(
        self,
        image: Image.Image,
        prompt: str,
        warmup: bool = False,
        sampling_params: Optional[SamplingParams] = None
    ) -> List[str]:
        """Run OCR inference on a single image.

        Args:
            image: PIL Image object.
            prompt: Prompt text.
            warmup: Whether this call is a warmup run (kept for API
                compatibility; the result is returned either way).
            sampling_params: Custom sampling parameters; defaults to the
                instance's greedy-decoding parameters.

        Returns:
            List of generated text strings (one per output).
        """
        messages = self._build_messages(image, prompt)
        inputs = self._prepare_inputs(messages)

        params = sampling_params if sampling_params else self.default_sampling_params

        # vLLM 0.6.0+ API: pass the inputs dict directly.
        outputs = self.llm.generate(
            inputs,
            params
        )

        # NOTE: the original had a duplicated `if not warmup: return ...`
        # followed by the same unconditional return - dead code removed.
        return [output.outputs[0].text for output in outputs]

    def batch_inference(
        self,
        images: List[Image.Image],
        prompts: List[str],
        sampling_params: Optional[SamplingParams] = None
    ) -> List[str]:
        """Run batched OCR inference (vLLM's core strength).

        Args:
            images: List of PIL Image objects.
            prompts: List of prompt texts, one per image.
            sampling_params: Custom sampling parameters; defaults to the
                instance's greedy-decoding parameters.

        Returns:
            List of generated text strings, in input order.

        Raises:
            ValueError: If images and prompts differ in length.
        """
        if len(images) != len(prompts):
            raise ValueError(f"images数量({len(images)})和prompts数量({len(prompts)})不匹配")

        # Prepare one vLLM input per (image, prompt) pair.
        all_inputs = []
        for image, prompt in zip(images, prompts):
            messages = self._build_messages(image, prompt)
            all_inputs.append(self._prepare_inputs(messages))

        params = sampling_params if sampling_params else self.default_sampling_params

        # vLLM 0.6.0+ API: pass the list of input dicts; vLLM batches them.
        outputs = self.llm.generate(
            all_inputs,
            params
        )

        return [output.outputs[0].text for output in outputs]

    def extract_icons(self, image: Image.Image) -> List[str]:
        """Identify GHS pictograms in a chemical-label image.

        Sends every loaded reference icon (sorted by GHS number) followed by
        the target image and the icon-extraction prompt in one request.

        Args:
            image: PIL Image object - the chemical label to analyze.

        Returns:
            List of generated text strings.
        """
        messages = [
            {
                "role": "user",
                "content": []
            }
        ]
        content_list = messages[0]["content"]

        # Add all reference icons, labeled, in GHS-number order.
        sorted_icons = sorted(self.icon_images.items(), key=lambda x: x[0])
        for icon_name, icon_image in sorted_icons:
            content_list.append({
                "type": "image",
                "image": icon_image,
            })
            content_list.append({
                "type": "text",
                "text": f"参考图像:{icon_name}"
            })

        # Add the image to analyze.
        content_list.append({
            "type": "image",
            "image": image,
        })

        # Add the extraction prompt.
        content_list.append({
            "type": "text",
            "text": PROMPT_EXTRACT_ICON
        })

        inputs = self._prepare_inputs(messages)

        # vLLM 0.6.0+ API: pass the inputs dict directly.
        outputs = self.llm.generate(
            inputs,
            self.default_sampling_params
        )

        return [output.outputs[0].text for output in outputs]

    def __del__(self):
        """Best-effort release of the vLLM engine on garbage collection."""
        try:
            if hasattr(self, 'llm'):
                del self.llm
                print("vLLM引擎已释放")
        except Exception:
            # Never raise from __del__ - the interpreter may be shutting down.
            pass
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test for the vLLM-accelerated OCR pipeline.
    print("初始化 QwenOcrVLLM...")
    qwen_ocr = QwenOcrVLLM(
        tensor_parallel_size=1,        # single GPU
        gpu_memory_utilization=0.9,    # use 90% of GPU memory
        max_model_len=8192,            # maximum sequence length
        dtype="bfloat16"               # bfloat16 precision
    )

    # --- Single-image inference --------------------------------------
    print("\n" + "=" * 60)
    print("测试单张图像推理...")
    print("=" * 60)

    test_image_path = "./test3.jpg"
    if os.path.exists(test_image_path):
        image = Image.open(test_image_path).convert("RGB")

        start_time = time.time()
        result = qwen_ocr.inference(image, PROMPT_EXTRACT_PREVENTION)
        elapsed = time.time() - start_time

        print(f"\n推理耗时: {elapsed:.3f}秒")
        print(f"提取结果:\n{result[0]}")
    else:
        print(f"测试图像 {test_image_path} 不存在,跳过测试")

    # --- Batch inference ---------------------------------------------
    print("\n" + "=" * 60)
    print("测试批量推理...")
    print("=" * 60)

    if os.path.exists(test_image_path):
        # One copy of the test image per prompt (6 prompts total).
        # (The original comment claimed 3 images while 6 were created.)
        prompts = [PROMPT_EXTRACT_ICON, PROMPT_EXTRACT_NAME, PROMPT_EXTRACT_COMPONENTS, PROMPT_EXTRACT_KEYWORD, PROMPT_EXTRACT_PREVENTION, PROMPT_EXTRACT_SUPPLIER]
        images = [Image.open(test_image_path).convert("RGB") for _ in range(len(prompts))]

        start_time = time.time()
        results = qwen_ocr.batch_inference(images, prompts)
        elapsed = time.time() - start_time

        for text in results:
            print(text)

        # BUG FIX: the per-image average hard-coded a divisor of 3 while the
        # batch holds 6 images - divide by the actual batch size instead.
        print(f"\n批量推理耗时: {elapsed:.3f}秒")
        print(f"平均每张: {elapsed / len(images):.3f}秒")
        print(f"批量推理加速比: {(elapsed / len(images)):.3f}秒/张 vs 单张推理")

    print("\n" + "=" * 60)
    print("测试完成!")
    print("=" * 60)
|