| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104 |
- """
- Agent OCR 服务启动脚本
- 用法:
- python start.py # 使用默认配置文件 ocr_config.yaml
- python start.py --config /path/to/cfg # 指定配置文件
- """
- import argparse
- import os
- import sys
- import yaml
- import uvicorn
def load_config(config_path: str) -> dict:
    """Load and parse the YAML configuration file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The parsed top-level mapping.

    Exits the process with status 1 (after printing an ``[ERROR]`` line to
    stderr) when the path is not a regular file, the file cannot be read or
    decoded, the YAML is malformed, or the top level is not a mapping.
    """
    # isfile (not exists): a directory path would otherwise pass this check
    # and then crash inside open() with IsADirectoryError.
    if not os.path.isfile(config_path):
        print(f"[ERROR] 配置文件不存在: {config_path}", file=sys.stderr)
        sys.exit(1)
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            cfg = yaml.safe_load(f)
    except OSError as exc:
        print(f"[ERROR] 无法读取配置文件: {config_path} ({exc})", file=sys.stderr)
        sys.exit(1)
    except (UnicodeDecodeError, yaml.YAMLError) as exc:
        # Report syntax/encoding problems cleanly instead of a raw traceback.
        print(f"[ERROR] 配置文件格式错误,请检查 yaml 格式: {exc}", file=sys.stderr)
        sys.exit(1)
    # safe_load returns None for an empty file and a scalar/list for other
    # documents; only a mapping is usable as a config.
    if not isinstance(cfg, dict):
        print("[ERROR] 配置文件格式错误,请检查 yaml 格式", file=sys.stderr)
        sys.exit(1)
    return cfg
def apply_inference_config(cfg: dict):
    """Export the ``inference`` config section as environment variables.

    The variables are written for config/config.py to read. Only truthy
    values are exported. A key that is missing or empty in the config leaves
    any same-named variable already in the environment untouched.

    Args:
        cfg: Parsed top-level configuration mapping.
    """
    # `or {}` also covers an `inference:` key with no value (YAML null),
    # which would otherwise be None and fail on .get().
    inf = cfg.get("inference") or {}
    mapping = {
        "INFERENCE_URL": inf.get("url"),
        "INFERENCE_AUTH_TOKEN": inf.get("auth_token"),
        "INFERENCE_MODEL": inf.get("model"),
    }
    for key, val in mapping.items():
        if val:
            # Environment values must be strings (YAML may yield ints etc.).
            os.environ[key] = str(val)
def _read_int_setting(section, key, default, minimum, maximum=None):
    """Return ``section[key]`` as an int in [minimum, maximum], else exit(1).

    A missing key or an explicit YAML null falls back to *default*.
    """
    raw = section.get(key)
    if raw is None:
        return default
    try:
        value = int(raw)
    except (TypeError, ValueError):
        print(f"[ERROR] server.{key} 必须是整数: {raw!r}", file=sys.stderr)
        sys.exit(1)
    if value < minimum or (maximum is not None and value > maximum):
        print(f"[ERROR] server.{key} 超出允许范围: {value}", file=sys.stderr)
        sys.exit(1)
    return value


def apply_server_config(cfg: dict) -> tuple:
    """Parse the ``server`` config section.

    Defaults: host ``"0.0.0.0"``, port ``6006``, max_concurrent ``5``. These
    are used when a key is missing or null. As a side effect, MAX_CONCURRENT
    is exported to the environment for run_api.py's lifespan to read.

    Args:
        cfg: Parsed top-level configuration mapping.

    Returns:
        ``(host, port, max_concurrent)``.

    Exits the process with status 1 when the section is not a mapping, the
    port is not an integer in 1-65535, or max_concurrent is not an
    integer >= 1.
    """
    # `or {}` also covers a `server:` key with no value (YAML null).
    srv = cfg.get("server") or {}
    if not isinstance(srv, dict):
        print("[ERROR] server 配置必须是键值映射", file=sys.stderr)
        sys.exit(1)
    host = srv.get("host")
    host = "0.0.0.0" if host is None else str(host)
    port = _read_int_setting(srv, "port", 6006, minimum=1, maximum=65535)
    max_concurrent = _read_int_setting(srv, "max_concurrent", 5, minimum=1)
    # Exported for run_api.py's lifespan to read.
    os.environ["MAX_CONCURRENT"] = str(max_concurrent)
    return host, port, max_concurrent
def setup_path():
    """Put the project root (the parent of this script's directory) first on sys.path.

    This lets the packages beside ocr/ (agent/api/config) be imported.
    Nothing happens if the root is already on sys.path.
    """
    script = os.path.abspath(__file__)
    project_root = os.path.dirname(os.path.dirname(script))
    if project_root in sys.path:
        return
    sys.path.insert(0, project_root)
def print_config(cfg: dict, host: str, port: int, max_concurrent: int):
    """Print a human-readable summary of the effective service configuration.

    The inference auth token is masked. At most its last 6 characters are
    shown, and never more than half of it, so a short token is never printed
    in full. Missing or null url/model values print as empty strings.

    Args:
        cfg: Parsed top-level configuration mapping.
        host: Bind host for the HTTP server.
        port: Bind port for the HTTP server.
        max_concurrent: Concurrency limit that is displayed.
    """
    # `or {}` also covers an `inference:` key with no value (YAML null).
    inf = cfg.get("inference") or {}
    # str(): YAML may yield a non-string token (e.g. all digits).
    token = str(inf.get("auth_token") or "")
    visible = min(6, len(token) // 2)
    # Slice from an explicit start index: token[-0:] would be the whole token.
    masked = "*" * (len(token) - visible) + token[len(token) - visible:]
    print("=" * 50)
    print(" Agent OCR 服务配置")
    print("=" * 50)
    print(f" 推理地址 : {inf.get('url') or ''}")
    print(f" 认证Token : {masked}")
    print(f" 模型名称 : {inf.get('model') or ''}")
    print(f" 服务地址 : http://{host}:{port}")
    print(f" 最大并发 : {max_concurrent}")
    print("=" * 50)
def main():
    """Parse CLI arguments, apply the YAML config, and start the uvicorn server.

    By default the config is ocr_config.yaml next to this script. It can be
    overridden with ``--config``. The app is served from ``api.run_api:app``
    in a single worker with reload disabled.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    cli = argparse.ArgumentParser(description="启动 Agent OCR API 服务")
    cli.add_argument(
        "--config",
        default=os.path.join(script_dir, "ocr_config.yaml"),
        help="配置文件路径(默认:与 start.py 同目录的 ocr_config.yaml)",
    )
    options = cli.parse_args()

    config = load_config(options.config)
    apply_inference_config(config)
    host, port, max_concurrent = apply_server_config(config)
    # sys.path must include the project root before uvicorn imports the app string.
    setup_path()
    print_config(config, host, port, max_concurrent)

    server_options = {
        "host": host,
        "port": port,
        "workers": 1,
        "log_level": "info",
        "access_log": True,
        "reload": False,
    }
    uvicorn.run("api.run_api:app", **server_options)


if __name__ == "__main__":
    # Start the server only when executed as a script, not when imported.
    main()
|