Browse Source

完善gbdt_lr排序功能

Sherlock 1 year ago
parent
commit
d5e15f6886

+ 4 - 1
models/rank/data/__init__.py

@@ -1,7 +1,10 @@
 from models.rank.data.config import CustConfig, ProductConfig
 from models.rank.data.dataloader import DataLoader
+from models.rank.data.utils import one_hot_embedding, sample_data_clear
 __all__ = [
     "CustConfig",
     "ProductConfig",
-    "DataLoader"
+    "DataLoader",
+    "one_hot_embedding",
+    "sample_data_clear"
 ]

+ 45 - 8
models/rank/data/config.py

@@ -25,19 +25,19 @@ class CustConfig:
     # 数据清洗规则
     CLEANING_RULES = {
         "BB_RTL_CUST_POSITION_TYPE_NAME":           {"method": "fillna", "opt": "fill", "value": "其他", "type": "str"},
+        "BB_RTL_CUST_MARKET_TYPE_NAME":             {"method": "fillna", "opt": "fill", "value": "城网", "type": "str"},
         "BB_RTL_CUST_SUB_BUSI_PLACE_NAME":          {"method": "fillna", "opt": "fill", "value": "其他", "type": "str"},
-        "BB_RTL_CUST_GRADE_NAME":                   {"method": "fillna", "opt": "fill", "value": "其他", "type": "str"},
+        "BB_RTL_CUST_GRADE_NAME":                   {"method": "fillna", "opt": "fill", "value": "十五档", "type": "str"},
         # "BB_RTL_CUST_TERMINALEVEL_NAME":          {"method": "fillna", "opt": "replace", "value": "BB_RTL_CUST_TERMINAL_LEVEL_NAME", "type": "str"},
         # "MD04_MG_RTL_CUST_CREDITCLASS_NAME":        {"method": "fillna", "opt": "fill", "value": "未评价", "type": "str"},
         # "MD04_MG_SAMPLE_CUST_FLAG":                 {"method": "fillna", "value": "N", "opt": "fill"},
         # "MD07_RTL_CUST_IS_SALE_LARGE_FLAG":         {"method": "fillna", "value": "N", "opt": "fill"},
-        "BB_RTL_CUST_CHAIN_FLAG":                   {"type": "fillna", "opt": "fill", "value": "0", "type": "str"},
         # "BB_RTL_CUST_CGT_OPERATE_SCOPE_NAME":       {"method": "fillna", "value": "中", "opt": "fill"},
-        "BB_RTL_CUST_CHAIN_FLAG":                   {"method": "fillna", "opt": "fill", "value": "0", "type": "str"},
-        "MD04_DIR_SAL_STORE_FLAG":                  {"method": "fillna", "opt": "fill", "value": "0", "type": "str"},
+        "BB_RTL_CUST_CHAIN_FLAG":                   {"method": "fillna", "opt": "fill", "value": "", "type": "str"},
+        "MD04_DIR_SAL_STORE_FLAG":                  {"method": "fillna", "opt": "fill", "value": "", "type": "str"},
         "STORE_AREA":                               {"method": "fillna", "opt": "mean", "type": "num"},
         "OPERATOR_AGE":                             {"method": "fillna", "opt": "mean", "type": "num"},
-        "OPERATOR_EDU_LEVEL":                       {"method": "fillna", "opt": "fill", "value": "00", "type": "str"},
+        "OPERATOR_EDU_LEVEL":                       {"method": "fillna", "opt": "fill", "value": "01", "type": "str"},
     }
     # one-hot编码
     ONEHOT = [
@@ -50,6 +50,20 @@ class CustConfig:
         "OPERATOR_EDU_LEVEL",
     ]
     
