- import base64
- from io import BytesIO
- from PIL import Image
- import json
- import os
- from pathlib import Path
- from typing import List, Dict, Any, Optional
- from qwen_vl_utils import process_vision_info
- from transformers import AutoProcessor
- import time
- import torch
- from config import MODEL_PATH, PROMPT_EXTRACT_NAME, PROMPT_EXTRACT_ICON, PROMPT_EXTRACT_COMPONENTS, PROMPT_EXTRACT_KEYWORD, PROMPT_EXTRACT_PREVENTION, PROMPT_EXTRACT_SUPPLIER
- # vLLM imports
- from vllm import LLM, SamplingParams
- from vllm.multimodal.utils import fetch_image
def image_to_base64(pil_image, image_format="JPEG"):
    """Encode a PIL image as a Base64 string.

    The image is serialized into an in-memory buffer using the given
    format, and the resulting bytes are Base64-encoded.

    Args:
        pil_image: PIL Image object (anything exposing ``save(buffer, format=...)``).
        image_format: Serialization format passed to ``save`` (default "JPEG").

    Returns:
        Base64-encoded string of the serialized image bytes.
    """
    buffer = BytesIO()
    pil_image.save(buffer, format=image_format)
    raw_bytes = buffer.getvalue()
    return base64.b64encode(raw_bytes).decode('utf-8')
