Ver Fonte

完善特征重要性功能,以batch的形式计算,优化内存占用

yangzeyu há 11 meses atrás
pai
commit
5f325133fd
3 ficheiros alterados com 138 adições e 64 exclusões
  1. 121 64
      models/rank/gbdt_lr_sort.py
  2. 5 0
      utils/__init__.py
  3. 12 0
      utils/result_process.py

+ 121 - 64
models/rank/gbdt_lr_sort.py

@@ -8,8 +8,10 @@ import numpy as np
 import pandas as pd
 from sklearn.preprocessing import StandardScaler
 import shap
+from tqdm import tqdm
+from utils import split_relation_subtable
 import os
-
+import tempfile
 
 class GbdtLrModel:
     def __init__(self, model_path):
@@ -124,65 +126,116 @@ class GbdtLrModel:
         if self._explanier is None:
             self._explanier = shap.TreeExplainer(self.gbdt_model)
         
-        # 计算SHAP交互值(限制样本数量以提高性能)
-        shap_interaction = self._explanier.shap_interaction_values(data)
-        
-        # 取平均交互值(绝对值)
-        mean_interaction = np.abs(shap_interaction).mean(0)
-        
-        # 构建交互矩阵DataFrame
-        interaction_df = pd.DataFrame(
-            mean_interaction,
-            index=data.columns,
-            columns=data.columns
-        )
-        
-        # 分离卷烟和商户特征(修正点1:直接生成特征名列表)
-        product_feats = [
-            f"{feat}_{item}" 
-            for feat, categories in ProductConfig.ONEHOT_CAT.items()
-            for item in categories
-        ]
-        
-        cust_feats = [
-            f"{feat}_{item}"
-            for feat, categories in {**CustConfig.ONEHOT_CAT, **ShopConfig.ONEHOT_CAT}.items()
-            for item in categories
-        ]
-        
-        # 提取交叉区块(修正点2:使用.loc[]和列表索引)
-        cross_matrix = interaction_df.loc[product_feats, cust_feats] 
-        
-        # 1. 将矩阵转换为长格式(行:卷烟特征,列:商户特征,值:SHAP交互值)
-        stacked = cross_matrix.stack().reset_index()
-        stacked.columns = ['product_feat', 'cust_feat', 'relation']
-        
-        # 2. 过滤掉零值或NaN的配对
-        filtered = stacked[
-            (stacked['relation'].abs() > 1e-6) &  # 排除极小值
-            (~stacked['relation'].isna())         # 排除NaN
-        ].copy()
-        
-        # 3. 转换为字典列表并按relation降序排序
-        results = (
-            filtered
-            .sort_values(['product_feat', 'relation'], ascending=[False, False])
-            .to_dict('records')
-        )
-        
-        # 4. 替换名字
-        feats_name_map = {**ImportanceFeaturesMap.CUSTOM_FEATURES_MAP, **ImportanceFeaturesMap.SHOPING_FEATURES_MAP, **ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP}
-        for item in results:
-            product_f = item["product_feat"]
-            product_infos = product_f.split("_")
-            item["product_feat"] = f"{feats_name_map['_'.join(product_infos[:-1])]}({product_infos[-1]})"
-            
-            cust_f = item["cust_feat"]
-            cust_infos = cust_f.split("_")
-            item["cust_feat"] = f"{feats_name_map['_'.join(cust_infos[:-1])]}({cust_infos[-1]})"
-        
-        results = pd.DataFrame(results, columns=['product_feat', 'cust_feat', 'relation']) 
-        return results
+        # 获取数据基本信息
+        n_samples = len(data)
+        n_features = len(data.columns)
+        batch_size = 500  # 可根据内存调整
+        
+        # 创建临时内存映射文件
+        temp_dir = tempfile.mkdtemp()
+        temp_file = os.path.join(temp_dir, "shap_interaction_temp.dat")
+        
+        try:
+            # 预创建内存映射文件
+            fp_shape = (n_samples, n_features, n_features)
+            fp = np.memmap(temp_file, dtype=np.float32, 
+                        mode='w+', 
+                        shape=fp_shape)
+            
+            # 分批计算并存储SHAP交互值
+            for i in tqdm(range(0, n_samples, batch_size), desc="计算SHAP交互值..."):
+                batch_data = data.iloc[i:i+batch_size]
+                batch_interaction = self._explanier.shap_interaction_values(batch_data)
+                fp[i:i+len(batch_interaction)] = batch_interaction.astype(np.float32)
+                fp.flush()  # 确保数据写入磁盘
+            print("SHAP交互值计算并存储完成")
+            
+            # 分批计算均值
+            mean_interaction = np.zeros((n_features, n_features), dtype=np.float32)
+            for i in tqdm(range(0, n_samples, batch_size), desc="计算均值..."):
+                batch = np.abs(fp[i:i+batch_size])  # 读取批数据并取绝对值
+                mean_interaction += batch.sum(axis=0)  # 按批累加
+            
+            mean_interaction /= n_samples  # 计算最终均值
+            print("均值计算完成")
+            
+            # 构建交互矩阵DataFrame
+            interaction_df = pd.DataFrame(
+                mean_interaction,
+                index=data.columns,
+                columns=data.columns
+            )
+            print("交互矩阵构建完成")
+            
+            # 分离卷烟和商户特征
+            product_feats = [
+                f"{feat}_{item}" 
+                for feat, categories in ProductConfig.ONEHOT_CAT.items()
+                for item in categories
+            ]
+            
+            cust_feats = [
+                f"{feat}_{item}"
+                for feat, categories in {**CustConfig.ONEHOT_CAT, **ShopConfig.ONEHOT_CAT}.items()
+                for item in categories
+            ]
+            print("特征分离完成")
+            
+            # 提取交叉区块
+            cross_matrix = interaction_df.loc[product_feats, cust_feats] 
+            print("交叉区块提取完成")
+            
+            # 转换为长格式
+            stacked = cross_matrix.stack().reset_index()
+            stacked.columns = ['product_feat', 'cust_feat', 'relation']
+            print("转换为长格式完成")
+            
+            # 过滤掉零值或NaN的配对
+            filtered = stacked[
+                (stacked['relation'].abs() > 1e-6) &  # 排除极小值
+                (~stacked['relation'].isna())         # 排除NaN
+            ].copy()
+            print("过滤完成")
+            
+            # 排序结果
+            results = (
+                filtered
+                .sort_values('relation', ascending=False)
+                .to_dict('records')
+            )
+            print("排序完成")
+            
+            # 替换特征名称
+            feats_name_map = {
+                **ImportanceFeaturesMap.CUSTOM_FEATURES_MAP, 
+                **ImportanceFeaturesMap.SHOPING_FEATURES_MAP, 
+                **ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP
+            }
+            
+            for item in results:
+                # 处理产品特征名
+                product_f = item["product_feat"]
+                product_infos = product_f.split("_")
+                item["product_feat"] = f"{feats_name_map['_'.join(product_infos[:-1])]}({product_infos[-1]})"
+                
+                # 处理客户特征名
+                cust_f = item["cust_feat"]
+                cust_infos = cust_f.split("_")
+                item["cust_feat"] = f"{feats_name_map['_'.join(cust_infos[:-1])]}({cust_infos[-1]})"
+            
+            print("名称替换完成")
+            
+            # 返回最终结果
+            return pd.DataFrame(results, columns=['product_feat', 'cust_feat', 'relation'])
+        
+        finally:
+            # 清理临时文件
+            try:
+                del fp  # 必须先删除内存映射对象
+                os.remove(temp_file)
+                os.rmdir(temp_dir)
+            except Exception as e:
+                print(f"清理临时文件时出错: {e}")
     
 if __name__ == "__main__":
     model_path = "./models/rank/weights/00000000000000000000000011445301/gbdtlr_model.pkl"
