|
|
@@ -0,0 +1,362 @@
|
|
|
+"""
|
|
|
+企业级 OCR API 服务
|
|
|
+提供基于 FastAPI 的高并发 OCR 推理服务
|
|
|
+"""
|
|
|
+import asyncio
|
|
|
+import base64
|
|
|
+import io
|
|
|
+import logging
|
|
|
+import sys
|
|
|
+from contextlib import asynccontextmanager
|
|
|
+from typing import Optional, Dict, Any
|
|
|
+from datetime import datetime
|
|
|
+
|
|
|
+from fastapi import FastAPI, HTTPException, status
|
|
|
+from fastapi.responses import JSONResponse
|
|
|
+from pydantic import BaseModel, Field, validator
|
|
|
+from PIL import Image
|
|
|
+import uvicorn
|
|
|
+
|
|
|
+from model.qwen_ocr import QwenOcr
|
|
|
+
|
|
|
+
|
|
|
# ==================== Logging configuration ====================
# Configure the root logger once at import time: INFO level, mirrored to
# stdout and to a UTF-8 log file next to the process working directory.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # Console output (stdout, so container log collectors pick it up).
        logging.StreamHandler(sys.stdout),
        # Persistent file log; created/opened on import.
        logging.FileHandler('ocr_api.log', encoding='utf-8')
    ]
)
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+
|
|
|
+# ==================== 请求/响应模型 ====================
|
|
|
class OCRRequest(BaseModel):
    """OCR inference request model.

    Attributes:
        image: base64-encoded image string.
        text: OCR prompt text.
    """
    image: str = Field(..., description="Base64 编码的图像字符串")
    text: str = Field(..., description="OCR 提示词文本")

    @validator('image')
    def validate_image(cls, v):
        """Validate that `image` is non-empty, well-formed base64."""
        if not v:
            raise ValueError("图像不能为空")
        try:
            # validate=True rejects strings containing characters outside
            # the base64 alphabet; the default silently discards them,
            # which let malformed payloads slip through this check.
            # NOTE(review): this also rejects line-wrapped base64 — assumed
            # clients send unwrapped strings; confirm against callers.
            base64.b64decode(v, validate=True)
        except Exception:
            raise ValueError("无效的 base64 图像格式")
        return v

    @validator('text')
    def validate_text(cls, v):
        """Require a non-blank prompt and strip surrounding whitespace."""
        if not v or not v.strip():
            raise ValueError("提示词不能为空")
        return v.strip()
|
|
|
+
|
|
|
+
|
|
|
class OCRResponse(BaseModel):
    """OCR inference response model."""
    # Whether the request succeeded.
    success: bool = Field(..., description="请求是否成功")
    # Inference result payload (shape depends on the model's output).
    data: Optional[Any] = Field(None, description="推理结果数据")
    # Human-readable response message.
    message: str = Field(..., description="响应消息")
    # ISO-8601 timestamp of when the response was produced.
    timestamp: str = Field(..., description="响应时间戳")
    # Request ID for log correlation/tracing.
    request_id: Optional[str] = Field(None, description="请求ID(用于追踪)")
|
|
|
+
|
|
|
+
|
|
|
class HealthResponse(BaseModel):
    """Health-check response model."""
    # "healthy" when the model is loaded, otherwise "unhealthy".
    status: str
    # Whether the OCR model has finished loading.
    model_loaded: bool
    # ISO-8601 timestamp of the check.
    timestamp: str
    # Number of inference requests currently in flight.
    concurrent_requests: int
    # Configured concurrency ceiling.
    max_concurrent: int
