| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697 |
import argparse
import os
import time

import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from models.rank import DataProcess, Trainer, GbdtLrModel
# FastAPI application that exposes the training / ranking endpoints.
app = FastAPI()

# Defaults shared by the request models below. A city's model is stored at
# <model_path>/<city_uuid>/<model_name>.
model_path, model_name = "./models/rank/weights", "model.pkl"
# Request body definitions.
class TrainRequest(BaseModel):
    """Request body for ``POST /train``."""

    # City whose model is trained; also used as the model sub-directory name.
    city_uuid: str
    # CSV path passed to both DataProcess and Trainer; presumably DataProcess
    # writes the assembled training data here and Trainer reads it (confirm).
    train_data_path: str = "./models/rank/train_data/gbdt_data.csv"
    # Root directory under which per-city model directories are created.
    model_path: str = model_path
    # File name of the saved model inside the city directory.
    model_name: str = model_name
-
class RecommendRequest(BaseModel):
    """Request body for ``POST /recommend``."""

    # City whose trained model is used for ranking.
    city_uuid: str
    # Product passed to the model's ``sort`` call.
    product_id: str
    # Maximum number of recommendations returned. Negative values are rejected
    # (HTTP 422): used as a slice bound they would drop items from the end of
    # the list instead of limiting its length.
    last_n: int = Field(200, ge=0)
    # Root directory of the per-city model directories.
    model_path: str = model_path
    # File name of the saved model inside the city directory.
    model_name: str = model_name
-
class ImportanceRequest(BaseModel):
    """Request body for ``POST /importance``."""

    # City whose trained model is inspected.
    city_uuid: str
    # Root directory of the per-city model directories.
    model_path: str = model_path
    # File name of the saved model inside the city directory.
    model_name: str = model_name
-
@app.post("/train")
def train(request: TrainRequest):
    """Assemble training data for a city, train a GBDT-LR model and save it.

    The model is written to ``<model_path>/<city_uuid>/<model_name>``. Data
    preparation, training, evaluation and saving all run synchronously inside
    this request, so the call returns only after every step has finished.

    Args:
        request: Target city, training-data CSV path and model location.

    Returns:
        A dict with a completion message.
    """
    # NOTE(review): every path component here comes straight from the request
    # body and is used unvalidated, so a client can make the service create
    # directories / write files anywhere the process has access to.
    model_dir = os.path.join(request.model_path, request.city_uuid)
    # Create output directories up front so an unusable path fails before a
    # potentially long training run. exist_ok=True replaces the racy
    # exists()-then-makedirs() pattern.
    os.makedirs(model_dir, exist_ok=True)

    train_data_dir = os.path.dirname(request.train_data_path)
    # dirname() returns "" for a bare file name such as "data.csv", and
    # os.makedirs("") raises FileNotFoundError, so only create a real directory.
    if train_data_dir:
        os.makedirs(train_data_dir, exist_ok=True)

    # Assemble the training data set.
    print("正在整合训练数据...")
    processor = DataProcess(request.city_uuid, request.train_data_path)
    processor.data_process()
    print("训练数据整合完成!")

    # Train. perf_counter() is monotonic, so the measured duration is not
    # skewed by wall-clock adjustments the way time.time() can be.
    trainer = Trainer(request.train_data_path)
    start_time = time.perf_counter()
    trainer.train()
    training_time_hours = (time.perf_counter() - start_time) / 3600
    print(f"训练时间: {training_time_hours:.4f} 小时")

    eval_metrics = trainer.evaluate()

    # Save the model.
    trainer.save_model(os.path.join(model_dir, request.model_name))

    # Report evaluation results (assumes every metric value is numeric,
    # since it is formatted with :.4f).
    print("GBDT-LR Evaluation Metrics:")
    for metric, value in eval_metrics.items():
        print(f"{metric}: {value:.4f}")

    return {"message": "训练完成!"}
@app.post("/recommend")
def recommend(request: RecommendRequest):
    """Rank items for a product with the city's trained model.

    Args:
        request: City, product, result limit and model location.

    Returns:
        ``{"recommendations": ...}`` holding at most ``last_n`` leading
        entries of the model's ``sort`` result.

    Raises:
        HTTPException: 404 when no trained model file exists for the city.
    """
    model_dir = os.path.join(request.model_path, request.city_uuid)
    model_file = os.path.join(model_dir, request.model_name)
    # Check the model file itself, not only its directory: /train creates the
    # directory before training starts, so a failed run leaves it without a
    # model and loading would crash with a 500 instead of this 404.
    if not os.path.isfile(model_file):
        raise HTTPException(status_code=404, detail="暂无该城市的模型,请先进行模型训练")

    # Load the model.
    # NOTE(review): model_path/model_name are client-supplied. If GbdtLrModel
    # unpickles the file (the default name "model.pkl" suggests so — confirm),
    # this lets a client load an arbitrary pickle from disk.
    model = GbdtLrModel(model_file)
    recommend_list = model.sort(request.city_uuid, request.product_id)

    # Slicing already clamps to the list length; max() stops a negative
    # last_n from dropping items off the end instead of limiting the result.
    return {"recommendations": recommend_list[:max(request.last_n, 0)]}
@app.post("/importance")
def importance(request: ImportanceRequest):
    """Return the feature importances reported by the city's trained model.

    Args:
        request: City and model location.

    Returns:
        A dict with the ``cust_features_importance`` and
        ``product_features_importance`` values from
        ``GbdtLrModel.generate_feats_importance()``.

    Raises:
        HTTPException: 404 when no trained model file exists for the city.
    """
    model_dir = os.path.join(request.model_path, request.city_uuid)
    model_file = os.path.join(model_dir, request.model_name)
    # Check the model file itself, not only its directory: /train creates the
    # directory before training starts, so a failed run leaves it without a
    # model and loading would crash with a 500 instead of this 404.
    if not os.path.isfile(model_file):
        raise HTTPException(status_code=404, detail="暂无该城市的模型,请先进行模型训练")

    # Load the model.
    # NOTE(review): model_path/model_name are client-supplied. If GbdtLrModel
    # unpickles the file (the default name "model.pkl" suggests so — confirm),
    # this lets a client load an arbitrary pickle from disk.
    model = GbdtLrModel(model_file)
    cust_features_importance, product_features_importance = model.generate_feats_importance()

    return {
        "cust_features_importance": cust_features_importance,
        "product_features_importance": product_features_importance,
    }
if __name__ == "__main__":
    # Serve the API directly with uvicorn when run as a script
    # (listens on all interfaces, port 8000).
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")
|