+    ONEHOT_CAT = {
+        "BB_RTL_CUST_POSITION_TYPE_NAME":           ["居民区", "商业娱乐区", "交通枢纽区", "旅游景区", "工业区", "集贸区", "院校学区", "办公区", "其他"],
+        "BB_RTL_CUST_MARKET_TYPE_NAME":             ["城网", "农网"],
+        "BB_RTL_CUST_SUB_BUSI_PLACE_NAME":          ["便利店", "超市", "烟草专业店", "娱乐服务类", "其他"],
+        "BB_RTL_CUST_GRADE_NAME":                   ['一档', '二档', '三档', '四档', '五档', '六档', '七档', '八档', '九档', '十档', '十一档', '十二档', 
+                                                    '十三档', '十四档', '十五档', '十六档', '十七档', '十八档', '十九档', '二十档', '二十一档', '二十二档', 
+                                                    '二十三档', '二十四档', '二十五档', '二十六档', '二十七档', '二十八档', '二十九档', '三十档'],
+        "BB_RTL_CUST_CHAIN_FLAG":                   ["是", "否"],
+        "MD04_DIR_SAL_STORE_FLAG":                  ["是", "否"],
+        "OPERATOR_EDU_LEVEL":                       [1, 2, 3, 4, 5, 6, 7]
+    }
+    
+    
+    
 class ProductConfig:
     FEATURE_COLUMNS = [
         "product_code",                                # 商品编码
@@ -93,7 +107,7 @@ class ProductConfig:
         "allot_price":                                 {"method": "fillna", "opt": "fill", "type": "num", "value": 0.0},
         "direct_whole_price":                          {"method": "fillna", "opt": "mean", "type": "num"},
         "retail_price":                                {"method": "fillna", "opt": "mean", "type": "num"},
-        "price_type_name":                             {"method": "fillna", "opt": "fill", "type": "str", "value": "其他"},
+        "price_type_name":                             {"method": "fillna", "opt": "fill", "type": "str", "value": "一类烟"},
         "gear_type_name":                              {"method": "fillna", "opt": "fill", "type": "str", "value": "其他"},
         "category_type_name":                          {"method": "fillna", "opt": "fill", "type": "str", "value": "其他"},
         "is_key_brand":                                {"method": "fillna", "opt": "fill", "type": "str", "value": "否"},
@@ -113,7 +127,7 @@ class ProductConfig:
         "tar_qty":                                     {"method": "fillna", "opt": "mean", "type": "num"},
         "product_style_code_name":                     {"method": "fillna", "opt": "fill", "type": "str", "value": "其他"},
         "chinese_mix":                                 {"method": "fillna", "opt": "fill", "type": "str", "value": "否"},
-        "sub_price_type_name":                         {"method": "fillna", "opt": "fill", "type": "str", "value": "其他"},
+        "sub_price_type_name":                         {"method": "fillna", "opt": "fill", "type": "str", "value": "普一类烟"},
     }
     
     ONEHOT = [
@@ -137,4 +151,27 @@ class ProductConfig:
         "product_style_code_name",                     # 包装类型名称
         "chinese_mix",                                 # 中式混合
         "sub_price_type_name",                         # 细分卷烟价类名称
-    ]
+    ]
+    ONEHOT_CAT = {
+        "price_type_name":                             ["一类烟", "二类烟", "三类烟", "四类烟", "五类烟", "无价类"],
+        "gear_type_name":                              ["第一档位", "第二档位", "第三档位", "第四档位", "第五档位", "第六档位", "第七档位", "第八档位", "其他"],
+        "category_type_name":                          ["第1品类", "第2品类", "第3品类", "第4品类", "第5品类", "第6品类", "第7品类", 
+                                                        "第8品类", "第9品类", "第10品类", "第11品类", "第12品类", "第13品类", "其他"],
+        "is_key_brand":                                ["是", "否"],
+        "is_high_level":                               ["是", "否"],
+        "is_upscale_level":                            ["是", "否"],
+        "is_high_price":                               ["是", "否"],
+        "is_low_price":                                ["是", "否"],
+        "is_low_tar":                                  ["是", "否"],
+        "is_encourage":                                ["是", "否"],
+        "is_abnormity":                                ["是", "否"],
+        "is_intake":                                   ["是", "否"],
+        "is_short":                                    ["是", "否"],
+        "is_medium":                                   ["是", "否"],
+        "is_shortbranch":                              ["是", "否"],
+        "is_ordinary_price_type":                      ["是", "否"],
+        "source_type":                                 ["是", "否"],
+        "product_style_code_name":                     ["条盒硬盒", "条包硬盒", "条盒软盒", "条包软盒", "铁盒", "其他"],
+        "chinese_mix":                                 ["是", "否"],
+        "sub_price_type_name":                         ["高端烟", "高价位烟", "普一类烟", "二类烟", "三类烟", "四类烟", "五类烟", "无价类"],
+    }

