# gbdt_lr_sort.py - GBDT + LR ranking of retail customers for a given product.

  1. import joblib
  2. # from dao import Redis, get_product_by_id, get_custs_by_ids, load_cust_data_from_mysql
  3. from database import RedisDatabaseHelper, MySqlDao
  4. from models.rank.data import ProductConfig, CustConfig, ImportanceFeaturesMap
  5. from models.rank.data.utils import one_hot_embedding, sample_data_clear
  6. import pandas as pd
  7. from sklearn.preprocessing import StandardScaler
  8. class GbdtLrModel:
  9. def __init__(self, model_path):
  10. self.load_model(model_path)
  11. self.redis = RedisDatabaseHelper().redis
  12. self._mysql_dao = MySqlDao()
  13. def load_model(self, model_path):
  14. models = joblib.load(model_path)
  15. self.gbdt_model, self.lr_model, self.onehot_encoder = models["gbdt_model"], models["lr_model"], models["onehot_encoder"]
  16. # def get_recall_list(self, city_uuid, product_id):
  17. # """根据卷烟id获取召回的商铺列表"""
  18. # key = f"fc:{city_uuid}:{product_id}"
  19. # self.recall_cust_list = self.redis.zrange(key, 0, -1, withscores=False)
  20. # def load_recall_data(self, city_uuid, product_id):
  21. # self.product_data = self._mysql_dao.get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
  22. # self.custs_data = self._mysql_dao.get_cust_by_ids(city_uuid, self.recall_cust_list)[CustConfig.FEATURE_COLUMNS]
  23. def get_cust_and_product_data(self, city_uuid, product_id):
  24. """从商户数据库中获取指定城市所有商户的id"""
  25. self.product_data = self._mysql_dao.get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
  26. self.custs_data = self._mysql_dao.load_cust_data(city_uuid)[CustConfig.FEATURE_COLUMNS]
  27. def generate_feats_map(self, city_uuid, product_id):
  28. """组合卷烟、商户特征矩阵"""
  29. # self.get_recall_list(city_uuid, product_id)
  30. # self.load_recall_data(city_uuid, product_id)
  31. self.get_cust_and_product_data(city_uuid, product_id)
  32. # 做数据清洗
  33. self.product_data = sample_data_clear(self.product_data, ProductConfig)
  34. self.custs_data = sample_data_clear(self.custs_data, CustConfig)
  35. # 笛卡尔积联合
  36. self.custs_data["descartes"] = 1
  37. self.product_data["descartes"] = 1
  38. self.feats_map = pd.merge(self.custs_data, self.product_data, on="descartes").drop("descartes", axis=1)
  39. self.recall_cust_list = self.feats_map["BB_RETAIL_CUSTOMER_CODE"].to_list()
  40. self.feats_map.drop('BB_RETAIL_CUSTOMER_CODE', axis=1, inplace=True)
  41. self.feats_map.drop('product_code', axis=1, inplace=True)
  42. # onehot编码
  43. onehot_feats = {**CustConfig.ONEHOT_CAT, **ProductConfig.ONEHOT_CAT}
  44. onehot_columns = list(onehot_feats.keys())
  45. numeric_columns = self.feats_map.drop(onehot_columns, axis=1).columns
  46. self.feats_map = one_hot_embedding(self.feats_map, onehot_feats)
  47. # 数字特征归一化
  48. scaler = StandardScaler()
  49. self.feats_map[numeric_columns] = scaler.fit_transform(self.feats_map[numeric_columns])
  50. def sort(self, city_uuid, product_id):
  51. self.generate_feats_map(city_uuid, product_id)
  52. gbdt_preds = self.gbdt_model.apply(self.feats_map)[:, :, 0]
  53. gbdt_feats_encoded = self.onehot_encoder.transform(gbdt_preds)
  54. scores = self.lr_model.predict_proba(gbdt_feats_encoded)[:, 1]
  55. self.recommend_list = []
  56. for cust_id, score in zip(self.recall_cust_list, scores):
  57. self.recommend_list.append({cust_id: float(score)})
  58. self.recommend_list = sorted(self.recommend_list, key=lambda x: list(x.values())[0], reverse=True)
  59. # for res in self.recommend_list[:200]:
  60. # print(res)
  61. return self.recommend_list
  62. def generate_feats_importance(self):
  63. """生成特征重要性"""
  64. # 获取GBDT模型的特征重要性
  65. feats_importance = self.gbdt_model.feature_importances_
  66. # 获取特征名称
  67. feats_names = self.gbdt_model.feature_names_in_
  68. importance_dict = dict(zip(feats_names, feats_importance))
  69. onehot_feats = {**CustConfig.ONEHOT_CAT, **ProductConfig.ONEHOT_CAT}
  70. for feat, categories in onehot_feats.items():
  71. related_columns = [col for col in feats_names if col.startswith(feat)]
  72. if related_columns:
  73. # 合并类别重要性
  74. combined_importance = sum(importance_dict[col] for col in related_columns)
  75. # 删除onehot类别列
  76. for col in related_columns:
  77. del importance_dict[col]
  78. # 添加合并后的重要性
  79. importance_dict[feat] = combined_importance
  80. # 排序
  81. sorted_importance = sorted(importance_dict.items(), key=lambda x: x[1], reverse=True)
  82. # 输出特征重要性
  83. cust_features_importance = []
  84. product_features_importance = []
  85. for feat, importance in sorted_importance:
  86. if feat in list(ImportanceFeaturesMap.CUSTOM_FEATRUES_MAP.keys()):
  87. cust_features_importance.append({ImportanceFeaturesMap.CUSTOM_FEATRUES_MAP[feat]: float(importance)})
  88. if feat in list(ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP.keys()):
  89. product_features_importance.append({ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP[feat]: float(importance)})
  90. if feat in list(ImportanceFeaturesMap.SHOPING_FEATURES_MAP.keys()):
  91. product_features_importance.append({ImportanceFeaturesMap.SHOPING_FEATURES_MAP[feat]: float(importance)})
  92. return cust_features_importance, product_features_importance
  93. if __name__ == "__main__":
  94. model_path = "./models/rank/weights/model.pkl"
  95. city_uuid = "00000000000000000000000011445301"
  96. product_id = "110102"
  97. gbdt_sort = GbdtLrModel(model_path)
  98. # gbdt_sort.sort(city_uuid, product_id)
  99. importances = gbdt_sort.generate_feats_importance()
  100. for importance in importances:
  101. print(importance)