# qwen_ocr_remote.py (3.5 KB)

  1. import base64
  2. from io import BytesIO
  3. from typing import List, Optional
  4. import requests
  5. from PIL import Image
  def image_to_base64(pil_image: Image.Image, image_format: str = "PNG") -> str:
      """Serialize a PIL image and return it as a base64-encoded string.

      Args:
          pil_image: Image to encode.
          image_format: Format name passed to the image's ``save`` method
              (defaults to ``"PNG"``).

      Returns:
          The saved image bytes, base64-encoded and decoded to text.
      """
      # Write the image into an in-memory buffer rather than a file on disk.
      with BytesIO() as stream:
          pil_image.save(stream, format=image_format)
          raw_bytes = stream.getvalue()
      return base64.b64encode(raw_bytes).decode("utf-8")
  class QwenOcrRemote:
      """OCR client for a remote OpenAI-compatible endpoint.

      Requests are sent in the ``POST /v1/chat/completions`` format; the image
      travels inside the user message as an ``image_url`` data URI
      (``data:image/png;base64,...``).
      """

      def __init__(
          self,
          api_url: str,
          api_key: str,
          model: str,
          max_tokens: int = 4096,
          temperature: float = 0,
          timeout: int = 60,
      ):
          """Store the endpoint, credentials and generation settings.

          Args:
              api_url: Endpoint URL, e.g.
                  http://10.69.29.202:31277/inference-api/.../v1/chat/completions
              api_key: Bearer token.
              model: Model name, e.g. Qwen3-VL-30B-A3B-Instruct.
              max_tokens: Maximum number of tokens to generate.
              temperature: Sampling temperature.
              timeout: Request timeout in seconds.
          """
          bearer_value = f"Bearer {api_key}"
          self.api_url = api_url
          self.model = model
          # Headers are built once here and reused for every request.
          self.headers = {
              "Content-Type": "application/json",
              "Authorization": bearer_value,
          }
          self.max_tokens, self.temperature = max_tokens, temperature
          self.timeout = timeout

      def _build_payload(self, image_b64: str, prompt: str) -> dict:
          """Assemble the chat-completions request body for one image and prompt.

          Args:
              image_b64: Base64 text of the encoded image.
              prompt: Prompt text sent alongside the image.

          Returns:
              A dict holding the model name, a system and a user message, and
              the generation settings, with streaming disabled.
          """
          # The MIME type is fixed to PNG; ``inference`` encodes images with the
          # default format of ``image_to_base64``, which is PNG.
          data_uri = f"data:image/png;base64,{image_b64}"
          image_part = {"type": "image_url", "image_url": {"url": data_uri}}
          text_part = {"type": "text", "text": prompt}
          system_message = {
              "role": "system",
              "content": "You are a helpful assistant.",
          }
          user_message = {"role": "user", "content": [image_part, text_part]}
          return {
              "model": self.model,
              "messages": [system_message, user_message],
              "max_tokens": self.max_tokens,
              "stream": False,
              "temperature": self.temperature,
          }

      def inference(self, image: Image.Image, prompt: str) -> str:
          """Run the model on a single image.

          Args:
              image: PIL image to send.
              prompt: Prompt text accompanying the image.

          Returns:
              The message content of the first choice in the JSON response.

          Raises:
              requests.HTTPError: If the server answers with an error status.
          """
          encoded_image = image_to_base64(image)
          response = requests.post(
              self.api_url,
              headers=self.headers,
              json=self._build_payload(encoded_image, prompt),
              timeout=self.timeout,
          )
          # Surface HTTP errors before attempting to parse the body.
          response.raise_for_status()
          first_choice = response.json()["choices"][0]
          return first_choice["message"]["content"]

      def batch_inference(
          self,
          images: List[Image.Image],
          prompts: List[str],
      ) -> List[str]:
          """Run inference over several images, one request after another.

          Args:
              images: PIL images to process.
              prompts: Prompts, one per image; must be as long as ``images``.

          Returns:
              One result string per image, in input order.

          Raises:
              ValueError: If ``images`` and ``prompts`` differ in length.
          """
          image_count, prompt_count = len(images), len(prompts)
          if image_count != prompt_count:
              raise ValueError(
                  f"images 数量({image_count}) 与 prompts 数量({prompt_count}) 不一致"
              )
          # Requests are issued sequentially, pairing each image with its prompt.
          return list(map(self.inference, images, prompts))