# gbdt_lr_sort.py
  1. import joblib
  2. # from dao import Redis, get_product_by_id, get_custs_by_ids, load_cust_data_from_mysql
  3. from database import RedisDatabaseHelper, MySqlDao
  4. from models.rank.data import DataLoader
  5. from models.rank.data import ProductConfig, CustConfig, ShopConfig, ImportanceFeaturesMap
  6. from models.rank.data.utils import one_hot_embedding, sample_data_clear
  7. import numpy as np
  8. import pandas as pd
  9. from sklearn.preprocessing import StandardScaler
  10. import shap
  11. from tqdm import tqdm
  12. from utils import split_relation_subtable
  13. import os
  14. import tempfile
  15. class GbdtLrModel:
  16. def __init__(self, model_path):
  17. self.load_model(model_path)
  18. self.redis = RedisDatabaseHelper().redis
  19. self._mysql_dao = MySqlDao()
  20. self._explanier = None
  21. def load_model(self, model_path):
  22. models = joblib.load(model_path)
  23. self.gbdt_model, self.lr_model, self.onehot_encoder = models["gbdt_model"], models["lr_model"], models["onehot_encoder"]
  24. # def get_recall_list(self, city_uuid, product_id):
  25. # """根据卷烟id获取召回的商铺列表"""
  26. # key = f"fc:{city_uuid}:{product_id}"
  27. # self.recall_cust_list = self.redis.zrange(key, 0, -1, withscores=False)
  28. # def load_recall_data(self, city_uuid, product_id):
  29. # self.product_data = self._mysql_dao.get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
  30. # self.custs_data = self._mysql_dao.get_cust_by_ids(city_uuid, self.recall_cust_list)[CustConfig.FEATURE_COLUMNS]
  31. def get_cust_and_product_data(self, city_uuid, product_id):
  32. """从商户数据库中获取指定城市所有商户的id"""
  33. self.product_data = self._mysql_dao.get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
  34. self.custs_data = self._mysql_dao.load_cust_data(city_uuid)[CustConfig.FEATURE_COLUMNS]
  35. def generate_feats_map(self, city_uuid, product_id):
  36. """组合卷烟、商户特征矩阵"""
  37. # self.get_recall_list(city_uuid, product_id)
  38. # self.load_recall_data(city_uuid, product_id)
  39. self.get_cust_and_product_data(city_uuid, product_id)
  40. # 做数据清洗
  41. self.product_data = sample_data_clear(self.product_data, ProductConfig)
  42. self.custs_data = sample_data_clear(self.custs_data, CustConfig)
  43. # 笛卡尔积联合
  44. self.custs_data["descartes"] = 1
  45. self.product_data["descartes"] = 1
  46. self.feats_map = pd.merge(self.custs_data, self.product_data, on="descartes").drop("descartes", axis=1)
  47. self.recall_cust_list = self.feats_map["BB_RETAIL_CUSTOMER_CODE"].to_list()
  48. self.feats_map.drop('BB_RETAIL_CUSTOMER_CODE', axis=1, inplace=True)
  49. self.feats_map.drop('product_code', axis=1, inplace=True)
  50. # onehot编码
  51. onehot_feats = {**CustConfig.ONEHOT_CAT, **ProductConfig.ONEHOT_CAT}
  52. onehot_columns = list(onehot_feats.keys())
  53. numeric_columns = self.feats_map.drop(onehot_columns, axis=1).columns
  54. self.feats_map = one_hot_embedding(self.feats_map, onehot_feats)
  55. # 数字特征归一化
  56. scaler = StandardScaler()
  57. self.feats_map[numeric_columns] = scaler.fit_transform(self.feats_map[numeric_columns])
  58. def recommend_sort(self, city_uuid, product_id):
  59. self.generate_feats_map(city_uuid, product_id)
  60. gbdt_preds = self.gbdt_model.apply(self.feats_map)[:, :, 0]
  61. gbdt_feats_encoded = self.onehot_encoder.transform(gbdt_preds)
  62. scores = self.lr_model.predict_proba(gbdt_feats_encoded)[:, 1]
  63. self.recommend_list = []
  64. for cust_id, score in zip(self.recall_cust_list, scores):
  65. self.recommend_list.append({cust_id: float(score)})
  66. self.recommend_list = sorted(self.recommend_list, key=lambda x: list(x.values())[0], reverse=True)
  67. # for res in self.recommend_list[:200]:
  68. # print(res)
  69. return self.recommend_list
  70. def generate_feats_importance(self):
  71. """生成特征重要性"""
  72. # 获取GBDT模型的特征重要性
  73. feats_importance = self.gbdt_model.feature_importances_
  74. # 获取特征名称
  75. feats_names = self.gbdt_model.feature_names_in_
  76. importance_dict = dict(zip(feats_names, feats_importance))
  77. onehot_feats = {**CustConfig.ONEHOT_CAT, **ShopConfig.ONEHOT_CAT, **ProductConfig.ONEHOT_CAT}
  78. for feat, categories in onehot_feats.items():
  79. related_columns = [f"{feat}_{item}" for item in categories]
  80. if related_columns:
  81. # 合并类别重要性
  82. combined_importance = sum(importance_dict[col] for col in related_columns)
  83. # 删除onehot类别列
  84. for col in related_columns:
  85. del importance_dict[col]
  86. # 添加合并后的重要性
  87. importance_dict[feat] = combined_importance
  88. # 排序
  89. sorted_importance = sorted(importance_dict.items(), key=lambda x: x[1], reverse=True)
  90. # 输出特征重要性
  91. cust_features_importance = []
  92. product_features_importance = []
  93. for feat, importance in sorted_importance:
  94. if feat in list(ImportanceFeaturesMap.CUSTOM_FEATURES_MAP.keys()):
  95. cust_features_importance.append({ImportanceFeaturesMap.CUSTOM_FEATURES_MAP[feat]: float(importance)})
  96. if feat in list(ImportanceFeaturesMap.SHOPING_FEATURES_MAP.keys()):
  97. cust_features_importance.append({ImportanceFeaturesMap.SHOPING_FEATURES_MAP[feat]: float(importance)})
  98. if feat in list(ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP.keys()):
  99. product_features_importance.append({ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP[feat]: float(importance)})
  100. return cust_features_importance, product_features_importance
  101. def generate_shap_interance(self, data):
  102. # 初始化SHAP解释器
  103. if self._explanier is None:
  104. self._explanier = shap.TreeExplainer(self.gbdt_model)
  105. # 获取数据基本信息
  106. n_samples = len(data)
  107. n_features = len(data.columns)
  108. batch_size = 500 # 可根据内存调整
  109. # 创建临时内存映射文件
  110. temp_dir = tempfile.mkdtemp()
  111. temp_file = os.path.join(temp_dir, "shap_interaction_temp.dat")
  112. try:
  113. # 预创建内存映射文件
  114. fp_shape = (n_samples, n_features, n_features)
  115. fp = np.memmap(temp_file, dtype=np.float32,
  116. mode='w+',
  117. shape=fp_shape)
  118. # 分批计算并存储SHAP交互值
  119. for i in tqdm(range(0, n_samples, batch_size), desc="计算SHAP交互值..."):
  120. batch_data = data.iloc[i:i+batch_size]
  121. batch_interaction = self._explanier.shap_interaction_values(batch_data)
  122. fp[i:i+len(batch_interaction)] = batch_interaction.astype(np.float32)
  123. fp.flush() # 确保数据写入磁盘
  124. print("SHAP交互值计算并存储完成")
  125. # 分批计算均值
  126. mean_interaction = np.zeros((n_features, n_features), dtype=np.float32)
  127. for i in tqdm(range(0, n_samples, batch_size), desc="计算均值..."):
  128. batch = np.abs(fp[i:i+batch_size]) # 读取批数据并取绝对值
  129. mean_interaction += batch.sum(axis=0) # 按批累加
  130. mean_interaction /= n_samples # 计算最终均值
  131. print("均值计算完成")
  132. # 构建交互矩阵DataFrame
  133. interaction_df = pd.DataFrame(
  134. mean_interaction,
  135. index=data.columns,
  136. columns=data.columns
  137. )
  138. print("交互矩阵构建完成")
  139. # 分离卷烟和商户特征
  140. product_feats = [
  141. f"{feat}_{item}"
  142. for feat, categories in ProductConfig.ONEHOT_CAT.items()
  143. for item in categories
  144. ]
  145. cust_feats = [
  146. f"{feat}_{item}"
  147. for feat, categories in {**CustConfig.ONEHOT_CAT, **ShopConfig.ONEHOT_CAT}.items()
  148. for item in categories
  149. ]
  150. print("特征分离完成")
  151. # 提取交叉区块
  152. cross_matrix = interaction_df.loc[product_feats, cust_feats]
  153. print("交叉区块提取完成")
  154. # 转换为长格式
  155. stacked = cross_matrix.stack().reset_index()
  156. stacked.columns = ['product_feat', 'cust_feat', 'relation']
  157. print("转换为长格式完成")
  158. # 过滤掉零值或NaN的配对
  159. filtered = stacked[
  160. (stacked['relation'].abs() > 1e-6) & # 排除极小值
  161. (~stacked['relation'].isna()) # 排除NaN
  162. ].copy()
  163. print("过滤完成")
  164. # 排序结果
  165. results = (
  166. filtered
  167. .sort_values('relation', ascending=False)
  168. .to_dict('records')
  169. )
  170. print("排序完成")
  171. # 替换特征名称
  172. feats_name_map = {
  173. **ImportanceFeaturesMap.CUSTOM_FEATURES_MAP,
  174. **ImportanceFeaturesMap.SHOPING_FEATURES_MAP,
  175. **ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP
  176. }
  177. for item in results:
  178. # 处理产品特征名
  179. product_f = item["product_feat"]
  180. product_infos = product_f.split("_")
  181. item["product_feat"] = f"{feats_name_map['_'.join(product_infos[:-1])]}({product_infos[-1]})"
  182. # 处理客户特征名
  183. cust_f = item["cust_feat"]
  184. cust_infos = cust_f.split("_")
  185. item["cust_feat"] = f"{feats_name_map['_'.join(cust_infos[:-1])]}({cust_infos[-1]})"
  186. print("名称替换完成")
  187. # 返回最终结果
  188. return pd.DataFrame(results, columns=['product_feat', 'cust_feat', 'relation'])
  189. finally:
  190. # 清理临时文件
  191. try:
  192. del fp # 必须先删除内存映射对象
  193. os.remove(temp_file)
  194. os.rmdir(temp_dir)
  195. except Exception as e:
  196. print(f"清理临时文件时出错: {e}")
