gbdt_lr_inference.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. import gc
  2. import joblib
  3. import re
  4. # from dao import Redis, get_product_by_id, get_custs_by_ids, load_cust_data_from_mysql
  5. from database import RedisDatabaseHelper, MySqlDao
  6. from models.rank.data import DataLoader
  7. from models.rank.data import ProductConfig, CustConfig, ImportanceFeaturesMap
  8. from models.rank.data.utils import one_hot_embedding, sample_data_clear
  9. import numpy as np
  10. import pandas as pd
  11. from sklearn.preprocessing import StandardScaler
  12. from core import get_logger
  13. logger = get_logger("models.rank.gbdtlr")
  14. def clean_column_name(col):
  15. """清理列名中的特殊字符,与 one_hot_embedding 保持一致"""
  16. return (re.sub(r'[",\\\n\r\t\b\f]', '_', col)
  17. .replace(' ', '_'))
  18. import shap
  19. from tqdm import tqdm
  20. import os
  21. def generate_feats_map(product_data, cust_data):
  22. """组合卷烟、商户特征矩阵"""
  23. # 笛卡尔积联合
  24. cust_data["descartes"] = 1
  25. product_data["descartes"] = 1
  26. feats_map = pd.merge(cust_data, product_data, on="descartes").drop("descartes", axis=1)
  27. # recall_cust_list = feats_map["BB_RETAIL_CUSTOMER_CODE"].to_list()
  28. feats_map.drop('cust_code', axis=1, inplace=True)
  29. feats_map.drop('product_code', axis=1, inplace=True)
  30. # onehot编码
  31. onehot_feats = {**CustConfig.ONEHOT_CAT, **ProductConfig.ONEHOT_CAT}
  32. onehot_columns = list(onehot_feats.keys())
  33. numeric_columns = feats_map.drop(onehot_columns, axis=1).columns
  34. feats_map = one_hot_embedding(feats_map, onehot_feats)
  35. # 数字特征归一化
  36. if len(numeric_columns) != 0:
  37. scaler = StandardScaler()
  38. feats_map[numeric_columns] = scaler.fit_transform(feats_map[numeric_columns])
  39. return feats_map
  40. class GbdtLrModel:
  41. def __init__(self, model_path):
  42. self.load_model(model_path)
  43. self.redis = RedisDatabaseHelper().redis
  44. self._mysql_dao = MySqlDao()
  45. self._explanier = None
  46. def load_model(self, model_path):
  47. models = joblib.load(model_path)
  48. self.gbdt_model, self.lr_model, self.onehot_encoder = models["lgbm_model"], models["lr_model"], models["onehot_encoder"]
  49. def get_cust_and_product_data(self, city_uuid, product_id):
  50. """从商户数据库中获取指定城市所有商户的id"""
  51. self.product_data = self._mysql_dao.get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
  52. self.custs_data = self._mysql_dao.load_cust_data(city_uuid)[CustConfig.FEATURE_COLUMNS]
  53. def get_recommend_list(self, recommend_sample, recall_list):
  54. gbdt_preds = self.gbdt_model.predict(recommend_sample, pred_leaf=True)
  55. gbdt_feats_encoded = self.onehot_encoder.transform(gbdt_preds)
  56. scores = self.lr_model.predict_proba(gbdt_feats_encoded)[:, 1] * 100
  57. recommend_list = [
  58. {"cust_code": cust_id, "recommend_score": float(score)}
  59. for cust_id, score in zip(recall_list, scores)
  60. ]
  61. recommend_list.sort(key=lambda x: x["recommend_score"], reverse=True)
  62. logger.info(f"Scored {len(recommend_list)} items in recommend list")
  63. return recommend_list
  64. def generate_feats_importance(self):
  65. """生成特征重要性"""
  66. # 获取GBDT模型的特征重要性
  67. feats_importance = self.gbdt_model.feature_importances_
  68. # 获取特征名称
  69. feats_names = self.gbdt_model.feature_name_
  70. importance_dict = dict(zip(feats_names, feats_importance))
  71. onehot_feats = {**CustConfig.ONEHOT_CAT, **ProductConfig.ONEHOT_CAT}
  72. for feat, categories in onehot_feats.items():
  73. related_columns = [f"{feat}_{item}" for item in categories]
  74. if related_columns:
  75. # 合并类别重要性
  76. combined_importance = sum(importance_dict[col] for col in related_columns)
  77. # 删除onehot类别列
  78. for col in related_columns:
  79. del importance_dict[col]
  80. # 添加合并后的重要性
  81. importance_dict[feat] = combined_importance
  82. # 排序
  83. sorted_importance = sorted(importance_dict.items(), key=lambda x: x[1], reverse=True)
  84. # 输出特征重要性
  85. cust_features_importance = []
  86. product_features_importance = []
  87. for feat, importance in sorted_importance:
  88. if feat in list(ImportanceFeaturesMap.CUSTOM_FEATURES_MAP.keys()):
  89. cust_features_importance.append({ImportanceFeaturesMap.CUSTOM_FEATURES_MAP[feat]: float(importance)})
  90. if feat in list(ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP.keys()):
  91. product_features_importance.append({ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP[feat]: float(importance)})
  92. return cust_features_importance, product_features_importance
  93. def generate_shap_interance(self, data):
  94. # 初始化SHAP解释器
  95. if self._explanier is None:
  96. self._explanier = shap.TreeExplainer(self.gbdt_model)
  97. # 获取数据基本信息
  98. n_samples = len(data)
  99. n_features = len(data.columns)
  100. batch_size = 200 # 可根据内存调整
  101. # 创建临时内存映射文件
  102. # temp_dir = tempfile.mkdtemp()
  103. temp_dir = "./data/tmp"
  104. temp_file = os.path.join(temp_dir, "shap_interaction_temp.dat")
  105. if os.path.exists(temp_dir):
  106. os.remove(temp_file)
  107. else:
  108. os.makedirs(temp_dir)
  109. try:
  110. # 预创建内存映射文件
  111. fp_shape = (n_samples, n_features, n_features)
  112. fp = np.memmap(temp_file, dtype=np.float32,
  113. mode='w+',
  114. shape=fp_shape)
  115. # 分批计算并存储SHAP交互值
  116. for i in tqdm(range(0, n_samples, batch_size), desc="计算SHAP交互值..."):
  117. batch_data = data.iloc[i:i+batch_size]
  118. batch_interaction = self._explanier.shap_interaction_values(batch_data)
  119. fp[i:i+len(batch_interaction)] = batch_interaction.astype(np.float32)
  120. fp.flush() # 确保数据写入磁盘
  121. # 分批计算均值
  122. mean_interaction = np.zeros((n_features, n_features), dtype=np.float32)
  123. for i in tqdm(range(0, n_samples, batch_size), desc="计算均值..."):
  124. batch = fp[i:i+batch_size] # 读取批数据并取绝对值
  125. mean_interaction += batch.sum(axis=0) # 按批累加
  126. mean_interaction /= n_samples # 计算最终均值
  127. # 构建交互矩阵DataFrame
  128. interaction_df = pd.DataFrame(
  129. mean_interaction,
  130. index=data.columns,
  131. columns=data.columns
  132. )
  133. # 分离卷烟和商户特征(应用列名清理)
  134. product_feats = [
  135. clean_column_name(f"{feat}_{item}")
  136. for feat, categories in ProductConfig.ONEHOT_CAT.items()
  137. for item in categories
  138. ]
  139. cust_feats = [
  140. clean_column_name(f"{feat}_{item}")
  141. for feat, categories in {**CustConfig.ONEHOT_CAT}.items()
  142. for item in categories
  143. ]
  144. # 提取交叉区块
  145. cross_matrix = interaction_df.loc[product_feats, cust_feats]
  146. # 转换为长格式
  147. stacked = cross_matrix.stack().reset_index()
  148. stacked.columns = ['product_feat', 'cust_feat', 'relation']
  149. # 过滤掉零值或NaN的配对
  150. filtered = stacked[
  151. (stacked['relation'].abs() > 1e-6) & # 排除极小值
  152. (~stacked['relation'].isna()) # 排除NaN
  153. ].copy()
  154. # 排序结果
  155. results = (
  156. filtered
  157. .sort_values('relation', ascending=False)
  158. .to_dict('records')
  159. )
  160. # 替换特征名称
  161. feats_name_map = {
  162. **ImportanceFeaturesMap.CUSTOM_FEATURES_MAP,
  163. **ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP
  164. }
  165. for item in results:
  166. # 处理产品特征名
  167. product_f = item["product_feat"]
  168. product_feat_name = None
  169. product_feat_value = None
  170. for key in feats_name_map.keys():
  171. if product_f.startswith(key + "_"):
  172. product_feat_name = feats_name_map[key]
  173. product_feat_value = product_f[len(key) + 1:]
  174. break
  175. if product_feat_name:
  176. item["product_feat"] = f"{product_feat_name}({product_feat_value})"
  177. # 处理客户特征名
  178. cust_f = item["cust_feat"]
  179. cust_feat_name = None
  180. cust_feat_value = None
  181. for key in feats_name_map.keys():
  182. if cust_f.startswith(key + "_"):
  183. cust_feat_name = feats_name_map[key]
  184. cust_feat_value = cust_f[len(key) + 1:]
  185. break
  186. if cust_feat_name:
  187. item["cust_feat"] = f"{cust_feat_name}({cust_feat_value})"
  188. # 返回最终结果
  189. return pd.DataFrame(results, columns=['product_feat', 'cust_feat', 'relation'])
  190. finally:
  191. # 清理临时文件
  192. try:
  193. del fp # 必须先删除内存映射对象
  194. gc.collect()
  195. os.remove(temp_file)
  196. os.rmdir(temp_dir)
  197. except Exception as e:
  198. logger.error(f"清理临时文件时出错: {e}")
  199. if __name__ == "__main__":
  200. model_path = "./models/rank/weights/00000000000000000000000011445301/gbdtlr_model.pkl"
  201. city_uuid = "00000000000000000000000011445301"
  202. product_id = "110102"
  203. gbdt_sort = GbdtLrModel(model_path)
  204. # gbdt_sort.sort(city_uuid, product_id)
  205. # cust_features_importance, product_features_importance = gbdt_sort.generate_feats_importance()
  206. # cust_df = pd.DataFrame([
  207. # {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
  208. # for item in cust_features_importance
  209. # ])
  210. # cust_df.to_csv("./data/cust_feats.csv", index=False)
  211. # product_df = pd.DataFrame([
  212. # {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
  213. # for item in product_features_importance
  214. # ])
  215. # product_df.to_csv("./data/product_feats.csv", index=False)
  216. data, _ = DataLoader("./data/gbdt/train_data.csv").split_dataset()
  217. data = data["data"].sample(n=300, replace=True, random_state=42)
  218. data.to_csv("./data/data.csv", index=False)
  219. # data = data["data"]