| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420 |
- """
- 企业级 OCR API 服务
- 提供基于 FastAPI 的高并发 OCR 推理服务
- """
- import asyncio
- import base64
- import io
- import logging
- import sys
- from contextlib import asynccontextmanager
- from typing import Optional, Dict, Any
- from datetime import datetime
- from fastapi import FastAPI, HTTPException, status
- from fastapi.responses import JSONResponse
- from pydantic import BaseModel, Field, validator
- from PIL import Image
- import uvicorn
- from model import QwenOcr, QwenOcrVLLM
- # ==================== 日志配置 ====================
# Root logging setup: INFO and above is written both to stdout and to the
# UTF-8 encoded file ``ocr_api.log`` (path is relative to the working directory).
logging.basicConfig(
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('ocr_api.log', encoding='utf-8'),
    ],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by everything in this file.
logger = logging.getLogger(__name__)
- # ==================== 请求/响应模型 ====================
# Schema for the ``image_url`` object of a content item: a single required
# ``url`` string (the field description says it is a base64 data URL).
# NOTE(review): no other definition visible in this file references this
# model; OCRRequest.messages is an untyped list -- confirm whether it is used.
class ImageURL(BaseModel):
    url: str = Field(..., description="data:image/png;base64,... 格式的图像 URL")
# Schema for one entry of a message's content list. ``type`` is required;
# ``image_url`` and ``text`` are both optional and default to None.
# NOTE(review): the model itself does not enforce that the populated field
# matches ``type``.
class ContentItem(BaseModel):
    type: str = Field(..., description="内容类型: image_url 或 text")
    image_url: Optional[ImageURL] = None
    text: Optional[str] = None
# Schema for a single chat message: a required ``role`` string and a required
# ``content`` of any type (typed as ``Any``, so it is not validated further).
class Message(BaseModel):
    role: str = Field(..., description="角色: system 或 user")
    content: Any = Field(..., description="消息内容")
class OCRRequest(BaseModel):
    """OCR inference request model (OpenAI-compatible format).

    ``messages`` is declared as a plain ``list``, so its entries are not
    validated by pydantic. The extraction helpers therefore skip entries that
    do not have the expected dict/list shape instead of failing with
    ``AttributeError``; problems are always reported as ``ValueError``.
    """

    model: Optional[str] = Field(None, description="模型名称")
    messages: list = Field(..., description="消息列表")
    max_tokens: Optional[int] = Field(4096, description="最大生成 token 数")
    stream: Optional[bool] = Field(False, description="是否流式输出")
    temperature: Optional[float] = Field(0, description="采样温度")

    @validator('messages')
    def validate_messages(cls, v):
        """Reject an empty ``messages`` list."""
        if not v:
            raise ValueError("messages 不能为空")
        return v

    def _iter_user_content_items(self):
        """Yield, in order, every dict item found in list-style content of ``user`` messages.

        Messages that are not dicts, whose role is not ``user``, or whose
        content is not a list are skipped, as are content items that are not
        dicts.
        """
        for msg in self.messages:
            if not isinstance(msg, dict) or msg.get('role') != 'user':
                continue
            content = msg.get('content', [])
            if not isinstance(content, list):
                continue
            for item in content:
                if isinstance(item, dict):
                    yield item

    def get_image_base64(self) -> str:
        """Return the base64 payload of the first ``image_url`` item in a user message.

        If the URL contains ``;base64,`` everything after its first occurrence
        is returned; otherwise the URL is returned unchanged. An item without
        an ``image_url``/``url`` key yields an empty string, as before.

        Raises:
            ValueError: if no ``image_url`` item exists, or if the first one
                has an ``image_url`` that is not an object or a ``url`` that
                is not a string.
        """
        for item in self._iter_user_content_items():
            if item.get('type') != 'image_url':
                continue
            image_url = item.get('image_url', {})
            url = image_url.get('url', '') if isinstance(image_url, dict) else None
            if not isinstance(url, str):
                # e.g. "image_url": null or "url": 123 -- a client error, not a server fault.
                raise ValueError("image_url 格式无效")
            # Strip a leading "data:image/png;base64," style prefix when present.
            if ';base64,' in url:
                return url.split(';base64,', 1)[1]
            return url
        raise ValueError("messages 中未找到 image_url")

    def get_prompt(self) -> str:
        """Return the text of the first ``text`` item in a user message.

        Raises:
            ValueError: if no user message contains a ``text`` item.
        """
        for item in self._iter_user_content_items():
            if item.get('type') == 'text':
                return item.get('text', '')
        raise ValueError("messages 中未找到 text")
# Schema for the message inside a completion choice: ``role`` defaults to
# "assistant" and ``content`` is optional text defaulting to None.
class ChoiceMessage(BaseModel):
    role: str = "assistant"
    content: Optional[str] = None