class QwenOcrVLLM:
    """Qwen OCR inference engine accelerated with the vLLM framework.

    vLLM advantages:
    1. PagedAttention - efficient KV-cache management
    2. Continuous batching - optimized GPU utilization
    3. Fast model execution - optimized CUDA/cuDNN kernels
    4. Quantization support - AWQ, GPTQ, etc.
    5. Tensor parallelism - multi-GPU inference
    """

    def __init__(
        self,
        icon_dir: str = "./icon",
        tensor_parallel_size: int = 1,
        gpu_memory_utilization: float = 0.9,
        max_model_len: int = 8192,
        dtype: str = "bfloat16",
        trust_remote_code: bool = True,
    ):
        """Initialize the vLLM engine and the chat-template processor.

        Args:
            icon_dir: Directory holding the GHS icon reference images.
            tensor_parallel_size: Tensor-parallel degree (multi-GPU inference).
            gpu_memory_utilization: Fraction of GPU memory to use (0.0-1.0).
            max_model_len: Maximum model sequence length.
            dtype: Data type ("auto", "half", "float16", "bfloat16", "float", "float32").
            trust_remote_code: Whether to trust remote model code.
        """
        print("=" * 60)
        print("初始化 vLLM 加速推理引擎...")
        print("=" * 60)
        # Build the vLLM engine.
        self.llm = LLM(
            model=MODEL_PATH,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            max_model_len=max_model_len,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            # Vision-model specific: allow up to 10 images per prompt
            # (extract_icons sends several reference icons + 1 target image).
            limit_mm_per_prompt={"image": 10},
        )
        # Processor is used only for chat-template rendering.
        self.processor = AutoProcessor.from_pretrained(
            MODEL_PATH,
            trust_remote_code=trust_remote_code
        )
        self.icon_dir = icon_dir
        # BUG FIX: the original commented out the eager icon load, but
        # extract_icons() still read self.icon_images, which raised
        # AttributeError. Keep startup fast by loading lazily on first
        # use; None means "not loaded yet".
        self.icon_images: Optional[Dict[str, Image.Image]] = None
        # Default sampling parameters: greedy decoding.
        self.default_sampling_params = SamplingParams(
            temperature=0.0,       # greedy decoding
            top_p=1.0,
            max_tokens=512,        # maximum number of generated tokens
            stop_token_ids=None,
            skip_special_tokens=True,
        )
        print("=" * 60)
        print("vLLM 引擎初始化完成!")
        print(f"- 模型路径: {MODEL_PATH}")
        print(f"- 张量并行: {tensor_parallel_size} GPU(s)")
        print(f"- 显存利用率: {gpu_memory_utilization * 100:.1f}%")
        print(f"- 数据类型: {dtype}")
        print(f"- 最大序列长度: {max_model_len}")
        print("=" * 60)
        # Warm up the model to trigger compilation/optimization paths.
        print("模型预热中...")
        self._warmup()
        print("模型预热完成!")
        print("=" * 60)

    def _load_icon_images(self) -> Dict[str, Image.Image]:
        """Load all *.png reference icons from the icon directory.

        Returns:
            Mapping from icon name (file stem, e.g. "GHS01") to RGB image.
            Empty when the directory is missing or unreadable.
        """
        icon_images: Dict[str, Image.Image] = {}
        icon_path = Path(self.icon_dir)
        if not icon_path.exists():
            print(f"警告: icon目录 {self.icon_dir} 不存在")
            return icon_images
        for icon_file in icon_path.glob("*.png"):
            icon_name = icon_file.stem  # file name without extension, e.g. GHS01
            try:
                icon_images[icon_name] = Image.open(icon_file).convert("RGB")
                print(f"已加载icon参考图像: {icon_name}")
            except Exception as e:
                # Best-effort: a single corrupt icon must not abort loading.
                print(f"加载icon图像 {icon_file} 失败: {e}")
        return icon_images

    def _warmup(self):
        """Run one dummy inference to trigger engine compilation/optimization."""
        dummy_image = Image.new('RGB', (224, 224), color='white')
        try:
            self.inference(dummy_image, PROMPT_EXTRACT_NAME, warmup=True)
        except Exception as e:
            # Warm-up failures are non-fatal; real requests may still work.
            print(f"预热过程中出现警告(可忽略): {e}")

    def _build_messages(self, image: Image.Image, prompt: str) -> List[Dict]:
        """Build the single-image chat message structure.

        Args:
            image: PIL Image object.
            prompt: Prompt text.

        Returns:
            Message list in the Qwen-VL chat format.
        """
        return [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

    def _prepare_inputs(self, messages: List[Dict]) -> Dict[str, Any]:
        """Convert chat messages into the vLLM generate() input dict.

        Args:
            messages: Message list (may contain one or many images).

        Returns:
            Dict with "prompt" text and "multi_modal_data".
        """
        # Render the chat template to the raw prompt string.
        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        # Extract the vision inputs referenced by the messages.
        image_inputs, video_inputs = process_vision_info(messages)
        # BUG FIX: the original always passed image_inputs[0], silently
        # dropping every image after the first in a multi-image prompt
        # (extract_icons builds such prompts). Pass the full list when
        # more than one image is present.
        if not image_inputs:
            image_data = None
        elif len(image_inputs) == 1:
            image_data = image_inputs[0]
        else:
            image_data = image_inputs
        # vLLM 0.6.0+ API: a dict holding prompt text plus multimodal data.
        return {
            "prompt": text,
            "multi_modal_data": {"image": image_data},
        }

    def inference(
        self,
        image: Image.Image,
        prompt: str,
        warmup: bool = False,
        sampling_params: Optional[SamplingParams] = None
    ) -> List[str]:
        """Run OCR inference on a single image.

        Args:
            image: PIL Image object.
            prompt: Prompt text.
            warmup: Kept for interface compatibility; the original returned
                the same value on both branches, so it has no effect.
            sampling_params: Optional custom sampling parameters.

        Returns:
            List of generated text strings.
        """
        messages = self._build_messages(image, prompt)
        inputs = self._prepare_inputs(messages)
        params = sampling_params if sampling_params else self.default_sampling_params
        # vLLM 0.6.0+ API: pass the inputs dict directly.
        outputs = self.llm.generate(inputs, params)
        return [output.outputs[0].text for output in outputs]

    def batch_inference(
        self,
        images: List[Image.Image],
        prompts: List[str],
        sampling_params: Optional[SamplingParams] = None
    ) -> List[str]:
        """Batched OCR inference (vLLM's core strength).

        Args:
            images: List of PIL Image objects.
            prompts: Prompt text list, one per image.
            sampling_params: Optional custom sampling parameters.

        Returns:
            List of generated text strings, one per input pair.

        Raises:
            ValueError: When images and prompts differ in length.
        """
        if len(images) != len(prompts):
            raise ValueError(f"images数量({len(images)})和prompts数量({len(prompts)})不匹配")
        # Build one vLLM input per (image, prompt) pair.
        all_inputs = [
            self._prepare_inputs(self._build_messages(image, prompt))
            for image, prompt in zip(images, prompts)
        ]
        params = sampling_params if sampling_params else self.default_sampling_params
        # vLLM 0.6.0+ API: pass the list of input dicts; the engine
        # schedules the whole batch with continuous batching.
        outputs = self.llm.generate(all_inputs, params)
        return [output.outputs[0].text for output in outputs]

    def extract_icons(self, image: Image.Image) -> List[str]:
        """Identify GHS pictograms in a chemical-label image.

        Builds a multi-image prompt: every reference icon (with its name)
        followed by the target image and the icon-extraction prompt.

        Args:
            image: PIL Image object - the chemical label to analyze.

        Returns:
            List of generated text strings.
        """
        # BUG FIX: lazily load the reference icons; the original never
        # assigned self.icon_images and crashed here with AttributeError.
        if self.icon_images is None:
            self.icon_images = self._load_icon_images()
        messages: List[Dict] = [{"role": "user", "content": []}]
        content_list = messages[0]["content"]
        # Add the reference icons in GHS-number order.
        for icon_name, icon_image in sorted(self.icon_images.items(), key=lambda x: x[0]):
            content_list.append({"type": "image", "image": icon_image})
            content_list.append({"type": "text", "text": f"参考图像:{icon_name}"})
        # Append the target image and the extraction prompt.
        content_list.append({"type": "image", "image": image})
        content_list.append({"type": "text", "text": PROMPT_EXTRACT_ICON})
        inputs = self._prepare_inputs(messages)
        outputs = self.llm.generate(inputs, self.default_sampling_params)
        return [output.outputs[0].text for output in outputs]

    def __del__(self):
        """Destructor: release the vLLM engine."""
        try:
            if hasattr(self, 'llm'):
                del self.llm
                print("vLLM引擎已释放")
        except Exception:
            # Interpreter shutdown may have torn down globals; ignore.
            pass
if __name__ == '__main__':
    # Smoke test for the vLLM-backed OCR engine.
    print("初始化 QwenOcrVLLM...")
    qwen_ocr = QwenOcrVLLM(
        tensor_parallel_size=1,      # single GPU
        gpu_memory_utilization=0.9,  # use 90% of GPU memory
        max_model_len=8192,          # maximum sequence length
        dtype="bfloat16"             # bfloat16 precision
    )

    # --- Single-image inference test ---
    print("\n" + "=" * 60)
    print("测试单张图像推理...")
    print("=" * 60)
    test_image_path = "./test3.jpg"
    if os.path.exists(test_image_path):
        image = Image.open(test_image_path).convert("RGB")
        start_time = time.time()
        result = qwen_ocr.inference(image, PROMPT_EXTRACT_PREVENTION)
        elapsed = time.time() - start_time
        print(f"\n推理耗时: {elapsed:.3f}秒")
        print(f"提取结果:\n{result[0]}")
    else:
        print(f"测试图像 {test_image_path} 不存在,跳过测试")

    # --- Batch inference test ---
    print("\n" + "=" * 60)
    print("测试批量推理...")
    print("=" * 60)
    if os.path.exists(test_image_path):
        prompts = [
            PROMPT_EXTRACT_ICON,
            PROMPT_EXTRACT_NAME,
            PROMPT_EXTRACT_COMPONENTS,
            PROMPT_EXTRACT_KEYWORD,
            PROMPT_EXTRACT_PREVENTION,
            PROMPT_EXTRACT_SUPPLIER,
        ]
        # One copy of the test image per prompt.
        images = [Image.open(test_image_path).convert("RGB") for _ in prompts]
        start_time = time.time()
        results = qwen_ocr.batch_inference(images, prompts)
        elapsed = time.time() - start_time
        for result_text in results:
            print(result_text)
        # BUG FIX: the original divided by a hard-coded 3 although the
        # batch contains 6 images; use the actual batch size.
        per_image = elapsed / len(images)
        print(f"\n批量推理耗时: {elapsed:.3f}秒")
        print(f"平均每张: {per_image:.3f}秒")
        print(f"批量推理加速比: {per_image:.3f}秒/张 vs 单张推理")
    print("\n" + "=" * 60)
    print("测试完成!")
    print("=" * 60)