Browse Source

将请求改为串行,并将请求格式改为海关模型符合的格式

Sherlock1011 2 months ago
parent
commit
f88455693a
8 changed files with 321 additions and 106 deletions
  1. 43 24
      agent/agent.py
  2. 1 1
      config/config.py
  3. 3 1
      model/__init__.py
  4. 116 58
      model/model_api.py
  5. 117 0
      model/qwen_ocr_remote.py
  6. 1 1
      model/qwen_ocr_vllm.py
  7. 14 13
      test_api.py
  8. 26 8
      test_api_client.py

+ 43 - 24
agent/agent.py

@@ -6,7 +6,6 @@ import base64
 import json
 from PIL import Image, ImageFilter, ImageEnhance
 import time
-from concurrent.futures import ThreadPoolExecutor, as_completed
 import requests
 
 def image_to_base64(pil_image, image_format="JPEG"):
@@ -46,45 +45,65 @@ class OcrAgent:
     def __init__(self):
         self._url = "http://127.0.0.1:8000/api/v1/ocr"
 
-    def extract_part_info(self, image_base64, prompts):
-        """根据提示词提取信息"""
+    def extract_single(self, image_base64: str, prompt: str, index: int):
+        """单个任务请求,返回 (index, 结果文本)"""
         response = requests.post(
             self._url,
             json={
-                "image": image_base64,
-                "text": prompts
-            }
+                "model": "Qwen3-VL-32B-Instruct",
+                "messages": [
+                    {"role": "system", "content": "You are a helpful assistant."},
+                    {
+                        "role": "user",
+                        "content": [
+                            {
+                                "type": "image_url",
+                                "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}
+                            },
+                            {"type": "text", "text": prompt}
+                        ]
+                    }
+                ],
+                "max_tokens": 4096,
+                "stream": False,
+                "temperature": 0
+            },
+            timeout=600
         )
-        result = response.json()
-        return result
+        response.raise_for_status()
+        content = response.json()["choices"][0]["message"]["content"]
+        return index, content
 
     def agent_ocr(self, image):
         """qwen_ocr提取化学品安全标签信息"""
-        image = resize_image(image, max_size=1024)
+        image = resize_image(image, max_size=512)
         image_base64 = image_to_base64(image)
 
         start_time = time.perf_counter()
 
-        # 定义需要并行执行的任务
+        # 定义需要并行执行的任务(顺序固定,用 index 保序)
         prompts = [
-            PROMPT_EXTRACT_ICON,
-            PROMPT_EXTRACT_NAME,
-            PROMPT_EXTRACT_COMPONENTS,
-            PROMPT_EXTRACT_KEYWORD,
-            PROMPT_EXTRACT_PREVENTION,
-            PROMPT_EXTRACT_SUPPLIER
+            PROMPT_EXTRACT_ICON,        # 0
+            PROMPT_EXTRACT_NAME,        # 1
+            PROMPT_EXTRACT_COMPONENTS,  # 2
+            PROMPT_EXTRACT_KEYWORD,     # 3
+            PROMPT_EXTRACT_PREVENTION,  # 4
+            PROMPT_EXTRACT_SUPPLIER     # 5
         ]
 
-        results = self.extract_part_info(image_base64, prompts)
-        results = results["data"]
+        # 串行发送 6 个请求
+        results = []
+        for idx, prompt in enumerate(prompts):
+            _, content = self.extract_single(image_base64, prompt, idx)
+            results.append(content)
 
-        # 从结果中提取数据
-        icon = json.loads(results[0])
-        name = json.loads(results[1])
-        tag = json.loads(results[2])
+        # 从结果中提取数据(顺序已由 index 保证)
+        icon        = json.loads(results[0])
+        name        = json.loads(results[1])
+        tag         = json.loads(results[2])
         risk_notice = json.loads(results[3])
-        pre_notice = json.loads(results[4])
-        suppliers = json.loads(results[5])
+        pre_notice  = json.loads(results[4])
+        suppliers   = json.loads(results[5])
 
         end_time = time.perf_counter()
         elapsed_time = end_time - start_time

+ 1 - 1
config/config.py

@@ -1,7 +1,7 @@
 # OCR配置文件
 
 # 模型路径
