# gbdt_lr_sort.py
  1. import joblib
  2. # from dao import Redis, get_product_by_id, get_custs_by_ids, load_cust_data_from_mysql
  3. from database import RedisDatabaseHelper, MySqlDao
  4. from models.rank.data import DataLoader
  5. from models.rank.data import ProductConfig, CustConfig, ShopConfig, ImportanceFeaturesMap
  6. from models.rank.data.utils import one_hot_embedding, sample_data_clear
  7. import numpy as np
  8. import pandas as pd
  9. from sklearn.preprocessing import StandardScaler
  10. import shap
  11. import os
class GbdtLrModel:
    """GBDT + LR ranking model for (city, product) -> customer scoring.

    Loads a pre-trained bundle (GBDT, LR, and the one-hot encoder that maps
    GBDT leaf indices to LR inputs) from disk. For a given city/product pair
    it builds the customer x product feature matrix, scores every customer
    with the GBDT-leaf -> one-hot -> LR pipeline, and exposes feature
    importance and SHAP interaction reports.
    """

    def __init__(self, model_path: str):
        # Load the pickled model bundle first; DB handles are created eagerly.
        self.load_model(model_path)
        self.redis = RedisDatabaseHelper().redis
        self._mysql_dao = MySqlDao()
        # Lazily-built shap.TreeExplainer cache.
        # NOTE(review): attribute name is a typo for "_explainer"; kept as-is
        # because it is referenced by name below.
        self._explanier = None

    def load_model(self, model_path: str):
        """Load the joblib bundle holding the GBDT, LR and one-hot encoder."""
        models = joblib.load(model_path)
        self.gbdt_model, self.lr_model, self.onehot_encoder = models["gbdt_model"], models["lr_model"], models["onehot_encoder"]

    # def get_recall_list(self, city_uuid, product_id):
    #     """Fetch the recalled shop list for a product id (from Redis)."""
    #     key = f"fc:{city_uuid}:{product_id}"
    #     self.recall_cust_list = self.redis.zrange(key, 0, -1, withscores=False)
    # def load_recall_data(self, city_uuid, product_id):
    #     self.product_data = self._mysql_dao.get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
    #     self.custs_data = self._mysql_dao.get_cust_by_ids(city_uuid, self.recall_cust_list)[CustConfig.FEATURE_COLUMNS]

    def get_cust_and_product_data(self, city_uuid: str, product_id: str):
        """Load the product row and all customer rows for the given city."""
        self.product_data = self._mysql_dao.get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
        self.custs_data = self._mysql_dao.load_cust_data(city_uuid)[CustConfig.FEATURE_COLUMNS]

    def generate_feats_map(self, city_uuid: str, product_id: str):
        """Build the combined customer x product feature matrix (self.feats_map)."""
        # self.get_recall_list(city_uuid, product_id)
        # self.load_recall_data(city_uuid, product_id)
        self.get_cust_and_product_data(city_uuid, product_id)
        # Clean both sides against their respective configs.
        self.product_data = sample_data_clear(self.product_data, ProductConfig)
        self.custs_data = sample_data_clear(self.custs_data, CustConfig)
        # Cartesian-product join: every customer paired with the single product.
        self.custs_data["descartes"] = 1
        self.product_data["descartes"] = 1
        self.feats_map = pd.merge(self.custs_data, self.product_data, on="descartes").drop("descartes", axis=1)
        # Remember the row -> customer-id mapping before dropping the id columns;
        # recommend_sort() zips this list against the model scores.
        self.recall_cust_list = self.feats_map["BB_RETAIL_CUSTOMER_CODE"].to_list()
        self.feats_map.drop('BB_RETAIL_CUSTOMER_CODE', axis=1, inplace=True)
        self.feats_map.drop('product_code', axis=1, inplace=True)
        # One-hot encoding of the configured categorical features.
        onehot_feats = {**CustConfig.ONEHOT_CAT, **ProductConfig.ONEHOT_CAT}
        onehot_columns = list(onehot_feats.keys())
        numeric_columns = self.feats_map.drop(onehot_columns, axis=1).columns
        self.feats_map = one_hot_embedding(self.feats_map, onehot_feats)
        # Standardize numeric features.
        # NOTE(review): the scaler is re-fit on the inference batch here, so
        # scaling statistics differ from whatever was used at training time —
        # confirm this matches how the GBDT/LR models were trained.
        scaler = StandardScaler()
        self.feats_map[numeric_columns] = scaler.fit_transform(self.feats_map[numeric_columns])

    def recommend_sort(self, city_uuid: str, product_id: str):
        """Score all customers for the product.

        Returns:
            list of single-entry {cust_id: score} dicts, sorted by score
            descending (also stored on self.recommend_list).
        """
        self.generate_feats_map(city_uuid, product_id)
        # GBDT leaf indices per tree; [:, :, 0] drops the trailing class axis.
        gbdt_preds = self.gbdt_model.apply(self.feats_map)[:, :, 0]
        # Leaf indices -> sparse one-hot features for the downstream LR.
        gbdt_feats_encoded = self.onehot_encoder.transform(gbdt_preds)
        # Positive-class probability is the ranking score.
        scores = self.lr_model.predict_proba(gbdt_feats_encoded)[:, 1]
        self.recommend_list = []
        for cust_id, score in zip(self.recall_cust_list, scores):
            self.recommend_list.append({cust_id: float(score)})
        self.recommend_list = sorted(self.recommend_list, key=lambda x: list(x.values())[0], reverse=True)
        # for res in self.recommend_list[:200]:
        #     print(res)
        return self.recommend_list

    def generate_feats_importance(self):
        """Aggregate GBDT feature importances back to pre-one-hot feature names.

        Returns:
            (cust_features_importance, product_features_importance): two lists
            of single-entry {display_name: importance} dicts, each sorted by
            importance descending. Shop features are folded into the customer
            list.
        """
        # Per-column importances from the trained GBDT.
        feats_importance = self.gbdt_model.feature_importances_
        feats_names = self.gbdt_model.feature_names_in_
        importance_dict = dict(zip(feats_names, feats_importance))
        onehot_feats = {**CustConfig.ONEHOT_CAT, **ShopConfig.ONEHOT_CAT, **ProductConfig.ONEHOT_CAT}
        for feat, categories in onehot_feats.items():
            related_columns = [f"{feat}_{item}" for item in categories]
            if related_columns:
                # Sum the importances of all one-hot columns of this feature.
                # NOTE(review): raises KeyError if any expected one-hot column
                # is absent from the trained model's feature set — confirm all
                # configured categories were present at training time.
                combined_importance = sum(importance_dict[col] for col in related_columns)
                # Drop the per-category columns...
                for col in related_columns:
                    del importance_dict[col]
                # ...and keep one merged entry per original feature.
                importance_dict[feat] = combined_importance
        # Sort by importance, descending.
        sorted_importance = sorted(importance_dict.items(), key=lambda x: x[1], reverse=True)
        # Split into customer- and product-side reports using display names.
        cust_features_importance = []
        product_features_importance = []
        for feat, importance in sorted_importance:
            if feat in list(ImportanceFeaturesMap.CUSTOM_FEATURES_MAP.keys()):
                cust_features_importance.append({ImportanceFeaturesMap.CUSTOM_FEATURES_MAP[feat]: float(importance)})
            if feat in list(ImportanceFeaturesMap.SHOPING_FEATURES_MAP.keys()):
                # Shop features are reported together with customer features.
                cust_features_importance.append({ImportanceFeaturesMap.SHOPING_FEATURES_MAP[feat]: float(importance)})
            if feat in list(ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP.keys()):
                product_features_importance.append({ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP[feat]: float(importance)})
        return cust_features_importance, product_features_importance

    def generate_shap_interance(self, data):
        """Compute product x customer SHAP interaction strengths.

        Args:
            data: feature DataFrame with the same columns the GBDT was trained
                on. Callers should limit the row count — SHAP interaction
                values are expensive to compute.

        Returns:
            DataFrame with columns ['product_feat', 'cust_feat', 'relation'],
            feature names mapped to "<display_name>(<category>)" form.
        """
        # Build the SHAP explainer once and cache it on the instance.
        if self._explanier is None:
            self._explanier = shap.TreeExplainer(self.gbdt_model)
        # Compute SHAP interaction values (sample count limited by caller).
        shap_interaction = self._explanier.shap_interaction_values(data)
        # Mean absolute interaction across samples.
        mean_interaction = np.abs(shap_interaction).mean(0)
        # Interaction matrix indexed by feature name on both axes.
        interaction_df = pd.DataFrame(
            mean_interaction,
            index=data.columns,
            columns=data.columns
        )
        # Split product vs customer features (fix 1: generate the expanded
        # one-hot column names directly from the configs).
        product_feats = [
            f"{feat}_{item}"
            for feat, categories in ProductConfig.ONEHOT_CAT.items()
            for item in categories
        ]
        cust_feats = [
            f"{feat}_{item}"
            for feat, categories in {**CustConfig.ONEHOT_CAT, **ShopConfig.ONEHOT_CAT}.items()
            for item in categories
        ]
        # Extract the product x customer cross block (fix 2: .loc with lists).
        cross_matrix = interaction_df.loc[product_feats, cust_feats]
        # 1. Long format: one row per (product_feat, cust_feat) pair.
        stacked = cross_matrix.stack().reset_index()
        stacked.columns = ['product_feat', 'cust_feat', 'relation']
        # 2. Drop near-zero and NaN interaction pairs.
        filtered = stacked[
            (stacked['relation'].abs() > 1e-6) &  # exclude negligible values
            (~stacked['relation'].isna())  # exclude NaN
        ].copy()
        # 3. Records sorted by product_feat then relation, both descending.
        # NOTE(review): the primary sort key is the feature NAME (reverse
        # alphabetical), not global relation strength — confirm intended.
        results = (
            filtered
            .sort_values(['product_feat', 'relation'], ascending=[False, False])
            .to_dict('records')
        )
        # 4. Map raw one-hot names "<feature>_<category>" to
        # "<display_name>(<category>)" via the shared name map.
        feats_name_map = {**ImportanceFeaturesMap.CUSTOM_FEATURES_MAP, **ImportanceFeaturesMap.SHOPING_FEATURES_MAP, **ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP}
        for item in results:
            product_f = item["product_feat"]
            # rsplit-style handling: last "_" segment is the category value.
            product_infos = product_f.split("_")
            item["product_feat"] = f"{feats_name_map['_'.join(product_infos[:-1])]}({product_infos[-1]})"
            cust_f = item["cust_feat"]
            cust_infos = cust_f.split("_")
            item["cust_feat"] = f"{feats_name_map['_'.join(cust_infos[:-1])]}({cust_infos[-1]})"
        results = pd.DataFrame(results, columns=['product_feat', 'cust_feat', 'relation'])
        return results
  150. if __name__ == "__main__":
  151. model_path = "./models/rank/weights/00000000000000000000000011445301/gbdtlr_model.pkl"
  152. city_uuid = "00000000000000000000000011445301"
  153. product_id = "110102"
  154. gbdt_sort = GbdtLrModel(model_path)
  155. # gbdt_sort.sort(city_uuid, product_id)
  156. # cust_features_importance, product_features_importance = gbdt_sort.generate_feats_importance()
  157. # cust_df = pd.DataFrame([
  158. # {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
  159. # for item in cust_features_importance
  160. # ])
  161. # cust_df.to_csv("./data/cust_feats.csv", index=False)
  162. # product_df = pd.DataFrame([
  163. # {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
  164. # for item in product_features_importance
  165. # ])
  166. # product_df.to_csv("./data/product_feats.csv", index=False)
  167. _, data = DataLoader("./data/gbdt/train_data.csv").split_dataset()
  168. result = gbdt_sort.generate_shap_interance(data["data"][:2000])
  169. result.to_csv("./data/feats_interaction.csv", index=False, encoding='utf-8-sig')