Преглед на файлове

gbdt_lr端口调用封装

Sherlock преди 1 година
родител
ревизия
f844ceaca5
променени са 3 файла, в които са добавени 103 реда и са изтрити 6 реда
  1. 6 6
      dao/dao.py
  2. 97 0
      gbdt_lr_api.py
  3. BIN
      requirements.txt

+ 6 - 6
dao/dao.py

@@ -3,17 +3,17 @@ from dao import Mysql
 def load_order_data_from_mysql(city_uuid):
     """从数据库中读取订单数据"""
     client = Mysql()
-    tablename = "yunfu_mock_data"
-    # tablename = "tads_brandcul_cust_order"
+    # tablename = "yunfu_mock_data"
+    tablename = "tads_brandcul_cust_order"
     query_text = "*"
     # city_uuid = "00000000000000000000000011441801"
-    # df = client.load_data(tablename, query_text, "city_uuid", city_uuid)
-    df = client.load_mock_data(tablename, query_text)
+    df = client.load_data(tablename, query_text, "city_uuid", city_uuid)
+    # df = client.load_mock_data(tablename, query_text)
     if len(df) == 0:
         return None
     
-    # df.drop('stat_month', axis=1, inplace=True)
-    # df.drop('city_uuid', axis=1, inplace=True)
+    df.drop('stat_month', axis=1, inplace=True)
+    df.drop('city_uuid', axis=1, inplace=True)
     
     # 去除重复值和填补缺失值
     df.drop_duplicates(inplace=True)

+ 97 - 0
gbdt_lr_api.py

@@ -0,0 +1,97 @@
+import argparse
+import os
+from models.rank import DataProcess, Trainer, GbdtLrModel
+import time
+import pandas as pd
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+
+app = FastAPI()  # application object the @app.post endpoints in this file register on
+
+model_path = "./models/rank/weights"  # default root directory for the per-city model folders
+model_name = "model.pkl"  # default model file name inside a city folder
+
+# Request-body schemas for the endpoints
+class TrainRequest(BaseModel):  # body of POST /train
+    city_uuid: str  # passed to DataProcess; also the sub-folder name under model_path
+    train_data_path: str = "./models/rank/train_data/gbdt_data.csv"  # handed to DataProcess and Trainer
+    model_path: str = model_path  # root directory the per-city model folder is created in
+    model_name: str = model_name  # file name the trained model is saved under
+    
+class RecommendRequest(BaseModel):  # body of POST /recommend
+    city_uuid: str  # selects the per-city model folder; also passed to model.sort
+    product_id: str  # passed to model.sort
+    last_n: int = 200  # the response keeps at most this many leading entries of the sorted list
+    model_path: str = model_path  # root directory holding the per-city model folders
+    model_name: str = model_name  # model file name inside the city folder
+    
+class ImportanceRequest(BaseModel):  # body of POST /importance
+    city_uuid: str  # selects the per-city model folder
+    model_path: str = model_path  # root directory holding the per-city model folders
+    model_name: str = model_name  # model file name inside the city folder
+    
+@app.post("/train")
+def train(request: TrainRequest):
+    model_dir = os.path.join(request.model_path, request.city_uuid)
+    train_data_dir = os.path.dirname(request.train_data_path)
+    if not os.path.exists(model_dir):
+        os.makedirs(model_dir)
+    
+    if not os.path.exists(train_data_dir):
+        os.makedirs(train_data_dir)
+        
+    # 准备数据集  
+    print("正在整合训练数据...")
+    processor = DataProcess(request.city_uuid, request.train_data_path)
+    processor.data_process()
+    print("训练数据整合完成!")
+    
+    # 进行训练
+    trainer = Trainer(request.train_data_path)
+    
+    start_time = time.time()
+    trainer.train()
+    end_time = time.time()
+    
+    training_time_hours = (end_time - start_time) / 3600
+    print(f"训练时间: {training_time_hours:.4f} 小时")
+    
+    eval_metrics = trainer.evaluate()
+    
+    # 保存模型
+    trainer.save_model(os.path.join(model_dir, request.model_name))
+    
+    # 输出评估结果
+    print("GBDT-LR Evaluation Metrics:")
+    for metric, value in eval_metrics.items():
+        print(f"{metric}: {value:.4f}")
+    
+    return {"message": "训练完成!"}
+
+@app.post("/recommend")
+def recommend(request: RecommendRequest):
+    model_dir = os.path.join(request.model_path, request.city_uuid)
+    if not os.path.exists(model_dir):
+        raise HTTPException(status_code=404, detail="暂无该城市的模型,请先进行模型训练")
+    
+    # 加载模型
+    model = GbdtLrModel(os.path.join(model_dir, request.model_name))
+    recommend_list = model.sort(request.city_uuid, request.product_id)
+    
+    return {"recommendations": recommend_list[:min(request.last_n, len(recommend_list))]}
+
+@app.post("/importance")
+def importance(request: ImportanceRequest):
+    model_dir = os.path.join(request.model_path, request.city_uuid)
+    if not os.path.exists(model_dir):
+        raise HTTPException(status_code=404, detail="暂无该城市的模型,请先进行模型训练")
+    
+    # 加载模型
+    model = GbdtLrModel(os.path.join(model_dir, request.model_name))
+    cust_features_importance, product_features_importance = model.generate_feats_importance()
+    
+    return {"cust_features_importance": cust_features_importance, "product_features_importance": product_features_importance}
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)

BIN
requirements.txt