-MODEL_PATH = "/root/autodl-tmp/llm/Qwen3-VL-8B-Instruct"
+MODEL_PATH = "/root/autodl-tmp/llm/Qwen3-VL-32B-Instruct"
 
 
 # ========== OCR提示词 - 分步骤提取 ==========

+ 3 - 1
model/__init__.py

@@ -1,7 +1,9 @@
 from .qwen_ocr import QwenOcr
 from .qwen_ocr_vllm import QwenOcrVLLM
+from .qwen_ocr_remote import QwenOcrRemote
 
 __all__ = [
     "QwenOcr",
-    "QwenOcrVLLM"
+    "QwenOcrVLLM",
+    "QwenOcrRemote",
 ]

+ 116 - 58
model/model_api.py

@@ -33,38 +33,84 @@ logger = logging.getLogger(__name__)
 
 
 # ==================== 请求/响应模型 ====================
-class OCRRequest(BaseModel):
-    """OCR 推理请求模型"""
-    image: str = Field(..., description="Base64 编码的图像字符串")
-    text: list = Field(..., description="OCR 提示词文本列表")
+class ImageURL(BaseModel):
+    url: str = Field(..., description="data:image/png;base64,... 格式的图像 URL")
 
-    @validator('image')
-    def validate_image(cls, v):
-        """验证 base64 图像格式"""
-        if not v:
-            raise ValueError("图像不能为空")
-        try:
-            # 尝试解码验证格式
-            base64.b64decode(v)
-        except Exception:
-            raise ValueError("无效的 base64 图像格式")
-        return v
 
-    @validator('text')
-    def validate_text(cls, v):
-        """验证提示词文本"""
+class ContentItem(BaseModel):
+    type: str = Field(..., description="内容类型: image_url 或 text")
+    image_url: Optional[ImageURL] = None
+    text: Optional[str] = None
+
+
+class Message(BaseModel):
+    role: str = Field(..., description="角色: system 或 user")
+    content: Any = Field(..., description="消息内容")
+
+
+class OCRRequest(BaseModel):
+    """OCR 推理请求模型(OpenAI 兼容格式)"""
+    model: Optional[str] = Field(None, description="模型名称")
+    messages: list = Field(..., description="消息列表")
+    max_tokens: Optional[int] = Field(4096, description="最大生成 token 数")
+    stream: Optional[bool] = Field(False, description="是否流式输出")
+    temperature: Optional[float] = Field(0, description="采样温度")
+
+    @validator('messages')
+    def validate_messages(cls, v):
         if not v:
-            raise ValueError("提示词不能为空")
+            raise ValueError("messages 不能为空")
         return v
 
+    def get_image_base64(self) -> str:
+        """从 messages 中提取 base64 图像(去掉 data:image/xxx;base64, 前缀)"""
+        for msg in self.messages:
+            if msg.get('role') != 'user':
+                continue
+            content = msg.get('content', [])
+            if not isinstance(content, list):
+                continue
+            for item in content:
+                if item.get('type') == 'image_url':
+                    url = item.get('image_url', {}).get('url', '')
+                    # 去掉 "data:image/png;base64," 前缀
+                    if ';base64,' in url:
+                        return url.split(';base64,', 1)[1]
+                    return url
+        raise ValueError("messages 中未找到 image_url")
+
+    def get_prompt(self) -> str:
+        """从 messages 中提取用户文本提示词"""
+        for msg in self.messages:
+            if msg.get('role') != 'user':
+                continue
+            content = msg.get('content', [])
+            if not isinstance(content, list):
+                continue
+            for item in content:
+                if item.get('type') == 'text':
+                    return item.get('text', '')
+        raise ValueError("messages 中未找到 text")
+
+
+class ChoiceMessage(BaseModel):
+    role: str = "assistant"
+    content: Optional[str] = None
+
+
+class Choice(BaseModel):
+    index: int = 0
+    message: ChoiceMessage
+    finish_reason: str = "stop"
+
 
 class OCRResponse(BaseModel):
-    """OCR 推理响应模型"""
-    success: bool = Field(..., description="请求是否成功")
-    data: Optional[Any] = Field(None, description="推理结果数据")
-    message: str = Field(..., description="响应消息")
+    """OCR 推理响应模型(OpenAI 兼容格式)"""
+    id: str = Field(..., description="请求ID")
+    object: str = Field("chat.completion", description="对象类型")
+    model: Optional[str] = Field(None, description="模型名称")
+    choices: list = Field(..., description="推理结果列表")
     timestamp: str = Field(..., description="响应时间戳")
