"""HTTP service for training and serving per-city GBDT-LR ranking models.

Endpoints:
    POST /train       build training data for a city, train and save its model
    POST /recommend   rank candidates for a city/product with the saved model
    POST /importance  report feature importances of a city's saved model
"""
import argparse
import os
import time

import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from models.rank import DataProcess, Trainer, GbdtLrModel

app = FastAPI()

# Default root directory for trained models (one sub-directory per city) and
# the default model file name inside each city directory.
model_path = "./models/rank/weights"
model_name = "model.pkl"

# NOTE(review): model_path, model_name and train_data_path below are taken from
# the request body, so clients choose where files are read and written on the
# server. Consider restricting them to server-side configuration.


# Request bodies
class TrainRequest(BaseModel):
    city_uuid: str
    train_data_path: str = "./models/rank/train_data/gbdt_data.csv"
    model_path: str = model_path
    model_name: str = model_name


class RecommendRequest(BaseModel):
    city_uuid: str
    product_id: str
    # Maximum number of recommendations returned.
    last_n: int = 200
    model_path: str = model_path
    model_name: str = model_name


class ImportanceRequest(BaseModel):
    city_uuid: str
    model_path: str = model_path
    model_name: str = model_name


def _city_model_dir(base_path: str, city_uuid: str) -> str:
    """Return the model directory for *city_uuid* under *base_path*.

    Raises:
        HTTPException(400): if city_uuid is not a single plain path component.
            It is joined into a filesystem path, so values such as "", "..",
            "../x" or "/etc" would otherwise escape base_path.
    """
    if (
        not city_uuid
        or city_uuid in (".", "..")
        or os.path.basename(city_uuid) != city_uuid
    ):
        raise HTTPException(status_code=400, detail="city_uuid 不合法")
    return os.path.join(base_path, city_uuid)


def _load_city_model(base_path: str, city_uuid: str, name: str) -> GbdtLrModel:
    """Load the saved model file *name* for *city_uuid*.

    Raises:
        HTTPException(400): if city_uuid is invalid (see _city_model_dir).
        HTTPException(404): if no model file has been saved for the city.
    """
    model_file = os.path.join(_city_model_dir(base_path, city_uuid), name)
    # Check the file itself, not just the directory: /train creates the city
    # directory before training starts, so a failed or still-running training
    # job leaves the directory without a model file.
    if not os.path.isfile(model_file):
        raise HTTPException(status_code=404, detail="暂无该城市的模型,请先进行模型训练")
    # NOTE(review): if GbdtLrModel unpickles this file (the default name is
    # "model.pkl"), a client-chosen model_path/model_name can make the server
    # load any pickle reachable on disk — confirm and restrict.
    return GbdtLrModel(model_file)


@app.post("/train")
def train(request: TrainRequest) -> dict:
    """Build the training set for a city, train a GBDT-LR model and save it.

    The model is written to <model_path>/<city_uuid>/<model_name>. Evaluation
    metrics and the training duration are printed to stdout.
    """
    model_dir = _city_model_dir(request.model_path, request.city_uuid)
    os.makedirs(model_dir, exist_ok=True)

    # dirname() is "" for a bare file name such as "data.csv"; os.makedirs("")
    # would raise, and the current directory needs no creation anyway.
    train_data_dir = os.path.dirname(request.train_data_path)
    if train_data_dir:
        os.makedirs(train_data_dir, exist_ok=True)

    # Prepare the dataset
    print("正在整合训练数据...")
    processor = DataProcess(request.city_uuid, request.train_data_path)
    processor.data_process()
    print("训练数据整合完成!")

    # Train; perf_counter is monotonic, so the duration is unaffected by clock changes.
    trainer = Trainer(request.train_data_path)
    start_time = time.perf_counter()
    trainer.train()
    training_time_hours = (time.perf_counter() - start_time) / 3600
    print(f"训练时间: {training_time_hours:.4f} 小时")

    eval_metrics = trainer.evaluate()

    # Save the model
    trainer.save_model(os.path.join(model_dir, request.model_name))

    # Print evaluation results
    print("GBDT-LR Evaluation Metrics:")
    for metric, value in eval_metrics.items():
        print(f"{metric}: {value:.4f}")

    return {"message": "训练完成!"}


@app.post("/recommend")
def recommend(request: RecommendRequest) -> dict:
    """Return at most ``last_n`` entries of the city model's ranking for a product.

    Raises:
        HTTPException(400): invalid city_uuid.
        HTTPException(404): no trained model for the city.
    """
    model = _load_city_model(request.model_path, request.city_uuid, request.model_name)
    recommend_list = model.sort(request.city_uuid, request.product_id)
    # Slicing already clamps to the list length; max(..., 0) stops a negative
    # last_n from silently trimming items off the end of the ranking.
    return {"recommendations": recommend_list[:max(request.last_n, 0)]}


@app.post("/importance")
def importance(request: ImportanceRequest) -> dict:
    """Return the customer and product feature importances of a city's model.

    Raises:
        HTTPException(400): invalid city_uuid.
        HTTPException(404): no trained model for the city.
    """
    model = _load_city_model(request.model_path, request.city_uuid, request.model_name)
    cust_features_importance, product_features_importance = model.generate_feats_importance()
    return {
        "cust_features_importance": cust_features_importance,
        "product_features_importance": product_features_importance,
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)