+ 7 - 7
models/rank/data/dataloader.py

@@ -3,6 +3,7 @@ from models.rank.data.config import CustConfig, ProductConfig
 from sklearn.preprocessing import OneHotEncoder
 from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import StandardScaler
+from models.rank.data.utils import one_hot_embedding
 
 class DataLoader:
     def __init__(self,path):
@@ -10,20 +11,18 @@ class DataLoader:
         self._load_data()
     
     def _load_data(self):
+       
         self._gbdt_data = pd.read_csv(self._gbdt_data_path, encoding="utf-8")
         self._gbdt_data.drop('BB_RETAIL_CUSTOMER_CODE', axis=1, inplace=True)
         self._gbdt_data.drop('product_code', axis=1, inplace=True)
         
-        self._onehot_columns = CustConfig.ONEHOT + ProductConfig.ONEHOT
+        self._onehot_feats = {**CustConfig.ONEHOT_CAT, **ProductConfig.ONEHOT_CAT}
+        
+        self._onehot_columns = list(self._onehot_feats.keys())
         self._numeric_columns = self._gbdt_data.drop(self._onehot_columns + ["label"], axis=1).columns
         
         # 将类别数据进行one-hot编码
-        self.one_hot_embedding(self._onehot_columns)
-        
-    def one_hot_embedding(self, onehot_columns):
-        """对指定的特征进行onehot编码"""
-        self._gbdt_data = pd.get_dummies(self._gbdt_data, columns=onehot_columns, drop_first=False)
-        
+        self._gbdt_data = one_hot_embedding(self._gbdt_data, self._onehot_feats)
         
     
     def split_dataset(self):
@@ -39,6 +38,7 @@ class DataLoader:
         scaler = StandardScaler()
         X_train[self._numeric_columns] = scaler.fit_transform(X_train[self._numeric_columns])
         X_test[self._numeric_columns] = scaler.fit_transform(X_test[self._numeric_columns])
+        print(X_test["notwithtax_adjust_price"])
         
         train_dataset = {"data": X_train, "label": y_train}
         test_dataset = {"data": X_test, "label": y_test}

+ 24 - 0
models/rank/data/utils.py

@@ -0,0 +1,24 @@
+import pandas as pd
+def one_hot_embedding(dataframe, onehout_feat):
+    """One-hot encode the given columns against fixed category lists.
+
+    Args:
+        dataframe: pandas DataFrame containing every column named in
+            ``onehout_feat``.
+        onehout_feat: mapping of column name -> list of category values.
+
+    Returns:
+        A new DataFrame in which each listed column is replaced by one
+        ``<column>_<category>`` integer (0/1) column per listed category.
+        The input frame is left unmodified.
+    """
+    # A fixed-category Categorical makes get_dummies emit a column for every
+    # listed category, even ones absent from this frame. Values outside the
+    # list become NaN and therefore encode as all zeros.
+    categorical_columns = {
+        feat: pd.Categorical(dataframe[feat], categories=categories, ordered=False)
+        for feat, categories in onehout_feat.items()
+    }
+    # assign() works on a copy, so the caller's frame keeps its original dtypes.
+    prepared = dataframe.assign(**categorical_columns)
+    return pd.get_dummies(
+        prepared,
+        columns=list(onehout_feat),
+        prefix_sep="_",
+        dtype=int,
+    )
+
+def sample_data_clear(data, config):
+    """Clean feature columns according to ``config.CLEANING_RULES``.
+
+    Args:
+        data: pandas DataFrame holding every feature named in the rules.
+        config: object exposing ``CLEANING_RULES``, a mapping of column name
+            to a rule dict with a ``type`` ("str" or "num"), a ``method``
+            and, for string rules, a fill ``value``.
+
+    Returns:
+        The same ``data`` object; its columns are updated in place.
+    """
+    for feature, rules in config.CLEANING_RULES.items():
+        is_numeric = rules["type"] == "num"
+        if is_numeric:
+            # Unparseable entries become NaN so they can be filled below.
+            data[feature] = pd.to_numeric(data[feature], errors="coerce")
+        # The rule tables declare method "fillna"; "fill" is still accepted
+        # so older rule dicts keep working. Rules without a method are skipped.
+        if rules.get("method") not in ("fillna", "fill"):
+            continue
+        if rules["type"] == "str":
+            data[feature] = data[feature].fillna(rules["value"])
+        elif is_numeric:
+            # Use the rule's explicit value when given, otherwise 0.0.
+            # NOTE(review): rules with opt "mean" carry no value and also
+            # get 0.0 here; no mean is computed from this frame.
+            data[feature] = data[feature].fillna(rules.get("value", 0.0))
+    return data