|
|
|
+
|
|
|
+
|
|
|
+# ==================== 模型管理器(单例模式) ====================
|
|
|
class ModelManager:
    """Singleton model manager — guarantees one global model instance.

    Instance creation is serialized with a class-level asyncio lock, and
    concurrent inference calls are throttled with a semaphore created at
    load time.
    """

    _instance: Optional['ModelManager'] = None
    _lock = asyncio.Lock()

    def __init__(self):
        # Underlying OCR model; None until load_model() succeeds.
        self.model: Optional[QwenOcr] = None
        self.is_loaded: bool = False
        # Created in load_model(); bounds concurrent inference calls.
        self.semaphore: Optional[asyncio.Semaphore] = None
        self.max_concurrent_requests: int = 10  # default concurrency ceiling
        # In-flight request counter, guarded by _request_lock.
        self.current_requests: int = 0
        self._request_lock = asyncio.Lock()

    @classmethod
    async def get_instance(cls) -> 'ModelManager':
        """Return the singleton instance (double-checked async locking)."""
        if cls._instance is None:
            async with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    async def load_model(self, max_concurrent: int = 5):
        """Load the model.

        Args:
            max_concurrent: maximum number of concurrent inference requests.

        Raises:
            RuntimeError: if model construction fails.
        """
        if self.is_loaded:
            logger.warning("模型已经加载,跳过重复加载")
            return

        try:
            logger.info("开始加载 QwenOcr 模型...")
            # Construct the model in a worker thread so the (potentially
            # slow) load does not block the event loop.
            # get_running_loop() replaces the deprecated get_event_loop()
            # inside coroutines (deprecated since Python 3.10).
            loop = asyncio.get_running_loop()
            self.model = await loop.run_in_executor(None, QwenOcr)

            # Initialize concurrency control.
            self.max_concurrent_requests = max_concurrent
            self.semaphore = asyncio.Semaphore(max_concurrent)

            self.is_loaded = True
            logger.info(f"模型加载成功! 最大并发数: {max_concurrent}")
        except Exception as e:
            logger.error(f"模型加载失败: {e}", exc_info=True)
            # Chain the original exception for easier debugging.
            raise RuntimeError(f"模型加载失败: {str(e)}") from e

    async def unload_model(self):
        """Unload the model and release resources.

        Waits (by polling) for in-flight requests to drain before
        dropping the model reference.
        """
        if not self.is_loaded:
            return

        try:
            logger.info("开始卸载模型...")
            # Drain: wait for all in-flight requests to finish.
            while self.current_requests > 0:
                logger.info(f"等待 {self.current_requests} 个请求完成...")
                await asyncio.sleep(0.5)

            self.model = None
            self.semaphore = None
            self.is_loaded = False
            logger.info("模型卸载成功")
        except Exception as e:
            logger.error(f"模型卸载失败: {e}", exc_info=True)

    def base64_to_pil(self, base64_str: str) -> Image.Image:
        """Convert a base64 string into a PIL image.

        Args:
            base64_str: base64-encoded image string.

        Returns:
            PIL.Image object, converted to RGB if needed.

        Raises:
            ValueError: if the string cannot be decoded into an image.
        """
        try:
            # Decode base64 into raw bytes.
            image_data = base64.b64decode(base64_str)
            # Parse the bytes as an image.
            image = Image.open(io.BytesIO(image_data))
            # Normalize to RGB (model input expectation — see inference()).
            if image.mode != 'RGB':
                image = image.convert('RGB')
            return image
        except Exception as e:
            logger.error(f"Base64 转换失败: {e}")
            raise ValueError(f"图像解码失败: {str(e)}") from e

    async def inference(self, image_base64: str, prompt: str) -> list:
        """Run OCR inference with concurrency control.

        Args:
            image_base64: base64-encoded image.
            prompt: prompt text.

        Returns:
            Inference result from the underlying model.

        Raises:
            RuntimeError: if the model is not loaded.
            ValueError: if the image cannot be decoded.
        """
        if not self.is_loaded or self.model is None:
            raise RuntimeError("模型未加载")

        # Throttle: at most max_concurrent_requests enter this section.
        async with self.semaphore:
            async with self._request_lock:
                self.current_requests += 1

            try:
                # Decode the image before handing off to the model.
                pil_image = self.base64_to_pil(image_base64)

                # Run the blocking model call in a worker thread so the
                # event loop stays responsive.
                loop = asyncio.get_running_loop()
                result = await loop.run_in_executor(
                    None,
                    self.model.inference,
                    pil_image,
                    prompt
                )

                return result
            finally:
                async with self._request_lock:
                    self.current_requests -= 1

    def get_status(self) -> Dict[str, Any]:
        """Return a snapshot of load state and concurrency counters."""
        return {
            "is_loaded": self.is_loaded,
            "current_requests": self.current_requests,
            "max_concurrent": self.max_concurrent_requests
        }
