""" OCR API 客户端测试脚本 演示如何调用 OCR API 服务 """ import base64 import json import requests from pathlib import Path from PIL import Image import io class OCRClient: """OCR API 客户端""" def __init__(self, base_url: str = "http://localhost:8000"): self.base_url = base_url.rstrip("/") self.session = requests.Session() # 设置超时和重试 self.timeout = 60 def image_to_base64(self, image_path: str) -> str: """ 将图像文件转换为 base64 字符串 Args: image_path: 图像文件路径 Returns: base64 编码字符串 """ image = Image.open(image_path).convert('RGB') buffered = io.BytesIO() image.save(buffered, format="JPEG") img_bytes = buffered.getvalue() return base64.b64encode(img_bytes).decode('utf-8') def pil_to_base64(self, pil_image: Image.Image) -> str: """ 将 PIL Image 转换为 base64 字符串 Args: pil_image: PIL Image 对象 Returns: base64 编码字符串 """ if pil_image.mode != 'RGB': pil_image = pil_image.convert('RGB') buffered = io.BytesIO() pil_image.save(buffered, format="JPEG") img_bytes = buffered.getvalue() return base64.b64encode(img_bytes).decode('utf-8') def health_check(self) -> dict: """ 健康检查 Returns: 健康状态信息 """ url = f"{self.base_url}/health" try: response = self.session.get(url, timeout=5) response.raise_for_status() return response.json() except Exception as e: return {"error": str(e)} def ocr_inference(self, image_base64: str, prompt: str) -> dict: """ 执行 OCR 推理 Args: image_base64: base64 编码的图像 prompt: 提示词 Returns: 推理结果 """ url = f"{self.base_url}/api/v1/ocr" payload = { "image": image_base64, "text": prompt } try: response = self.session.post( url, json=payload, timeout=self.timeout ) response.raise_for_status() return response.json() except requests.exceptions.HTTPError as e: return { "success": False, "error": f"HTTP Error: {e.response.status_code}", "detail": e.response.text } except Exception as e: return { "success": False, "error": str(e) } def ocr_from_file(self, image_path: str, prompt: str) -> dict: """ 从文件执行 OCR 推理 Args: image_path: 图像文件路径 prompt: 提示词 Returns: 推理结果 """ image_base64 = self.image_to_base64(image_path) return self.ocr_inference(image_base64, prompt) # ==================== 测试示例 ==================== def test_basic(): """基本测试""" # 创建客户端 client = OCRClient("http://localhost:8000") # 1. 健康检查 print("=" * 50) print("1. 健康检查") print("=" * 50) health = client.health_check() print(json.dumps(health, indent=2, ensure_ascii=False)) # 2. OCR 推理测试 print("\n" + "=" * 50) print("2. OCR 推理测试") print("=" * 50) # 示例提示词 prompt = """ 你是一个专业的化学品安全标签说明识别助手。 请从图像中提取化学品的中文名称和英文名称(如有)。 按照以下JSON格式输出结果: { "name_cn": "化学品中文名称", "name_en": "化学品英文名称" } 注意:返回结果必须是标准JSON格式,不要包含```json```标记。 """ # 替换为实际的图像路径 image_path = "./test3.jpg" if Path(image_path).exists(): result = client.ocr_from_file(image_path, prompt) print(json.dumps(result, indent=2, ensure_ascii=False)) else: print(f"图像文件不存在: {image_path}") print("请将测试图像放在当前目录下") def test_concurrent(): """并发测试""" import concurrent.futures import time client = OCRClient("http://localhost:8000") prompt = "提取图像中的文字信息" # 创建一个测试图像(白色背景) test_image = Image.new('RGB', (224, 224), color='white') image_base64 = client.pil_to_base64(test_image) def send_request(idx): """发送单个请求""" start = time.time() result = client.ocr_inference(image_base64, prompt) elapsed = time.time() - start return idx, elapsed, result.get("success", False) print("\n" + "=" * 50) print("3. 并发测试 (10 个并发请求)") print("=" * 50) # 发送 10 个并发请求 with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: futures = [executor.submit(send_request, i) for i in range(10)] results = [f.result() for f in concurrent.futures.as_completed(futures)] # 统计结果 success_count = sum(1 for _, _, success in results if success) avg_time = sum(elapsed for _, elapsed, _ in results) / len(results) print(f"总请求数: {len(results)}") print(f"成功请求: {success_count}") print(f"失败请求: {len(results) - success_count}") print(f"平均响应时间: {avg_time:.2f}秒") if __name__ == "__main__": # 基本测试 test_basic() # 并发测试 # test_concurrent()