-    request_id: Optional[str] = Field(None, description="请求ID(用于追踪)")
 
 
 class HealthResponse(BaseModel):
@@ -166,14 +212,14 @@ class ModelManager:
             logger.error(f"Base64 转换失败: {e}")
             raise ValueError(f"图像解码失败: {str(e)}")
 
-    async def inference(self, image_base64: str, prompts: str) -> list:
+    async def inference(self, image_base64: str, prompt: str) -> str:
         """
-        执行 OCR 推理(带并发控制)
+        执行 OCR 推理(带并发控制)
         Args:
-            image_base64: base64 编码的图像
-            prompts: 提示词
+            image_base64: base64 编码的图像(不含 data URI 前缀)
+            prompt: 提示词
         Returns:
-            推理结果
+            推理结果字符串
         """
         if not self.is_loaded or self.model is None:
             raise RuntimeError("模型未加载")
@@ -187,16 +233,16 @@ class ModelManager:
                 # 转换图像
                 pil_image = self.base64_to_pil(image_base64)
 
-                # 在线程池中执行推理,避免阻塞
+                # 在线程池中执行推理避免阻塞
                 loop = asyncio.get_event_loop()
-                result = await loop.run_in_executor(
+                results = await loop.run_in_executor(
                     None,
                     self.model.batch_inference,
-                    [pil_image] * len(prompts),
-                    prompts
+                    [pil_image],
+                    [prompt]
                 )
 
-                return result
+                return results[0]
             finally:
                 async with self._request_lock:
                     self.current_requests -= 1
@@ -270,42 +316,55 @@ async def health_check():
 @app.post("/api/v1/ocr", response_model=OCRResponse)
 async def ocr_inference(request: OCRRequest):
     """
-    OCR 推理端点
-
-    Args:
-        request: OCRRequest 对象,包含 image(base64) 和 text(提示词)
-
-    Returns:
-        OCRResponse: 推理结果
+    OCR 推理端点(OpenAI 兼容格式)
+
+    请求体与 /v1/chat/completions 格式一致:
+    {
+      "model": "...",
+      "messages": [
+        {"role": "system", "content": "..."},
+        {"role": "user", "content": [
+          {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
+          {"type": "text", "text": "问题"}
+        ]}
+      ],
+      "max_tokens": 4096,
+      "stream": false,
+      "temperature": 0
+    }
     """
-    request_id = f"req_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
+    request_id = f"chatcmpl-{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
     logger.info(f"[{request_id}] 收到 OCR 请求")
 
     try:
-        # 获取模型管理器
         manager = await ModelManager.get_instance()
 
         if not manager.is_loaded:
             raise HTTPException(
                 status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-                detail="模型未加载,服务暂不可用"
+                detail="模型未加载服务暂不可用"
             )
 
-        # 执行推理
+        # 从 messages 中提取图像和提示词
+        image_base64 = request.get_image_base64()
+        prompt = request.get_prompt()
+
         logger.info(f"[{request_id}] 开始推理...")
-        result = await manager.inference(request.image, request.text)
+        content = await manager.inference(image_base64, prompt)
         logger.info(f"[{request_id}] 推理完成")
 
         return OCRResponse(
-            success=True,
-            data=result,
-            message="推理成功",
-            timestamp=datetime.now().isoformat(),
-            request_id=request_id
+            id=request_id,
+            model=request.model,
+            choices=[{
+                "index": 0,
+                "message": {"role": "assistant", "content": content},
+                "finish_reason": "stop"
+            }],
+            timestamp=datetime.now().isoformat()
         )
 
     except ValueError as e:
-        # 参数验证错误
         logger.warning(f"[{request_id}] 参数验证失败: {e}")
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
@@ -313,7 +372,6 @@ async def ocr_inference(request: OCRRequest):
         )
 
     except RuntimeError as e:
-        # 模型运行时错误
         logger.error(f"[{request_id}] 运行时错误: {e}")
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -321,7 +379,6 @@ async def ocr_inference(request: OCRRequest):
         )
 
     except Exception as e:
