| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362 |
- """
- 企业级 OCR API 服务
- 提供基于 FastAPI 的高并发 OCR 推理服务
- """
- import asyncio
- import base64
- import io
- import logging
- import sys
- from contextlib import asynccontextmanager
- from typing import Optional, Dict, Any
- from datetime import datetime
- from fastapi import FastAPI, HTTPException, status
- from fastapi.responses import JSONResponse
- from pydantic import BaseModel, Field, validator
- from PIL import Image
- import uvicorn
- from model import QwenOcr, QwenOcrVLLM
# ==================== Logging configuration ====================
# Mirror every record to stdout (container logs) and to a UTF-8 file on disk.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_LOG_HANDLERS = [
    logging.StreamHandler(sys.stdout),
    logging.FileHandler('ocr_api.log', encoding='utf-8'),
]
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=_LOG_HANDLERS)

logger = logging.getLogger(__name__)
- # ==================== 请求/响应模型 ====================
class OCRRequest(BaseModel):
    """OCR inference request payload.

    Attributes:
        image: Base64-encoded image string.
        text: List of OCR prompt strings; the image is run once per prompt.
    """
    image: str = Field(..., description="Base64 编码的图像字符串")
    text: list = Field(..., description="OCR 提示词文本列表")

    @validator('image')
    def validate_image(cls, v):
        """Reject empty or non-base64 image payloads up front."""
        if not v:
            raise ValueError("图像不能为空")
        try:
            # FIX: validate=True makes b64decode reject non-alphabet
            # characters instead of silently discarding them, so corrupt
            # payloads fail here (400) rather than later at PIL decode time.
            base64.b64decode(v, validate=True)
        except Exception:
            raise ValueError("无效的 base64 图像格式")
        return v

    @validator('text')
    def validate_text(cls, v):
        """Require at least one prompt."""
        if not v:
            raise ValueError("提示词不能为空")
        return v
class OCRResponse(BaseModel):
    """Response envelope returned by the OCR inference endpoint."""
    # True when inference completed without error.
    success: bool = Field(..., description="请求是否成功")
    # Raw inference results; shape depends on the model's batch_inference output.
    data: Optional[Any] = Field(None, description="推理结果数据")
    # Human-readable status message.
    message: str = Field(..., description="响应消息")
    # ISO-8601 timestamp of when the response was produced.
    timestamp: str = Field(..., description="响应时间戳")
    # Server-generated id for log correlation; None if not assigned.
    request_id: Optional[str] = Field(None, description="请求ID(用于追踪)")
class HealthResponse(BaseModel):
    """Health-check response: model state plus concurrency counters."""
    # "healthy" when the model is loaded, otherwise "unhealthy".
    status: str
    # Whether the model has finished loading.
    model_loaded: bool
    # ISO-8601 timestamp of the check.
    timestamp: str
    # Number of inference requests currently in flight.
    concurrent_requests: int
    # Configured concurrency ceiling.
    max_concurrent: int
