|
|
@@ -33,38 +33,84 @@ logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
# ==================== 请求/响应模型 ====================
|
|
|
-class OCRRequest(BaseModel):
|
|
|
- """OCR 推理请求模型"""
|
|
|
- image: str = Field(..., description="Base64 编码的图像字符串")
|
|
|
- text: list = Field(..., description="OCR 提示词文本列表")
|
|
|
+class ImageURL(BaseModel):
|
|
|
+ url: str = Field(..., description="data:image/png;base64,... 格式的图像 URL")
|
|
|
|
|
|
- @validator('image')
|
|
|
- def validate_image(cls, v):
|
|
|
- """验证 base64 图像格式"""
|
|
|
- if not v:
|
|
|
- raise ValueError("图像不能为空")
|
|
|
- try:
|
|
|
- # 尝试解码验证格式
|
|
|
- base64.b64decode(v)
|
|
|
- except Exception:
|
|
|
- raise ValueError("无效的 base64 图像格式")
|
|
|
- return v
|
|
|
|
|
|
- @validator('text')
|
|
|
- def validate_text(cls, v):
|
|
|
- """验证提示词文本"""
|
|
|
+class ContentItem(BaseModel):
|
|
|
+ type: str = Field(..., description="内容类型: image_url 或 text")
|
|
|
+ image_url: Optional[ImageURL] = None
|
|
|
+ text: Optional[str] = None
|
|
|
+
|
|
|
+
|
|
|
+class Message(BaseModel):
|
|
|
+ role: str = Field(..., description="角色: system 或 user")
|
|
|
+ content: Any = Field(..., description="消息内容")
|
|
|
+
|
|
|
+
|
|
|
+class OCRRequest(BaseModel):
|
|
|
+ """OCR 推理请求模型(OpenAI 兼容格式)"""
|
|
|
+ model: Optional[str] = Field(None, description="模型名称")
|
|
|
+ messages: list = Field(..., description="消息列表")
|
|
|
+ max_tokens: Optional[int] = Field(4096, description="最大生成 token 数")
|
|
|
+ stream: Optional[bool] = Field(False, description="是否流式输出")
|
|
|
+ temperature: Optional[float] = Field(0, description="采样温度")
|
|
|
+
|
|
|
+ @validator('messages')
|
|
|
+ def validate_messages(cls, v):
|
|
|
if not v:
|
|
|
- raise ValueError("提示词不能为空")
|
|
|
+ raise ValueError("messages 不能为空")
|
|
|
return v
|
|
|
|
|
|
+ def get_image_base64(self) -> str:
|
|
|
+ """从 messages 中提取 base64 图像(去掉 data:image/xxx;base64, 前缀)"""
|
|
|
+ for msg in self.messages:
|
|
|
+ if msg.get('role') != 'user':
|
|
|
+ continue
|
|
|
+ content = msg.get('content', [])
|
|
|
+ if not isinstance(content, list):
|
|
|
+ continue
|
|
|
+ for item in content:
|
|
|
+ if item.get('type') == 'image_url':
|
|
|
+ url = item.get('image_url', {}).get('url', '')
|
|
|
+ # 去掉 "data:image/png;base64," 前缀
|
|
|
+ if ';base64,' in url:
|
|
|
+ return url.split(';base64,', 1)[1]
|
|
|
+ return url
|
|
|
+ raise ValueError("messages 中未找到 image_url")
|
|
|
+
|
|
|
+ def get_prompt(self) -> str:
|
|
|
+ """从 messages 中提取用户文本提示词"""
|
|
|
+ for msg in self.messages:
|
|
|
+ if msg.get('role') != 'user':
|
|
|
+ continue
|
|
|
+ content = msg.get('content', [])
|
|
|
+ if not isinstance(content, list):
|
|
|
+ continue
|
|
|
+ for item in content:
|
|
|
+ if item.get('type') == 'text':
|
|
|
+ return item.get('text', '')
|
|
|
+ raise ValueError("messages 中未找到 text")
|
|
|
+
|
|
|
+
|
|
|
+class ChoiceMessage(BaseModel):
|
|
|
+ role: str = "assistant"
|
|
|
+ content: Optional[str] = None
|
|
|
+
|
|
|
+
|
|
|
+class Choice(BaseModel):
|
|
|
+ index: int = 0
|
|
|
+ message: ChoiceMessage
|
|
|
+ finish_reason: str = "stop"
|
|
|
+
|
|
|
|
|
|
class OCRResponse(BaseModel):
|
|
|
- """OCR 推理响应模型"""
|
|
|
- success: bool = Field(..., description="请求是否成功")
|
|
|
- data: Optional[Any] = Field(None, description="推理结果数据")
|
|
|
- message: str = Field(..., description="响应消息")
|
|
|
+ """OCR 推理响应模型(OpenAI 兼容格式)"""
|
|
|
+ id: str = Field(..., description="请求ID")
|
|
|
+ object: str = Field("chat.completion", description="对象类型")
|
|
|
+ model: Optional[str] = Field(None, description="模型名称")
|
|
|
+ choices: list = Field(..., description="推理结果列表")
|
|
|
timestamp: str = Field(..., description="响应时间戳")
|
|
|
- request_id: Optional[str] = Field(None, description="请求ID(用于追踪)")
|
|
|
|
|
|
|
|
|
class HealthResponse(BaseModel):
|
|
|
@@ -166,14 +212,14 @@ class ModelManager:
|
|
|
logger.error(f"Base64 转换失败: {e}")
|
|
|
raise ValueError(f"图像解码失败: {str(e)}")
|
|
|
|
|
|
- async def inference(self, image_base64: str, prompts: str) -> list:
|
|
|
+ async def inference(self, image_base64: str, prompt: str) -> str:
|
|
|
"""
|
|
|
- 执行 OCR 推理(带并发控制)
|
|
|
+ 执行 OCR 推理(带并发控制)
|
|
|
Args:
|
|
|
- image_base64: base64 编码的图像
|
|
|
- prompts: 提示词
|
|
|
+ image_base64: base64 编码的图像(不含 data URI 前缀)
|
|
|
+ prompt: 提示词
|
|
|
Returns:
|
|
|
- 推理结果
|
|
|
+ 推理结果字符串
|
|
|
"""
|
|
|
if not self.is_loaded or self.model is None:
|
|
|
raise RuntimeError("模型未加载")
|
|
|
@@ -187,16 +233,16 @@ class ModelManager:
|
|
|
# 转换图像
|
|
|
pil_image = self.base64_to_pil(image_base64)
|
|
|
|
|
|
- # 在线程池中执行推理,避免阻塞
|
|
|
+ # 在线程池中执行推理,避免阻塞
|
|
|
loop = asyncio.get_event_loop()
|
|
|
- result = await loop.run_in_executor(
|
|
|
+ results = await loop.run_in_executor(
|
|
|
None,
|
|
|
self.model.batch_inference,
|
|
|
- [pil_image] * len(prompts),
|
|
|
- prompts
|
|
|
+ [pil_image],
|
|
|
+ [prompt]
|
|
|
)
|
|
|
|
|
|
- return result
|
|
|
+ return results[0]
|
|
|
finally:
|
|
|
async with self._request_lock:
|
|
|
self.current_requests -= 1
|
|
|
@@ -270,42 +316,55 @@ async def health_check():
|
|
|
@app.post("/api/v1/ocr", response_model=OCRResponse)
|
|
|
async def ocr_inference(request: OCRRequest):
|
|
|
"""
|
|
|
- OCR 推理端点
|
|
|
-
|
|
|
- Args:
|
|
|
- request: OCRRequest 对象,包含 image(base64) 和 text(提示词)
|
|
|
-
|
|
|
- Returns:
|
|
|
- OCRResponse: 推理结果
|
|
|
+ OCR 推理端点(OpenAI 兼容格式)
|
|
|
+
|
|
|
+ 请求体与 /v1/chat/completions 格式一致:
|
|
|
+ {
|
|
|
+ "model": "...",
|
|
|
+ "messages": [
|
|
|
+ {"role": "system", "content": "..."},
|
|
|
+ {"role": "user", "content": [
|
|
|
+ {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
|
|
|
+ {"type": "text", "text": "问题"}
|
|
|
+ ]}
|
|
|
+ ],
|
|
|
+ "max_tokens": 4096,
|
|
|
+ "stream": false,
|
|
|
+ "temperature": 0
|
|
|
+ }
|
|
|
"""
|
|
|
- request_id = f"req_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
|
|
|
+ request_id = f"chatcmpl-{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
|
|
|
logger.info(f"[{request_id}] 收到 OCR 请求")
|
|
|
|
|
|
try:
|
|
|
- # 获取模型管理器
|
|
|
manager = await ModelManager.get_instance()
|
|
|
|
|
|
if not manager.is_loaded:
|
|
|
raise HTTPException(
|
|
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
|
|
- detail="模型未加载,服务暂不可用"
|
|
|
+ detail="模型未加载,服务暂不可用"
|
|
|
)
|
|
|
|
|
|
- # 执行推理
|
|
|
+ # 从 messages 中提取图像和提示词
|
|
|
+ image_base64 = request.get_image_base64()
|
|
|
+ prompt = request.get_prompt()
|
|
|
+
|
|
|
logger.info(f"[{request_id}] 开始推理...")
|
|
|
- result = await manager.inference(request.image, request.text)
|
|
|
+ content = await manager.inference(image_base64, prompt)
|
|
|
logger.info(f"[{request_id}] 推理完成")
|
|
|
|
|
|
return OCRResponse(
|
|
|
- success=True,
|
|
|
- data=result,
|
|
|
- message="推理成功",
|
|
|
- timestamp=datetime.now().isoformat(),
|
|
|
- request_id=request_id
|
|
|
+ id=request_id,
|
|
|
+ model=request.model,
|
|
|
+ choices=[{
|
|
|
+ "index": 0,
|
|
|
+ "message": {"role": "assistant", "content": content},
|
|
|
+ "finish_reason": "stop"
|
|
|
+ }],
|
|
|
+ timestamp=datetime.now().isoformat()
|
|
|
)
|
|
|
|
|
|
except ValueError as e:
|
|
|
- # 参数验证错误
|
|
|
logger.warning(f"[{request_id}] 参数验证失败: {e}")
|
|
|
raise HTTPException(
|
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
|
@@ -313,7 +372,6 @@ async def ocr_inference(request: OCRRequest):
|
|
|
)
|
|
|
|
|
|
except RuntimeError as e:
|
|
|
- # 模型运行时错误
|
|
|
logger.error(f"[{request_id}] 运行时错误: {e}")
|
|
|
raise HTTPException(
|
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
|
@@ -321,7 +379,6 @@ async def ocr_inference(request: OCRRequest):
|
|
|
)
|
|
|
|
|
|
except Exception as e:
|
|
|
- # 未知错误
|
|
|
logger.error(f"[{request_id}] 未知错误: {e}", exc_info=True)
|
|
|
raise HTTPException(
|
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
|
@@ -336,10 +393,11 @@ async def global_exception_handler(request, exc):
|
|
|
return JSONResponse(
|
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
|
content={
|
|
|
- "success": False,
|
|
|
- "data": None,
|
|
|
- "message": f"服务器错误: {str(exc)}",
|
|
|
- "timestamp": datetime.now().isoformat()
|
|
|
+ "id": None,
|
|
|
+ "object": "chat.completion",
|
|
|
+ "choices": [],
|
|
|
+ "timestamp": datetime.now().isoformat(),
|
|
|
+ "error": {"message": str(exc)}
|
|
|
}
|
|
|
)
|
|
|
|