# test_api_client.py
  1. """
  2. OCR API 客户端测试脚本
  3. 演示如何调用 OCR API 服务
  4. """
  5. import base64
  6. import json
  7. import requests
  8. from pathlib import Path
  9. from PIL import Image
  10. import io
  11. class OCRClient:
  12. """OCR API 客户端"""
  13. def __init__(self, base_url: str = "http://localhost:8000"):
  14. self.base_url = base_url.rstrip("/")
  15. self.session = requests.Session()
  16. # 设置超时和重试
  17. self.timeout = 60
  18. def image_to_base64(self, image_path: str) -> str:
  19. """
  20. 将图像文件转换为 base64 字符串
  21. Args:
  22. image_path: 图像文件路径
  23. Returns:
  24. base64 编码字符串
  25. """
  26. image = Image.open(image_path).convert('RGB')
  27. buffered = io.BytesIO()
  28. image.save(buffered, format="JPEG")
  29. img_bytes = buffered.getvalue()
  30. return base64.b64encode(img_bytes).decode('utf-8')
  31. def pil_to_base64(self, pil_image: Image.Image) -> str:
  32. """
  33. 将 PIL Image 转换为 base64 字符串
  34. Args:
  35. pil_image: PIL Image 对象
  36. Returns:
  37. base64 编码字符串
  38. """
  39. if pil_image.mode != 'RGB':
  40. pil_image = pil_image.convert('RGB')
  41. buffered = io.BytesIO()
  42. pil_image.save(buffered, format="JPEG")
  43. img_bytes = buffered.getvalue()
  44. return base64.b64encode(img_bytes).decode('utf-8')
  45. def health_check(self) -> dict:
  46. """
  47. 健康检查
  48. Returns:
  49. 健康状态信息
  50. """
  51. url = f"{self.base_url}/health"
  52. try:
  53. response = self.session.get(url, timeout=5)
  54. response.raise_for_status()
  55. return response.json()
  56. except Exception as e:
  57. return {"error": str(e)}
  58. def ocr_inference(self, image_base64: str, prompt: str) -> dict:
  59. """
  60. 执行 OCR 推理
  61. Args:
  62. image_base64: base64 编码的图像
  63. prompt: 提示词
  64. Returns:
  65. 推理结果
  66. """
  67. url = f"{self.base_url}/api/v1/ocr"
  68. payload = {
  69. "image": image_base64,
  70. "text": prompt
  71. }
  72. try:
  73. response = self.session.post(
  74. url,
  75. json=payload,
  76. timeout=self.timeout
  77. )
  78. response.raise_for_status()
  79. return response.json()
  80. except requests.exceptions.HTTPError as e:
  81. return {
  82. "success": False,
  83. "error": f"HTTP Error: {e.response.status_code}",
  84. "detail": e.response.text
  85. }
  86. except Exception as e:
  87. return {
  88. "success": False,
  89. "error": str(e)
  90. }
  91. def ocr_from_file(self, image_path: str, prompt: str) -> dict:
  92. """
  93. 从文件执行 OCR 推理
  94. Args:
  95. image_path: 图像文件路径
  96. prompt: 提示词
  97. Returns:
  98. 推理结果
  99. """
  100. image_base64 = self.image_to_base64(image_path)
  101. return self.ocr_inference(image_base64, prompt)
  102. # ==================== 测试示例 ====================
  103. def test_basic():
  104. """基本测试"""
  105. # 创建客户端
  106. client = OCRClient("http://localhost:8000")
  107. # 1. 健康检查
  108. print("=" * 50)
  109. print("1. 健康检查")
  110. print("=" * 50)
  111. health = client.health_check()
  112. print(json.dumps(health, indent=2, ensure_ascii=False))
  113. # 2. OCR 推理测试
  114. print("\n" + "=" * 50)
  115. print("2. OCR 推理测试")
  116. print("=" * 50)
  117. # 示例提示词
  118. prompt = """
  119. 你是一个专业的化学品安全标签说明识别助手。
  120. 请从图像中提取化学品的中文名称和英文名称(如有)。
  121. 按照以下JSON格式输出结果:
  122. {
  123. "name_cn": "化学品中文名称",
  124. "name_en": "化学品英文名称"
  125. }
  126. 注意:返回结果必须是标准JSON格式,不要包含```json```标记。
  127. """
  128. # 替换为实际的图像路径
  129. image_path = "./test3.jpg"
  130. if Path(image_path).exists():
  131. result = client.ocr_from_file(image_path, prompt)
  132. print(json.dumps(result, indent=2, ensure_ascii=False))
  133. else:
  134. print(f"图像文件不存在: {image_path}")
  135. print("请将测试图像放在当前目录下")
  136. def test_concurrent():
  137. """并发测试"""
  138. import concurrent.futures
  139. import time
  140. client = OCRClient("http://localhost:8000")
  141. prompt = "提取图像中的文字信息"
  142. # 创建一个测试图像(白色背景)
  143. test_image = Image.new('RGB', (224, 224), color='white')
  144. image_base64 = client.pil_to_base64(test_image)
  145. def send_request(idx):
  146. """发送单个请求"""
  147. start = time.time()
  148. result = client.ocr_inference(image_base64, prompt)
  149. elapsed = time.time() - start
  150. return idx, elapsed, result.get("success", False)
  151. print("\n" + "=" * 50)
  152. print("3. 并发测试 (10 个并发请求)")
  153. print("=" * 50)
  154. # 发送 10 个并发请求
  155. with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
  156. futures = [executor.submit(send_request, i) for i in range(10)]
  157. results = [f.result() for f in concurrent.futures.as_completed(futures)]
  158. # 统计结果
  159. success_count = sum(1 for _, _, success in results if success)
  160. avg_time = sum(elapsed for _, elapsed, _ in results) / len(results)
  161. print(f"总请求数: {len(results)}")
  162. print(f"成功请求: {success_count}")
  163. print(f"失败请求: {len(results) - success_count}")
  164. print(f"平均响应时间: {avg_time:.2f}秒")
if __name__ == "__main__":
    # Run the basic smoke test when executed as a script.
    test_basic()
    # Uncomment the call below to run the concurrency test as well.
    # test_concurrent()