|
|
|
+
|
|
|
+
|
|
|
+# ==================== FastAPI 应用 ====================
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifecycle: load the model on startup, unload on shutdown."""
    # --- startup phase ---
    logger.info("应用启动中...")
    mgr = await ModelManager.get_instance()
    try:
        await mgr.load_model(max_concurrent=10)
        logger.info("应用启动完成")
    except Exception as exc:
        # Startup failure is fatal: log and propagate so the app aborts.
        logger.error(f"应用启动失败: {exc}")
        raise

    # Hand control to the running application.
    yield

    # --- shutdown phase ---
    logger.info("应用关闭中...")
    await mgr.unload_model()
    logger.info("应用已关闭")
|
|
|
+
|
|
|
+
|
|
|
# Create the FastAPI application.
app = FastAPI(
    title="QwenOCR API",
    description="企业级 OCR 推理服务",
    version="1.0.0",
    # lifespan loads the model on startup and unloads it on shutdown.
    lifespan=lifespan
)
|
|
|
+
|
|
|
+
|
|
|
+# ==================== API 端点 ====================
|
|
|
+@app.get("/", response_model=Dict[str, str])
|
|
|
+async def root():
|
|
|
+ """根路径"""
|
|
|
+ return {
|
|
|
+ "message": "QwenOCR API Service",
|
|
|
+ "version": "1.0.0",
|
|
|
+ "docs": "/docs"
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+@app.get("/health", response_model=HealthResponse)
|
|
|
+async def health_check():
|
|
|
+ """健康检查端点"""
|
|
|
+ manager = await ModelManager.get_instance()
|
|
|
+ status_info = manager.get_status()
|
|
|
+
|
|
|
+ return HealthResponse(
|
|
|
+ status="healthy" if status_info["is_loaded"] else "unhealthy",
|
|
|
+ model_loaded=status_info["is_loaded"],
|
|
|
+ timestamp=datetime.now().isoformat(),
|
|
|
+ concurrent_requests=status_info["current_requests"],
|
|
|
+ max_concurrent=status_info["max_concurrent"]
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+@app.post("/api/v1/ocr", response_model=OCRResponse)
|
|
|
+async def ocr_inference(request: OCRRequest):
|
|
|
+ """
|
|
|
+ OCR 推理端点
|
|
|
+
|
|
|
+ Args:
|
|
|
+ request: OCRRequest 对象,包含 image(base64) 和 text(提示词)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ OCRResponse: 推理结果
|
|
|
+ """
|
|
|
+ request_id = f"req_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
|
|
|
+ logger.info(f"[{request_id}] 收到 OCR 请求")
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 获取模型管理器
|
|
|
+ manager = await ModelManager.get_instance()
|
|
|
+
|
|
|
+ if not manager.is_loaded:
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
|
|
+ detail="模型未加载,服务暂不可用"
|
|
|
+ )
|
|
|
+
|
|
|
+ # 执行推理
|
|
|
+ logger.info(f"[{request_id}] 开始推理...")
|
|
|
+ result = await manager.inference(request.image, request.text)
|
|
|
+ logger.info(f"[{request_id}] 推理完成")
|
|
|
+
|
|
|
+ return OCRResponse(
|
|
|
+ success=True,
|
|
|
+ data=result,
|
|
|
+ message="推理成功",
|
|
|
+ timestamp=datetime.now().isoformat(),
|
|
|
+ request_id=request_id
|
|
|
+ )
|
|
|
+
|
|
|
+ except ValueError as e:
|
|
|
+ # 参数验证错误
|
|
|
+ logger.warning(f"[{request_id}] 参数验证失败: {e}")
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_400_BAD_REQUEST,
|
|
|
+ detail=str(e)
|
|
|
+ )
|
|
|
+
|
|
|
+ except RuntimeError as e:
|
|
|
+ # 模型运行时错误
|
|
|
+ logger.error(f"[{request_id}] 运行时错误: {e}")
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
|
+ detail=f"推理失败: {str(e)}"
|
|
|
+ )
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ # 未知错误
|
|
|
+ logger.error(f"[{request_id}] 未知错误: {e}", exc_info=True)
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
|
+ detail=f"服务器内部错误: {str(e)}"
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Last-resort handler: log the failure and answer with a 500 JSON body."""
    logger.error(f"全局异常捕获: {exc}", exc_info=True)
    payload = {
        "success": False,
        "data": None,
        "message": f"服务器错误: {str(exc)}",
        "timestamp": datetime.now().isoformat(),
    }
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content=payload,
    )
|
|
|
+
|
|
|
+
|
|
|
+# ==================== 主函数 ====================
|
|
|
def main():
    """Launch the uvicorn server hosting the OCR API."""
    server_options = dict(
        host="0.0.0.0",
        port=8000,
        # Single worker: the model's memory footprint makes multiple
        # copies impractical.
        workers=1,
        log_level="info",
        access_log=True,
        # Hot reload disabled for production.
        reload=False,
    )
    uvicorn.run("model.model_api:app", **server_options)
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ main()
|