|
|
@@ -1,10 +1,13 @@
|
|
|
import joblib
|
|
|
# from dao import Redis, get_product_by_id, get_custs_by_ids, load_cust_data_from_mysql
|
|
|
from database import RedisDatabaseHelper, MySqlDao
|
|
|
+from models.rank.data import DataLoader
|
|
|
from models.rank.data import ProductConfig, CustConfig, ShopConfig, ImportanceFeaturesMap
|
|
|
from models.rank.data.utils import one_hot_embedding, sample_data_clear
|
|
|
+import numpy as np
|
|
|
import pandas as pd
|
|
|
from sklearn.preprocessing import StandardScaler
|
|
|
+import shap
|
|
|
import os
|
|
|
|
|
|
|
|
|
@@ -13,6 +16,7 @@ class GbdtLrModel:
|
|
|
self.load_model(model_path)
|
|
|
self.redis = RedisDatabaseHelper().redis
|
|
|
self._mysql_dao = MySqlDao()
|
|
|
+ self._explanier = None
|
|
|
|
|
|
def load_model(self, model_path):
|
|
|
models = joblib.load(model_path)
|
|
|
@@ -61,7 +65,7 @@ class GbdtLrModel:
|
|
|
scaler = StandardScaler()
|
|
|
self.feats_map[numeric_columns] = scaler.fit_transform(self.feats_map[numeric_columns])
|
|
|
|
|
|
- def sort(self, city_uuid, product_id):
|
|
|
+ def recommend_sort(self, city_uuid, product_id):
|
|
|
self.generate_feats_map(city_uuid, product_id)
|
|
|
|
|
|
gbdt_preds = self.gbdt_model.apply(self.feats_map)[:, :, 0]
|
|
|
@@ -81,16 +85,14 @@ class GbdtLrModel:
|
|
|
"""生成特征重要性"""
|
|
|
# 获取GBDT模型的特征重要性
|
|
|
feats_importance = self.gbdt_model.feature_importances_
|
|
|
-
|
|
|
# 获取特征名称
|
|
|
feats_names = self.gbdt_model.feature_names_in_
|
|
|
-
|
|
|
importance_dict = dict(zip(feats_names, feats_importance))
|
|
|
|
|
|
onehot_feats = {**CustConfig.ONEHOT_CAT, **ShopConfig.ONEHOT_CAT, **ProductConfig.ONEHOT_CAT}
|
|
|
for feat, categories in onehot_feats.items():
|
|
|
+ related_columns = [f"{feat}_{item}" for item in categories]
|
|
|
|
|
|
- related_columns = [col for col in feats_names if col.startswith(feat)]
|
|
|
if related_columns:
|
|
|
# 合并类别重要性
|
|
|
combined_importance = sum(importance_dict[col] for col in related_columns)
|
|
|
@@ -117,6 +119,71 @@ class GbdtLrModel:
|
|
|
|
|
|
return cust_features_importance, product_features_importance
|
|
|
|
|
|
+ def generate_shap_interance(self, data):
|
|
|
+ # 初始化SHAP解释器
|
|
|
+ if self._explanier is None:
|
|
|
+ self._explanier = shap.TreeExplainer(self.gbdt_model)
|
|
|
+
|
|
|
+ # 计算SHAP交互值(耗时操作;本方法不限制样本数,调用方应预先截取样本以保证性能)
|
|
|
+ shap_interaction = self._explanier.shap_interaction_values(data)
|
|
|
+
|
|
|
+ # 取平均交互值(绝对值)
|
|
|
+ mean_interaction = np.abs(shap_interaction).mean(0)
|
|
|
+
|
|
|
+ # 构建交互矩阵DataFrame
|
|
|
+ interaction_df = pd.DataFrame(
|
|
|
+ mean_interaction,
|
|
|
+ index=data.columns,
|
|
|
+ columns=data.columns
|
|
|
+ )
|
|
|
+
|
|
|
+ # 分离卷烟和商户特征(修正点1:直接生成特征名列表)
|
|
|
+ product_feats = [
|
|
|
+ f"{feat}_{item}"
|
|
|
+ for feat, categories in ProductConfig.ONEHOT_CAT.items()
|
|
|
+ for item in categories
|
|
|
+ ]
|
|
|
+
|
|
|
+ cust_feats = [
|
|
|
+ f"{feat}_{item}"
|
|
|
+ for feat, categories in {**CustConfig.ONEHOT_CAT, **ShopConfig.ONEHOT_CAT}.items()
|
|
|
+ for item in categories
|
|
|
+ ]
|
|
|
+
|
|
|
+ # 提取交叉区块(修正点2:使用.loc[]和列表索引)
|
|
|
+ cross_matrix = interaction_df.loc[product_feats, cust_feats]
|
|
|
+
|
|
|
+ # 1. 将矩阵转换为长格式(行:卷烟特征,列:商户特征,值:SHAP交互值)
|
|
|
+ stacked = cross_matrix.stack().reset_index()
|
|
|
+ stacked.columns = ['product_feat', 'cust_feat', 'relation']
|
|
|
+
|
|
|
+ # 2. 过滤掉零值或NaN的配对
|
|
|
+ filtered = stacked[
|
|
|
+ (stacked['relation'].abs() > 1e-6) & # 排除极小值
|
|
|
+ (~stacked['relation'].isna()) # 排除NaN
|
|
|
+ ].copy()
|
|
|
+
|
|
|
+ # 3. 转换为字典列表,先按 product_feat 再按 relation 降序排序
|
|
|
+ results = (
|
|
|
+ filtered
|
|
|
+ .sort_values(['product_feat', 'relation'], ascending=[False, False])
|
|
|
+ .to_dict('records')
|
|
|
+ )
|
|
|
+
|
|
|
+ # 4. 替换名字
|
|
|
+ feats_name_map = {**ImportanceFeaturesMap.CUSTOM_FEATURES_MAP, **ImportanceFeaturesMap.SHOPING_FEATURES_MAP, **ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP}
|
|
|
+ for item in results:
|
|
|
+ product_f = item["product_feat"]
|
|
|
+ product_infos = product_f.split("_")
|
|
|
+ item["product_feat"] = f"{feats_name_map['_'.join(product_infos[:-1])]}({product_infos[-1]})"
|
|
|
+
|
|
|
+ cust_f = item["cust_feat"]
|
|
|
+ cust_infos = cust_f.split("_")
|
|
|
+ item["cust_feat"] = f"{feats_name_map['_'.join(cust_infos[:-1])]}({cust_infos[-1]})"
|
|
|
+
|
|
|
+ results = pd.DataFrame(results, columns=['product_feat', 'cust_feat', 'relation'])
|
|
|
+ return results
|
|
|
+
|
|
|
if __name__ == "__main__":
|
|
|
model_path = "./models/rank/weights/00000000000000000000000011445301/gbdtlr_model.pkl"
|
|
|
city_uuid = "00000000000000000000000011445301"
|
|
|
@@ -124,17 +191,20 @@ if __name__ == "__main__":
|
|
|
gbdt_sort = GbdtLrModel(model_path)
|
|
|
# gbdt_sort.recommend_sort(city_uuid, product_id)
|
|
|
|
|
|
- cust_features_importance, product_features_importance = gbdt_sort.generate_feats_importance()
|
|
|
+ # cust_features_importance, product_features_importance = gbdt_sort.generate_feats_importance()
|
|
|
|
|
|
- cust_df = pd.DataFrame([
|
|
|
- {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
|
|
|
- for item in cust_features_importance
|
|
|
- ])
|
|
|
- cust_df.to_csv("./data/cust_feats.csv", index=False)
|
|
|
+ # cust_df = pd.DataFrame([
|
|
|
+ # {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
|
|
|
+ # for item in cust_features_importance
|
|
|
+ # ])
|
|
|
+ # cust_df.to_csv("./data/cust_feats.csv", index=False)
|
|
|
+
|
|
|
+ # product_df = pd.DataFrame([
|
|
|
+ # {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
|
|
|
+ # for item in product_features_importance
|
|
|
+ # ])
|
|
|
+ # product_df.to_csv("./data/product_feats.csv", index=False)
|
|
|
+ _, data = DataLoader("./data/gbdt/train_data.csv").split_dataset()
|
|
|
+ result = gbdt_sort.generate_shap_interance(data["data"][:2000])
|
|
|
|
|
|
- product_df = pd.DataFrame([
|
|
|
- {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
|
|
|
- for item in product_features_importance
|
|
|
- ])
|
|
|
- product_df.to_csv("./data/product_feats.csv", index=False)
|
|
|
-
|
|
|
+ result.to_csv("./data/feats_interaction.csv", index=False, encoding='utf-8-sig')
|