gbdt_lr_sort.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. import joblib
  2. # from dao import Redis, get_product_by_id, get_custs_by_ids, load_cust_data_from_mysql
  3. from database import RedisDatabaseHelper, MySqlDao
  4. from models.rank.data import ProductConfig, CustConfig, ShopConfig, ImportanceFeaturesMap
  5. from models.rank.data.utils import one_hot_embedding, sample_data_clear
  6. import pandas as pd
  7. from sklearn.preprocessing import StandardScaler
  8. import os
  9. class GbdtLrModel:
  10. def __init__(self, model_path):
  11. self.load_model(model_path)
  12. self.redis = RedisDatabaseHelper().redis
  13. self._mysql_dao = MySqlDao()
  14. def load_model(self, model_path):
  15. models = joblib.load(model_path)
  16. self.gbdt_model, self.lr_model, self.onehot_encoder = models["gbdt_model"], models["lr_model"], models["onehot_encoder"]
  17. # def get_recall_list(self, city_uuid, product_id):
  18. # """根据卷烟id获取召回的商铺列表"""
  19. # key = f"fc:{city_uuid}:{product_id}"
  20. # self.recall_cust_list = self.redis.zrange(key, 0, -1, withscores=False)
  21. # def load_recall_data(self, city_uuid, product_id):
  22. # self.product_data = self._mysql_dao.get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
  23. # self.custs_data = self._mysql_dao.get_cust_by_ids(city_uuid, self.recall_cust_list)[CustConfig.FEATURE_COLUMNS]
  24. def get_cust_and_product_data(self, city_uuid, product_id):
  25. """从商户数据库中获取指定城市所有商户的id"""
  26. self.product_data = self._mysql_dao.get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
  27. self.custs_data = self._mysql_dao.load_cust_data(city_uuid)[CustConfig.FEATURE_COLUMNS]
  28. def generate_feats_map(self, city_uuid, product_id):
  29. """组合卷烟、商户特征矩阵"""
  30. # self.get_recall_list(city_uuid, product_id)
  31. # self.load_recall_data(city_uuid, product_id)
  32. self.get_cust_and_product_data(city_uuid, product_id)
  33. # 做数据清洗
  34. self.product_data = sample_data_clear(self.product_data, ProductConfig)
  35. self.custs_data = sample_data_clear(self.custs_data, CustConfig)
  36. # 笛卡尔积联合
  37. self.custs_data["descartes"] = 1
  38. self.product_data["descartes"] = 1
  39. self.feats_map = pd.merge(self.custs_data, self.product_data, on="descartes").drop("descartes", axis=1)
  40. self.recall_cust_list = self.feats_map["BB_RETAIL_CUSTOMER_CODE"].to_list()
  41. self.feats_map.drop('BB_RETAIL_CUSTOMER_CODE', axis=1, inplace=True)
  42. self.feats_map.drop('product_code', axis=1, inplace=True)
  43. # onehot编码
  44. onehot_feats = {**CustConfig.ONEHOT_CAT, **ProductConfig.ONEHOT_CAT}
  45. onehot_columns = list(onehot_feats.keys())
  46. numeric_columns = self.feats_map.drop(onehot_columns, axis=1).columns
  47. self.feats_map = one_hot_embedding(self.feats_map, onehot_feats)
  48. # 数字特征归一化
  49. scaler = StandardScaler()
  50. self.feats_map[numeric_columns] = scaler.fit_transform(self.feats_map[numeric_columns])
  51. def sort(self, city_uuid, product_id):
  52. self.generate_feats_map(city_uuid, product_id)
  53. gbdt_preds = self.gbdt_model.apply(self.feats_map)[:, :, 0]
  54. gbdt_feats_encoded = self.onehot_encoder.transform(gbdt_preds)
  55. scores = self.lr_model.predict_proba(gbdt_feats_encoded)[:, 1]
  56. self.recommend_list = []
  57. for cust_id, score in zip(self.recall_cust_list, scores):
  58. self.recommend_list.append({cust_id: float(score)})
  59. self.recommend_list = sorted(self.recommend_list, key=lambda x: list(x.values())[0], reverse=True)
  60. # for res in self.recommend_list[:200]:
  61. # print(res)
  62. return self.recommend_list
  63. def generate_feats_importance(self):
  64. """生成特征重要性"""
  65. # 获取GBDT模型的特征重要性
  66. feats_importance = self.gbdt_model.feature_importances_
  67. # 获取特征名称
  68. feats_names = self.gbdt_model.feature_names_in_
  69. importance_dict = dict(zip(feats_names, feats_importance))
  70. onehot_feats = {**CustConfig.ONEHOT_CAT, **ShopConfig.ONEHOT_CAT, **ProductConfig.ONEHOT_CAT}
  71. for feat, categories in onehot_feats.items():
  72. related_columns = [col for col in feats_names if col.startswith(feat)]
  73. if related_columns:
  74. # 合并类别重要性
  75. combined_importance = sum(importance_dict[col] for col in related_columns)
  76. # 删除onehot类别列
  77. for col in related_columns:
  78. del importance_dict[col]
  79. # 添加合并后的重要性
  80. importance_dict[feat] = combined_importance
  81. # 排序
  82. sorted_importance = sorted(importance_dict.items(), key=lambda x: x[1], reverse=True)
  83. # 输出特征重要性
  84. cust_features_importance = []
  85. product_features_importance = []
  86. for feat, importance in sorted_importance:
  87. if feat in list(ImportanceFeaturesMap.CUSTOM_FEATURES_MAP.keys()):
  88. cust_features_importance.append({ImportanceFeaturesMap.CUSTOM_FEATURES_MAP[feat]: float(importance)})
  89. if feat in list(ImportanceFeaturesMap.SHOPING_FEATURES_MAP.keys()):
  90. cust_features_importance.append({ImportanceFeaturesMap.SHOPING_FEATURES_MAP[feat]: float(importance)})
  91. if feat in list(ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP.keys()):
  92. product_features_importance.append({ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP[feat]: float(importance)})
  93. return cust_features_importance, product_features_importance
  94. if __name__ == "__main__":
  95. model_path = "./models/rank/weights/00000000000000000000000011445301/gbdtlr_model.pkl"
  96. city_uuid = "00000000000000000000000011445301"
  97. product_id = "110102"
  98. gbdt_sort = GbdtLrModel(model_path)
  99. # gbdt_sort.sort(city_uuid, product_id)
  100. cust_features_importance, product_features_importance = gbdt_sort.generate_feats_importance()
  101. cust_df = pd.DataFrame([
  102. {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
  103. for item in cust_features_importance
  104. ])
  105. cust_df.to_csv("./data/cust_feats.csv", index=False)
  106. product_df = pd.DataFrame([
  107. {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
  108. for item in product_features_importance
  109. ])
  110. product_df.to_csv("./data/product_feats.csv", index=False)