| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117 |
- import base64
- from io import BytesIO
- from typing import List, Optional
- import requests
- from PIL import Image
def image_to_base64(pil_image: Image.Image, image_format: str = "PNG") -> str:
    """Serialize a PIL image and return it as a base64 text string.

    Args:
        pil_image: The image to encode.
        image_format: Pillow format name passed to ``Image.save``
            (defaults to ``"PNG"``).

    Returns:
        The saved image bytes, base64-encoded and decoded as UTF-8 text.
    """
    with BytesIO() as stream:
        pil_image.save(stream, format=image_format)
        raw_bytes = stream.getvalue()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
class QwenOcrRemote:
    """OCR client that calls a remote OpenAI-compatible chat-completions API.

    Request format: ``POST /v1/chat/completions``. The image is sent inline
    as an ``image_url`` content part holding a ``data:image/png;base64,...``
    data URL, followed by the text prompt.
    """

    # The format used to encode the image and the MIME type declared in the
    # data URL must agree, so both are defined together here.
    _IMAGE_FORMAT = "PNG"
    _IMAGE_MIME = "image/png"
    # Pillow modes that the PNG writer cannot store; converted to RGB first.
    _NON_PNG_MODES = ("CMYK", "YCbCr")

    def __init__(
        self,
        api_url: str,
        api_key: str,
        model: str,
        max_tokens: int = 4096,
        temperature: float = 0,
        timeout: int = 60,
        system_prompt: str = "You are a helpful assistant.",
    ):
        """
        Args:
            api_url: Full endpoint URL, e.g.
                ``http://<host>:<port>/.../v1/chat/completions``.
            api_key: Token sent as ``Authorization: Bearer <api_key>``.
            model: Model name, e.g. ``Qwen3-VL-30B-A3B-Instruct``.
            max_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.
            timeout: Request timeout in seconds, passed to ``requests.post``.
            system_prompt: Content of the system message sent with every
                request. Defaults to the prompt that was previously
                hard-coded.
        """
        self.api_url = api_url
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.timeout = timeout
        self.system_prompt = system_prompt

    def _encode_image(self, image: Image.Image) -> str:
        """Return *image* as a base64 string in the format ``_IMAGE_FORMAT``."""
        # PNG cannot store e.g. CMYK (common for scanned documents); saving
        # such an image would raise, so convert it to RGB first.
        if image.mode in self._NON_PNG_MODES:
            image = image.convert("RGB")
        return image_to_base64(image, image_format=self._IMAGE_FORMAT)

    def _build_payload(self, image_b64: str, prompt: str) -> dict:
        """Build the chat-completions request body.

        Args:
            image_b64: Base64-encoded image bytes (format ``_IMAGE_FORMAT``).
            prompt: Text prompt placed after the image in the user message.

        Returns:
            A JSON-serializable dict with a system message, a user message
            (image part + text part) and the generation settings.
        """
        return {
            "model": self.model,
            "messages": [
                {"role": "system", "content": self.system_prompt},
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{self._IMAGE_MIME};base64,{image_b64}"
                            },
                        },
                        {"type": "text", "text": prompt},
                    ],
                },
            ],
            "max_tokens": self.max_tokens,
            "stream": False,
            "temperature": self.temperature,
        }

    def inference(self, image: Image.Image, prompt: str) -> str:
        """Run OCR on a single image.

        Args:
            image: PIL image to send.
            prompt: Prompt text sent alongside the image.

        Returns:
            The text of the first choice's message; ``""`` if the server
            returned a null content.

        Raises:
            requests.HTTPError: On a non-2xx response; the message includes
                (a prefix of) the response body.
            requests.RequestException: On connection errors or timeout.
        """
        payload = self._build_payload(self._encode_image(image), prompt)
        response = requests.post(
            self.api_url,
            headers=self.headers,
            json=payload,
            timeout=self.timeout,
        )
        try:
            response.raise_for_status()
        except requests.HTTPError as err:
            # The bare status line hides the server's explanation, which is
            # usually in the body; keep it (truncated) in the message.
            raise requests.HTTPError(
                f"{err}; response body: {response.text[:1000]}",
                response=response,
            ) from err
        result = response.json()
        content = result["choices"][0]["message"]["content"]
        # The OpenAI schema allows "content": null; keep the -> str contract.
        return content if content is not None else ""

    def batch_inference(
        self,
        images: List[Image.Image],
        prompts: List[str],
    ) -> List[str]:
        """Run OCR on several images, one sequential request per image.

        Args:
            images: PIL images to send.
            prompts: One prompt per image; must have the same length as
                *images*.

        Returns:
            One result string per image, in input order.

        Raises:
            ValueError: If ``len(images) != len(prompts)``.
        """
        if len(images) != len(prompts):
            raise ValueError(
                f"images 数量({len(images)}) 与 prompts 数量({len(prompts)}) 不一致"
            )
        return [self.inference(img, prompt) for img, prompt in zip(images, prompts)]
|