ソースを参照

增加vllm加速推理框架

Sherlock1011 2 ヶ月 前
コミット
34f934c4c4
8 ファイル変更449 行追加58 行削除
  1. 22 36
      agent/agent.py
  2. 5 4
      api/run_api.py
  3. 1 1
      config/config.py
  4. 3 1
      model/__init__.py
  5. 10 10
      model/model_api.py
  6. 404 0
      model/qwen_ocr_vllm.py
  7. 0 3
      requirements.txt
  8. 4 3
      test_api.py

+ 22 - 36
agent/agent.py

@@ -46,17 +46,17 @@ class OcrAgent:
     def __init__(self):
         self._url = "http://127.0.0.1:8000/api/v1/ocr"
 
-    def extract_part_info(self, image_base64, prompt):
+    def extract_part_info(self, image_base64, prompts):
         """根据提示词提取信息"""
         response = requests.post(
             self._url,
             json={
                 "image": image_base64,
-                "text": prompt
+                "text": prompts
             }
         )
         result = response.json()
-        return json.loads(result['data'][0])
+        return result
 
     def agent_ocr(self, image):
         """qwen_ocr提取化学品安全标签信息"""
@@ -66,40 +66,25 @@ class OcrAgent:
         start_time = time.perf_counter()
 
         # 定义需要批量提交给模型的提取提示词(单次请求完成全部提取)
-        tasks = {
-            'icon': PROMPT_EXTRACT_ICON,
-            'name': PROMPT_EXTRACT_NAME,
-            'tag': PROMPT_EXTRACT_COMPONENTS,
-            'risk_notice': PROMPT_EXTRACT_KEYWORD,
-            'pre_notice': PROMPT_EXTRACT_PREVENTION,
-            'suppliers': PROMPT_EXTRACT_SUPPLIER
-        }
-
-        # 使用线程池并行执行所有提取任务
-        results = {}
-        with ThreadPoolExecutor(max_workers=6) as executor:
-            # 提交所有任务
-            future_to_task = {
-                executor.submit(self.extract_part_info, image_base64, prompt): task_name
-                for task_name, prompt in tasks.items()
-            }
-
-            # 收集结果
-            for future in as_completed(future_to_task):
-                task_name = future_to_task[future]
-                try:
-                    results[task_name] = future.result()
-                except Exception as e:
-                    print(f"任务 {task_name} 执行失败: {e}")
-                    results[task_name] = {}
+        prompts = [
+            PROMPT_EXTRACT_ICON,
+            PROMPT_EXTRACT_NAME,
+            PROMPT_EXTRACT_COMPONENTS,
+            PROMPT_EXTRACT_KEYWORD,
+            PROMPT_EXTRACT_PREVENTION,
+            PROMPT_EXTRACT_SUPPLIER
+        ]
+
+        results = self.extract_part_info(image_base64, prompts)
+        results = results["data"]
 
         # 从结果中提取数据
-        icon = results.get('icon', {})
-        name = results.get('name', {})
-        tag = results.get('tag', {})
-        risk_notice = results.get('risk_notice', {})
-        pre_notice = results.get('pre_notice', {})
-        suppliers = results.get('suppliers', {})
+        icon = json.loads(results[0])
+        name = json.loads(results[1])
+        tag = json.loads(results[2])
+        risk_notice = json.loads(results[3])
+        pre_notice = json.loads(results[4])
+        suppliers = json.loads(results[5])
 
         end_time = time.perf_counter()
         elapsed_time = end_time - start_time
@@ -125,4 +110,5 @@ class OcrAgent:
 if __name__ == "__main__":
     image = Image.open("./test1.jpg").convert("RGB")
     agent = OcrAgent()
-    agent.agent_ocr(image)
+    res = agent.agent_ocr(image)
+    print(res)

+ 5 - 4
api/run_api.py

@@ -13,7 +13,7 @@ from datetime import datetime
 
 from fastapi import FastAPI, HTTPException, status
 from fastapi.responses import JSONResponse
-from pydantic import BaseModel, Field, validator
+from pydantic import BaseModel, Field, field_validator
 from PIL import Image
 import uvicorn
 
@@ -37,7 +37,8 @@ class AgentOCRRequest(BaseModel):
     """Agent OCR 请求模型"""
     image: str = Field(..., description="Base64 编码的图像字符串")
 
-    @validator('image')
+    @field_validator('image')
+    @classmethod
     def validate_image(cls, v):
         """验证 base64 图像格式"""
         if not v:
