|
@@ -0,0 +1,97 @@
|
|
|
|
|
+import argparse
|
|
|
|
|
+import os
|
|
|
|
|
+from models.rank import DataProcess, Trainer, GbdtLrModel
|
|
|
|
|
+import time
|
|
|
|
|
+import pandas as pd
|
|
|
|
|
+from fastapi import FastAPI, HTTPException
|
|
|
|
|
+from pydantic import BaseModel
|
|
|
|
|
+
|
|
|
|
|
# FastAPI application exposing the rank-model service endpoints.
app = FastAPI()

# Default base directory where per-city model weights are stored.
model_path = "./models/rank/weights"
# Default filename of the serialized (pickled) model inside a city directory.
model_name = "model.pkl"
|
|
|
|
|
+
|
|
|
|
|
# Request body definitions.
class TrainRequest(BaseModel):
    """Request body for the /train endpoint."""

    # City identifier; used as the per-city model sub-directory name.
    city_uuid: str
    # CSV file where the assembled training data is written and read back.
    train_data_path: str = "./models/rank/train_data/gbdt_data.csv"
    # Base directory for model weights (module-level default).
    model_path: str = model_path
    # Filename for the serialized model (module-level default).
    model_name: str = model_name
|
|
|
|
|
+
|
|
|
|
|
class RecommendRequest(BaseModel):
    """Request body for the /recommend endpoint."""

    # City identifier; selects which trained model to load.
    city_uuid: str
    # Anchor product to rank recommendations for.
    product_id: str
    # Maximum number of recommendations to return.
    last_n: int = 200
    # Base directory for model weights (module-level default).
    model_path: str = model_path
    # Filename for the serialized model (module-level default).
    model_name: str = model_name
|
|
|
|
|
+
|
|
|
|
|
class ImportanceRequest(BaseModel):
    """Request body for the /importance endpoint."""

    # City identifier; selects which trained model to load.
    city_uuid: str
    # Base directory for model weights (module-level default).
    model_path: str = model_path
    # Filename for the serialized model (module-level default).
    model_name: str = model_name
|
|
|
|
|
+
|
|
|
|
|
+@app.post("/train")
|
|
|
|
|
+def train(request: TrainRequest):
|
|
|
|
|
+ model_dir = os.path.join(request.model_path, request.city_uuid)
|
|
|
|
|
+ train_data_dir = os.path.dirname(request.train_data_path)
|
|
|
|
|
+ if not os.path.exists(model_dir):
|
|
|
|
|
+ os.makedirs(model_dir)
|
|
|
|
|
+
|
|
|
|
|
+ if not os.path.exists(train_data_dir):
|
|
|
|
|
+ os.makedirs(train_data_dir)
|
|
|
|
|
+
|
|
|
|
|
+ # 准备数据集
|
|
|
|
|
+ print("正在整合训练数据...")
|
|
|
|
|
+ processor = DataProcess(request.city_uuid, request.train_data_path)
|
|
|
|
|
+ processor.data_process()
|
|
|
|
|
+ print("训练数据整合完成!")
|
|
|
|
|
+
|
|
|
|
|
+ # 进行训练
|
|
|
|
|
+ trainer = Trainer(request.train_data_path)
|
|
|
|
|
+
|
|
|
|
|
+ start_time = time.time()
|
|
|
|
|
+ trainer.train()
|
|
|
|
|
+ end_time = time.time()
|
|
|
|
|
+
|
|
|
|
|
+ training_time_hours = (end_time - start_time) / 3600
|
|
|
|
|
+ print(f"训练时间: {training_time_hours:.4f} 小时")
|
|
|
|
|
+
|
|
|
|
|
+ eval_metrics = trainer.evaluate()
|
|
|
|
|
+
|
|
|
|
|
+ # 保存模型
|
|
|
|
|
+ trainer.save_model(os.path.join(model_dir, request.model_name))
|
|
|
|
|
+
|
|
|
|
|
+ # 输出评估结果
|
|
|
|
|
+ print("GBDT-LR Evaluation Metrics:")
|
|
|
|
|
+ for metric, value in eval_metrics.items():
|
|
|
|
|
+ print(f"{metric}: {value:.4f}")
|
|
|
|
|
+
|
|
|
|
|
+ return {"message": "训练完成!"}
|
|
|
|
|
+
|
|
|
|
|
+@app.post("/recommend")
|
|
|
|
|
+def recommend(request: RecommendRequest):
|
|
|
|
|
+ model_dir = os.path.join(request.model_path, request.city_uuid)
|
|
|
|
|
+ if not os.path.exists(model_dir):
|
|
|
|
|
+ raise HTTPException(status_code=404, detail="暂无该城市的模型,请先进行模型训练")
|
|
|
|
|
+
|
|
|
|
|
+ # 加载模型
|
|
|
|
|
+ model = GbdtLrModel(os.path.join(model_dir, request.model_name))
|
|
|
|
|
+ recommend_list = model.sort(request.city_uuid, request.product_id)
|
|
|
|
|
+
|
|
|
|
|
+ return {"recommendations": recommend_list[:min(request.last_n, len(recommend_list))]}
|
|
|
|
|
+
|
|
|
|
|
+@app.post("/importance")
|
|
|
|
|
+def importance(request: ImportanceRequest):
|
|
|
|
|
+ model_dir = os.path.join(request.model_path, request.city_uuid)
|
|
|
|
|
+ if not os.path.exists(model_dir):
|
|
|
|
|
+ raise HTTPException(status_code=404, detail="暂无该城市的模型,请先进行模型训练")
|
|
|
|
|
+
|
|
|
|
|
+ # 加载模型
|
|
|
|
|
+ model = GbdtLrModel(os.path.join(model_dir, request.model_name))
|
|
|
|
|
+ cust_features_importance, product_features_importance = model.generate_feats_importance()
|
|
|
|
|
+
|
|
|
|
|
+ return {"cust_features_importance": cust_features_importance, "product_features_importance": product_features_importance}
|
|
|
|
|
+
|
|
|
|
|
+if __name__ == "__main__":
|
|
|
|
|
+ import uvicorn
|
|
|
|
|
+ uvicorn.run(app, host="0.0.0.0", port=8000)
|