-        # 未知错误
         logger.error(f"[{request_id}] 未知错误: {e}", exc_info=True)
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -336,10 +393,11 @@ async def global_exception_handler(request, exc):
     return JSONResponse(
         status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
         content={
-            "success": False,
-            "data": None,
-            "message": f"服务器错误: {str(exc)}",
-            "timestamp": datetime.now().isoformat()
+            "id": None,
+            "object": "chat.completion",
+            "choices": [],
+            "timestamp": datetime.now().isoformat(),
+            "error": {"message": str(exc)}
         }
     )
 

+ 117 - 0
model/qwen_ocr_remote.py

@@ -0,0 +1,117 @@
+import base64
+from io import BytesIO
+from typing import List, Optional
+
+import requests
+from PIL import Image
+
+
+def image_to_base64(pil_image: Image.Image, image_format: str = "PNG") -> str:
+    """将 PIL Image 转换为 base64 字符串"""
+    buffered = BytesIO()
+    pil_image.save(buffered, format=image_format)
+    return base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+
+class QwenOcrRemote:
+    """调用远端 OpenAI 兼容接口的 OCR 客户端
+
+    接口格式:POST /v1/chat/completions
+    图像通过 image_url (data:image/png;base64,...) 方式传入
+    """
+
+    def __init__(
+        self,
+        api_url: str,
+        api_key: str,
+        model: str,
+        max_tokens: int = 4096,
+        temperature: float = 0,
+        timeout: int = 60,
+    ):
+        """
+        Args:
+            api_url:     接口地址,例如 http://10.69.29.202:31277/inference-api/.../v1/chat/completions
+            api_key:     Bearer token
+            model:       模型名称,例如 Qwen3-VL-30B-A3B-Instruct
+            max_tokens:  最大生成 token 数
+            temperature: 采样温度
+            timeout:     请求超时秒数
+        """
+        self.api_url = api_url
+        self.headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {api_key}",
+        }
+        self.model = model
+        self.max_tokens = max_tokens
+        self.temperature = temperature
+        self.timeout = timeout
+
+    def _build_payload(self, image_b64: str, prompt: str) -> dict:
+        """构建请求体,格式与你贴的 curl 完全一致"""
+        return {
+            "model": self.model,
+            "messages": [
+                {"role": "system", "content": "You are a helpful assistant."},
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/png;base64,{image_b64}"
+                            },
+                        },
+                        {"type": "text", "text": prompt},
+                    ],
+                },
+            ],
+            "max_tokens": self.max_tokens,
+            "stream": False,
+            "temperature": self.temperature,
+        }
+
+    def inference(self, image: Image.Image, prompt: str) -> str:
+        """单张图像推理
+
+        Args:
+            image:  PIL Image 对象
+            prompt: 提示词文本
+
+        Returns:
+            模型返回的文本字符串
+        """
+        image_b64 = image_to_base64(image)
+        payload = self._build_payload(image_b64, prompt)
+
+        response = requests.post(
+            self.api_url,
+            headers=self.headers,
+            json=payload,
+            timeout=self.timeout,
+        )
+        response.raise_for_status()
+
+        result = response.json()
+        return result["choices"][0]["message"]["content"]
+
+    def batch_inference(
+        self,
+        images: List[Image.Image],
+        prompts: List[str],
+    ) -> List[str]:
+        """批量推理(顺序请求)
+
+        Args:
+            images:  PIL Image 列表
+            prompts: 提示词列表,长度须与 images 一致
+
+        Returns:
+            每张图对应的推理结果列表
+        """
+        if len(images) != len(prompts):
+            raise ValueError(
+                f"images 数量({len(images)}) 与 prompts 数量({len(prompts)}) 不一致"
+            )
+        return [self.inference(img, prompt) for img, prompt in zip(images, prompts)]

+ 1 - 1
model/qwen_ocr_vllm.py

