run_api.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. """
  2. 企业级 Agent OCR API 服务
  3. 提供基于 FastAPI 的高并发化学品安全标签信息提取服务
  4. """
import asyncio
import base64
import io
import logging
import os
import sys
import uuid
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Any, Dict, Optional

import uvicorn
from fastapi import FastAPI, HTTPException, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel, Field, field_validator

from agent import OcrAgent
  20. # ==================== 日志配置 ====================
  21. logging.basicConfig(
  22. level=logging.INFO,
  23. format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
  24. handlers=[
  25. logging.StreamHandler(sys.stdout),
  26. logging.FileHandler('agent_ocr_api.log', encoding='utf-8')
  27. ]
  28. )
  29. logger = logging.getLogger(__name__)
  30. # ==================== 请求/响应模型 ====================
  31. class AgentOCRRequest(BaseModel):
  32. """Agent OCR 请求模型"""
  33. image: str = Field(..., description="Base64 编码的图像字符串")
  34. @field_validator('image')
  35. @classmethod
  36. def validate_image(cls, v):
  37. """验证 base64 图像格式"""
  38. if not v:
  39. raise ValueError("图像不能为空")
  40. try:
  41. # 尝试解码验证格式
  42. base64.b64decode(v)
  43. except Exception:
  44. raise ValueError("无效的 base64 图像格式")
  45. return v
  46. class SuccessResponse(BaseModel):
  47. """成功响应模型"""
  48. code: str = Field(default="200", description="响应代码")
  49. data: Dict[str, Any] = Field(..., description="提取的化学品标签信息")
  50. message: str = Field(default="操作成功", description="响应消息")
  51. class ErrorResponse(BaseModel):
  52. """错误响应模型"""
  53. code: str = Field(default="500", description="错误代码")
  54. data: Dict[str, Any] = Field(default_factory=dict, description="空数据")
  55. message: str = Field(default="请求失败", description="错误消息")
  56. class HealthResponse(BaseModel):
  57. """健康检查响应模型"""
  58. status: str
  59. agent_loaded: bool
  60. timestamp: str
  61. concurrent_requests: int
  62. max_concurrent: int
  63. # ==================== Agent 管理器(单例模式) ====================
  64. class AgentManager:
  65. """Agent 管理器 - 单例模式确保全局只有一个 Agent 实例"""
  66. _instance: Optional['AgentManager'] = None
  67. _lock = asyncio.Lock()
  68. def __init__(self):
  69. self.agent: Optional[OcrAgent] = None
  70. self.is_loaded: bool = False
  71. self.semaphore: Optional[asyncio.Semaphore] = None
  72. self.max_concurrent_requests: int = 10 # 最大并发请求数
  73. self.current_requests: int = 0
  74. self._request_lock = asyncio.Lock()
  75. @classmethod
  76. async def get_instance(cls) -> 'AgentManager':
  77. """获取单例实例(线程安全)"""
  78. if cls._instance is None:
  79. async with cls._lock:
  80. if cls._instance is None:
  81. cls._instance = cls()
  82. return cls._instance
  83. async def load_agent(self, max_concurrent: int = 5):
  84. """
  85. 加载 Agent
  86. Args:
  87. max_concurrent: 最大并发请求数
  88. """
  89. if self.is_loaded:
  90. logger.warning("Agent 已经加载,跳过重复加载")
  91. return
  92. try:
  93. logger.info("开始加载 OcrAgent...")
  94. # 在线程池中加载 Agent,避免阻塞事件循环
  95. loop = asyncio.get_event_loop()
  96. self.agent = await loop.run_in_executor(None, OcrAgent)
  97. # 初始化并发控制
  98. self.max_concurrent_requests = max_concurrent
  99. self.semaphore = asyncio.Semaphore(max_concurrent)
  100. self.is_loaded = True
  101. logger.info(f"Agent 加载成功! 最大并发数: {max_concurrent}")
  102. except Exception as e:
  103. logger.error(f"Agent 加载失败: {e}", exc_info=True)
  104. raise RuntimeError(f"Agent 加载失败: {str(e)}")
  105. async def unload_agent(self):
  106. """卸载 Agent 并释放资源"""
  107. if not self.is_loaded:
  108. return
  109. try:
  110. logger.info("开始卸载 Agent...")
  111. # 等待所有正在进行的请求完成
  112. while self.current_requests > 0:
  113. logger.info(f"等待 {self.current_requests} 个请求完成...")
  114. await asyncio.sleep(0.5)
  115. self.agent = None
  116. self.semaphore = None
  117. self.is_loaded = False
  118. logger.info("Agent 卸载成功")
  119. except Exception as e:
  120. logger.error(f"Agent 卸载失败: {e}", exc_info=True)
  121. def base64_to_pil(self, base64_str: str) -> Image.Image:
  122. """
  123. 将 base64 字符串转换为 PIL Image
  124. Args:
  125. base64_str: base64 编码的图像字符串
  126. Returns:
  127. PIL.Image 对象
  128. """
  129. try:
  130. # 解码 base64
  131. image_data = base64.b64decode(base64_str)
  132. # 转换为 PIL Image
  133. image = Image.open(io.BytesIO(image_data))
  134. # 确保是 RGB 模式
  135. if image.mode != 'RGB':
  136. image = image.convert('RGB')
  137. return image
  138. except Exception as e:
  139. logger.error(f"Base64 转换失败: {e}")
  140. raise ValueError(f"图像解码失败: {str(e)}")
  141. async def process_ocr(self, image_base64: str) -> Dict[str, Any]:
  142. """
  143. 执行 Agent OCR 处理(带并发控制)
  144. Args:
  145. image_base64: base64 编码的图像
  146. Returns:
  147. 化学品标签信息提取结果
  148. """
  149. if not self.is_loaded or self.agent is None:
  150. raise RuntimeError("Agent 未加载")
  151. # 并发控制
  152. async with self.semaphore:
  153. async with self._request_lock:
  154. self.current_requests += 1
  155. try:
  156. # 转换图像
  157. pil_image = self.base64_to_pil(image_base64)
  158. # 在线程池中执行 agent_ocr,避免阻塞
  159. loop = asyncio.get_event_loop()
  160. result = await loop.run_in_executor(
  161. None,
  162. self.agent.agent_ocr,
  163. pil_image
  164. )
  165. return result
  166. finally:
  167. async with self._request_lock:
  168. self.current_requests -= 1
  169. def get_status(self) -> Dict[str, Any]:
  170. """获取 Agent 状态"""
  171. return {
  172. "is_loaded": self.is_loaded,
  173. "current_requests": self.current_requests,
  174. "max_concurrent": self.max_concurrent_requests
  175. }
  176. # ==================== FastAPI 应用 ====================
  177. @asynccontextmanager
  178. async def lifespan(app: FastAPI):
  179. """应用生命周期管理"""
  180. # 启动时加载 Agent
  181. logger.info("应用启动中...")
  182. manager = await AgentManager.get_instance()
  183. try:
  184. max_concurrent = int(os.environ.get("MAX_CONCURRENT", 5))
  185. await manager.load_agent(max_concurrent=max_concurrent)
  186. logger.info("应用启动完成")
  187. except Exception as e:
  188. logger.error(f"应用启动失败: {e}")
  189. raise
  190. yield
  191. # 关闭时卸载 Agent
  192. logger.info("应用关闭中...")
  193. await manager.unload_agent()
  194. logger.info("应用已关闭")
  195. # 创建 FastAPI 应用
  196. app = FastAPI(
  197. title="Agent OCR API",
  198. description="企业级化学品安全标签信息提取服务",
  199. version="1.0.0",
  200. lifespan=lifespan
  201. )
  202. # ==================== API 端点 ====================
  203. @app.get("/", response_model=Dict[str, str])
  204. async def root():
  205. """根路径"""
  206. return {
  207. "message": "Agent OCR API Service",
  208. "version": "1.0.0",
  209. "docs": "/docs"
  210. }
  211. @app.get("/health", response_model=HealthResponse)
  212. async def health_check():
  213. """健康检查端点"""
  214. manager = await AgentManager.get_instance()
  215. status_info = manager.get_status()
  216. return HealthResponse(
  217. status="healthy" if status_info["is_loaded"] else "unhealthy",
  218. agent_loaded=status_info["is_loaded"],
  219. timestamp=datetime.now().isoformat(),
  220. concurrent_requests=status_info["current_requests"],
  221. max_concurrent=status_info["max_concurrent"]
  222. )
  223. @app.post("/api/v1/agent_ocr")
  224. async def agent_ocr_endpoint(request: AgentOCRRequest):
  225. """
  226. Agent OCR 化学品标签信息提取端点
  227. Args:
  228. request: AgentOCRRequest 对象,包含 image(base64 编码的图像)
  229. Returns:
  230. 成功: {"code": "200", "data": {...}, "message": "操作成功"}
  231. 失败: {"code": "500", "data": {}, "message": "请求失败"}
  232. """
  233. request_id = f"req_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
  234. logger.info(f"[{request_id}] 收到 Agent OCR 请求")
  235. try:
  236. # 获取 Agent 管理器
  237. manager = await AgentManager.get_instance()
  238. if not manager.is_loaded:
  239. logger.error(f"[{request_id}] Agent 未加载")
  240. return ErrorResponse(
  241. code="500",
  242. data={},
  243. message="请求失败"
  244. )
  245. # 执行 OCR 处理
  246. logger.info(f"[{request_id}] 开始处理...")
  247. result = await manager.process_ocr(request.image)
  248. logger.info(f"[{request_id}] 处理完成")
  249. return SuccessResponse(
  250. code="200",
  251. data=result,
  252. message="操作成功"
  253. )
  254. except ValueError as e:
  255. # 请求参数验证错误(如 base64 格式非法)
  256. logger.warning(f"[{request_id}] 请求参数验证失败: {e}")
  257. return ErrorResponse(
  258. code="500",
  259. data={},
  260. message="请求失败"
  261. )
  262. except RuntimeError as e:
  263. # 运行时错误(含模型返回 JSON 解析失败)
  264. logger.error(f"[{request_id}] 运行时错误: {e}")
  265. return ErrorResponse(
  266. code="500",
  267. data={},
  268. message="请求失败"
  269. )
  270. except Exception as e:
  271. # 未知错误
  272. logger.error(f"[{request_id}] 未知错误: {e}", exc_info=True)
  273. return ErrorResponse(
  274. code="500",
  275. data={},
  276. message="请求失败"
  277. )
  278. @app.exception_handler(Exception)
  279. async def global_exception_handler(request, exc):
  280. """全局异常处理器"""
  281. logger.error(f"全局异常捕获: {exc}", exc_info=True)
  282. return JSONResponse(
  283. status_code=200, # 按照要求,即使失败也返回 200 HTTP 状态码
  284. content={
  285. "code": "500",
  286. "data": {},
  287. "message": "请求失败"
  288. }
  289. )
  290. # ==================== 主函数 ====================
  291. def main():
  292. """启动服务"""
  293. uvicorn.run(
  294. "api.run_api:app",
  295. host="0.0.0.0",
  296. port=6006, # 使用 8001 端口,避免与 model_api 的 8000 端口冲突
  297. workers=1, # 由于 Agent 占用资源,使用单 worker
  298. log_level="info",
  299. access_log=True,
  300. reload=False # 生产环境禁用热重载
  301. )
  302. if __name__ == "__main__":
  303. main()