| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361 |
- """
- 企业级 Agent OCR API 服务
- 提供基于 FastAPI 的高并发化学品安全标签信息提取服务
- """
import asyncio
import base64
import io
import logging
import sys
import uuid
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Optional, Dict, Any

import uvicorn
from fastapi import FastAPI, HTTPException, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel, Field, field_validator

from agent import OcrAgent
# ==================== Logging configuration ====================
# Send records both to stdout and to agent_ocr_api.log (UTF-8; the relative
# path resolves against the process's working directory).
_console_handler = logging.StreamHandler(sys.stdout)
_file_handler = logging.FileHandler('agent_ocr_api.log', encoding='utf-8')
logging.basicConfig(
    handlers=[_console_handler, _file_handler],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# ==================== Request / response models ====================
class AgentOCRRequest(BaseModel):
    """Request body for the Agent OCR endpoint."""

    image: str = Field(..., description="Base64 编码的图像字符串")

    @field_validator('image')
    @classmethod
    def validate_image(cls, v):
        """Validate and normalize the base64 image payload.

        Accepts either a bare base64 string or a ``data:<mime>;base64,<payload>``
        data URL. Whitespace (e.g. line breaks from MIME-style wrapping) is
        removed. The remaining payload is decoded strictly, so characters
        outside the base64 alphabet are rejected instead of silently dropped.

        Returns:
            The normalized payload (data-URL prefix and whitespace removed).

        Raises:
            ValueError: if the payload is empty or is not valid base64.
        """
        if not v:
            raise ValueError("图像不能为空")
        payload = v
        # "data:image/png;base64,AAAA..." -> keep only the part after the
        # first comma. A bare base64 string can never start with "data:"
        # because ':' is not in the base64 alphabet.
        if payload.startswith("data:") and "," in payload:
            payload = payload.split(",", 1)[1]
        # The lenient default decoder silently skipped whitespace; strip it
        # explicitly so strict validation below still accepts wrapped input.
        payload = "".join(payload.split())
        if not payload:
            raise ValueError("图像不能为空")
        try:
            # validate=True rejects non-alphabet characters; the default
            # (validate=False) discards them, letting garbage pass this check.
            base64.b64decode(payload, validate=True)
        except ValueError as exc:  # binascii.Error is a ValueError subclass
            raise ValueError("无效的 base64 图像格式") from exc
        return payload
class SuccessResponse(BaseModel):
    """Envelope for a successful call: code "200" plus the extracted data."""

    code: str = Field(description="响应代码", default="200")
    data: Dict[str, Any] = Field(..., description="提取的化学品标签信息")
    message: str = Field(description="响应消息", default="操作成功")
class ErrorResponse(BaseModel):
    """Envelope for a failed call: code "500", empty data, generic message."""

    code: str = Field(description="错误代码", default="500")
    data: Dict[str, Any] = Field(description="空数据", default_factory=dict)
    message: str = Field(description="错误消息", default="请求失败")
class HealthResponse(BaseModel):
    """Body returned by the health-check endpoint.

    Every field is required; descriptions are added so the OpenAPI schema is
    documented like the other response models.
    """

    status: str = Field(..., description="服务状态: healthy 或 unhealthy")
    agent_loaded: bool = Field(..., description="Agent 是否已加载")
    timestamp: str = Field(..., description="检查时间 (ISO 8601 格式)")
    concurrent_requests: int = Field(..., description="当前正在处理的请求数")
    max_concurrent: int = Field(..., description="最大并发请求数")
# ==================== Agent manager (singleton) ====================
class AgentManager:
    """Process-wide holder for the single ``OcrAgent`` instance.

    Owns the agent, an ``asyncio.Semaphore`` that bounds how many OCR calls
    run at once, and a counter of calls currently holding that semaphore
    (reported by ``get_status`` and used by ``unload_agent`` to drain work).

    State is only mutated from coroutines on the event loop; blocking work
    (agent construction, image decoding, inference) is pushed to the default
    thread-pool executor.
    """

    # Shared instance returned by get_instance().
    _instance: Optional['AgentManager'] = None

    def __init__(self):
        self.agent: Optional[OcrAgent] = None
        self.is_loaded: bool = False
        # Created in load_agent() once the concurrency limit is known.
        self.semaphore: Optional[asyncio.Semaphore] = None
        # Value reported by get_status() until load_agent() overrides it.
        self.max_concurrent_requests: int = 10
        # Number of OCR calls currently holding the semaphore.
        self.current_requests: int = 0

    @classmethod
    async def get_instance(cls) -> 'AgentManager':
        """Return the shared instance, creating it on first call.

        The check and the construction run with no ``await`` in between, so
        coroutines on the same event loop cannot interleave here and no lock
        is required. (A class-level ``asyncio.Lock()`` created at import time
        can also end up bound to a different loop than the server's on
        Python < 3.10.) Not safe to call from other threads.
        """
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    async def load_agent(self, max_concurrent: int = 5):
        """Construct the ``OcrAgent`` and set up concurrency control.

        The ``OcrAgent`` constructor runs in the default executor so the event
        loop stays responsive. Calling this after a successful load is a no-op.

        Args:
            max_concurrent: maximum number of OCR calls allowed to run at once.

        Raises:
            ValueError: if ``max_concurrent`` is less than 1.
            RuntimeError: if constructing the agent fails (the original
                exception is chained as ``__cause__``).
        """
        if self.is_loaded:
            logger.warning("Agent 已经加载,跳过重复加载")
            return
        if max_concurrent < 1:
            # Semaphore(0) would make every request wait forever; check before
            # paying for the agent construction.
            raise ValueError(f"max_concurrent 必须大于等于 1, 当前值: {max_concurrent}")
        try:
            logger.info("开始加载 OcrAgent...")
            loop = asyncio.get_running_loop()
            self.agent = await loop.run_in_executor(None, OcrAgent)
            self.max_concurrent_requests = max_concurrent
            self.semaphore = asyncio.Semaphore(max_concurrent)
            self.is_loaded = True
            logger.info(f"Agent 加载成功! 最大并发数: {max_concurrent}")
        except Exception as e:
            logger.error(f"Agent 加载失败: {e}", exc_info=True)
            raise RuntimeError(f"Agent 加载失败: {str(e)}") from e

    async def unload_agent(self):
        """Release the agent after in-flight OCR calls have finished.

        ``is_loaded`` is cleared first so requests arriving during shutdown
        are rejected instead of extending the drain; requests already queued
        on the semaphore re-check it in ``process_ocr``. Errors are logged,
        not raised.
        """
        if not self.is_loaded:
            return
        try:
            logger.info("开始卸载 Agent...")
            # Stop admitting new work before waiting for current work.
            self.is_loaded = False
            while self.current_requests > 0:
                logger.info(f"等待 {self.current_requests} 个请求完成...")
                await asyncio.sleep(0.5)
            self.agent = None
            self.semaphore = None
            logger.info("Agent 卸载成功")
        except Exception as e:
            logger.error(f"Agent 卸载失败: {e}", exc_info=True)

    def base64_to_pil(self, base64_str: str) -> Image.Image:
        """Decode a base64 string into a fully loaded RGB ``PIL.Image``.

        Blocking (decode + pixel conversion); callers on the event loop should
        run it in an executor, as ``process_ocr`` does.

        Args:
            base64_str: base64-encoded image bytes.

        Returns:
            The decoded image, converted to RGB mode if necessary.

        Raises:
            ValueError: if the data is not valid base64 or not a readable
                image (original exception chained as ``__cause__``).
        """
        try:
            image_data = base64.b64decode(base64_str)
            image = Image.open(io.BytesIO(image_data))
            # Image.open is lazy: force a full decode now so truncated or
            # corrupt files surface here as ValueError rather than later
            # inside the agent.
            image.load()
            if image.mode != 'RGB':
                image = image.convert('RGB')
            return image
        except Exception as e:
            logger.error(f"Base64 转换失败: {e}")
            raise ValueError(f"图像解码失败: {str(e)}") from e

    async def process_ocr(self, image_base64: str) -> Dict[str, Any]:
        """Decode the image and run ``agent_ocr`` on it, bounded by the semaphore.

        Both image decoding and the agent call run in the default executor so
        neither blocks the event loop.

        Args:
            image_base64: base64-encoded image.

        Returns:
            The value returned by ``self.agent.agent_ocr`` for the decoded
            image (annotated as a dict; not checked here).

        Raises:
            RuntimeError: if the agent is not loaded, or was unloaded while
                this call waited for a slot.
            ValueError: if the image cannot be decoded.
        """
        if not self.is_loaded or self.agent is None:
            raise RuntimeError("Agent 未加载")
        async with self.semaphore:
            # The agent may have been unloaded while this call was queued.
            agent = self.agent
            if not self.is_loaded or agent is None:
                raise RuntimeError("Agent 未加载")
            # No await between here and the check above, so the counter
            # update cannot interleave with other coroutines on this loop.
            self.current_requests += 1
            try:
                loop = asyncio.get_running_loop()
                pil_image = await loop.run_in_executor(
                    None, self.base64_to_pil, image_base64
                )
                return await loop.run_in_executor(None, agent.agent_ocr, pil_image)
            finally:
                self.current_requests -= 1

    def get_status(self) -> Dict[str, Any]:
        """Return a snapshot of load state and concurrency counters."""
        return {
            "is_loaded": self.is_loaded,
            "current_requests": self.current_requests,
            "max_concurrent": self.max_concurrent_requests
        }
# ==================== FastAPI application ====================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the OCR agent at startup and unload it at shutdown.

    A startup failure is logged and re-raised, so the server does not start
    without a working agent. Unloading sits in a ``finally`` block so it also
    runs if the application exits with an exception while serving.
    """
    logger.info("应用启动中...")
    manager = await AgentManager.get_instance()
    try:
        await manager.load_agent(max_concurrent=5)
        logger.info("应用启动完成")
    except Exception as e:
        logger.error(f"应用启动失败: {e}")
        raise
    try:
        yield
    finally:
        logger.info("应用关闭中...")
        await manager.unload_agent()
        logger.info("应用已关闭")
# The FastAPI application; `lifespan` loads the agent on startup and
# unloads it on shutdown.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    description="企业级化学品安全标签信息提取服务",
    title="Agent OCR API",
)
# ==================== API endpoints ====================
@app.get("/", response_model=Dict[str, str])
async def root():
    """Return basic service metadata: name, version and docs path."""
    info = dict(message="Agent OCR API Service", version="1.0.0", docs="/docs")
    return info
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report agent load state, current concurrency and a local timestamp."""
    manager = await AgentManager.get_instance()
    snapshot = manager.get_status()
    loaded = snapshot["is_loaded"]
    overall = "healthy" if loaded else "unhealthy"
    return HealthResponse(
        status=overall,
        agent_loaded=loaded,
        timestamp=datetime.now().isoformat(),
        concurrent_requests=snapshot["current_requests"],
        max_concurrent=snapshot["max_concurrent"],
    )
def _ocr_error_response() -> ErrorResponse:
    """Build the uniform failure envelope returned by ``agent_ocr_endpoint``."""
    return ErrorResponse(code="500", data={}, message="请求失败")


@app.post("/api/v1/agent_ocr")
async def agent_ocr_endpoint(request: AgentOCRRequest):
    """Extract chemical safety-label information from a base64 image.

    Failures are reported in the response body, never by raising.

    Args:
        request: request body; its ``image`` field holds the base64 image.

    Returns:
        Success: {"code": "200", "data": {...}, "message": "操作成功"}
        Failure: {"code": "500", "data": {}, "message": "请求失败"}
    """
    # Timestamp for readability plus a random suffix: the clock alone can
    # yield identical ids for concurrent requests (e.g. coarse timer
    # resolution on some platforms).
    request_id = (
        f"req_{datetime.now().strftime('%Y%m%d%H%M%S%f')}_{uuid.uuid4().hex[:8]}"
    )
    logger.info(f"[{request_id}] 收到 Agent OCR 请求")
    try:
        manager = await AgentManager.get_instance()
        if not manager.is_loaded:
            logger.error(f"[{request_id}] Agent 未加载")
            return _ocr_error_response()
        logger.info(f"[{request_id}] 开始处理...")
        result = await manager.process_ocr(request.image)
        logger.info(f"[{request_id}] 处理完成")
        return SuccessResponse(code="200", data=result, message="操作成功")
    except ValueError as e:
        # Input problem, e.g. the image could not be decoded.
        logger.warning(f"[{request_id}] 参数验证失败: {e}")
    except RuntimeError as e:
        # Service-side problem, e.g. the agent is not (or no longer) loaded.
        logger.error(f"[{request_id}] 运行时错误: {e}")
    except Exception as e:
        logger.error(f"[{request_id}] 未知错误: {e}", exc_info=True)
    return _ocr_error_response()
def _failure_json_response() -> JSONResponse:
    """Uniform failure envelope used by the exception handlers below."""
    return JSONResponse(
        # Per the service requirements, failures still use HTTP status 200;
        # the failure is signalled by "code" in the body.
        status_code=200,
        content={
            "code": "500",
            "data": {},
            "message": "请求失败"
        }
    )


@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Turn any otherwise unhandled exception into the failure envelope."""
    logger.error(f"全局异常捕获: {exc}", exc_info=True)
    return _failure_json_response()


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    """Return the failure envelope for request-validation errors.

    FastAPI's built-in handler for these (missing ``image`` field, rejected
    base64, ...) answers with HTTP 422 and its own body shape, which would
    bypass the "always HTTP 200" contract enforced above.
    """
    # Log only locations and messages: each error's "input" entry may hold
    # the entire base64 image.
    details = [(err.get("loc"), err.get("msg")) for err in exc.errors()]
    logger.warning(f"请求参数校验失败: {details}")
    return _failure_json_response()
# ==================== Entry point ====================
def main(host: str = "0.0.0.0", port: int = 6006):
    """Start the API server with uvicorn.

    Args:
        host: interface to bind (default: all interfaces).
        port: TCP port to listen on.
    """
    uvicorn.run(
        # Pass the app object instead of the "api.run_api:app" import string:
        # the string form makes uvicorn import this file a second time under
        # another module name (or fail when the `api` package is not
        # importable from the working directory). An import string is only
        # needed for reload or multiple workers, neither of which is used.
        app,
        host=host,
        port=port,
        workers=1,  # single worker: the agent is resource-heavy
        log_level="info",
        access_log=True,
        reload=False,  # hot reload disabled for production
    )


if __name__ == "__main__":
    main()
|