@@ -346,9 +347,9 @@ async def global_exception_handler(request, exc):
 def main():
     """启动服务"""
     uvicorn.run(
-        "run_api:app",
+        "api.run_api:app",
         host="0.0.0.0",
-        port=7080,  # 使用 8001 端口,避免与 model_api 的 8000 端口冲突
+        port=6006,  # 使用 6006 端口,避免与 model_api 的 8000 端口冲突
         workers=1,  # 由于 Agent 占用资源,使用单 worker
         log_level="info",
         access_log=True,

+ 1 - 1
config/config.py

@@ -1,7 +1,7 @@
 # OCR配置文件
 
 # 模型路径
-MODEL_PATH = "/root/llm/Qwen3-VL-8B-Instruct"
+MODEL_PATH = "/root/autodl-tmp/llm/Qwen3-VL-8B-Instruct"
 
 
 # ========== OCR提示词 - 分步骤提取 ==========

+ 3 - 1
model/__init__.py

@@ -1,5 +1,7 @@
 from .qwen_ocr import QwenOcr
+from .qwen_ocr_vllm import QwenOcrVLLM
 
 __all__ = [
-    "QwenOcr"
+    "QwenOcr",
+    "QwenOcrVLLM"
 ]

+ 10 - 10
model/model_api.py

@@ -17,7 +17,7 @@ from pydantic import BaseModel, Field, validator
 from PIL import Image
 import uvicorn
 
-from model.qwen_ocr import QwenOcr
+from model import QwenOcr, QwenOcrVLLM
 
 
 # ==================== 日志配置 ====================
@@ -36,7 +36,7 @@ logger = logging.getLogger(__name__)
 class OCRRequest(BaseModel):
     """OCR 推理请求模型"""
     image: str = Field(..., description="Base64 编码的图像字符串")
-    text: str = Field(..., description="OCR 提示词文本")
+    text: list = Field(..., description="OCR 提示词文本列表")
 
     @validator('image')
     def validate_image(cls, v):
@@ -53,9 +53,9 @@ class OCRRequest(BaseModel):
     @validator('text')
     def validate_text(cls, v):
         """验证提示词文本"""
-        if not v or not v.strip():
+        if not v:
             raise ValueError("提示词不能为空")
-        return v.strip()
+        return v
 
 
 class OCRResponse(BaseModel):
@@ -114,7 +114,7 @@ class ModelManager:
             logger.info("开始加载 QwenOcr 模型...")
             # 在线程池中加载模型,避免阻塞事件循环
             loop = asyncio.get_event_loop()
-            self.model = await loop.run_in_executor(None, QwenOcr)
+            self.model = await loop.run_in_executor(None, QwenOcrVLLM)
 
             # 初始化并发控制
             self.max_concurrent_requests = max_concurrent
@@ -166,12 +166,12 @@ class ModelManager:
             logger.error(f"Base64 转换失败: {e}")
             raise ValueError(f"图像解码失败: {str(e)}")
 
-    async def inference(self, image_base64: str, prompt: str) -> list:
+    async def inference(self, image_base64: str, prompts: str) -> list:
         """
         执行 OCR 推理(带并发控制)
         Args:
             image_base64: base64 编码的图像
-            prompt: 提示词
+            prompts: 提示词列表
         Returns:
             推理结果
         """
@@ -191,9 +191,9 @@ class ModelManager:
                 loop = asyncio.get_event_loop()
                 result = await loop.run_in_executor(
                     None,
-                    self.model.inference,
-                    pil_image,
-                    prompt
+                    self.model.batch_inference,
+                    [pil_image] * len(prompts),
+                    prompts
                 )
 
                 return result

+ 404 - 0
model/qwen_ocr_vllm.py

@@ -0,0 +1,404 @@
+import base64
+from io import BytesIO
+from PIL import Image
+import json
+import os
+from pathlib import Path
+from typing import List, Dict, Any, Optional
+
+from qwen_vl_utils import process_vision_info
+from transformers import AutoProcessor
+import time
+import torch
+
+from config import MODEL_PATH, PROMPT_EXTRACT_NAME, PROMPT_EXTRACT_ICON, PROMPT_EXTRACT_COMPONENTS, PROMPT_EXTRACT_KEYWORD, PROMPT_EXTRACT_PREVENTION, PROMPT_EXTRACT_SUPPLIER
+
+# vLLM imports
+from vllm import LLM, SamplingParams
+from vllm.multimodal.utils import fetch_image
+
+
def image_to_base64(pil_image, image_format="JPEG"):
    """Encode a PIL image as a base64 string.

    Args:
        pil_image: PIL Image (anything exposing ``save(buffer, format=...)``).
        image_format: serialization format, e.g. "JPEG" or "PNG".

    Returns:
        Base64-encoded string of the serialized image bytes.
    """
    with BytesIO() as buffer:
        pil_image.save(buffer, format=image_format)
        raw_bytes = buffer.getvalue()
    return base64.b64encode(raw_bytes).decode('utf-8')
+
+
class QwenOcrVLLM:
    """Qwen OCR inference accelerated by the vLLM framework.

    Why vLLM:
    1. PagedAttention - efficient KV-cache management
    2. Continuous batching - better GPU utilization
    3. Fast execution - optimized CUDA/cuDNN kernels
    4. Quantization support - AWQ, GPTQ, etc.
    5. Tensor parallelism - multi-GPU inference
    """

    def __init__(
        self,
        icon_dir: str = "./icon",
        tensor_parallel_size: int = 1,
        gpu_memory_utilization: float = 0.9,
        max_model_len: int = 8192,
        dtype: str = "bfloat16",
        trust_remote_code: bool = True,
    ):
        """Initialize the vLLM engine and chat-template processor, then warm up.

        Args:
            icon_dir: directory holding GHS icon reference images (*.png).
            tensor_parallel_size: tensor-parallel degree (number of GPUs).
            gpu_memory_utilization: fraction of GPU memory vLLM may use (0.0-1.0).
            max_model_len: maximum model sequence length.
            dtype: data type ("auto", "half", "float16", "bfloat16",
                "float", "float32").
            trust_remote_code: whether to trust code shipped with the model repo.
        """
        print("=" * 60)
        print("初始化 vLLM 加速推理引擎...")
        print("=" * 60)

        # vLLM engine. limit_mm_per_prompt caps images per prompt; needed
        # because extract_icons packs several reference icons plus the target
        # image into a single prompt.
        self.llm = LLM(
            model=MODEL_PATH,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            max_model_len=max_model_len,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            limit_mm_per_prompt={"image": 10},
        )

        # The processor is only used to render the chat template into a prompt.
        self.processor = AutoProcessor.from_pretrained(
            MODEL_PATH,
            trust_remote_code=trust_remote_code
        )

        # BUG FIX: the eager _load_icon_images() call was commented out, yet
        # extract_icons still read self.icon_images -> AttributeError at runtime.
        # Keep startup fast by loading lazily on first extract_icons() call.
        self.icon_dir = icon_dir
        self.icon_images = None

        # Greedy decoding (temperature 0): deterministic output suits
        # structured OCR extraction.
        self.default_sampling_params = SamplingParams(
            temperature=0.0,
            top_p=1.0,
            max_tokens=512,
            stop_token_ids=None,
            skip_special_tokens=True,
        )

        print("=" * 60)
        print("vLLM 引擎初始化完成!")
        print(f"- 模型路径: {MODEL_PATH}")
        print(f"- 张量并行: {tensor_parallel_size} GPU(s)")
        print(f"- 显存利用率: {gpu_memory_utilization * 100:.1f}%")
        print(f"- 数据类型: {dtype}")
        print(f"- 最大序列长度: {max_model_len}")
        print("=" * 60)

        # Warm up so first real request does not pay compilation cost.
        print("模型预热中...")
        self._warmup()
        print("模型预热完成!")
        print("=" * 60)

    def _load_icon_images(self) -> Dict[str, Image.Image]:
        """Load all *.png reference images from the icon directory.

        Returns:
            Mapping from icon name (file stem, e.g. "GHS01") to RGB image.
            Empty dict when the directory does not exist.
        """
        icon_images = {}
        icon_path = Path(self.icon_dir)

        if not icon_path.exists():
            print(f"警告: icon目录 {self.icon_dir} 不存在")
            return icon_images

        for icon_file in icon_path.glob("*.png"):
            icon_name = icon_file.stem  # file name without extension, e.g. GHS01
            try:
                icon_image = Image.open(icon_file).convert("RGB")
                icon_images[icon_name] = icon_image
                print(f"已加载icon参考图像: {icon_name}")
            except Exception as e:
                # Skip unreadable files instead of failing the whole load.
                print(f"加载icon图像 {icon_file} 失败: {e}")

        return icon_images

    def _warmup(self):
        """Run one dummy inference to trigger engine compilation/optimization."""
        dummy_image = Image.new('RGB', (224, 224), color='white')
        prompt = PROMPT_EXTRACT_NAME
        try:
            self.inference(dummy_image, prompt, warmup=True)
        except Exception as e:
            # Warm-up is best-effort; a failure here must not block startup.
            print(f"预热过程中出现警告(可忽略): {e}")

    def _build_messages(self, image: Image.Image, prompt: str) -> List[Dict]:
        """Build a single-turn chat message with one image and one text part.

        Args:
            image: PIL image to attach to the prompt.
            prompt: instruction text.

        Returns:
            Message list in Qwen-VL chat format.
        """
        return [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

    def _prepare_inputs(
        self,
        messages: List[Dict]
    ) -> Dict[str, Any]:
        """Convert chat messages into the vLLM generate() input format.

        Args:
            messages: chat messages (from _build_messages or extract_icons).

        Returns:
            Dict with "prompt" (templated text) and "multi_modal_data".
        """
        # Render the chat template into the raw prompt string.
        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Collect the PIL images referenced by the messages.
        image_inputs, video_inputs = process_vision_info(messages)

        # BUG FIX: previously only image_inputs[0] was forwarded, silently
        # dropping every extra image in multi-image prompts (extract_icons).
        # vLLM accepts a single image or a list; keep the single-image shape
        # for single-image prompts for backward compatibility.
        if not image_inputs:
            image_data = None
        elif len(image_inputs) == 1:
            image_data = image_inputs[0]
        else:
            image_data = image_inputs

        # vLLM 0.6.0+ API: prompt text plus multi-modal data in one dict.
        return {
            "prompt": text,
            "multi_modal_data": {
                "image": image_data
            }
        }

    def inference(
        self,
        image: Image.Image,
        prompt: str,
        warmup: bool = False,
        sampling_params: Optional[SamplingParams] = None
    ) -> List[str]:
        """Run OCR on a single image.

        Args:
            image: input PIL image.
            prompt: extraction prompt.
            warmup: kept for interface compatibility only; the previous
                `if not warmup` branch was dead code (both paths returned
                the same value), so the result is simply always returned.
            sampling_params: optional override of the default sampling params.

        Returns:
            Generated texts (one element per request).
        """
        messages = self._build_messages(image, prompt)
        inputs = self._prepare_inputs(messages)
        params = sampling_params if sampling_params else self.default_sampling_params

        # vLLM 0.6.0+ API: pass the inputs dict directly.
        outputs = self.llm.generate(inputs, params)
        return [output.outputs[0].text for output in outputs]

    def batch_inference(
        self,
        images: List[Image.Image],
        prompts: List[str],
        sampling_params: Optional[SamplingParams] = None
    ) -> List[str]:
        """Run OCR on a batch of image/prompt pairs (vLLM's core strength).

        Args:
            images: input images, paired positionally with prompts.
            prompts: extraction prompts.
            sampling_params: optional override of the default sampling params.

        Returns:
            Generated texts in input order.

        Raises:
            ValueError: when images and prompts differ in length.
        """
        if len(images) != len(prompts):
            raise ValueError(f"images数量({len(images)})和prompts数量({len(prompts)})不匹配")

        all_inputs = [
            self._prepare_inputs(self._build_messages(image, prompt))
            for image, prompt in zip(images, prompts)
        ]

        params = sampling_params if sampling_params else self.default_sampling_params

        # vLLM 0.6.0+ API: pass the list of inputs directly; the engine
        # batches them with continuous batching.
        outputs = self.llm.generate(all_inputs, params)
        return [output.outputs[0].text for output in outputs]

    def extract_icons(self, image: Image.Image) -> List[str]:
        """Identify GHS pictograms by showing reference icons next to the image.

        Args:
            image: chemical-label image to analyze.

        Returns:
            Generated texts.
        """
        # Lazy-load the icon references on first use (see __init__ note).
        if not self.icon_images:
            self.icon_images = self._load_icon_images()

        content_list = []

        # Reference icons in GHS-number order, each followed by its label.
        for icon_name, icon_image in sorted(self.icon_images.items(), key=lambda x: x[0]):
            content_list.append({"type": "image", "image": icon_image})
            content_list.append({"type": "text", "text": f"参考图像:{icon_name}"})

        # The image to classify, then the instruction prompt.
        content_list.append({"type": "image", "image": image})
        content_list.append({"type": "text", "text": PROMPT_EXTRACT_ICON})

        messages = [{"role": "user", "content": content_list}]
        inputs = self._prepare_inputs(messages)

        outputs = self.llm.generate(inputs, self.default_sampling_params)
        return [output.outputs[0].text for output in outputs]

    def __del__(self):
        """Release the vLLM engine on garbage collection (best-effort)."""
        try:
            if hasattr(self, 'llm'):
                del self.llm
                print("vLLM引擎已释放")
        except Exception:
            # Interpreter may already be tearing down; never raise from __del__.
            pass
+
+
if __name__ == '__main__':
    # Smoke test: single-image inference followed by a batched run.
    print("初始化 QwenOcrVLLM...")
    qwen_ocr = QwenOcrVLLM(
        tensor_parallel_size=1,      # single GPU
        gpu_memory_utilization=0.9,  # use 90% of GPU memory
        max_model_len=8192,          # maximum sequence length
        dtype="bfloat16"             # bfloat16 precision
    )

    # ---- Single-image inference ----
    print("\n" + "=" * 60)
    print("测试单张图像推理...")
    print("=" * 60)

    test_image_path = "./test3.jpg"
    if os.path.exists(test_image_path):
        image = Image.open(test_image_path).convert("RGB")

        start_time = time.time()
        result = qwen_ocr.inference(image, PROMPT_EXTRACT_PREVENTION)
        elapsed = time.time() - start_time

        print(f"\n推理耗时: {elapsed:.3f}秒")
        print(f"提取结果:\n{result[0]}")
    else:
        print(f"测试图像 {test_image_path} 不存在,跳过测试")

    # ---- Batched inference ----
    print("\n" + "=" * 60)
    print("测试批量推理...")
    print("=" * 60)

    if os.path.exists(test_image_path):
        prompts = [PROMPT_EXTRACT_ICON, PROMPT_EXTRACT_NAME, PROMPT_EXTRACT_COMPONENTS, PROMPT_EXTRACT_KEYWORD, PROMPT_EXTRACT_PREVENTION, PROMPT_EXTRACT_SUPPLIER]
        # One copy of the test image per prompt (the old comment said 3 copies
        # while 6 were actually created).
        images = [Image.open(test_image_path).convert("RGB") for _ in range(len(prompts))]

        start_time = time.time()
        results = qwen_ocr.batch_inference(images, prompts)
        elapsed = time.time() - start_time
        for text in results:
            print(text)

        # BUG FIX: the per-image averages previously divided by a hard-coded 3
        # although 6 images are processed; use the real batch size.
        print(f"\n批量推理耗时: {elapsed:.3f}秒")
        print(f"平均每张: {elapsed/len(images):.3f}秒")
        print(f"批量推理加速比: {(elapsed/len(images)):.3f}秒/张 vs 单张推理")

    print("\n" + "=" * 60)
    print("测试完成!")
    print("=" * 60)

+ 0 - 3
requirements.txt

@@ -176,9 +176,6 @@ terminado==0.18.1
 tinycss2==1.4.0
 tokenizers==0.22.1
 tomli==2.2.1
-torch==2.8.0
-torch_npu==2.8.0
-torchvision==0.23.0
 tornado==6.4.2
 tqdm @ file:///croot/tqdm_1724853943256/work
 traitlets==5.14.3

+ 4 - 3
test_api.py

@@ -3,6 +3,7 @@ from io import BytesIO
 import base64
 import json
 from PIL import Image, ImageFilter, ImageEnhance
+from config import PROMPT_EXTRACT_ICON, PROMPT_EXTRACT_NAME, PROMPT_EXTRACT_COMPONENTS, PROMPT_EXTRACT_KEYWORD, PROMPT_EXTRACT_PREVENTION, PROMPT_EXTRACT_SUPPLIER
 
 def image_to_base64(pil_image, image_format="JPEG"):
     """将PIL Image图像转换为Base64编码"""
@@ -42,16 +43,16 @@ image_base64 = image_to_base64(image)
 #     "http://127.0.0.1:8000/api/v1/ocr",
 #     json={
 #         "image": image_base64,
-#         "text": PROMPT_EXTRACT_NAME
+#         "text": [PROMPT_EXTRACT_ICON, PROMPT_EXTRACT_NAME, PROMPT_EXTRACT_COMPONENTS, PROMPT_EXTRACT_KEYWORD, PROMPT_EXTRACT_PREVENTION, PROMPT_EXTRACT_SUPPLIER]
 #     }
 # )
 
 response = requests.post(
-    "https://749757254390085-http-7080.edge-proxy.gpugeek.com:8443/api/v1/agent_ocr",
+    "https://u475436-9425-5ad0e9a4.gda1.seetacloud.com:6443/api/v1/agent_ocr",
     json={
         "image": image_base64,
     }
 )
 
 result = response.json()
-print(result)
+print(result['data'])