Переглянути джерело

增加sample输入进行gbdt推理

WarriorMin 11 місяців тому
батько
коміт
2d130c83ee

+ 1 - 1
database/dao/mysql_dao.py

@@ -50,7 +50,7 @@ class MySqlDao:
         params = {"city_uuid": city_uuid}
         
         data = self.db_helper.load_data_with_page(query, params)
-        # data.drop('stat_month', axis=1, inplace=True)
+        data.drop('stat_month', axis=1, inplace=True)
         data.drop('city_uuid', axis=1, inplace=True)
         
         return data

+ 1 - 1
database/db/mysql.py

@@ -110,4 +110,4 @@ class MySqlDatabaseHelper:
             session.rollback()
             print(f"Error: {e}")
         finally:
-            session.close()
+            session.close()

+ 27 - 3
inference.py

@@ -83,10 +83,34 @@ def run():
     pass
 
 if __name__ == '__main__':
-    generate_features_shap("00000000000000000000000011445301", "350139", delivery_count=5000)
+    # generate_features_shap("00000000000000000000000011445301", "420202", delivery_count=5000)
     # recommend_list = get_recommend_list("00000000000000000000000011445301", "420202")
     # recommend_list = pd.DataFrame(recommend_list)
     # recommend_list.to_csv("./data/recommend_list.csv", index=False, encoding="utf-8-sig")
-    # data = dao.get_order_by_cust("00000000000000000000000011445301", "445381107139")
+    
+    # 拿龙军数据
+    # data = dao.get_order_by_cust("00000000000000000000000011445301", "445323105795")
     # data = data.groupby(["cust_code", "product_code", "product_name"], as_index=False)["sale_qty"].sum()
-    # data.to_csv("./data/cust.csv", index=False)
+    # data.to_csv("./data/cust.csv", index=False)
+    
+    city_uuid = "00000000000000000000000011445301"
+    order_data = dao.get_order_by_cust("00000000000000000000000011445301", "445323105795")
+    order_data["sale_qty"] = order_data["sale_qty"].fillna(0)
+    order_data = order_data.infer_objects(copy=False)
+    order_data = order_data.groupby(["cust_code", "product_code", "product_name"], as_index=False)["sale_qty"].sum()
+    
+    cust_data = dao.load_cust_data(city_uuid)[CustConfig.FEATURE_COLUMNS]
+    sample_data_clear(cust_data, CustConfig)
+    shop_data = dao.load_shopping_data(city_uuid)[ShopConfig.FEATURE_COLUMNS]
+    sample_data_clear(shop_data, ShopConfig)
+    cust_ids = shop_data.set_index("cust_code")
+    cust_data = cust_data.join(cust_ids, on="BB_RETAIL_CUSTOMER_CODE", how="inner")
+    
+    product_data = dao.load_product_data(city_uuid)[ProductConfig.FEATURE_COLUMNS]
+    sample_data_clear(product_data, ProductConfig)
+    
+    order_data = order_data.merge(product_data, on="product_code", how="inner")
+    order_data = order_data.merge(cust_data, left_on='cust_code', right_on='BB_RETAIL_CUSTOMER_CODE', how="inner")
+    
+    result = gbdtlr_model.inference_from_sample(order_data)
+    result.to_csv("./data/junlong.csv", index=False)

+ 22 - 0
models/rank/gbdt_lr_inference.py

@@ -74,6 +74,28 @@ class GbdtLrModel:
         )
         return recommend_list
     
+    def inference_from_sample(self, sample):
+        inference_sample = sample.drop(columns=["BB_RETAIL_CUSTOMER_CODE", "product_code", "sale_qty", "product_name", "cust_code"])
+        
+        onehot_feats = {**CustConfig.ONEHOT_CAT, **ProductConfig.ONEHOT_CAT, **ShopConfig.ONEHOT_CAT}
+        onehot_columns = list(onehot_feats.keys())
+        numeric_columns = inference_sample.drop(onehot_columns, axis=1).columns
+        inference_sample = one_hot_embedding(inference_sample, onehot_feats)
+        print(numeric_columns)
+        # 数字特征归一化
+        if len(numeric_columns) != 0:
+            scaler = StandardScaler()
+            inference_sample[numeric_columns] = scaler.fit_transform(inference_sample[numeric_columns])
+        
+        gbdt_preds = self.gbdt_model.predict(inference_sample, pred_leaf=True)
+        gbdt_feats_encoded = self.onehot_encoder.transform(gbdt_preds)
+        scores = self.lr_model.predict_proba(gbdt_feats_encoded)[:, 1]
+        
+        sample["score"] = scores
+        
+        return sample[["cust_code", "product_code", "product_name", "sale_qty", "score"] + ProductConfig.FEATURE_COLUMNS]
+        
+    
     def generate_feats_importance(self):
         """生成特征重要性"""
         # 获取GBDT模型的特征重要性

+ 1 - 1
utils/result_process.py

@@ -110,5 +110,5 @@ def get_cust_list_from_history_order(city_uuid, product_code):
     return merge_data
         
 if __name__ == "__main__":
-    order_data = get_cust_list_from_history_order("00000000000000000000000011445301", "350355")
+    order_data = get_cust_list_from_history_order("00000000000000000000000011445301", "420202")
     order_data.to_csv("./data/eval.csv", index=False)