@@ -204,7 +257,11 @@ if __name__ == "__main__":
     #     for item in product_features_importance
     # ])
     # product_df.to_csv("./data/product_feats.csv", index=False)
-    _, data = DataLoader("./data/gbdt/train_data.csv").split_dataset()
-    result = gbdt_sort.generate_shap_interance(data["data"][:2000])
-    
-    result.to_csv("./data/feats_interaction.csv", index=False, encoding='utf-8-sig')
+    data, _ = DataLoader("./data/gbdt/train_data.csv").split_dataset()
+    # data = data["data"].sample(n=1000, replace=True, random_state=42)
+    data = data["data"]
+    result = gbdt_sort.generate_shap_interance(data)
+    print("保存结果")
+    result.to_csv("./data/feats_interaction.csv", index=False, encoding='utf-8-sig')
+    split_relation_subtable(result, "./data")
+    

+ 5 - 0
utils/__init__.py

@@ -1,2 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
+from utils.result_process import split_relation_subtable
+
+__all__ = [
+    "split_relation_subtable"
+]

+ 12 - 0
utils/result_process.py

@@ -0,0 +1,12 @@
import os


def split_relation_subtable(data, save_dir):
    """Split the product/customer feature-relation table into per-feature CSVs.

    Rows are grouped by the part of ``product_feat`` that precedes the first
    ``(`` (rows whose ``product_feat`` yields no such prefix are left out by
    the group-by). Each group is sorted by ``relation`` in descending order
    and written to ``<save_dir>/<group name>.csv`` without the index column.

    Args:
        data: DataFrame with at least the columns ``product_feat`` (string)
            and ``relation``. It is not modified by this function.
        save_dir: Directory that receives one CSV file per group. It is
            created (including parents) when it does not exist yet.

    Returns:
        None. The result is the set of CSV files written to ``save_dir``.
    """
    # Attach the grouping key to a copy (``assign`` returns a new frame) so
    # the caller's DataFrame does not silently gain a ``group_key`` column.
    keyed = data.assign(
        group_key=data["product_feat"].str.extract(r'^([^(]+)', expand=False)
    )

    # Writing into a missing directory would otherwise fail on the first file.
    os.makedirs(save_dir, exist_ok=True)

    for name, group in keyed.groupby('group_key'):
        sub_data = (
            group
            .drop(columns=['group_key'])
            .sort_values('relation', ascending=False)
        )
        sub_data.to_csv(os.path.join(save_dir, f"{name}.csv"), index=False)