if __name__ == "__main__":
    # Model weights and sample query parameters for a manual run.
    model_path = "./models/rank/weights/00000000000000000000000011445301/gbdtlr_model.pkl"
    city_uuid = "00000000000000000000000011445301"
    product_id = "110102"
    gbdt_sort = GbdtLrModel(model_path)
    # gbdt_sort.sort(city_uuid, product_id)
    # cust_features_importance, product_features_importance = gbdt_sort.generate_feats_importance()
    # cust_df = pd.DataFrame([
    #     {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
    #     for item in cust_features_importance
    # ])
    # cust_df.to_csv("./data/cust_feats.csv", index=False)
    # product_df = pd.DataFrame([
    #     {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
    #     for item in product_features_importance
    # ])
    # product_df.to_csv("./data/product_feats.csv", index=False)
    # Load the training split and compute product x customer SHAP
    # interaction strengths over it.
    data, _ = DataLoader("./data/gbdt/train_data.csv").split_dataset()
    # data = data["data"].sample(n=1000, replace=True, random_state=42)
    data = data["data"]
    result = gbdt_sort.generate_shap_interance(data)
    print("保存结果")
    # utf-8-sig so the CSV opens correctly in Excel (BOM for non-ASCII text).
    result.to_csv("./data/feats_interaction.csv", index=False, encoding='utf-8-sig')
    split_relation_subtable(result, "./data")