Browse source

Develop feature correlation functionality

Sherlock 11 months ago
parent
commit
122601aa09
1 changed file with 86 additions and 16 deletions

+ 86 - 16
models/rank/gbdt_lr_sort.py

@@ -1,10 +1,13 @@
 import joblib
 # from dao import Redis, get_product_by_id, get_custs_by_ids, load_cust_data_from_mysql
 from database import RedisDatabaseHelper, MySqlDao
+from models.rank.data import DataLoader
 from models.rank.data import ProductConfig, CustConfig, ShopConfig, ImportanceFeaturesMap
 from models.rank.data.utils import one_hot_embedding, sample_data_clear
+import numpy as np
 import pandas as pd
 from sklearn.preprocessing import StandardScaler
+import shap
 import os
 
 
@@ -13,6 +16,7 @@ class GbdtLrModel:
         self.load_model(model_path)
         self.redis = RedisDatabaseHelper().redis
         self._mysql_dao = MySqlDao()
+        self._explainer = None
     
     def load_model(self, model_path):
         models = joblib.load(model_path)
@@ -61,7 +65,7 @@ class GbdtLrModel:
         scaler = StandardScaler()
         self.feats_map[numeric_columns] = scaler.fit_transform(self.feats_map[numeric_columns])
     
-    def sort(self, city_uuid, product_id):
+    def recommend_sort(self, city_uuid, product_id):
         self.generate_feats_map(city_uuid, product_id)
         
         gbdt_preds = self.gbdt_model.apply(self.feats_map)[:, :, 0]
@@ -81,16 +85,14 @@ class GbdtLrModel:
         """生成特征重要性"""
         # 获取GBDT模型的特征重要性
         feats_importance = self.gbdt_model.feature_importances_
-        
         # Get the feature names
         feats_names = self.gbdt_model.feature_names_in_
-        
         importance_dict = dict(zip(feats_names, feats_importance))
         
         onehot_feats = {**CustConfig.ONEHOT_CAT, **ShopConfig.ONEHOT_CAT, **ProductConfig.ONEHOT_CAT}
         for feat, categories in onehot_feats.items():
+            related_columns = [f"{feat}_{item}" for item in categories]
             
-            related_columns = [col for col in feats_names if col.startswith(feat)]
             if related_columns:
                 # Combine the importances of the one-hot categories
                 combined_importance = sum(importance_dict[col] for col in related_columns)
@@ -117,6 +119,71 @@ class GbdtLrModel:
                 
         return cust_features_importance, product_features_importance
     
+    def generate_shap_interaction(self, data):
+        # Lazily initialize the SHAP explainer
+        if self._explainer is None:
+            self._explainer = shap.TreeExplainer(self.gbdt_model)
+        
+        # Compute SHAP interaction values (cap the sample size upstream; this is expensive)
+        shap_interaction = self._explainer.shap_interaction_values(data)
+        
+        # Mean absolute interaction values across samples
+        mean_interaction = np.abs(shap_interaction).mean(0)
+        
+        # Build the interaction matrix as a DataFrame
+        interaction_df = pd.DataFrame(
+            mean_interaction,
+            index=data.columns,
+            columns=data.columns
+        )
+        
+        # Split product (cigarette) and merchant features (fix 1: build the feature-name lists directly)
+        product_feats = [
+            f"{feat}_{item}" 
+            for feat, categories in ProductConfig.ONEHOT_CAT.items()
+            for item in categories
+        ]
+        
+        cust_feats = [
+            f"{feat}_{item}"
+            for feat, categories in {**CustConfig.ONEHOT_CAT, **ShopConfig.ONEHOT_CAT}.items()
+            for item in categories
+        ]
+        
+        # Extract the cross block (fix 2: index with .loc[] and explicit name lists)
+        cross_matrix = interaction_df.loc[product_feats, cust_feats]
+        
+        # 1. Reshape the matrix to long format (row: product feature, column: merchant feature, value: SHAP interaction)
+        stacked = cross_matrix.stack().reset_index()
+        stacked.columns = ['product_feat', 'cust_feat', 'relation']
+        
+        # 2. Filter out zero-valued and NaN pairs
+        filtered = stacked[
+            (stacked['relation'].abs() > 1e-6) &  # drop near-zero interactions
+            (~stacked['relation'].isna())         # drop NaN
+        ].copy()
+        
+        # 3. Convert to a list of records, sorted by product_feat and then relation, descending
+        results = (
+            filtered
+            .sort_values(['product_feat', 'relation'], ascending=[False, False])
+            .to_dict('records')
+        )
+        
+        # 4. Map raw feature names to display names
+        feats_name_map = {**ImportanceFeaturesMap.CUSTOM_FEATURES_MAP, **ImportanceFeaturesMap.SHOPING_FEATURES_MAP, **ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP}
+        for item in results:
+            product_f = item["product_feat"]
+            product_infos = product_f.split("_")
+            item["product_feat"] = f"{feats_name_map['_'.join(product_infos[:-1])]}({product_infos[-1]})"
+            
+            cust_f = item["cust_feat"]
+            cust_infos = cust_f.split("_")
+            item["cust_feat"] = f"{feats_name_map['_'.join(cust_infos[:-1])]}({cust_infos[-1]})"
+        
+        results = pd.DataFrame(results, columns=['product_feat', 'cust_feat', 'relation']) 
+        return results
+    
 if __name__ == "__main__":
     model_path = "./models/rank/weights/00000000000000000000000011445301/gbdtlr_model.pkl"
     city_uuid = "00000000000000000000000011445301"
@@ -124,17 +191,20 @@ if __name__ == "__main__":
     gbdt_sort = GbdtLrModel(model_path)
     # gbdt_sort.sort(city_uuid, product_id)
     
-    cust_features_importance,  product_features_importance = gbdt_sort.generate_feats_importance()
+    # cust_features_importance,  product_features_importance = gbdt_sort.generate_feats_importance()
 
-    cust_df = pd.DataFrame([
-        {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
-        for item in cust_features_importance
-    ])
-    cust_df.to_csv("./data/cust_feats.csv", index=False)
+    # cust_df = pd.DataFrame([
+    #     {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
+    #     for item in cust_features_importance
+    # ])
+    # cust_df.to_csv("./data/cust_feats.csv", index=False)
+    
+    # product_df = pd.DataFrame([
+    #     {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
+    #     for item in product_features_importance
+    # ])
+    # product_df.to_csv("./data/product_feats.csv", index=False)
+    _, data = DataLoader("./data/gbdt/train_data.csv").split_dataset()
+    result = gbdt_sort.generate_shap_interaction(data["data"][:2000])
     
-    product_df = pd.DataFrame([
-        {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
-        for item in product_features_importance
-    ])
-    product_df.to_csv("./data/product_feats.csv", index=False)
-    
+    result.to_csv("./data/feats_interaction.csv", index=False, encoding='utf-8-sig')
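
For reference, a minimal standalone sketch of the cross-feature SHAP interaction pattern this commit introduces. The toy model, feature names ("prod_*"/"cust_*"), and training data below are illustrative assumptions, not the project's real DataLoader, configs, or feature maps.

import numpy as np
import pandas as pd
import shap
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier

# Toy data: 4 features; the first 2 stand in for product one-hot columns,
# the last 2 for merchant one-hot columns.
X, y = make_classification(n_samples=200, n_features=4, n_informative=3,
                           n_redundant=0, random_state=0)
cols = ["prod_a", "prod_b", "cust_a", "cust_b"]
X = pd.DataFrame(X, columns=cols)

model = GradientBoostingClassifier(n_estimators=20, random_state=0).fit(X, y)

# TreeExplainer without background data uses the tree_path_dependent
# algorithm, which shap_interaction_values requires.
explainer = shap.TreeExplainer(model)
inter = explainer.shap_interaction_values(X)   # (n_samples, n_feats, n_feats)
mean_inter = np.abs(inter).mean(0)             # (n_feats, n_feats)

inter_df = pd.DataFrame(mean_inter, index=cols, columns=cols)

# Cross block (rows: product features, columns: merchant features),
# then long format with one (product_feat, cust_feat, relation) row per pair.
cross = inter_df.loc[["prod_a", "prod_b"], ["cust_a", "cust_b"]]
pairs = cross.stack().reset_index()
pairs.columns = ["product_feat", "cust_feat", "relation"]
pairs = pairs[pairs["relation"].abs() > 1e-6].sort_values("relation", ascending=False)
print(pairs)

Because SHAP interaction values are computed per sample over an n_features x n_features matrix, capping the input (as the __main__ block does with data["data"][:2000]) keeps runtime manageable.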