Jelajahi Sumber

处理训练数据缺失值问题

yangzeyu 1 tahun lalu
induk
komit
90902e7909
2 berkas diubah dengan 9 penambahan dan 8 penghapusan
  1. 1 1
      gbdt_lr.py
  2. 8 7
      models/rank/data/preprocess.py

+ 1 - 1
gbdt_lr.py

@@ -86,7 +86,7 @@ def run():
     parser.add_argument("--recommend", action='store_true')
     parser.add_argument("--importance", action='store_true')
     
-    parser.add_argument("--train_data_path", type=str, default="./models/rank/data/gbdt_data.csv")
+    parser.add_argument("--train_data_path", type=str, default="./models/rank/train_data/gbdt_data.csv")
     parser.add_argument("--model_path", type=str, default="./models/rank/weights")
     parser.add_argument("--model_name", type=str, default='model.pkl')
     parser.add_argument("--last_n", type=int, default=200)

+ 8 - 7
models/rank/data/preprocess.py

@@ -49,6 +49,7 @@ class DataProcess():
     
     def _clean_cust_data(self):
         """用户信息表数据清洗"""
+        self._cust_data["BB_RETAIL_CUSTOMER_CODE"] = self._cust_data["BB_RETAIL_CUSTOMER_CODE"].astype(str)
         # 根据配置规则清洗数据
         for feature, rules, in CustConfig.CLEANING_RULES.items():
             if rules["type"] == "num":
@@ -66,6 +67,7 @@ class DataProcess():
     
     def _clean_product_data(self):
         """卷烟信息表数据清洗"""
+        self._product_data["product_code"] = self._product_data["product_code"].astype(str)
         for feature, rules, in ProductConfig.CLEANING_RULES.items():
             if rules["type"] == "num":
                 self._product_data[feature] = pd.to_numeric(self._product_data[feature], errors="coerce")
@@ -78,7 +80,8 @@ class DataProcess():
                 self._product_data[feature] = self._product_data[feature].infer_objects(copy=False)
                     
     def _clean_order_data(self):
-        pass
+        self._order_data["BB_RETAIL_CUSTOMER_CODE"] = self._order_data["BB_RETAIL_CUSTOMER_CODE"].astype(str)
+        self._order_data["PRODUCT_CODE"] = self._order_data["PRODUCT_CODE"].astype(str)
     
     def _calculate_score(self):
         """计算order记录的分数"""
@@ -113,8 +116,8 @@ class DataProcess():
         
         self._train_data = self._order_score.copy()
         
-        self._train_data = self._train_data.join(cust_feats, on="BB_RETAIL_CUSTOMER_CODE", how="left")
-        self._train_data = self._train_data.join(product_feats, on="product_code", how="left")
+        self._train_data = self._train_data.join(cust_feats, on="BB_RETAIL_CUSTOMER_CODE", how="inner")
+        self._train_data = self._train_data.join(product_feats, on="product_code", how="inner")
         
         self._train_data = shuffle(self._train_data, random_state=42)
 
@@ -144,8 +147,6 @@ class DataProcess():
         
         positive_count = len(positive_samples)
         negative_count = min(1 * positive_count, len(negative_samples))
-        print(positive_count)
-        print(negative_count)
         
         # 随机抽取与正样本数量相同(1倍)的负样本,上限为全部负样本数
         negative_samples_sampled = negative_samples.sample(n=negative_count, random_state=42)
@@ -157,7 +158,7 @@ class DataProcess():
         self._train_data.to_csv(self._save_res_path, index=False)
     
 if __name__ == '__main__':
-    city_uuid = "00000000000000000000000011445301"
-    save_path = "./models/rank/data/gbdt_data.csv"
+    city_uuid = "00000000000000000000000011441801"
+    save_path = "./models/rank/train_data/gbdt_data.csv"
     processor = DataProcess(city_uuid, save_path)
     processor.data_process()