# model_api.py
"""
Enterprise OCR API service.

Provides a high-concurrency OCR inference service built on FastAPI.
"""
import asyncio
import base64
import io
import logging
import sys
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Any, Dict, Iterator, Optional

import uvicorn
from fastapi import FastAPI, HTTPException, status
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel, Field, validator

from model import QwenOcr, QwenOcrVLLM
  19. # ==================== 日志配置 ====================
  20. logging.basicConfig(
  21. level=logging.INFO,
  22. format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
  23. handlers=[
  24. logging.StreamHandler(sys.stdout),
  25. logging.FileHandler('ocr_api.log', encoding='utf-8')
  26. ]
  27. )
  28. logger = logging.getLogger(__name__)
  29. # ==================== 请求/响应模型 ====================
  30. class ImageURL(BaseModel):
  31. url: str = Field(..., description="data:image/png;base64,... 格式的图像 URL")
  32. class ContentItem(BaseModel):
  33. type: str = Field(..., description="内容类型: image_url 或 text")
  34. image_url: Optional[ImageURL] = None
  35. text: Optional[str] = None
  36. class Message(BaseModel):
  37. role: str = Field(..., description="角色: system 或 user")
  38. content: Any = Field(..., description="消息内容")
  39. class OCRRequest(BaseModel):
  40. """OCR 推理请求模型(OpenAI 兼容格式)"""
  41. model: Optional[str] = Field(None, description="模型名称")
  42. messages: list = Field(..., description="消息列表")
  43. max_tokens: Optional[int] = Field(4096, description="最大生成 token 数")
  44. stream: Optional[bool] = Field(False, description="是否流式输出")
  45. temperature: Optional[float] = Field(0, description="采样温度")
  46. @validator('messages')
  47. def validate_messages(cls, v):
  48. if not v:
  49. raise ValueError("messages 不能为空")
  50. return v
  51. def get_image_base64(self) -> str:
  52. """从 messages 中提取 base64 图像(去掉 data:image/xxx;base64, 前缀)"""
  53. for msg in self.messages:
  54. if msg.get('role') != 'user':
  55. continue
  56. content = msg.get('content', [])
  57. if not isinstance(content, list):
  58. continue
  59. for item in content:
  60. if item.get('type') == 'image_url':
  61. url = item.get('image_url', {}).get('url', '')
  62. # 去掉 "data:image/png;base64," 前缀
  63. if ';base64,' in url:
  64. return url.split(';base64,', 1)[1]
  65. return url
  66. raise ValueError("messages 中未找到 image_url")
  67. def get_prompt(self) -> str:
  68. """从 messages 中提取用户文本提示词"""
  69. for msg in self.messages:
  70. if msg.get('role') != 'user':
  71. continue
  72. content = msg.get('content', [])
  73. if not isinstance(content, list):
  74. continue
  75. for item in content:
  76. if item.get('type') == 'text':
  77. return item.get('text', '')
  78. raise ValueError("messages 中未找到 text")
  79. class ChoiceMessage(BaseModel):
  80. role: str = "assistant"
  81. content: Optional[str] = None
  82. class Choice(BaseModel):
  83. index: int = 0
  84. message: ChoiceMessage
  85. finish_reason: str = "stop"
  86. class OCRResponse(BaseModel):
  87. """OCR 推理响应模型(OpenAI 兼容格式)"""
  88. id: str = Field(..., description="请求ID")
  89. object: str = Field("chat.completion", description="对象类型")
  90. model: Optional[str] = Field(None, description="模型名称")
  91. choices: list = Field(..., description="推理结果列表")
  92. timestamp: str = Field(..., description="响应时间戳")
  93. class HealthResponse(BaseModel):
  94. """健康检查响应模型"""
  95. status: str
  96. model_loaded: bool
  97. timestamp: str
  98. concurrent_requests: int
  99. max_concurrent: int
  100. # ==================== 模型管理器(单例模式) ====================
  101. class ModelManager:
  102. """模型管理器 - 单例模式确保全局只有一个模型实例"""
  103. _instance: Optional['ModelManager'] = None
  104. _lock = asyncio.Lock()
  105. def __init__(self):
  106. self.model: Optional[QwenOcr] = None
  107. self.is_loaded: bool = False
  108. self.semaphore: Optional[asyncio.Semaphore] = None
  109. self.max_concurrent_requests: int = 10 # 最大并发请求数
  110. self.current_requests: int = 0
  111. self._request_lock = asyncio.Lock()
  112. @classmethod
  113. async def get_instance(cls) -> 'ModelManager':
  114. """获取单例实例(线程安全)"""
  115. if cls._instance is None:
  116. async with cls._lock:
  117. if cls._instance is None:
  118. cls._instance = cls()
  119. return cls._instance
  120. async def load_model(self, max_concurrent: int = 5):
  121. """
  122. 加载模型
  123. Args:
  124. max_concurrent: 最大并发请求数
  125. """
  126. if self.is_loaded:
  127. logger.warning("模型已经加载,跳过重复加载")
  128. return
  129. try:
  130. logger.info("开始加载 QwenOcr 模型...")
  131. # 在线程池中加载模型,避免阻塞事件循环
  132. loop = asyncio.get_event_loop()
  133. self.model = await loop.run_in_executor(None, QwenOcrVLLM)
  134. # 初始化并发控制
  135. self.max_concurrent_requests = max_concurrent
  136. self.semaphore = asyncio.Semaphore(max_concurrent)
  137. self.is_loaded = True
  138. logger.info(f"模型加载成功! 最大并发数: {max_concurrent}")
  139. except Exception as e:
  140. logger.error(f"模型加载失败: {e}", exc_info=True)
  141. raise RuntimeError(f"模型加载失败: {str(e)}")
  142. async def unload_model(self):
  143. """卸载模型并释放资源"""
  144. if not self.is_loaded:
  145. return
  146. try:
  147. logger.info("开始卸载模型...")
  148. # 等待所有正在进行的请求完成
  149. while self.current_requests > 0:
  150. logger.info(f"等待 {self.current_requests} 个请求完成...")
  151. await asyncio.sleep(0.5)
  152. self.model = None
  153. self.semaphore = None
  154. self.is_loaded = False
  155. logger.info("模型卸载成功")
  156. except Exception as e:
  157. logger.error(f"模型卸载失败: {e}", exc_info=True)
  158. def base64_to_pil(self, base64_str: str) -> Image.Image:
  159. """
  160. 将 base64 字符串转换为 PIL Image
  161. Args:
  162. base64_str: base64 编码的图像字符串
  163. Returns:
  164. PIL.Image 对象
  165. """
  166. try:
  167. # 解码 base64
  168. image_data = base64.b64decode(base64_str)
  169. # 转换为 PIL Image
  170. image = Image.open(io.BytesIO(image_data))
  171. # 确保是 RGB 模式
  172. if image.mode != 'RGB':
  173. image = image.convert('RGB')
  174. return image
  175. except Exception as e:
  176. logger.error(f"Base64 转换失败: {e}")
  177. raise ValueError(f"图像解码失败: {str(e)}")
  178. async def inference(self, image_base64: str, prompt: str) -> str:
  179. """
  180. 执行 OCR 推理(带并发控制)
  181. Args:
  182. image_base64: base64 编码的图像(不含 data URI 前缀)
  183. prompt: 提示词
  184. Returns:
  185. 推理结果字符串
  186. """
  187. if not self.is_loaded or self.model is None:
  188. raise RuntimeError("模型未加载")
  189. # 并发控制
  190. async with self.semaphore:
  191. async with self._request_lock:
  192. self.current_requests += 1
  193. try:
  194. # 转换图像
  195. pil_image = self.base64_to_pil(image_base64)
  196. # 在线程池中执行推理,避免阻塞
  197. loop = asyncio.get_event_loop()
  198. results = await loop.run_in_executor(
  199. None,
  200. self.model.batch_inference,
  201. [pil_image],
  202. [prompt]
  203. )
  204. return results[0]
  205. finally:
  206. async with self._request_lock:
  207. self.current_requests -= 1
  208. def get_status(self) -> Dict[str, Any]:
  209. """获取模型状态"""
  210. return {
  211. "is_loaded": self.is_loaded,
  212. "current_requests": self.current_requests,
  213. "max_concurrent": self.max_concurrent_requests
  214. }
  215. # ==================== FastAPI 应用 ====================
  216. @asynccontextmanager
  217. async def lifespan(app: FastAPI):
  218. """应用生命周期管理"""
  219. # 启动时加载模型
  220. logger.info("应用启动中...")
  221. manager = await ModelManager.get_instance()
  222. try:
  223. await manager.load_model(max_concurrent=10)
  224. logger.info("应用启动完成")
  225. except Exception as e:
  226. logger.error(f"应用启动失败: {e}")
  227. raise
  228. yield
  229. # 关闭时卸载模型
  230. logger.info("应用关闭中...")
  231. await manager.unload_model()
  232. logger.info("应用已关闭")
  233. # 创建 FastAPI 应用
  234. app = FastAPI(
  235. title="QwenOCR API",
  236. description="企业级 OCR 推理服务",
  237. version="1.0.0",
  238. lifespan=lifespan
  239. )
  240. # ==================== API 端点 ====================
  241. @app.get("/", response_model=Dict[str, str])
  242. async def root():
  243. """根路径"""
  244. return {
  245. "message": "QwenOCR API Service",
  246. "version": "1.0.0",
  247. "docs": "/docs"
  248. }
  249. @app.get("/health", response_model=HealthResponse)
  250. async def health_check():
  251. """健康检查端点"""
  252. manager = await ModelManager.get_instance()
  253. status_info = manager.get_status()
  254. return HealthResponse(
  255. status="healthy" if status_info["is_loaded"] else "unhealthy",
  256. model_loaded=status_info["is_loaded"],
  257. timestamp=datetime.now().isoformat(),
  258. concurrent_requests=status_info["current_requests"],
  259. max_concurrent=status_info["max_concurrent"]
  260. )
  261. @app.post("/api/v1/ocr", response_model=OCRResponse)
  262. async def ocr_inference(request: OCRRequest):
  263. """
  264. OCR 推理端点(OpenAI 兼容格式)
  265. 请求体与 /v1/chat/completions 格式一致:
  266. {
  267. "model": "...",
  268. "messages": [
  269. {"role": "system", "content": "..."},
  270. {"role": "user", "content": [
  271. {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
  272. {"type": "text", "text": "问题"}
  273. ]}
  274. ],
  275. "max_tokens": 4096,
  276. "stream": false,
  277. "temperature": 0
  278. }
  279. """
  280. request_id = f"chatcmpl-{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
  281. logger.info(f"[{request_id}] 收到 OCR 请求")
  282. try:
  283. manager = await ModelManager.get_instance()
  284. if not manager.is_loaded:
  285. raise HTTPException(
  286. status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
  287. detail="模型未加载,服务暂不可用"
  288. )
  289. # 从 messages 中提取图像和提示词
  290. image_base64 = request.get_image_base64()
  291. prompt = request.get_prompt()
  292. logger.info(f"[{request_id}] 开始推理...")
  293. content = await manager.inference(image_base64, prompt)
  294. logger.info(f"[{request_id}] 推理完成")
  295. return OCRResponse(
  296. id=request_id,
  297. model=request.model,
  298. choices=[{
  299. "index": 0,
  300. "message": {"role": "assistant", "content": content},
  301. "finish_reason": "stop"
  302. }],
  303. timestamp=datetime.now().isoformat()
  304. )
  305. except ValueError as e:
  306. logger.warning(f"[{request_id}] 参数验证失败: {e}")
  307. raise HTTPException(
  308. status_code=status.HTTP_400_BAD_REQUEST,
  309. detail=str(e)
  310. )
  311. except RuntimeError as e:
  312. logger.error(f"[{request_id}] 运行时错误: {e}")
  313. raise HTTPException(
  314. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  315. detail=f"推理失败: {str(e)}"
  316. )
  317. except Exception as e:
  318. logger.error(f"[{request_id}] 未知错误: {e}", exc_info=True)
  319. raise HTTPException(
  320. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  321. detail=f"服务器内部错误: {str(e)}"
  322. )
  323. @app.exception_handler(Exception)
  324. async def global_exception_handler(request, exc):
  325. """全局异常处理器"""
  326. logger.error(f"全局异常捕获: {exc}", exc_info=True)
  327. return JSONResponse(
  328. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  329. content={
  330. "id": None,
  331. "object": "chat.completion",
  332. "choices": [],
  333. "timestamp": datetime.now().isoformat(),
  334. "error": {"message": str(exc)}
  335. }
  336. )
  337. # ==================== 主函数 ====================
  338. def main():
  339. """启动服务"""
  340. uvicorn.run(
  341. "model.model_api:app",
  342. host="0.0.0.0",
  343. port=8000,
  344. workers=1, # 由于模型占用内存大,使用单worker
  345. log_level="info",
  346. access_log=True,
  347. reload=False # 生产环境禁用热重载
  348. )
  349. if __name__ == "__main__":
  350. main()