- # ==================== 模型管理器(单例模式) ====================
class ModelManager:
    """Singleton model manager: one model instance, bounded concurrency.

    All blocking work (model construction, inference) runs in the default
    thread-pool executor so the event loop is never stalled.
    """

    _instance: Optional['ModelManager'] = None
    _lock = asyncio.Lock()

    def __init__(self):
        # Loaded model instance. Annotated as Any because load_model
        # constructs QwenOcrVLLM (the old Optional[QwenOcr] annotation
        # did not match what is actually assigned).
        self.model: Optional[Any] = None
        self.is_loaded: bool = False
        # Caps the number of concurrent inference calls; created on load.
        self.semaphore: Optional[asyncio.Semaphore] = None
        self.max_concurrent_requests: int = 10
        self.current_requests: int = 0
        # Protects current_requests updates.
        self._request_lock = asyncio.Lock()

    @classmethod
    async def get_instance(cls) -> 'ModelManager':
        """Return the singleton instance (double-checked under an async lock)."""
        if cls._instance is None:
            async with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    async def load_model(self, max_concurrent: int = 5):
        """Load the model and set up concurrency controls.

        Args:
            max_concurrent: Maximum number of concurrent inference requests.

        Raises:
            RuntimeError: If the model fails to load.
        """
        if self.is_loaded:
            logger.warning("模型已经加载,跳过重复加载")
            return
        try:
            logger.info("开始加载 QwenOcr 模型...")
            # get_running_loop() is the supported call inside a coroutine;
            # get_event_loop() here is deprecated since Python 3.10.
            loop = asyncio.get_running_loop()
            self.model = await loop.run_in_executor(None, QwenOcrVLLM)
            self.max_concurrent_requests = max_concurrent
            self.semaphore = asyncio.Semaphore(max_concurrent)
            self.is_loaded = True
            logger.info(f"模型加载成功! 最大并发数: {max_concurrent}")
        except Exception as e:
            logger.error(f"模型加载失败: {e}", exc_info=True)
            raise RuntimeError(f"模型加载失败: {str(e)}") from e

    async def unload_model(self):
        """Drain in-flight requests, then release the model.

        is_loaded is cleared *before* draining so no new requests start
        while we wait (previously it stayed True during the drain, so a
        steady request stream could prolong shutdown indefinitely).
        """
        if not self.is_loaded:
            return
        try:
            logger.info("开始卸载模型...")
            # Reject new work immediately; requests already past the
            # is_loaded check hold the semaphore and finish normally.
            self.is_loaded = False
            while self.current_requests > 0:
                logger.info(f"等待 {self.current_requests} 个请求完成...")
                await asyncio.sleep(0.5)
            self.model = None
            self.semaphore = None
            logger.info("模型卸载成功")
        except Exception as e:
            logger.error(f"模型卸载失败: {e}", exc_info=True)

    def base64_to_pil(self, base64_str: str) -> "Image.Image":
        """Decode a base64 string into an RGB PIL image.

        Args:
            base64_str: Base64-encoded image bytes.

        Returns:
            A PIL.Image in RGB mode.

        Raises:
            ValueError: If the payload cannot be decoded as an image.
        """
        try:
            image_data = base64.b64decode(base64_str)
            image = Image.open(io.BytesIO(image_data))
            # Normalize to RGB so the model never sees palette/alpha modes.
            if image.mode != 'RGB':
                image = image.convert('RGB')
            return image
        except Exception as e:
            logger.error(f"Base64 转换失败: {e}")
            raise ValueError(f"图像解码失败: {str(e)}") from e

    async def inference(self, image_base64: str, prompts: list) -> list:
        """Run OCR inference under the concurrency limit.

        Args:
            image_base64: Base64-encoded image.
            prompts: List of prompt strings; the image is paired with each
                prompt for a batched call (the old `prompts: str` annotation
                was wrong — the value is indexed with len() and batched).

        Returns:
            Results from the model's batch_inference call.

        Raises:
            RuntimeError: If the model is not loaded.
            ValueError: If the image cannot be decoded.
        """
        if not self.is_loaded or self.model is None:
            raise RuntimeError("模型未加载")
        # The semaphore enforces the cap; current_requests is only for
        # status reporting and shutdown draining.
        async with self.semaphore:
            async with self._request_lock:
                self.current_requests += 1
            try:
                pil_image = self.base64_to_pil(image_base64)
                # Run the blocking model call in the thread pool.
                loop = asyncio.get_running_loop()
                result = await loop.run_in_executor(
                    None,
                    self.model.batch_inference,
                    [pil_image] * len(prompts),
                    prompts,
                )
                return result
            finally:
                async with self._request_lock:
                    self.current_requests -= 1

    def get_status(self) -> Dict[str, Any]:
        """Return a snapshot of load state and concurrency counters."""
        return {
            "is_loaded": self.is_loaded,
            "current_requests": self.current_requests,
            "max_concurrent": self.max_concurrent_requests,
        }
