|
|
@@ -8,8 +8,10 @@ import numpy as np
|
|
|
import pandas as pd
|
|
|
from sklearn.preprocessing import StandardScaler
|
|
|
import shap
|
|
|
+from tqdm import tqdm
|
|
|
+from utils import split_relation_subtable
|
|
|
import os
|
|
|
-
|
|
|
+import tempfile
|
|
|
|
|
|
class GbdtLrModel:
|
|
|
def __init__(self, model_path):
|
|
|
@@ -124,65 +126,116 @@ class GbdtLrModel:
|
|
|
if self._explanier is None:
|
|
|
self._explanier = shap.TreeExplainer(self.gbdt_model)
|
|
|
|
|
|
- # 计算SHAP交互值(限制样本数量以提高性能)
|
|
|
- shap_interaction = self._explanier.shap_interaction_values(data)
|
|
|
-
|
|
|
- # 取平均交互值(绝对值)
|
|
|
- mean_interaction = np.abs(shap_interaction).mean(0)
|
|
|
-
|
|
|
- # 构建交互矩阵DataFrame
|
|
|
- interaction_df = pd.DataFrame(
|
|
|
- mean_interaction,
|
|
|
- index=data.columns,
|
|
|
- columns=data.columns
|
|
|
- )
|
|
|
-
|
|
|
- # 分离卷烟和商户特征(修正点1:直接生成特征名列表)
|
|
|
- product_feats = [
|
|
|
- f"{feat}_{item}"
|
|
|
- for feat, categories in ProductConfig.ONEHOT_CAT.items()
|
|
|
- for item in categories
|
|
|
- ]
|
|
|
-
|
|
|
- cust_feats = [
|
|
|
- f"{feat}_{item}"
|
|
|
- for feat, categories in {**CustConfig.ONEHOT_CAT, **ShopConfig.ONEHOT_CAT}.items()
|
|
|
- for item in categories
|
|
|
- ]
|
|
|
-
|
|
|
- # 提取交叉区块(修正点2:使用.loc[]和列表索引)
|
|
|
- cross_matrix = interaction_df.loc[product_feats, cust_feats]
|
|
|
-
|
|
|
- # 1. 将矩阵转换为长格式(行:卷烟特征,列:商户特征,值:SHAP交互值)
|
|
|
- stacked = cross_matrix.stack().reset_index()
|
|
|
- stacked.columns = ['product_feat', 'cust_feat', 'relation']
|
|
|
-
|
|
|
- # 2. 过滤掉零值或NaN的配对
|
|
|
- filtered = stacked[
|
|
|
- (stacked['relation'].abs() > 1e-6) & # 排除极小值
|
|
|
- (~stacked['relation'].isna()) # 排除NaN
|
|
|
- ].copy()
|
|
|
-
|
|
|
- # 3. 转换为字典列表并按relation降序排序
|
|
|
- results = (
|
|
|
- filtered
|
|
|
- .sort_values(['product_feat', 'relation'], ascending=[False, False])
|
|
|
- .to_dict('records')
|
|
|
- )
|
|
|
-
|
|
|
- # 4. 替换名字
|
|
|
- feats_name_map = {**ImportanceFeaturesMap.CUSTOM_FEATURES_MAP, **ImportanceFeaturesMap.SHOPING_FEATURES_MAP, **ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP}
|
|
|
- for item in results:
|
|
|
- product_f = item["product_feat"]
|
|
|
- product_infos = product_f.split("_")
|
|
|
- item["product_feat"] = f"{feats_name_map['_'.join(product_infos[:-1])]}({product_infos[-1]})"
|
|
|
-
|
|
|
- cust_f = item["cust_feat"]
|
|
|
- cust_infos = cust_f.split("_")
|
|
|
- item["cust_feat"] = f"{feats_name_map['_'.join(cust_infos[:-1])]}({cust_infos[-1]})"
|
|
|
-
|
|
|
- results = pd.DataFrame(results, columns=['product_feat', 'cust_feat', 'relation'])
|
|
|
- return results
|
|
|
+ # 获取数据基本信息
|
|
|
+ n_samples = len(data)
|
|
|
+ n_features = len(data.columns)
|
|
|
+ batch_size = 500 # 可根据内存调整
|
|
|
+
|
|
|
+ # 创建临时内存映射文件
|
|
|
+ temp_dir = tempfile.mkdtemp()
|
|
|
+ temp_file = os.path.join(temp_dir, "shap_interaction_temp.dat")
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 预创建内存映射文件
|
|
|
+ fp_shape = (n_samples, n_features, n_features)
|
|
|
+ fp = np.memmap(temp_file, dtype=np.float32,
|
|
|
+ mode='w+',
|
|
|
+ shape=fp_shape)
|
|
|
+
|
|
|
+ # 分批计算并存储SHAP交互值
|
|
|
+ for i in tqdm(range(0, n_samples, batch_size), desc="计算SHAP交互值..."):
|
|
|
+ batch_data = data.iloc[i:i+batch_size]
|
|
|
+ batch_interaction = self._explanier.shap_interaction_values(batch_data)
|
|
|
+ fp[i:i+len(batch_interaction)] = batch_interaction.astype(np.float32)
|
|
|
+ fp.flush() # 确保数据写入磁盘
|
|
|
+ print("SHAP交互值计算并存储完成")
|
|
|
+
|
|
|
+ # 分批计算均值
|
|
|
+ mean_interaction = np.zeros((n_features, n_features), dtype=np.float32)
|
|
|
+ for i in tqdm(range(0, n_samples, batch_size), desc="计算均值..."):
|
|
|
+ batch = np.abs(fp[i:i+batch_size]) # 读取批数据并取绝对值
|
|
|
+ mean_interaction += batch.sum(axis=0) # 按批累加
|
|
|
+
|
|
|
+ mean_interaction /= n_samples # 计算最终均值
|
|
|
+ print("均值计算完成")
|
|
|
+
|
|
|
+ # 构建交互矩阵DataFrame
|
|
|
+ interaction_df = pd.DataFrame(
|
|
|
+ mean_interaction,
|
|
|
+ index=data.columns,
|
|
|
+ columns=data.columns
|
|
|
+ )
|
|
|
+ print("交互矩阵构建完成")
|
|
|
+
|
|
|
+ # 分离卷烟和商户特征
|
|
|
+ product_feats = [
|
|
|
+ f"{feat}_{item}"
|
|
|
+ for feat, categories in ProductConfig.ONEHOT_CAT.items()
|
|
|
+ for item in categories
|
|
|
+ ]
|
|
|
+
|
|
|
+ cust_feats = [
|
|
|
+ f"{feat}_{item}"
|
|
|
+ for feat, categories in {**CustConfig.ONEHOT_CAT, **ShopConfig.ONEHOT_CAT}.items()
|
|
|
+ for item in categories
|
|
|
+ ]
|
|
|
+ print("特征分离完成")
|
|
|
+
|
|
|
+ # 提取交叉区块
|
|
|
+ cross_matrix = interaction_df.loc[product_feats, cust_feats]
|
|
|
+ print("交叉区块提取完成")
|
|
|
+
|
|
|
+ # 转换为长格式
|
|
|
+ stacked = cross_matrix.stack().reset_index()
|
|
|
+ stacked.columns = ['product_feat', 'cust_feat', 'relation']
|
|
|
+ print("转换为长格式完成")
|
|
|
+
|
|
|
+ # 过滤掉零值或NaN的配对
|
|
|
+ filtered = stacked[
|
|
|
+ (stacked['relation'].abs() > 1e-6) & # 排除极小值
|
|
|
+ (~stacked['relation'].isna()) # 排除NaN
|
|
|
+ ].copy()
|
|
|
+ print("过滤完成")
|
|
|
+
|
|
|
+ # 排序结果
|
|
|
+ results = (
|
|
|
+ filtered
|
|
|
+ .sort_values('relation', ascending=False)
|
|
|
+ .to_dict('records')
|
|
|
+ )
|
|
|
+ print("排序完成")
|
|
|
+
|
|
|
+ # 替换特征名称
|
|
|
+ feats_name_map = {
|
|
|
+ **ImportanceFeaturesMap.CUSTOM_FEATURES_MAP,
|
|
|
+ **ImportanceFeaturesMap.SHOPING_FEATURES_MAP,
|
|
|
+ **ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP
|
|
|
+ }
|
|
|
+
|
|
|
+ for item in results:
|
|
|
+ # 处理产品特征名
|
|
|
+ product_f = item["product_feat"]
|
|
|
+ product_infos = product_f.split("_")
|
|
|
+ item["product_feat"] = f"{feats_name_map['_'.join(product_infos[:-1])]}({product_infos[-1]})"
|
|
|
+
|
|
|
+ # 处理客户特征名
|
|
|
+ cust_f = item["cust_feat"]
|
|
|
+ cust_infos = cust_f.split("_")
|
|
|
+ item["cust_feat"] = f"{feats_name_map['_'.join(cust_infos[:-1])]}({cust_infos[-1]})"
|
|
|
+
|
|
|
+ print("名称替换完成")
|
|
|
+
|
|
|
+ # 返回最终结果
|
|
|
+ return pd.DataFrame(results, columns=['product_feat', 'cust_feat', 'relation'])
|
|
|
+
|
|
|
+ finally:
|
|
|
+ # 清理临时文件
|
|
|
+ try:
|
|
|
+ del fp # 必须先删除内存映射对象
|
|
|
+ os.remove(temp_file)
|
|
|
+ os.rmdir(temp_dir)
|
|
|
+ except Exception as e:
|
|
|
+ print(f"清理临时文件时出错: {e}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
model_path = "./models/rank/weights/00000000000000000000000011445301/gbdtlr_model.pkl"
|
|
|
@@ -204,7 +257,11 @@ if __name__ == "__main__":
|
|
|
# for item in product_features_importance
|
|
|
# ])
|
|
|
# product_df.to_csv("./data/product_feats.csv", index=False)
|
|
|
- _, data = DataLoader("./data/gbdt/train_data.csv").split_dataset()
|
|
|
- result = gbdt_sort.generate_shap_interance(data["data"][:2000])
|
|
|
-
|
|
|
- result.to_csv("./data/feats_interaction.csv", index=False, encoding='utf-8-sig')
|
|
|
+ data, _ = DataLoader("./data/gbdt/train_data.csv").split_dataset()
|
|
|
+ # data = data["data"].sample(n=1000, replace=True, random_state=42)
|
|
|
+ data = data["data"]
|
|
|
+ result = gbdt_sort.generate_shap_interance(data)
|
|
|
+ print("保存结果")
|
|
|
+ result.to_csv("./data/feats_interaction.csv", index=False, encoding='utf-8-sig')
|
|
|
+ split_relation_subtable(result, "./data")
|
|
|
+
|