gbdt_lr_sort.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. import joblib
  2. from dao import Redis, get_product_by_id, get_custs_by_ids
  3. from models.rank.data import ProductConfig, CustConfig
  4. from models.rank.data.utils import one_hot_embedding, sample_data_clear
  5. import pandas as pd
  6. from sklearn.preprocessing import StandardScaler
  class GbdtLrModel:
      """Rank the customers recalled for a product with a GBDT + LR pipeline.

      The GBDT model maps every feature row to its leaf indices, the stored
      one-hot encoder expands those indices, and the LR model turns the
      encoded leaves into a probability that is used as the ranking score.
      """

      def __init__(self, model_path):
          """Load the model bundle and open the Redis connection.

          Args:
              model_path: Path of the joblib bundle read by ``load_model``.
          """
          self.load_model(model_path)
          self.redis = Redis().redis

      def load_model(self, model_path):
          """Load the GBDT model, the LR model and the leaf one-hot encoder.

          Args:
              model_path: Path of a joblib file holding a mapping with the keys
                  ``gbdt_model``, ``lr_model`` and ``onehot_encoder``.

          Raises:
              KeyError: If the bundle lacks one of the expected keys.
          """
          # NOTE(review): joblib.load unpickles arbitrary objects, so only
          # bundles from a trusted location should ever be passed in here.
          models = joblib.load(model_path)
          self.gbdt_model, self.lr_model, self.onehot_encoder = (
              models["gbdt_model"],
              models["lr_model"],
              models["onehot_encoder"],
          )

      def get_recall_list(self, city_uuid, product_id):
          """Fetch the recalled shop list for a cigarette (product) id.

          Reads every member of the sorted set ``fc:{city_uuid}:{product_id}``
          without scores and stores the result in ``self.recall_cust_list``.
          """
          key = f"fc:{city_uuid}:{product_id}"
          self.recall_cust_list = self.redis.zrange(key, 0, -1, withscores=False)

      def load_recall_data(self, city_uuid, product_id):
          """Load the product row and the recalled customers' feature columns.

          Requires ``self.recall_cust_list`` to be set, i.e. ``get_recall_list``
          must have been called first.
          """
          self.product_data = get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
          self.custs_data = get_custs_by_ids(city_uuid, self.recall_cust_list)[CustConfig.FEATURE_COLUMNS]

      def generate_feats_map(self, city_uuid, product_id):
          """Build the combined customer x product feature matrix.

          Stores the matrix in ``self.feats_map`` and realigns
          ``self.recall_cust_list`` with its rows.
          """
          self.get_recall_list(city_uuid, product_id)
          self.load_recall_data(city_uuid, product_id)

          # Data cleaning.
          self.product_data = sample_data_clear(self.product_data, ProductConfig)
          self.custs_data = sample_data_clear(self.custs_data, CustConfig)

          # Cartesian product of both frames via a constant join key.
          self.custs_data["descartes"] = 1
          self.product_data["descartes"] = 1
          feats = pd.merge(self.custs_data, self.product_data, on="descartes").drop("descartes", axis=1)

          # Keep the customer codes in row order so scores can be zipped back.
          self.recall_cust_list = feats["BB_RETAIL_CUSTOMER_CODE"].to_list()
          feats.drop('BB_RETAIL_CUSTOMER_CODE', axis=1, inplace=True)
          feats.drop('product_code', axis=1, inplace=True)
          self.feats_map = feats

          # One-hot encoding; the numeric columns are everything else.
          onehot_feats = {**CustConfig.ONEHOT_CAT, **ProductConfig.ONEHOT_CAT}
          onehot_columns = list(onehot_feats.keys())
          numeric_columns = self.feats_map.drop(onehot_columns, axis=1).columns
          self.feats_map = one_hot_embedding(self.feats_map, onehot_feats)

          # Standardise the numeric features. StandardScaler raises on an input
          # with zero rows or zero columns, so there is nothing to scale then.
          # NOTE(review): the scaler is fitted on this inference batch instead
          # of reusing training statistics - confirm that this is intended.
          if len(numeric_columns) > 0 and len(self.feats_map) > 0:
              scaler = StandardScaler()
              self.feats_map[numeric_columns] = scaler.fit_transform(self.feats_map[numeric_columns])

      def sort(self, city_uuid, product_id, top_n=200):
          """Score the recalled customers and rank them, best first.

          Args:
              city_uuid: City identifier used for the recall key and data lookup.
              product_id: Product identifier whose recalled customers are ranked.
              top_n: Number of leading entries to print; ``None`` prints all.

          Returns:
              The full ranked list of ``{customer_code: score}`` dicts, which is
              also stored in ``self.recommend_list``. Empty when nothing was
              recalled.
          """
          self.generate_feats_map(city_uuid, product_id)
          if len(self.feats_map) == 0:
              self.recommend_list = []
              return self.recommend_list

          gbdt_preds = self.gbdt_model.apply(self.feats_map)[:, :, 0]
          gbdt_feats_encoded = self.onehot_encoder.transform(gbdt_preds)
          scores = self.lr_model.predict_proba(gbdt_feats_encoded)[:, 1]

          self.recommend_list = self._rank_by_score(self.recall_cust_list, scores)
          for res in self.recommend_list[:top_n]:
              print(res)
          return self.recommend_list

      @staticmethod
      def _rank_by_score(cust_ids, scores):
          """Pair ids with scores and return ``{id: score}`` dicts, best first."""
          pairs = zip(cust_ids, (float(score) for score in scores))
          ranked = sorted(pairs, key=lambda pair: pair[1], reverse=True)
          return [{cust_id: score} for cust_id, score in ranked]

      def generate_feats_importance(self):
          """Generate the feature importances of the GBDT model.

          The importances of all one-hot columns belonging to one categorical
          feature are summed and reported under the feature's own name.

          Returns:
              A list of ``{feature_name: importance}`` dicts, most important
              first.
          """
          # Importances and the matching feature names of the GBDT model.
          feats_importance = self.gbdt_model.feature_importances_
          feats_names = self.gbdt_model.feature_names_in_
          onehot_feats = {**CustConfig.ONEHOT_CAT, **ProductConfig.ONEHOT_CAT}
          return self._merge_onehot_importance(feats_names, feats_importance, list(onehot_feats))

      @staticmethod
      def _merge_onehot_importance(feats_names, feats_importance, onehot_names):
          """Fold one-hot column importances into their source features.

          A column belongs to the longest one-hot feature name it starts with,
          so features whose names share a prefix (``type`` / ``type_level``)
          neither steal each other's columns nor trigger a ``KeyError``.
          NOTE(review): a plain numeric column whose name merely starts with a
          one-hot feature name is still merged into that feature.
          """
          by_specificity = sorted(onehot_names, key=len, reverse=True)
          plain = {}
          merged = {}
          for col, weight in zip(feats_names, feats_importance):
              owner = next((feat for feat in by_specificity if col.startswith(feat)), None)
              if owner is None:
                  plain[col] = weight
              else:
                  merged[owner] = merged.get(owner, 0) + weight

          # Untouched columns first, then merged features in configuration
          # order, so that equal importances keep a predictable order.
          importance = dict(plain)
          for feat in onehot_names:
              if feat in merged:
                  importance[feat] = merged[feat]

          ranked = sorted(importance.items(), key=lambda item: item[1], reverse=True)
          return [{feat: float(weight)} for feat, weight in ranked]
  if __name__ == "__main__":
      # Manual check: load the trained bundle and print the merged feature
      # importances, one {feature: importance} entry per line.
      model_path = "./models/rank/weights/model.pkl"

      # Sample identifiers for an ad-hoc ranking run through
      # GbdtLrModel.sort(city_uuid, product_id); the report below ignores them.
      city_uuid = "00000000000000000000000011445301"
      product_id = "110102"

      ranker = GbdtLrModel(model_path)
      for entry in ranker.generate_feats_importance():
          print(entry)