- # ==================== FastAPI 应用 ====================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load the model at startup, unload at shutdown."""
    logger.info("应用启动中...")
    manager = await ModelManager.get_instance()
    try:
        await manager.load_model(max_concurrent=10)
    except Exception as e:
        logger.error(f"应用启动失败: {e}")
        raise
    logger.info("应用启动完成")

    yield  # application serves requests here

    logger.info("应用关闭中...")
    await manager.unload_model()
    logger.info("应用已关闭")
# Create the FastAPI application; the lifespan handler loads the model at
# startup and releases it on shutdown.
app = FastAPI(
    title="QwenOCR API",
    description="企业级 OCR 推理服务",
    version="1.0.0",
    lifespan=lifespan
)
- # ==================== API 端点 ====================
@app.get("/", response_model=Dict[str, str])
async def root():
    """Landing endpoint: service name, version, and docs location."""
    info = {
        "message": "QwenOCR API Service",
        "version": "1.0.0",
        "docs": "/docs",
    }
    return info
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report model load state and current concurrency usage."""
    manager = await ModelManager.get_instance()
    info = manager.get_status()
    loaded = info["is_loaded"]
    return HealthResponse(
        status="healthy" if loaded else "unhealthy",
        model_loaded=loaded,
        timestamp=datetime.now().isoformat(),
        concurrent_requests=info["current_requests"],
        max_concurrent=info["max_concurrent"],
    )
@app.post("/api/v1/ocr", response_model=OCRResponse)
async def ocr_inference(request: OCRRequest):
    """OCR inference endpoint.

    Args:
        request: OCRRequest with image (base64) and text (prompt list).

    Returns:
        OCRResponse wrapping the inference results.

    Raises:
        HTTPException: 503 when the model is not loaded, 400 on invalid
            input, 500 on inference failures or unexpected errors.
    """
    request_id = f"req_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
    logger.info(f"[{request_id}] 收到 OCR 请求")
    try:
        manager = await ModelManager.get_instance()
        if not manager.is_loaded:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="模型未加载,服务暂不可用"
            )
        logger.info(f"[{request_id}] 开始推理...")
        result = await manager.inference(request.image, request.text)
        logger.info(f"[{request_id}] 推理完成")
        return OCRResponse(
            success=True,
            data=result,
            message="推理成功",
            timestamp=datetime.now().isoformat(),
            request_id=request_id
        )
    except HTTPException:
        # BUG FIX: without this clause, the 503 raised above fell through
        # to the generic `except Exception` below and was re-surfaced as a
        # 500 "服务器内部错误", hiding the service-unavailable status.
        raise
    except ValueError as e:
        # Input validation errors (e.g. undecodable image) -> 400.
        logger.warning(f"[{request_id}] 参数验证失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except RuntimeError as e:
        # Model runtime errors -> 500.
        logger.error(f"[{request_id}] 运行时错误: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"推理失败: {str(e)}"
        )
    except Exception as e:
        # Anything unexpected -> 500 with full traceback in the log.
        logger.error(f"[{request_id}] 未知错误: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"服务器内部错误: {str(e)}"
        )
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Last-resort handler: log the exception and return a uniform 500 body."""
    logger.error(f"全局异常捕获: {exc}", exc_info=True)
    body = {
        "success": False,
        "data": None,
        "message": f"服务器错误: {str(exc)}",
        "timestamp": datetime.now().isoformat(),
    }
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content=body,
    )
- # ==================== 主函数 ====================
def main():
    """Start the uvicorn server hosting this app.

    NOTE(review): the import string "model.model_api:app" assumes this file
    lives at model/model_api.py — confirm against the package layout.
    """
    uvicorn.run(
        "model.model_api:app",
        host="0.0.0.0",
        port=8000,
        workers=1,  # single worker: the model's memory footprint is large
        log_level="info",
        access_log=True,
        reload=False  # hot reload disabled for production
    )

if __name__ == "__main__":
    main()
|