start.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. """
  2. Agent OCR 服务启动脚本
  3. 用法:
  4. python start.py # 使用默认配置文件 ocr_config.yaml
  5. python start.py --config /path/to/cfg # 指定配置文件
  6. """
  7. import argparse
  8. import os
  9. import sys
  10. import yaml
  11. import uvicorn
  def load_config(config_path: str) -> dict:
      """Load the YAML configuration file and return its top-level mapping.

      On any problem an error is printed to stderr and the process exits with
      status 1. Problems include: the path is not a regular file, the YAML
      cannot be parsed, or the document is not a mapping.

      Args:
          config_path: Path to the YAML configuration file.

      Returns:
          The parsed top-level mapping.
      """
      # isfile (rather than exists) so a directory path gets the friendly
      # message instead of an IsADirectoryError traceback from open().
      if not os.path.isfile(config_path):
          print(f"[ERROR] 配置文件不存在: {config_path}", file=sys.stderr)
          sys.exit(1)
      try:
          with open(config_path, "r", encoding="utf-8") as f:
              cfg = yaml.safe_load(f)
      except yaml.YAMLError as err:
          # Syntactically broken YAML previously escaped as a raw traceback.
          print(f"[ERROR] 配置文件格式错误,请检查 yaml 格式: {err}", file=sys.stderr)
          sys.exit(1)
      # An empty file loads as None; a scalar or list is not a usable config.
      if not isinstance(cfg, dict):
          print("[ERROR] 配置文件格式错误,请检查 yaml 格式", file=sys.stderr)
          sys.exit(1)
      return cfg
  def apply_inference_config(cfg: dict) -> None:
      """Export the ``inference`` config section as environment variables.

      According to the original note, ``config/config.py`` reads these
      variables back. Only non-empty values are written, so an environment
      variable that is already set is left untouched when the config leaves
      the corresponding key blank.

      Args:
          cfg: Parsed configuration mapping (as returned by ``load_config``).
      """
      # A bare ``inference:`` line parses as None; treat it as a missing section
      # instead of crashing on None.get().
      inf = cfg.get("inference") or {}
      mapping = {
          "INFERENCE_URL": inf.get("url"),
          "INFERENCE_AUTH_TOKEN": inf.get("auth_token"),
          "INFERENCE_MODEL": inf.get("model"),
      }
      for key, val in mapping.items():
          if val:
              # Environment values must be strings (YAML may yield ints).
              os.environ[key] = str(val)
  def apply_server_config(cfg: dict) -> tuple:
      """Resolve the ``server`` config section into ``(host, port, max_concurrent)``.

      A key that is missing, or present with an empty (None) value, falls back
      to its default: ``"0.0.0.0"``, ``6006`` or ``5``. The resolved
      ``max_concurrent`` is also exported as the ``MAX_CONCURRENT``
      environment variable.

      The process prints an error to stderr and exits with status 1 when:
      ``port`` or ``max_concurrent`` is not an integer, ``port`` is outside
      0-65535, or ``max_concurrent`` is below 1.

      Args:
          cfg: Parsed configuration mapping (as returned by ``load_config``).

      Returns:
          ``(host, port, max_concurrent)`` as ``(str, int, int)``.
      """
      # A bare ``server:`` line parses as None; treat it as a missing section.
      srv = cfg.get("server") or {}

      def _value(key, default):
          # A bare ``key:`` also yields None; treat it as "not set" so that
          # host does not become the string "None" and port is not int(None).
          val = srv.get(key)
          return default if val is None else val

      host = str(_value("host", "0.0.0.0"))
      try:
          port = int(_value("port", 6006))
          max_concurrent = int(_value("max_concurrent", 5))
      except (TypeError, ValueError) as err:
          print(f"[ERROR] server 配置项必须为整数: {err}", file=sys.stderr)
          sys.exit(1)
      if not 0 <= port <= 65535:
          print(f"[ERROR] server.port 超出范围 (0-65535): {port}", file=sys.stderr)
          sys.exit(1)
      # A limit below 1 could never admit a request. This assumes run_api uses
      # the value as a concurrency cap, which the name suggests; confirm there.
      if max_concurrent < 1:
          print(f"[ERROR] server.max_concurrent 必须大于等于 1: {max_concurrent}", file=sys.stderr)
          sys.exit(1)
      # Exported for the lifespan hook in run_api.py (per the original note).
      os.environ["MAX_CONCURRENT"] = str(max_concurrent)
      return host, port, max_concurrent
  def setup_path():
      """Put the project root at the front of ``sys.path`` unless it is already present.

      The project root is the parent of the directory containing this script.
      Per the original note, that makes the top-level packages (agent/api/config)
      importable.
      """
      here = os.path.dirname(os.path.abspath(__file__))
      project_root = os.path.dirname(here)
      if project_root in sys.path:
          return
      sys.path.insert(0, project_root)
  def print_config(cfg: dict, host: str, port: int, max_concurrent: int):
      """Print a human-readable summary of the effective service configuration.

      The inference auth token is masked:
      - longer than 6 characters: only its last 6 characters are shown;
      - 6 characters or fewer: masked completely, because "show the last 6"
        would otherwise print the whole secret.

      Args:
          cfg: Parsed configuration mapping (as returned by ``load_config``).
          host: Bind host, as resolved by ``apply_server_config``.
          port: Bind port, as resolved by ``apply_server_config``.
          max_concurrent: Concurrency limit, as resolved by ``apply_server_config``.
      """
      # A bare ``inference:`` line parses as None; treat it as a missing section.
      inf = cfg.get("inference") or {}
      # str() because YAML parses an all-digit token as an int.
      token = str(inf.get("auth_token") or "")
      visible = 6
      if len(token) > visible:
          masked = "*" * (len(token) - visible) + token[-visible:]
      else:
          masked = "*" * len(token)
      print("=" * 50)
      print(" Agent OCR 服务配置")
      print("=" * 50)
      print(f" 推理地址 : {inf.get('url') or ''}")
      print(f" 认证Token : {masked}")
      print(f" 模型名称 : {inf.get('model') or ''}")
      print(f" 服务地址 : http://{host}:{port}")
      print(f" 最大并发 : {max_concurrent}")
      print("=" * 50)
  def main():
      """Command-line entry point: load the config, apply it, and start uvicorn.

      Steps:
      1. Read ``--config`` (default: ``ocr_config.yaml`` next to this script).
      2. Export the inference and server settings to the environment.
      3. Put the project root on ``sys.path`` and print a summary.
      4. Serve ``api.run_api:app`` with one worker and reload disabled.
      """
      script_dir = os.path.dirname(os.path.abspath(__file__))
      cli = argparse.ArgumentParser(description="启动 Agent OCR API 服务")
      cli.add_argument(
          "--config",
          default=os.path.join(script_dir, "ocr_config.yaml"),
          help="配置文件路径(默认:与 start.py 同目录的 ocr_config.yaml)",
      )
      options = cli.parse_args()

      config = load_config(options.config)
      apply_inference_config(config)
      bind_host, bind_port, limit = apply_server_config(config)
      setup_path()
      print_config(config, bind_host, bind_port, limit)

      server_settings = dict(
          host=bind_host,
          port=bind_port,
          workers=1,
          log_level="info",
          access_log=True,
          reload=False,
      )
      uvicorn.run("api.run_api:app", **server_settings)
  if "__main__" == __name__:
      # Start the server only when run as a script, never on import.
      main()