test_api_client.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. """
  2. OCR API 客户端测试脚本
  3. 演示如何调用 OCR API 服务
  4. """
  5. import base64
  6. import json
  7. import requests
  8. from pathlib import Path
  9. from PIL import Image
  10. import io
  11. class OCRClient:
  12. """OCR API 客户端"""
  13. def __init__(self, base_url: str = "http://localhost:8000"):
  14. self.base_url = base_url.rstrip("/")
  15. self.session = requests.Session()
  16. # 设置超时和重试
  17. self.timeout = 60
  18. def image_to_base64(self, image_path: str) -> str:
  19. """
  20. 将图像文件转换为 base64 字符串
  21. Args:
  22. image_path: 图像文件路径
  23. Returns:
  24. base64 编码字符串
  25. """
  26. image = Image.open(image_path).convert('RGB')
  27. buffered = io.BytesIO()
  28. image.save(buffered, format="JPEG")
  29. img_bytes = buffered.getvalue()
  30. return base64.b64encode(img_bytes).decode('utf-8')
  31. def pil_to_base64(self, pil_image: Image.Image) -> str:
  32. """
  33. 将 PIL Image 转换为 base64 字符串
  34. Args:
  35. pil_image: PIL Image 对象
  36. Returns:
  37. base64 编码字符串
  38. """
  39. if pil_image.mode != 'RGB':
  40. pil_image = pil_image.convert('RGB')
  41. buffered = io.BytesIO()
  42. pil_image.save(buffered, format="JPEG")
  43. img_bytes = buffered.getvalue()
  44. return base64.b64encode(img_bytes).decode('utf-8')
  45. def health_check(self) -> dict:
  46. """
  47. 健康检查
  48. Returns:
  49. 健康状态信息
  50. """
  51. url = f"{self.base_url}/health"
  52. try:
  53. response = self.session.get(url, timeout=5)
  54. response.raise_for_status()
  55. return response.json()
  56. except Exception as e:
  57. return {"error": str(e)}
  58. def ocr_inference(self, image_base64: str, prompt: str, model: str = "Qwen3-VL-32B-Instruct") -> dict:
  59. """
  60. 执行 OCR 推理(OpenAI 兼容格式)
  61. Args:
  62. image_base64: base64 编码的图像(不含 data URI 前缀)
  63. prompt: 提示词
  64. model: 模型名称
  65. Returns:
  66. 推理结果
  67. """
  68. url = f"{self.base_url}/api/v1/ocr"
  69. payload = {
  70. "model": model,
  71. "messages": [
  72. {"role": "system", "content": "You are a helpful assistant."},
  73. {
  74. "role": "user",
  75. "content": [
  76. {
  77. "type": "image_url",
  78. "image_url": {
  79. "url": f"data:image/jpeg;base64,{image_base64}"
  80. }
  81. },
  82. {"type": "text", "text": prompt}
  83. ]
  84. }
  85. ],
  86. "max_tokens": 4096,
  87. "stream": False,
  88. "temperature": 0
  89. }
  90. try:
  91. response = self.session.post(
  92. url,
  93. json=payload,
  94. timeout=self.timeout
  95. )
  96. response.raise_for_status()
  97. return response.json()
  98. except requests.exceptions.HTTPError as e:
  99. return {
  100. "success": False,
  101. "error": f"HTTP Error: {e.response.status_code}",
  102. "detail": e.response.text
  103. }
  104. except Exception as e:
  105. return {
  106. "success": False,
  107. "error": str(e)
  108. }
  109. def ocr_from_file(self, image_path: str, prompt: str) -> dict:
  110. """
  111. 从文件执行 OCR 推理
  112. Args:
  113. image_path: 图像文件路径
  114. prompt: 提示词
  115. Returns:
  116. 推理结果
  117. """
  118. image_base64 = self.image_to_base64(image_path)
  119. return self.ocr_inference(image_base64, prompt)
  120. # ==================== 测试示例 ====================
  121. def test_basic():
  122. """基本测试"""
  123. # 创建客户端
  124. client = OCRClient("http://localhost:8000")
  125. # 1. 健康检查
  126. print("=" * 50)
  127. print("1. 健康检查")
  128. print("=" * 50)
  129. health = client.health_check()
  130. print(json.dumps(health, indent=2, ensure_ascii=False))
  131. # 2. OCR 推理测试
  132. print("\n" + "=" * 50)
  133. print("2. OCR 推理测试")
  134. print("=" * 50)
  135. # 示例提示词
  136. prompt = """
  137. 你是一个专业的化学品安全标签说明识别助手。
  138. 请从图像中提取化学品的中文名称和英文名称(如有)。
  139. 按照以下JSON格式输出结果:
  140. {
  141. "name_cn": "化学品中文名称",
  142. "name_en": "化学品英文名称"
  143. }
  144. 注意:返回结果必须是标准JSON格式,不要包含```json```标记。
  145. """
  146. # 替换为实际的图像路径
  147. image_path = "./test3.jpg"
  148. if Path(image_path).exists():
  149. result = client.ocr_from_file(image_path, prompt)
  150. print(json.dumps(result, indent=2, ensure_ascii=False))
  151. else:
  152. print(f"图像文件不存在: {image_path}")
  153. print("请将测试图像放在当前目录下")
  154. def test_concurrent():
  155. """并发测试"""
  156. import concurrent.futures
  157. import time
  158. client = OCRClient("http://localhost:8000")
  159. prompt = "提取图像中的文字信息"
  160. # 创建一个测试图像(白色背景)
  161. test_image = Image.new('RGB', (224, 224), color='white')
  162. image_base64 = client.pil_to_base64(test_image)
  163. def send_request(idx):
  164. """发送单个请求"""
  165. start = time.time()
  166. result = client.ocr_inference(image_base64, prompt)
  167. elapsed = time.time() - start
  168. return idx, elapsed, bool(result.get("choices"))
  169. print("\n" + "=" * 50)
  170. print("3. 并发测试 (10 个并发请求)")
  171. print("=" * 50)
  172. # 发送 10 个并发请求
  173. with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
  174. futures = [executor.submit(send_request, i) for i in range(10)]
  175. results = [f.result() for f in concurrent.futures.as_completed(futures)]
  176. # 统计结果
  177. success_count = sum(1 for _, _, success in results if success)
  178. avg_time = sum(elapsed for _, elapsed, _ in results) / len(results)
  179. print(f"总请求数: {len(results)}")
  180. print(f"成功请求: {success_count}")
  181. print(f"失败请求: {len(results) - success_count}")
  182. print(f"平均响应时间: {avg_time:.2f}秒")
  183. if __name__ == "__main__":
  184. # 基本测试
  185. # test_basic()
  186. # 并发测试
  187. test_concurrent()