start.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. """
  2. Agent OCR 服务启动脚本
  3. 用法:
  4. python start.py # 使用默认配置文件 ocr_config.yaml
  5. python start.py --config /path/to/cfg # 指定配置文件
  6. """
  7. import argparse
  8. import os
  9. import sys
  10. import yaml
  11. import uvicorn
  12. def load_config(config_path: str) -> dict:
  13. """加载 yaml 配置文件"""
  14. if not os.path.exists(config_path):
  15. print(f"[ERROR] 配置文件不存在: {config_path}")
  16. sys.exit(1)
  17. with open(config_path, "r", encoding="utf-8") as f:
  18. cfg = yaml.safe_load(f)
  19. if not isinstance(cfg, dict):
  20. print("[ERROR] 配置文件格式错误,请检查 yaml 格式")
  21. sys.exit(1)
  22. return cfg
  23. def apply_inference_config(cfg: dict):
  24. """将推理服务配置写入环境变量,供 config/config.py 读取"""
  25. inf = cfg.get("inference", {})
  26. mapping = {
  27. "INFERENCE_URL": inf.get("url"),
  28. "INFERENCE_AUTH_TOKEN": inf.get("auth_token"),
  29. "INFERENCE_MODEL": inf.get("model"),
  30. }
  31. for key, val in mapping.items():
  32. if val:
  33. os.environ[key] = str(val)
  34. def apply_server_config(cfg: dict) -> tuple:
  35. """解析服务配置,返回 (host, port, max_concurrent)"""
  36. srv = cfg.get("server", {})
  37. host = str(srv.get("host", "0.0.0.0"))
  38. port = int(srv.get("port", 6006))
  39. max_concurrent = int(srv.get("max_concurrent", 5))
  40. os.environ["MAX_CONCURRENT"] = str(max_concurrent)
  41. return host, port, max_concurrent
  42. def apply_image_config(cfg: dict):
  43. """将图像预处理配置写入环境变量,供 agent/agent.py 读取"""
  44. img = cfg.get("image", {})
  45. os.environ["IMAGE_MAX_SIZE"] = str(img.get("max_size", 512))
  46. os.environ["IMAGE_COMPRESS"] = str(img.get("compress", False)).lower()
  47. os.environ["IMAGE_COMPRESS_QUALITY"]= str(img.get("compress_quality", 70))
  48. def setup_path():
  49. """将项目根目录(ocr/ 的上级)加入 sys.path,确保能找到 agent/api/config 包"""
  50. ocr_dir = os.path.dirname(os.path.abspath(__file__))
  51. root_dir = os.path.dirname(ocr_dir)
  52. if root_dir not in sys.path:
  53. sys.path.insert(0, root_dir)
  54. def print_config(cfg: dict, host: str, port: int, max_concurrent: int):
  55. inf = cfg.get("inference", {})
  56. img = cfg.get("image", {})
  57. token = inf.get("auth_token", "")
  58. masked = ("*" * max(0, len(token) - 6)) + token[-6:] if token else ""
  59. print("=" * 50)
  60. print(" Agent OCR 服务配置")
  61. print("=" * 50)
  62. print(f" 推理地址 : {inf.get('url', '')}")
  63. print(f" 认证Token : {masked}")
  64. print(f" 模型名称 : {inf.get('model', '')}")
  65. print(f" 服务地址 : http://{host}:{port}")
  66. print(f" 最大并发 : {max_concurrent}")
  67. print(f" 图像最大边 : {img.get('max_size', 512)}px")
  68. compress = img.get("compress", False)
  69. print(f" 图像压缩 : {'开启,质量=' + str(img.get('compress_quality', 70)) if compress else '关闭'}")
  70. print("=" * 50)
  71. def main():
  72. parser = argparse.ArgumentParser(description="启动 Agent OCR API 服务")
  73. parser.add_argument(
  74. "--config",
  75. default=os.path.join(os.path.dirname(os.path.abspath(__file__)), "ocr_config.yaml"),
  76. help="配置文件路径(默认:与 start.py 同目录的 ocr_config.yaml)",
  77. )
  78. args = parser.parse_args()
  79. cfg = load_config(args.config)
  80. apply_inference_config(cfg)
  81. host, port, max_concurrent = apply_server_config(cfg)
  82. apply_image_config(cfg)
  83. setup_path()
  84. print_config(cfg, host, port, max_concurrent)
  85. uvicorn.run(
  86. "api.run_api:app",
  87. host=host,
  88. port=port,
  89. workers=1,
  90. log_level="info",
  91. access_log=True,
  92. reload=False,
  93. )
  94. if __name__ == "__main__":
  95. main()