# Schema for one completion choice: position ``index`` (default 0), the
# required ``message``, and ``finish_reason`` (default "stop").
# NOTE(review): OCRResponse.choices is an untyped list, so nothing visible in
# this file validates responses against this model -- confirm intent.
class Choice(BaseModel):
    index: int = 0
    message: ChoiceMessage
    finish_reason: str = "stop"
class OCRResponse(BaseModel):
    """OCR inference response model (OpenAI-compatible format)."""
    # ``id``, ``choices`` and ``timestamp`` are required; ``object`` defaults
    # to "chat.completion" and ``model`` defaults to None.
    id: str = Field(..., description="请求ID")
    object: str = Field("chat.completion", description="对象类型")
    model: Optional[str] = Field(None, description="模型名称")
    # Untyped list: entries are not validated against a nested model.
    choices: list = Field(..., description="推理结果列表")
    timestamp: str = Field(..., description="响应时间戳")
class HealthResponse(BaseModel):
    """Health-check response model."""
    # All five fields are required (no defaults).
    status: str
    model_loaded: bool
    timestamp: str
    concurrent_requests: int
    max_concurrent: int
- # ==================== 模型管理器(单例模式) ====================
- class ModelManager:
- """模型管理器 - 单例模式确保全局只有一个模型实例"""
- _instance: Optional['ModelManager'] = None
- _lock = asyncio.Lock()
- def __init__(self):
- self.model: Optional[QwenOcr] = None
- self.is_loaded: bool = False
- self.semaphore: Optional[asyncio.Semaphore] = None
- self.max_concurrent_requests: int = 10 # 最大并发请求数
- self.current_requests: int = 0
- self._request_lock = asyncio.Lock()
- @classmethod
- async def get_instance(cls) -> 'ModelManager':
- """获取单例实例(线程安全)"""
- if cls._instance is None:
- async with cls._lock:
- if cls._instance is None:
- cls._instance = cls()
- return cls._instance
- async def load_model(self, max_concurrent: int = 5):
- """
- 加载模型
- Args:
- max_concurrent: 最大并发请求数
- """
- if self.is_loaded:
- logger.warning("模型已经加载,跳过重复加载")
- return
- try:
- logger.info("开始加载 QwenOcr 模型...")
- # 在线程池中加载模型,避免阻塞事件循环
- loop = asyncio.get_event_loop()
- self.model = await loop.run_in_executor(None, QwenOcrVLLM)
- # 初始化并发控制
- self.max_concurrent_requests = max_concurrent
- self.semaphore = asyncio.Semaphore(max_concurrent)
- self.is_loaded = True
- logger.info(f"模型加载成功! 最大并发数: {max_concurrent}")
- except Exception as e:
- logger.error(f"模型加载失败: {e}", exc_info=True)
- raise RuntimeError(f"模型加载失败: {str(e)}")
- async def unload_model(self):
- """卸载模型并释放资源"""
- if not self.is_loaded:
- return
- try:
- logger.info("开始卸载模型...")
- # 等待所有正在进行的请求完成
- while self.current_requests > 0:
- logger.info(f"等待 {self.current_requests} 个请求完成...")
- await asyncio.sleep(0.5)
- self.model = None
- self.semaphore = None
- self.is_loaded = False
- logger.info("模型卸载成功")
- except Exception as e:
- logger.error(f"模型卸载失败: {e}", exc_info=True)
- def base64_to_pil(self, base64_str: str) -> Image.Image:
- """
- 将 base64 字符串转换为 PIL Image
- Args:
- base64_str: base64 编码的图像字符串
- Returns:
- PIL.Image 对象
- """
- try:
- # 解码 base64
- image_data = base64.b64decode(base64_str)
- # 转换为 PIL Image
- image = Image.open(io.BytesIO(image_data))
- # 确保是 RGB 模式
- if image.mode != 'RGB':
- image = image.convert('RGB')
- return image
- except Exception as e:
- logger.error(f"Base64 转换失败: {e}")
- raise ValueError(f"图像解码失败: {str(e)}")
- async def inference(self, image_base64: str, prompt: str) -> str:
- """
- 执行 OCR 推理(带并发控制)
- Args:
- image_base64: base64 编码的图像(不含 data URI 前缀)
- prompt: 提示词
- Returns:
- 推理结果字符串
- """
- if not self.is_loaded or self.model is None:
- raise RuntimeError("模型未加载")
- # 并发控制
- async with self.semaphore:
- async with self._request_lock:
- self.current_requests += 1
- try:
- # 转换图像
- pil_image = self.base64_to_pil(image_base64)
- # 在线程池中执行推理,避免阻塞
- loop = asyncio.get_event_loop()
- results = await loop.run_in_executor(
- None,
- self.model.batch_inference,
- [pil_image],
- [prompt]
- )
- return results[0]
- finally:
- async with self._request_lock:
- self.current_requests -= 1
- def get_status(self) -> Dict[str, Any]:
- """获取模型状态"""
- return {
- "is_loaded": self.is_loaded,
- "current_requests": self.current_requests,
- "max_concurrent": self.max_concurrent_requests
- }
- # ==================== FastAPI 应用 ====================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: load the model before serving, unload it afterwards.

    A failure while loading is logged and re-raised, so startup is aborted.
    """
    logger.info("应用启动中...")
    model_manager = await ModelManager.get_instance()

    # Startup phase.
    try:
        await model_manager.load_model(max_concurrent=10)
    except Exception as startup_error:
        logger.error(f"应用启动失败: {startup_error}")
        raise
    else:
        logger.info("应用启动完成")

    yield

    # Shutdown phase.
    logger.info("应用关闭中...")
    await model_manager.unload_model()
    logger.info("应用已关闭")
- # 创建 FastAPI 应用
# The ASGI application object; startup and shutdown are handled by ``lifespan``.
app = FastAPI(
    lifespan=lifespan,
    title="QwenOCR API",
    version="1.0.0",
    description="企业级 OCR 推理服务",
)
- # ==================== API 端点 ====================
@app.get("/", response_model=Dict[str, str])
async def root():
    """Root path: return the service name, version and docs location."""
    service_info = dict(
        message="QwenOCR API Service",
        version="1.0.0",
        docs="/docs",
    )
    return service_info
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health-check endpoint: report model load state and request counters."""
    model_manager = await ModelManager.get_instance()
    snapshot = model_manager.get_status()
    loaded = snapshot["is_loaded"]

    # "healthy" only while a model is loaded.
    if loaded:
        health_label = "healthy"
    else:
        health_label = "unhealthy"

    return HealthResponse(
        status=health_label,
        model_loaded=loaded,
        timestamp=datetime.now().isoformat(),
        concurrent_requests=snapshot["current_requests"],
        max_concurrent=snapshot["max_concurrent"],
    )