@@ -41,7 +41,7 @@ class QwenOcrVLLM:
     def __init__(
         self,
         icon_dir: str = "./icon",
-        tensor_parallel_size: int = 1,
+        tensor_parallel_size: int = 2,
         gpu_memory_utilization: float = 0.9,
         max_model_len: int = 8192,
         dtype: str = "bfloat16",

+ 14 - 13
test_api.py

@@ -2,8 +2,9 @@ import requests
 from io import BytesIO
 import base64
 import json
+import time
 from PIL import Image, ImageFilter, ImageEnhance
-from config import PROMPT_EXTRACT_ICON, PROMPT_EXTRACT_NAME, PROMPT_EXTRACT_COMPONENTS, PROMPT_EXTRACT_KEYWORD, PROMPT_EXTRACT_PREVENTION, PROMPT_EXTRACT_SUPPLIER
+
 
 def image_to_base64(pil_image, image_format="JPEG"):
     """将PIL Image图像转换为Base64编码"""
@@ -13,12 +14,12 @@ def image_to_base64(pil_image, image_format="JPEG"):
     encode_image = base64.b64encode(img_byte_array).decode('utf-8')
     return encode_image
 
+
 def resize_image(image, max_size=512):
     """缩放图像尺寸,保持 OCR 质量"""
     width, height = image.size
     max_dim = max(width, height)
 
-    # 如果图像不需要缩小,直接返回
     if max_dim <= max_size:
         return image
 
@@ -34,25 +35,25 @@ def resize_image(image, max_size=512):
 
     return resized
 
+
+# ==================== 请求 agent_ocr 服务 ====================
 image = Image.open('./test1.jpg')
 image = resize_image(image)
 image_base64 = image_to_base64(image)
 
-
-# response = requests.post(
-#     "http://127.0.0.1:8000/api/v1/ocr",
-#     json={
-#         "image": image_base64,
-#         "text": [PROMPT_EXTRACT_ICON, PROMPT_EXTRACT_NAME, PROMPT_EXTRACT_COMPONENTS, PROMPT_EXTRACT_KEYWORD, PROMPT_EXTRACT_PREVENTION, PROMPT_EXTRACT_SUPPLIER]
-#     }
-# )
+print("发送请求到 agent_ocr 服务...")
+start = time.time()
 
 response = requests.post(
-    "https://u475436-9425-5ad0e9a4.gda1.seetacloud.com:6443/api/v1/agent_ocr",
+    "http://127.0.0.1:6006/api/v1/agent_ocr",
     json={
         "image": image_base64,
-    }
+    },
+    timeout=700
 )
 
+elapsed = time.time() - start
+print(f"耗时: {elapsed:.2f}s  状态码: {response.status_code}")
+
 result = response.json()
-print(result['data'])
+print(json.dumps(result, indent=2, ensure_ascii=False))

+ 26 - 8
test_api_client.py

@@ -62,19 +62,37 @@ class OCRClient:
         except Exception as e:
             return {"error": str(e)}
 
-    def ocr_inference(self, image_base64: str, prompt: str) -> dict:
+    def ocr_inference(self, image_base64: str, prompt: str, model: str = "Qwen3-VL-32B-Instruct") -> dict:
         """
-        执行 OCR 推理
+        执行 OCR 推理(OpenAI 兼容格式)
         Args:
-            image_base64: base64 编码的图像
+            image_base64: base64 编码的图像(不含 data URI 前缀)
             prompt: 提示词
+            model: 模型名称
         Returns:
             推理结果
         """
         url = f"{self.base_url}/api/v1/ocr"
         payload = {
-            "image": image_base64,
-            "text": prompt
+            "model": model,
+            "messages": [
+                {"role": "system", "content": "You are a helpful assistant."},
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/jpeg;base64,{image_base64}"
+                            }
+                        },
+                        {"type": "text", "text": prompt}
+                    ]
+                }
+            ],
+            "max_tokens": 4096,
+            "stream": False,
+            "temperature": 0
         }
 
         try:
@@ -171,7 +189,7 @@ def test_concurrent():
         start = time.time()
         result = client.ocr_inference(image_base64, prompt)
         elapsed = time.time() - start
-        return idx, elapsed, result.get("success", False)
+        return idx, elapsed, bool(result.get("choices"))
 
     print("\n" + "=" * 50)
     print("3. 并发测试 (10 个并发请求)")
@@ -194,7 +212,7 @@ def test_concurrent():
 
 if __name__ == "__main__":
     # 基本测试
-    test_basic()
+    # test_basic()
 
     # 并发测试
-    # test_concurrent()
+    test_concurrent()