Sherlock před 1 rokem
rodič
revize
1b1eea75f9

+ 36 - 5
models/rank/data/config.py

@@ -24,9 +24,9 @@ class CustConfig:
     ]
     # 数据清洗规则
     CLEANING_RULES = {
-        "BB_RTL_CUST_POSITION_TYPE_NAME":           {"method": "fillna", "opt": "fill", "value": "其", "type": "str"},
-        "BB_RTL_CUST_MARKET_TYPE_NAME":             {"method": "fillna", "opt": "fill", "value": "其它", "type": "str"},
-        "BB_RTL_CUST_SUB_BUSI_PLACE_NAME":          {"method": "fillna", "opt": "fill", "value": "其它", "type": "str"},
+        "BB_RTL_CUST_POSITION_TYPE_NAME":           {"method": "fillna", "opt": "fill", "value": "其", "type": "str"},
+        "BB_RTL_CUST_SUB_BUSI_PLACE_NAME":          {"method": "fillna", "opt": "fill", "value": "其他", "type": "str"},
+        "BB_RTL_CUST_GRADE_NAME":                   {"method": "fillna", "opt": "fill", "value": "其他", "type": "str"},
         # "BB_RTL_CUST_TERMINALEVEL_NAME":          {"method": "fillna", "opt": "replace", "value": "BB_RTL_CUST_TERMINAL_LEVEL_NAME", "type": "str"},
         # "MD04_MG_RTL_CUST_CREDITCLASS_NAME":        {"method": "fillna", "opt": "fill", "value": "未评价", "type": "str"},
         # "MD04_MG_SAMPLE_CUST_FLAG":                 {"method": "fillna", "value": "N", "opt": "fill"},
@@ -40,7 +40,15 @@ class CustConfig:
         "OPERATOR_EDU_LEVEL":                       {"method": "fillna", "opt": "fill", "value": "00", "type": "str"},
     }
     # one-hot编码
-    
+    ONEHOT = [
+        "BB_RTL_CUST_POSITION_TYPE_NAME",
+        "BB_RTL_CUST_MARKET_TYPE_NAME",
+        "BB_RTL_CUST_SUB_BUSI_PLACE_NAME",
+        "BB_RTL_CUST_GRADE_NAME",
+        "BB_RTL_CUST_CHAIN_FLAG",
+        "MD04_DIR_SAL_STORE_FLAG",
+        "OPERATOR_EDU_LEVEL",
+    ]
     
 class ProductConfig:
     FEATURE_COLUMNS = [
@@ -106,4 +114,27 @@ class ProductConfig:
         "product_style_code_name":                     {"method": "fillna", "opt": "fill", "type": "str", "value": "其他"},
         "chinese_mix":                                 {"method": "fillna", "opt": "fill", "type": "str", "value": "否"},
         "sub_price_type_name":                         {"method": "fillna", "opt": "fill", "type": "str", "value": "其他"},
-    }
+    }
+    
+    ONEHOT = [
+        "price_type_name",                             # 卷烟价类名称
+        "gear_type_name",                              # 卷烟档位名称
+        "category_type_name",                          # 卷烟品类名称
+        "is_key_brand",                                # 是否重点品牌
+        "is_high_level",                               # 是否高端烟
+        "is_upscale_level",                            # 是否高端烟不含高价
+        "is_high_price",                               # 是否高价烟
+        "is_low_price",                                # 是否低价烟
+        "is_low_tar",                                  # 是否低焦油烟
+        "is_encourage",                                # 是否全国鼓励品牌
+        "is_abnormity",                                # 是否异形包装
+        "is_intake",                                   # 是否进口烟
+        "is_short",                                    # 是否紧俏品牌
+        "is_medium",                                   # 是否中支烟
+        "is_shortbranch",                              # 是否短支烟
+        "is_ordinary_price_type",                      # 是否普一类烟
+        "source_type",                                 # 来源类型
+        "product_style_code_name",                     # 包装类型名称
+        "chinese_mix",                                 # 中式混合
+        "sub_price_type_name",                         # 细分卷烟价类名称
+    ]

+ 60 - 0
models/rank/data/dataloader.py