+ 39 - 4
models/rank/gbdt_lr_sort.py

@@ -1,6 +1,11 @@
 import joblib
 from dao import Redis, get_product_by_id, get_custs_by_ids
 from models.rank.data import ProductConfig, CustConfig
+from models.rank.data.utils import one_hot_embedding, sample_data_clear
+import pandas as pd
+from sklearn.preprocessing import StandardScaler
+
+
 class GbdtLrSort:
     def __init__(self, model_path):
         self.load_model(model_path)
@@ -8,7 +13,7 @@ class GbdtLrSort:
     
     def load_model(self, model_path):
         models = joblib.load(model_path)
-        self.gbdt_model = models["gbdt_model"], models["lr_model"], models["onehot_encoder"]
+        self.gbdt_model, self.lr_model, self.onehot_encoder = models["gbdt_model"], models["lr_model"], models["onehot_encoder"]
         
     
     def get_recall_list(self, city_uuid, product_id):
@@ -19,17 +24,47 @@ class GbdtLrSort:
     def load_recall_data(self, city_uuid, product_id):
         self.product_data = get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
         self.custs_data = get_custs_by_ids(city_uuid, self.recall_cust_list)[CustConfig.FEATURE_COLUMNS]
-        print(self.product_data)
     
     def generate_feats_map(self, city_uuid, product_id):
         """组合卷烟、商户特征矩阵"""
         self.get_recall_list(city_uuid, product_id)
         self.load_recall_data(city_uuid, product_id)
         # 做数据清洗
+        self.product_data = sample_data_clear(self.product_data, ProductConfig)
+        self.custs_data = sample_data_clear(self.custs_data, CustConfig)
         
+        # 笛卡尔积联合
+        self.custs_data["descartes"] = 1
+        self.product_data["descartes"] = 1
+        self.feats_map = pd.merge(self.custs_data, self.product_data, on="descartes").drop("descartes", axis=1)
+        self.recall_cust_list = self.feats_map["BB_RETAIL_CUSTOMER_CODE"].to_list()
+        self.feats_map.drop('BB_RETAIL_CUSTOMER_CODE', axis=1, inplace=True)
+        self.feats_map.drop('product_code', axis=1, inplace=True)
+        
+        # onehot编码
+        onehot_feats = {**CustConfig.ONEHOT_CAT, **ProductConfig.ONEHOT_CAT}
+        onehot_columns = list(onehot_feats.keys())
+        numeric_columns = self.feats_map.drop(onehot_columns, axis=1).columns
+        self.feats_map = one_hot_embedding(self.feats_map, onehot_feats)
+        
+        # 数字特征归一化
+        scaler = StandardScaler()
+        self.feats_map[numeric_columns] = scaler.fit_transform(self.feats_map[numeric_columns])
     
     def sort(self, city_uuid, product_id):
-        pass
+        self.generate_feats_map(city_uuid, product_id)
+        
+        gbdt_preds = self.gbdt_model.apply(self.feats_map)[:, :, 0]
+        gbdt_feats_encoded = self.onehot_encoder.transform(gbdt_preds)
+        scores = self.lr_model.predict_proba(gbdt_feats_encoded)[:, 1]
+        
+        self.recommend_list = []
+        for cust_id, score in zip(self.recall_cust_list, scores):
+            self.recommend_list.append({cust_id: float(score)})
+            
+        self.recommend_list = sorted(self.recommend_list, key=lambda x: list(x.values())[0], reverse=True)
+        for res in self.recommend_list:
+            print(res)
     
     def generate_feats_importance(self):
         pass
@@ -39,4 +74,4 @@ if __name__ == "__main__":
     city_uuid = "00000000000000000000000011445301"
     product_id = "110102"
     gbdt_sort = GbdtLrSort(model_path)
-    gbdt_sort.generate_feats_map(city_uuid, product_id)
+    gbdt_sort.sort(city_uuid, product_id)