Ver código fonte

封装gbdt_lr训练推理流程

Sherlock 1 ano atrás
pai
commit
90aa78a347

+ 1 - 0
dao/dao.py

@@ -4,6 +4,7 @@ def load_order_data_from_mysql(city_uuid):
     """从数据库中读取订单数据"""
     client = Mysql()
     tablename = "yunfu_mock_data"
+    # tablename = "tads_brandcul_cust_order"
     query_text = "*"
     # city_uuid = "00000000000000000000000011441801"
     # df = client.load_data(tablename, query_text, "city_uuid", city_uuid)

+ 110 - 0
gbdt_lr.py

@@ -0,0 +1,110 @@
+import argparse
+import os
+from models.rank import DataProcess, Trainer, GbdtLrModel
+import time
+import pandas as pd
+
+# train_data_path = "./models/rank/data/gbdt_data.csv"
+# model_path = "./models/rank/weights"
+
+def train(args):
+    model_dir = os.path.join(args.model_path, args.city_uuid)
+    if not os.path.exists(model_dir):
+        os.makedirs(model_dir)
+    
+    # 准备数据集  
+    print("正在整合训练数据...")
+    processor = DataProcess(args.city_uuid, args.train_data_path)
+    processor.data_process()
+    print("训练数据整合完成!")
+    
+    # 进行训练
+    trainer(args, model_dir)
+
+def trainer(args, model_dir):
+    trainer = Trainer(args.train_data_path)
+    
+    start_time = time.time()
+    trainer.train()
+    end_time = time.time()
+    
+    training_time_hours = (end_time - start_time) / 3600
+    print(f"训练时间: {training_time_hours:.4f} 小时")
+    
+    eval_metrics = trainer.evaluate()
+    
+    # 输出评估结果
+    print("GBDT-LR Evaluation Metrics:")
+    for metric, value in eval_metrics.items():
+        print(f"{metric}: {value:.4f}")
+        
+    # 保存模型
+    trainer.save_model(os.path.join(model_dir, args.model_name))
+
+def recommend_by_product(args):
+    model_dir = os.path.join(args.model_path, args.city_uuid)
+    if not os.path.exists(model_dir):
+        print("暂无该城市的模型,请先进行模型训练")
+        return
+    
+    # 加载模型
+    model = GbdtLrModel(os.path.join(model_dir, args.model_name))
+    recommend_list = model.sort(args.city_uuid, args.product_id)
+    for item in recommend_list[:min(args.last_n, len(recommend_list))]:
+        print(item)
+
+def get_features_importance(args):
+    model_dir = os.path.join(args.model_path, args.city_uuid)
+    if not os.path.exists(model_dir):
+        print("暂无该城市的模型,请先进行模型训练")
+        return
+    
+    # 加载模型
+    model = GbdtLrModel(os.path.join(model_dir, args.model_name))
+    cust_features_importance, product_features_importance = model.generate_feats_importance()
+    
+    # 将字典列表转换为 DataFrame
+    cust_df = pd.DataFrame([
+        {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
+        for item in cust_features_importance
+    ])
+    
+    product_df = pd.DataFrame([
+        {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
+        for item in product_features_importance
+    ])
+    
+    cust_file_path = os.path.join(model_dir, "cust_features_importance.csv")
+    product_file_path = os.path.join(model_dir, "product_features_importance.csv")
+    cust_df.to_csv(cust_file_path, index=False, encoding='utf-8')
+    product_df.to_csv(product_file_path, index=False, encoding='utf-8')
+        
+def run():
+    parser = argparse.ArgumentParser()
+    
+    parser.add_argument("--run_train", action='store_true')
+    parser.add_argument("--recommend", action='store_true')
+    parser.add_argument("--importance", action='store_true')
+    
+    parser.add_argument("--train_data_path", type=str, default="./models/rank/data/gbdt_data.csv")
+    parser.add_argument("--model_path", type=str, default="./models/rank/weights")
+    parser.add_argument("--model_name", type=str, default='model.pkl')
+    parser.add_argument("--last_n", type=int, default=200)
+    
+    parser.add_argument("--city_uuid", type=str, default='00000000000000000000000011445301')
+    parser.add_argument("--product_id", type=str, default='110102')
+    
+    
+    args = parser.parse_args()
+    
+    if args.run_train:
+        train(args)
+        
+    if args.recommend:
+        recommend_by_product(args)
+        
+    if args.importance:
+        get_features_importance(args)
+        
+if __name__ == "__main__":
+    run()

+ 5 - 1
models/rank/__init__.py

@@ -1,7 +1,11 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
 from models.rank.data.preprocess import DataProcess
+from models.rank.gbdt_lr import Trainer
+from models.rank.gbdt_lr_sort import GbdtLrModel
 
 __all__ = [
-    "DataProcess"
+    "DataProcess",
+    "Trainer",
+    "GbdtLrModel"
 ]

+ 3 - 2
models/rank/data/__init__.py

@@ -1,4 +1,4 @@
-from models.rank.data.config import CustConfig, ProductConfig, OrderConfig
+from models.rank.data.config import CustConfig, ProductConfig, OrderConfig, ImportanceFeaturesMap
 from models.rank.data.dataloader import DataLoader
 from models.rank.data.utils import one_hot_embedding, sample_data_clear
 __all__ = [
@@ -7,5 +7,6 @@ __all__ = [
     "OrderConfig",
     "DataLoader",
     "one_hot_embedding",
-    "sample_data_clear"
+    "sample_data_clear",
+    "ImportanceFeaturesMap"
 ]

+ 47 - 2
models/rank/data/config.py

@@ -13,6 +13,7 @@ class CustConfig:
         "MD04_DIR_SAL_STORE_FLAG",                     # 直营店标识
         "BB_CUSTOMER_MANAGER_SCOPE_NAME",              # 零售户经营范围名称
         "PRODUCT_INSALE_QTY",                          # 在销品规数
+        "CUST_INVESTMENT",                             # 店铺资源投入建设
         
         # "NEW_PRODUCT_MEMBERS_QTY_SAMEPRICE_OCC",       # 新品订货量占同价类比重
         # "PRODUCT_LISTING_RATE",                        # 品规上架率
@@ -50,6 +51,7 @@ class CustConfig:
         "MD04_DIR_SAL_STORE_FLAG":                  {"method": "fillna", "opt": "fill", "value": "否", "type": "str"},
         "BB_CUSTOMER_MANAGER_SCOPE_NAME":           {"method": "fillna", "opt": "fill", "value": "否", "type": "str"},
         "PRODUCT_INSALE_QTY":                       {"method": "fillna", "opt": "mean", "type": "num"},
+        "CUST_INVESTMENT":                          {"method": "fillna", "opt": "mean", "type": "num"}
         
         
         # "NEW_PRODUCT_MEMBERS_QTY_SAMEPRICE_OCC":    {"method": "fillna", "opt": "mean", "type": "num"},
@@ -227,8 +229,11 @@ class OrderConfig:
         "MONTH6_SALE_QTY_YOY",                              # 销售量同比
         "MONTH6_SALE_QTY_MOM",                              # 销售量环比
         "MONTH6_SALE_AMT_YOY",                              # 销售额(购进额)同比
-        "MONTH6_SALE_AMT_MOM",                              # 销售额(狗金额)环比
+        "MONTH6_SALE_AMT_MOM",                              # 销售额(购进额)环比
+        "STOCK_QTY",                                        # 库存
         "ORDER_FULLORDR_RATE",                              # 订足率
+        "FULL_FILLMENT_RATE",                               # 订单满足率
+        "ORDER_FULLORDR_RATE_MOM",                          # 订足率环比
         "CUSTOMER_REPURCHASE_RATE",                         # 会员重购率   
         "DEMAND_RATE",                                      # 需求量满足率
         "LISTING_RATE",                                     # 品规商上架率
@@ -237,11 +242,51 @@ class OrderConfig:
         "YLT_TURNOVER_RATE",                                # 易灵通动销率
         "YLT_BAR_PACKAGE_SALE_OCC",                         # 易灵通调包销售占比
         "UNPACKING_RATE",                                   # 拆包率
+        "POS_PACKAGE_PRICE",                                # pos机单包价格
     ]
     
     WEIGHTS = {
         "MONTH6_SALE_QTY":                                  0.15,
         "MONTH6_SALE_QTY_MOM":                              0.2,
         "ORDER_FULLORDR_RATE":                              0.3,
-        "DEMAND_RATE":                                      0.35,
+        "ORDER_FULLORDR_RATE_MOM":                          0.35,
+    }
+    
+class ImportanceFeaturesMap:
+    CUSTOM_FEATRUES_MAP = {
+        "BB_RTL_CUST_GRADE_NAME":                           "零售户分档名称",
+        "BB_RTL_CUST_MARKET_TYPE_NAME":                     "零售户市场类型名称",
+        "STORE_AREA":                                       "店铺经营面积",
+        "BB_RTL_CUST_BUSINESS_TYPE_NAME":                   "零售户业态名称",
+        "OPERATOR_EDU_LEVEL":                               "零售客户经营者文化程",
+        "OPERATOR_AGE":                                     "经营者年龄",
+        "BB_RTL_CUST_CHAIN_FLAG":                           "零售户连锁标识",
+        "PRESENT_STAR_TERMINAL":                            "终端星级",
+        "MD04_MG_RTL_CUST_CREDITCLASS_NAME":                "零售户信用等级名称",
+        "MD04_DIR_SAL_STORE_FLAG":                          "直营店标识",
+        "BB_CUSTOMER_MANAGER_SCOPE_NAME":                   "零售户经营范围名称",
+        "PRODUCT_INSALE_QTY":                               "在销品规数",
+        "CUST_INVESTMENT":                                  "店铺资源投入建设",
+    }
+    
+    PRODUCT_FEATRUES_MAP = {
+        # ProductConfig 字段映射
+        "direct_retail_price":                              "建议零售价",
+        "is_low_tar":                                       "是否低焦油烟",
+        "tar_qty":                                          "焦油含量",
+        "is_exploding_beads":                               "是否爆珠",
+        "is_shortbranch":                                   "是否短支烟",
+        "is_medium":                                        "是否中支烟",
+        "is_tiny":                                          "是否细支",
+        "product_style_code_name":                          "包装类型名称",
+        "org_is_abnormity":                                 "是否异形包装",
+        "is_chuangxin":                                     "是否创新品类",
+        "is_key_brand":                                     "是否重点品牌",
+        "foster_level_hy":                                  "是否行业共育品规",
+        "foster_level_sj":                                  "是否省级共育品规",
+        "is_cigar":                                         "是否雪茄型卷烟",
+        "co_qty":                                           "一氧化碳含量",
+        "tbc_total_length":                                 "烟支总长度",
+        "tbc_length":                                       "烟支长度",
+        "filter_length":                                    "滤嘴长度",
     }

+ 6 - 4
models/rank/data/preprocess.py

@@ -7,8 +7,8 @@ from sklearn.utils import shuffle
 import numpy as np
 
 class DataProcess():
-    def __init__(self, city_uuid):
-        self._save_res_path = "./models/rank/data/gbdt_data.csv"
+    def __init__(self, city_uuid, save_path):
+        self._save_res_path = save_path
         print("正在加载cust_info...")
         self._cust_data = load_cust_data_from_mysql(city_uuid)
         print("正在加载product_info...")
@@ -117,7 +117,7 @@ class DataProcess():
         self._train_data = self._train_data.join(product_feats, on="product_code", how="left")
         
         self._train_data = shuffle(self._train_data, random_state=42)
-        
+
         self._train_data.to_csv(self._save_res_path, index=False)
     
     def _descartes(self):
@@ -157,5 +157,7 @@ class DataProcess():
         self._train_data.to_csv(self._save_res_path, index=False)
     
 if __name__ == '__main__':
-    processor = DataProcess("00000000000000000000000011445301")
+    city_uuid = "00000000000000000000000011445301"
+    save_path = "./models/rank/data/gbdt_data.csv"
+    processor = DataProcess(city_uuid, save_path)
     processor.data_process()

+ 13 - 9
models/rank/gbdt_lr_sort.py

@@ -1,6 +1,6 @@
 import joblib
 from dao import Redis, get_product_by_id, get_custs_by_ids, load_cust_data_from_mysql
-from models.rank.data import ProductConfig, CustConfig
+from models.rank.data import ProductConfig, CustConfig, ImportanceFeaturesMap
 from models.rank.data.utils import one_hot_embedding, sample_data_clear
 import pandas as pd
 from sklearn.preprocessing import StandardScaler
@@ -70,8 +70,8 @@ class GbdtLrModel:
             self.recommend_list.append({cust_id: float(score)})
             
         self.recommend_list = sorted(self.recommend_list, key=lambda x: list(x.values())[0], reverse=True)
-        for res in self.recommend_list[:200]:
-            print(res)
+        # for res in self.recommend_list[:200]:
+        #     print(res)
         return self.recommend_list
     
     def generate_feats_importance(self):
@@ -100,10 +100,14 @@ class GbdtLrModel:
         sorted_importance = sorted(importance_dict.items(), key=lambda x: x[1], reverse=True)
         
         # 输出特征重要性
-        features_importance = []
+        cust_features_importance = []
+        product_features_importance = []
         for feat, importance in sorted_importance:
-            features_importance.append({feat: float(importance)})
-        return features_importance
+            if feat in list(ImportanceFeaturesMap.CUSTOM_FEATRUES_MAP.keys()):
+                cust_features_importance.append({ImportanceFeaturesMap.CUSTOM_FEATRUES_MAP[feat]: float(importance)})
+            if feat in list(ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP.keys()):
+                product_features_importance.append({ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP[feat]: float(importance)})
+        return cust_features_importance, product_features_importance
     
 if __name__ == "__main__":
     model_path = "./models/rank/weights/model.pkl"
@@ -112,6 +116,6 @@ if __name__ == "__main__":
     gbdt_sort = GbdtLrModel(model_path)
     gbdt_sort.sort(city_uuid, product_id)
     
-    # importances = gbdt_sort.generate_feats_importance()
-    # for importance in importances:
-    #     print(importance)
+    importances = gbdt_sort.generate_feats_importance()
+    for importance in importances:
+        print(importance)

+ 24 - 1
烟草模型部署文档.md

@@ -39,7 +39,8 @@ redis:
     parser.add_argument("--run_all", action='store_true')
     parser.add_argument("--run_hot", action='store_true')
     parser.add_argument("--run_itemcf", action='store_true')
-    parser.add_argument("--run_itemcf_inference", action='store_true'
+    parser.add_argument("--run_itemcf_inference", action='store_true')
+    parser.add_argument("--city_uuid", type=str, help="City UUID for filtering data")
 ```
 
 ### 总共有4种启动模式分别是:
@@ -49,6 +50,28 @@ redis:
         3. 启动协同过滤  
         4. 启动协同过滤推理
 
+## 2、模型启动配置说明:
+
+### gbdt_lr.py
+
+```
+    parser.add_argument("--run_train", action='store_true')
+    parser.add_argument("--recommend", action='store_true')
+    parser.add_argument("--importance", action='store_true')
+
+    parser.add_argument("--city_uuid", type=str, default='00000000000000000000000011445301')
+    parser.add_argument("--product_id", type=str, default='110102')
+```
+
+### gbdt_lr总共3个功能:
+
+1. 启动gbdt_lr训练  python -m gbdt_lr --run_train --city_uuid "00000000000000000000000011445301"
+        2. 根据城市id和product_id进行推荐,需要指定city_uuid、product_id。      python -m gbdt_lr --recommend --city_uuid "00000000000000000000000011445301" --product_id '110102'
+        3. 获取指定城市的特征重要性指标。  python -m gbdt_lr --importance --city_uuid "00000000000000000000000011445301"
+        注意:在数据准备阶段,会将训练数据保存到./models/rank/data/gbdt_data.csv中
+              模型文件会存放在 ./models/rank/weights/city_uuid/model.pkl
+              重要性指标会存放在 ./models/rank/weights/city_uuid/ 下,分别是商户指标重要性(cust_features_importance.csv)和卷烟指标重要性(product_features_importance.csv)
+
 ## 3、模型docker运行配置说明:
 
 ### docker镜像是:registry.cn-hangzhou.aliyuncs.com/hexiaoshi/brandcultivation:0.0.1