| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200 |
- """
- OCR API 客户端测试脚本
- 演示如何调用 OCR API 服务
- """
- import base64
- import json
- import requests
- from pathlib import Path
- from PIL import Image
- import io
class OCRClient:
    """Client for the OCR API service.

    Wraps a ``requests.Session`` bound to one server base URL and offers
    helpers for encoding images as base64 JPEG, checking server health and
    submitting OCR inference requests.

    The client can be used as a context manager so that the underlying
    session is closed when the ``with`` block exits.
    """

    # Timeout (seconds) used for the lightweight health probe.
    _HEALTH_TIMEOUT = 5

    def __init__(self, base_url: str = "http://localhost:8000", timeout: float = 60):
        """Create a client.

        Args:
            base_url: Root URL of the OCR service; trailing slashes are removed.
            timeout: Timeout in seconds applied to inference requests.
        """
        self.base_url = base_url.rstrip("/")
        self.session = requests.Session()
        # Only a timeout is configured here; no retry policy is installed.
        self.timeout = timeout

    def __enter__(self):
        """Return the client itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the session on leaving the ``with`` block."""
        self.close()

    def close(self) -> None:
        """Close the underlying HTTP session."""
        self.session.close()

    def image_to_base64(self, image_path: str) -> str:
        """Encode an image file as a base64 JPEG string.

        Args:
            image_path: Path of the image file to read.

        Returns:
            The image re-encoded as RGB JPEG, base64 encoded as text.
        """
        # The context manager closes the file handle that Image.open keeps
        # open; encoding must happen inside the block because pixel data is
        # loaded lazily.
        with Image.open(image_path) as image:
            return self.pil_to_base64(image)

    def pil_to_base64(self, pil_image: "Image.Image") -> str:
        """Encode a PIL image as a base64 JPEG string.

        Args:
            pil_image: PIL image object; converted to RGB when it is in
                another mode.

        Returns:
            The image encoded as JPEG, base64 encoded as text.
        """
        if pil_image.mode != 'RGB':
            pil_image = pil_image.convert('RGB')
        buffered = io.BytesIO()
        pil_image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')

    def health_check(self) -> dict:
        """Query the service health endpoint.

        Returns:
            The decoded JSON body of ``GET /health``, or ``{"error": <message>}``
            when the request or the decoding fails.
        """
        url = f"{self.base_url}/health"
        try:
            response = self.session.get(url, timeout=self._HEALTH_TIMEOUT)
            response.raise_for_status()
            return response.json()
        except Exception as e:
            # Best-effort probe: report the failure instead of raising.
            return {"error": str(e)}

    def ocr_inference(self, image_base64: str, prompt: str) -> dict:
        """Run OCR inference on a base64-encoded image.

        Args:
            image_base64: Base64-encoded image data.
            prompt: Prompt text sent along with the image.

        Returns:
            The decoded JSON body of ``POST /api/v1/ocr``. On an HTTP error
            status a dict with ``success=False``, ``error`` and ``detail``
            (the response body); on any other failure a dict with
            ``success=False`` and ``error``.
        """
        url = f"{self.base_url}/api/v1/ocr"
        payload = {
            "image": image_base64,
            "text": prompt,
        }
        try:
            response = self.session.post(
                url,
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            return {
                "success": False,
                "error": f"HTTP Error: {e.response.status_code}",
                "detail": e.response.text,
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e),
            }

    def ocr_from_file(self, image_path: str, prompt: str) -> dict:
        """Run OCR inference on an image file.

        Args:
            image_path: Path of the image file to send.
            prompt: Prompt text sent along with the image.

        Returns:
            The result dict produced by ``ocr_inference``.
        """
        image_base64 = self.image_to_base64(image_path)
        return self.ocr_inference(image_base64, prompt)
- # ==================== 测试示例 ====================
def test_basic():
    """Basic smoke test: print the health status, then one OCR result.

    Talks to the service at http://localhost:8000 and runs OCR on
    ``./test3.jpg`` when that file exists; otherwise prints a hint.
    """
    rule = "=" * 50
    api = OCRClient("http://localhost:8000")

    # Step 1: health check.
    print(rule)
    print("1. 健康检查")
    print(rule)
    print(json.dumps(api.health_check(), indent=2, ensure_ascii=False))

    # Step 2: OCR inference.
    print("\n" + rule)
    print("2. OCR 推理测试")
    print(rule)

    # Example prompt asking for the chemical's Chinese/English names as JSON.
    prompt = """
你是一个专业的化学品安全标签说明识别助手。
请从图像中提取化学品的中文名称和英文名称(如有)。
按照以下JSON格式输出结果:
{
"name_cn": "化学品中文名称",
"name_en": "化学品英文名称"
}
注意:返回结果必须是标准JSON格式,不要包含```json```标记。
"""

    # Replace with the path of a real image.
    image_path = "./test3.jpg"
    if not Path(image_path).exists():
        print(f"图像文件不存在: {image_path}")
        print("请将测试图像放在当前目录下")
        return

    outcome = api.ocr_from_file(image_path, prompt)
    print(json.dumps(outcome, indent=2, ensure_ascii=False))
def test_concurrent(num_requests: int = 10):
    """Concurrency test: fire parallel OCR requests and print statistics.

    Sends ``num_requests`` identical requests (a blank white 224x224 image)
    from a thread pool with one worker per request, then prints the number
    of successful and failed requests and the mean response time.

    Args:
        num_requests: Number of concurrent requests to send; must be >= 1.

    Raises:
        ValueError: If ``num_requests`` is smaller than 1.
    """
    import concurrent.futures
    import time

    # ThreadPoolExecutor rejects max_workers < 1 and the average below
    # would divide by zero, so fail early with a clear message.
    if num_requests < 1:
        raise ValueError("num_requests must be at least 1")

    client = OCRClient("http://localhost:8000")
    prompt = "提取图像中的文字信息"

    # Build a synthetic test image (white background).
    test_image = Image.new('RGB', (224, 224), color='white')
    image_base64 = client.pil_to_base64(test_image)

    def send_request(idx):
        """Send one request and return (index, elapsed seconds, success flag)."""
        # perf_counter is monotonic, so the measured interval is not
        # affected by wall-clock adjustments.
        start = time.perf_counter()
        result = client.ocr_inference(image_base64, prompt)
        elapsed = time.perf_counter() - start
        return idx, elapsed, result.get("success", False)

    print("\n" + "=" * 50)
    print(f"3. 并发测试 ({num_requests} 个并发请求)")
    print("=" * 50)

    # Submit all requests at once and collect results as they finish.
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_requests) as executor:
        futures = [executor.submit(send_request, i) for i in range(num_requests)]
        results = [f.result() for f in concurrent.futures.as_completed(futures)]

    # Summarise the outcome.
    success_count = sum(1 for _, _, success in results if success)
    avg_time = sum(elapsed for _, elapsed, _ in results) / len(results)
    print(f"总请求数: {len(results)}")
    print(f"成功请求: {success_count}")
    print(f"失败请求: {len(results) - success_count}")
    print(f"平均响应时间: {avg_time:.2f}秒")
if __name__ == "__main__":
    # Script entry point: run the basic test by default.
    test_basic()
    # Concurrency test; uncomment the call below to run it as well.
    # test_concurrent()
|