@@ -0,0 +1,60 @@
+import pandas as pd
+from models.rank.data.config import CustConfig, ProductConfig
+from sklearn.preprocessing import OneHotEncoder
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import StandardScaler
+
+class DataLoader:
+    """Load the ranking CSV, one-hot encode categorical columns and split it.
+
+    The CSV is read eagerly in ``__init__``. The identifier columns
+    ``BB_RETAIL_CUSTOMER_CODE`` and ``product_code`` are dropped, the columns
+    listed in ``CustConfig.ONEHOT`` and ``ProductConfig.ONEHOT`` are one-hot
+    encoded, and every remaining column except ``label`` is treated as a
+    numeric feature that ``split_dataset`` standardizes.
+    """
+
+    def __init__(self, path):
+        """Store the CSV location and load/encode the data immediately.
+
+        Args:
+            path: Location of the CSV file, passed to ``pandas.read_csv``.
+        """
+        self._gbdt_data_path = path
+        self._load_data()
+
+    def _load_data(self):
+        """Read the CSV, drop identifier columns and one-hot encode categoricals."""
+        self._gbdt_data = pd.read_csv(self._gbdt_data_path, encoding="utf-8")
+        self._gbdt_data.drop(
+            columns=["BB_RETAIL_CUSTOMER_CODE", "product_code"], inplace=True
+        )
+
+        self._onehot_columns = CustConfig.ONEHOT + ProductConfig.ONEHOT
+        # Record the numeric column names BEFORE encoding: get_dummies replaces
+        # the categorical columns with new indicator columns, after which the
+        # configured names can no longer be used to tell the two groups apart.
+        self._numeric_columns = self._gbdt_data.drop(
+            self._onehot_columns + ["label"], axis=1
+        ).columns
+
+        # One-hot encode the categorical features.
+        self.one_hot_embedding(self._onehot_columns)
+
+    def one_hot_embedding(self, onehot_columns):
+        """One-hot encode the given columns of the loaded data in place.
+
+        Args:
+            onehot_columns: Names of the columns to expand; every category
+                level is kept (``drop_first=False``).
+        """
+        self._gbdt_data = pd.get_dummies(
+            self._gbdt_data, columns=onehot_columns, drop_first=False
+        )
+
+    def split_dataset(self):
+        """Split the data 70/15/15 into train/validation/test and standardize.
+
+        The scaler is fitted on the training features only and then applied
+        unchanged to the validation and test features, so all three splits
+        share one scale and no held-out statistics leak into preprocessing.
+
+        Returns:
+            A ``(train, val, test)`` tuple of dicts, each holding the feature
+            frame under ``"data"`` and the matching labels under ``"label"``.
+        """
+        # 1. Separate features from the label column.
+        features = self._gbdt_data.drop("label", axis=1)
+        labels = self._gbdt_data["label"]
+
+        # 2. 70% train, then the remaining 30% halved into validation and test.
+        X_train, X_temp, y_train, y_temp = train_test_split(
+            features, labels, test_size=0.3, random_state=42, shuffle=True
+        )
+        X_val, X_test, y_val, y_test = train_test_split(
+            X_temp, y_temp, test_size=0.5, random_state=42, shuffle=True
+        )
+
+        # Take explicit copies so the column assignments below write to frames
+        # we own instead of to slices of the source frame.
+        X_train, X_val, X_test = X_train.copy(), X_val.copy(), X_test.copy()
+
+        # 3. Standardize the numeric features only; the one-hot indicator
+        #    columns are left untouched.
+        numeric_columns = list(self._numeric_columns)
+        if numeric_columns:  # StandardScaler rejects a zero-column input
+            scaler = StandardScaler()
+            X_train[numeric_columns] = scaler.fit_transform(X_train[numeric_columns])
+            # Reuse the training mean/std: refitting per split would scale each
+            # split differently and make evaluation inconsistent with training.
+            X_val[numeric_columns] = scaler.transform(X_val[numeric_columns])
+            X_test[numeric_columns] = scaler.transform(X_test[numeric_columns])
+
+        train_dataset = {"data": X_train, "label": y_train}
+        val_dataset = {"data": X_val, "label": y_val}
+        test_dataset = {"data": X_test, "label": y_test}
+
+        return train_dataset, val_dataset, test_dataset
+    
+# Script entry point: load the CSV from its default location and run the
+# split once; the returned datasets are discarded.
+if __name__ == "__main__":
+    csv_path = "./models/rank/data/gbdt_data.csv"
+    DataLoader(csv_path).split_dataset()

+ 5 - 1
models/rank/data/preprocess.py

@@ -1,9 +1,11 @@
 from dao.dao import load_cust_data_from_mysql, load_product_data_from_mysql, load_order_data_from_mysql
 from models.rank.data.config import CustConfig, ProductConfig
+import os
 import pandas as pd
 
 class DataProcess():
     def __init__(self, city_uuid):
+        self._save_res_path = "./models/rank/data/gbdt_data.csv"
         print("正在加载cust_info...")
         self._cust_data = load_cust_data_from_mysql(city_uuid)
         print("正在加载product_info...")
@@ -13,6 +15,8 @@ class DataProcess():
         
     def data_process(self):
         """数据预处理"""
+        if os.path.exists(self._save_res_path):
+            os.remove(self._save_res_path)
         
         # 1. 获取指定的特征组合
         self._cust_data = self._cust_data[CustConfig.FEATURE_COLUMNS]
@@ -94,7 +98,7 @@ class DataProcess():
         self._train_data = self._train_data.sample(frac=1, random_state=42).reset_index(drop=True)
         
         # 保存训练数据
-        self._train_data.to_csv("./models/rank/data/gbdt_data.csv", index=False)
+        self._train_data.to_csv(self._save_res_path, index=False)
     
 if __name__ == '__main__':
     processor = DataProcess("00000000000000000000000011445301")