|
@@ -0,0 +1,104 @@
|
|
|
|
|
+"""
|
|
|
|
|
+Agent OCR 服务启动脚本
|
|
|
|
|
+用法:
|
|
|
|
|
+ python start.py # 使用默认配置文件 ocr_config.yaml
|
|
|
|
|
+ python start.py --config /path/to/cfg # 指定配置文件
|
|
|
|
|
+"""
|
|
|
|
|
+import argparse
|
|
|
|
|
+import os
|
|
|
|
|
+import sys
|
|
|
|
|
+
|
|
|
|
|
+import yaml
|
|
|
|
|
+import uvicorn
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def load_config(config_path: str) -> dict:
    """Load the YAML configuration file and return its top-level mapping.

    Every failure mode is reported on stdout with an ``[ERROR]`` prefix and
    terminates the process with exit status 1, so callers never receive a
    partially loaded configuration.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The parsed configuration as a dict.

    Raises:
        SystemExit: With code 1 when the path is not a regular file, the
            file cannot be read or decoded, the YAML is malformed, or the
            document's top level is not a mapping.
    """
    # isfile() rather than exists(): a directory "exists" but cannot be
    # opened, which would otherwise surface as an uncaught traceback.
    if not os.path.isfile(config_path):
        print(f"[ERROR] 配置文件不存在: {config_path}")
        sys.exit(1)

    try:
        with open(config_path, "r", encoding="utf-8") as f:
            cfg = yaml.safe_load(f)
    except (OSError, UnicodeDecodeError) as exc:
        # Permission problems, races with deletion, or non-UTF-8 content.
        print(f"[ERROR] 配置文件读取失败: {exc}")
        sys.exit(1)
    except yaml.YAMLError as exc:
        # Malformed YAML: reuse the existing format-error message and add
        # the parser's detail so the user can locate the offending line.
        print("[ERROR] 配置文件格式错误,请检查 yaml 格式")
        print(f"[ERROR] {exc}")
        sys.exit(1)

    # An empty file parses to None and a scalar/list document is not usable
    # as a configuration either.
    if not isinstance(cfg, dict):
        print("[ERROR] 配置文件格式错误,请检查 yaml 格式")
        sys.exit(1)
    return cfg
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def apply_inference_config(cfg: dict):
    """Export inference-service settings as environment variables.

    Reads ``url``, ``auth_token`` and ``model`` from the ``inference``
    section of *cfg* and writes each truthy value, converted to ``str``, to
    ``INFERENCE_URL``, ``INFERENCE_AUTH_TOKEN`` and ``INFERENCE_MODEL``
    respectively. Missing or falsy values are skipped, leaving any existing
    environment variable untouched. According to the original author's note
    these variables are intended to be read by ``config/config.py``.

    Args:
        cfg: Parsed configuration mapping.
    """
    # A bare `inference:` key in YAML parses to None, so a plain
    # `.get("inference", {})` would hand back None and crash on `.get`.
    inf = cfg.get("inference") or {}

    env_to_field = (
        ("INFERENCE_URL", "url"),
        ("INFERENCE_AUTH_TOKEN", "auth_token"),
        ("INFERENCE_MODEL", "model"),
    )
    for env_name, field in env_to_field:
        val = inf.get(field)
        if val:
            os.environ[env_name] = str(val)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def apply_server_config(cfg: dict) -> tuple:
    """Resolve the server settings and return ``(host, port, max_concurrent)``.

    Values come from the ``server`` section of *cfg*. A missing section, a
    missing key, or an explicit null value falls back to the defaults
    ``"0.0.0.0"``, ``6006`` and ``5``. As a side effect the resolved
    concurrency limit is written to the ``MAX_CONCURRENT`` environment
    variable; per the original author's note it is intended to be read by
    the lifespan handler in ``run_api.py``.

    Args:
        cfg: Parsed configuration mapping.

    Returns:
        A ``(host, port, max_concurrent)`` tuple of ``(str, int, int)``.

    Raises:
        ValueError: If ``port`` or ``max_concurrent`` is not convertible to
            an integer.
    """
    # A bare `server:` key in YAML parses to None rather than a mapping.
    srv = cfg.get("server") or {}

    def _setting(name, default):
        # Treat an explicit null the same as an absent key; otherwise
        # int(None) raises and str(None) yields the literal host "None".
        value = srv.get(name)
        return default if value is None else value

    host = str(_setting("host", "0.0.0.0"))
    port = int(_setting("port", 6006))
    max_concurrent = int(_setting("max_concurrent", 5))

    os.environ["MAX_CONCURRENT"] = str(max_concurrent)
    return host, port, max_concurrent
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def setup_path():
    """Put the parent of this script's directory at the front of ``sys.path``.

    The original author's note says this directory is the project root and
    that adding it makes the ``agent``, ``api`` and ``config`` packages
    importable. The call is idempotent: nothing is inserted when the
    directory is already listed.
    """
    script_path = os.path.abspath(__file__)
    project_root = os.path.dirname(os.path.dirname(script_path))
    if project_root in sys.path:
        return
    sys.path.insert(0, project_root)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def print_config(cfg: dict, host: str, port: int, max_concurrent: int):
    """Print a startup summary of the effective configuration to stdout.

    The auth token is masked: only its last 6 characters are shown, and a
    token of 6 characters or fewer is masked completely so that a short
    token is never printed in clear.

    Args:
        cfg: Parsed configuration mapping; the ``inference`` section supplies
            the url, auth token and model name shown in the summary.
        host: Host the server will bind to.
        port: Port the server will listen on.
        max_concurrent: Resolved concurrency limit.
    """
    # A bare `inference:` key in YAML parses to None rather than a mapping.
    inf = cfg.get("inference") or {}

    # YAML may yield a non-string token (e.g. an all-digit value parses as
    # int), so normalise to str before measuring and slicing it.
    raw_token = inf.get("auth_token")
    token = "" if raw_token is None else str(raw_token)

    visible = 6
    if len(token) > visible:
        masked = "*" * (len(token) - visible) + token[-visible:]
    else:
        # Showing the "last 6" of a short token would reveal all of it.
        masked = "*" * len(token)

    rule = "=" * 50
    print(rule)
    print("  Agent OCR 服务配置")
    print(rule)
    # `or ''` keeps an explicit null from being rendered as "None".
    print(f"  推理地址   : {inf.get('url') or ''}")
    print(f"  认证Token  : {masked}")
    print(f"  模型名称   : {inf.get('model') or ''}")
    print(f"  服务地址   : http://{host}:{port}")
    print(f"  最大并发   : {max_concurrent}")
    print(rule)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def main():
    """Parse the command line, apply the configuration and start the server.

    Steps, in order: load the YAML file named by ``--config`` (defaulting to
    ``ocr_config.yaml`` next to this script), export the inference settings
    to the environment, resolve the server settings, extend ``sys.path``,
    print the configuration summary, and hand control to ``uvicorn.run``.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    default_config = os.path.join(script_dir, "ocr_config.yaml")

    cli = argparse.ArgumentParser(description="启动 Agent OCR API 服务")
    cli.add_argument(
        "--config",
        default=default_config,
        help="配置文件路径(默认:与 start.py 同目录的 ocr_config.yaml)",
    )
    options = cli.parse_args()

    config = load_config(options.config)
    apply_inference_config(config)
    bind_host, bind_port, concurrency = apply_server_config(config)
    # sys.path is extended before uvicorn is asked to resolve the
    # "api.run_api:app" import string below.
    setup_path()

    print_config(config, bind_host, bind_port, concurrency)

    uvicorn.run(
        "api.run_api:app",
        host=bind_host,
        port=bind_port,
        workers=1,
        log_level="info",
        access_log=True,
        reload=False,
    )
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
# Script entry point: start the service only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|