# gbdt_lr_api.py — FastAPI service exposing GBDT-LR training (/train),
# recommendation (/recommend) and feature-importance (/importance) endpoints.
  1. import argparse
  2. import os
  3. from models.rank import DataProcess, Trainer, GbdtLrModel
  4. import time
  5. import pandas as pd
  6. from fastapi import FastAPI, HTTPException
  7. from pydantic import BaseModel
  8. app = FastAPI()
  9. model_path = "./models/rank/weights"
  10. model_name = "model.pkl"
  11. # 定义请求体
  12. class TrainRequest(BaseModel):
  13. city_uuid: str
  14. train_data_path: str = "./models/rank/train_data/gbdt_data.csv"
  15. model_path: str = model_path
  16. model_name: str = model_name
  17. class RecommendRequest(BaseModel):
  18. city_uuid: str
  19. product_id: str
  20. last_n: int = 200
  21. model_path: str = model_path
  22. model_name: str = model_name
  23. class ImportanceRequest(BaseModel):
  24. city_uuid: str
  25. model_path: str = model_path
  26. model_name: str = model_name
  27. @app.post("/train")
  28. def train(request: TrainRequest):
  29. model_dir = os.path.join(request.model_path, request.city_uuid)
  30. train_data_dir = os.path.dirname(request.train_data_path)
  31. if not os.path.exists(model_dir):
  32. os.makedirs(model_dir)
  33. if not os.path.exists(train_data_dir):
  34. os.makedirs(train_data_dir)
  35. # 准备数据集
  36. print("正在整合训练数据...")
  37. processor = DataProcess(request.city_uuid, request.train_data_path)
  38. processor.data_process()
  39. print("训练数据整合完成!")
  40. # 进行训练
  41. trainer = Trainer(request.train_data_path)
  42. start_time = time.time()
  43. trainer.train()
  44. end_time = time.time()
  45. training_time_hours = (end_time - start_time) / 3600
  46. print(f"训练时间: {training_time_hours:.4f} 小时")
  47. eval_metrics = trainer.evaluate()
  48. # 保存模型
  49. trainer.save_model(os.path.join(model_dir, request.model_name))
  50. # 输出评估结果
  51. print("GBDT-LR Evaluation Metrics:")
  52. for metric, value in eval_metrics.items():
  53. print(f"{metric}: {value:.4f}")
  54. return {"message": "训练完成!"}
  55. @app.post("/recommend")
  56. def recommend(request: RecommendRequest):
  57. model_dir = os.path.join(request.model_path, request.city_uuid)
  58. if not os.path.exists(model_dir):
  59. raise HTTPException(status_code=404, detail="暂无该城市的模型,请先进行模型训练")
  60. # 加载模型
  61. model = GbdtLrModel(os.path.join(model_dir, request.model_name))
  62. recommend_list = model.sort(request.city_uuid, request.product_id)
  63. return {"recommendations": recommend_list[:min(request.last_n, len(recommend_list))]}
  64. @app.post("/importance")
  65. def importance(request: ImportanceRequest):
  66. model_dir = os.path.join(request.model_path, request.city_uuid)
  67. if not os.path.exists(model_dir):
  68. raise HTTPException(status_code=404, detail="暂无该城市的模型,请先进行模型训练")
  69. # 加载模型
  70. model = GbdtLrModel(os.path.join(model_dir, request.model_name))
  71. cust_features_importance, product_features_importance = model.generate_feats_importance()
  72. return {"cust_features_importance": cust_features_importance, "product_features_importance": product_features_importance}
  73. if __name__ == "__main__":
  74. import uvicorn
  75. uvicorn.run(app, host="0.0.0.0", port=8000)