@app.post("/api/v1/ocr", response_model=OCRResponse)
async def ocr_inference(request: OCRRequest):
    """
    OCR inference endpoint (OpenAI-compatible format).

    The request body has the same shape as /v1/chat/completions:
    {
        "model": "...",
        "messages": [
            {"role": "system", "content": "..."},
            {"role": "user", "content": [
                {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
                {"type": "text", "text": "question"}
            ]}
        ],
        "max_tokens": 4096,
        "stream": false,
        "temperature": 0
    }

    Only the first image and the first text item of the user messages are
    used; they are passed to ``ModelManager.inference``.

    Error mapping:
        503 - the model is not loaded,
        400 - ``ValueError`` (missing image/text, undecodable image),
        500 - ``RuntimeError`` or any other exception.
    """
    request_id = f"chatcmpl-{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
    logger.info(f"[{request_id}] 收到 OCR 请求")
    try:
        manager = await ModelManager.get_instance()
        if not manager.is_loaded:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="模型未加载,服务暂不可用"
            )
        # Pull the image payload and the prompt out of ``messages``.
        image_base64 = request.get_image_base64()
        prompt = request.get_prompt()
        logger.info(f"[{request_id}] 开始推理...")
        content = await manager.inference(image_base64, prompt)
        logger.info(f"[{request_id}] 推理完成")
        return OCRResponse(
            id=request_id,
            model=request.model,
            choices=[{
                "index": 0,
                "message": {"role": "assistant", "content": content},
                "finish_reason": "stop"
            }],
            timestamp=datetime.now().isoformat()
        )
    except HTTPException:
        # HTTPException is an Exception subclass: without this clause the 503
        # raised above would be swallowed by the generic handler below and
        # reported to the client as a 500.
        raise
    except ValueError as e:
        logger.warning(f"[{request_id}] 参数验证失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        ) from e
    except RuntimeError as e:
        logger.error(f"[{request_id}] 运行时错误: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"推理失败: {str(e)}"
        ) from e
    except Exception as e:
        logger.error(f"[{request_id}] 未知错误: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"服务器内部错误: {str(e)}"
        ) from e
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Fallback handler for ``Exception``: log it and answer with a 500 JSON body.

    The body carries the same top-level keys as a completion response
    (``id``, ``object``, ``choices``, ``timestamp``) plus an ``error`` object
    holding the exception text.
    """
    logger.error(f"全局异常捕获: {exc}", exc_info=True)

    error_body = {
        "id": None,
        "object": "chat.completion",
        "choices": [],
        "timestamp": datetime.now().isoformat(),
        "error": {"message": str(exc)},
    }
    return JSONResponse(
        content=error_body,
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
    )
- # ==================== 主函数 ====================
def main():
    """Start the uvicorn server for this application with fixed settings."""
    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        # Single worker: the original author notes the model uses a lot of memory.
        "workers": 1,
        "log_level": "info",
        "access_log": True,
        # Hot reload stays off for production use.
        "reload": False,
    }
    uvicorn.run("model.model_api:app", **server_options)
# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|