# gbdt_lr_sort.py
from operator import itemgetter

import joblib
import pandas as pd
from sklearn.preprocessing import StandardScaler

from dao import Redis, get_product_by_id, get_custs_by_ids, load_cust_data_from_mysql
from models.rank.data import ProductConfig, CustConfig
from models.rank.data.utils import one_hot_embedding, sample_data_clear
  7. class GbdtLrModel:
  8. def __init__(self, model_path):
  9. self.load_model(model_path)
  10. self.redis = Redis().redis
  11. def load_model(self, model_path):
  12. models = joblib.load(model_path)
  13. self.gbdt_model, self.lr_model, self.onehot_encoder = models["gbdt_model"], models["lr_model"], models["onehot_encoder"]
  14. # def get_recall_list(self, city_uuid, product_id):
  15. # """根据卷烟id获取召回的商铺列表"""
  16. # key = f"fc:{city_uuid}:{product_id}"
  17. # self.recall_cust_list = self.redis.zrange(key, 0, -1, withscores=False)
  18. # def load_recall_data(self, city_uuid, product_id):
  19. # self.product_data = get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
  20. # self.custs_data = get_custs_by_ids(city_uuid, self.recall_cust_list)[CustConfig.FEATURE_COLUMNS]
  21. def get_cust_and_product_data(self, city_uuid, product_id):
  22. """从商户数据库中获取指定城市所有商户的id"""
  23. self.product_data = get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
  24. self.custs_data = load_cust_data_from_mysql(city_uuid)[CustConfig.FEATURE_COLUMNS]
  25. def generate_feats_map(self, city_uuid, product_id):
  26. """组合卷烟、商户特征矩阵"""
  27. # self.get_recall_list(city_uuid, product_id)
  28. # self.load_recall_data(city_uuid, product_id)
  29. self.get_cust_and_product_data(city_uuid, product_id)
  30. # 做数据清洗
  31. self.product_data = sample_data_clear(self.product_data, ProductConfig)
  32. self.custs_data = sample_data_clear(self.custs_data, CustConfig)
  33. # 笛卡尔积联合
  34. self.custs_data["descartes"] = 1
  35. self.product_data["descartes"] = 1
  36. self.feats_map = pd.merge(self.custs_data, self.product_data, on="descartes").drop("descartes", axis=1)
  37. self.recall_cust_list = self.feats_map["BB_RETAIL_CUSTOMER_CODE"].to_list()
  38. self.feats_map.drop('BB_RETAIL_CUSTOMER_CODE', axis=1, inplace=True)
  39. self.feats_map.drop('product_code', axis=1, inplace=True)
  40. # onehot编码
  41. onehot_feats = {**CustConfig.ONEHOT_CAT, **ProductConfig.ONEHOT_CAT}
  42. onehot_columns = list(onehot_feats.keys())
  43. numeric_columns = self.feats_map.drop(onehot_columns, axis=1).columns
  44. self.feats_map = one_hot_embedding(self.feats_map, onehot_feats)
  45. # 数字特征归一化
  46. scaler = StandardScaler()
  47. self.feats_map[numeric_columns] = scaler.fit_transform(self.feats_map[numeric_columns])
  48. def sort(self, city_uuid, product_id):
  49. self.generate_feats_map(city_uuid, product_id)
  50. gbdt_preds = self.gbdt_model.apply(self.feats_map)[:, :, 0]
  51. gbdt_feats_encoded = self.onehot_encoder.transform(gbdt_preds)
  52. scores = self.lr_model.predict_proba(gbdt_feats_encoded)[:, 1]
  53. self.recommend_list = []
  54. for cust_id, score in zip(self.recall_cust_list, scores):
  55. self.recommend_list.append({cust_id: float(score)})
  56. self.recommend_list = sorted(self.recommend_list, key=lambda x: list(x.values())[0], reverse=True)
  57. for res in self.recommend_list[:200]:
  58. print(res)
  59. return self.recommend_list
  60. def generate_feats_importance(self):
  61. """生成特征重要性"""
  62. # 获取GBDT模型的特征重要性
  63. feats_importance = self.gbdt_model.feature_importances_
  64. # 获取特征名称
  65. feats_names = self.gbdt_model.feature_names_in_
  66. importance_dict = dict(zip(feats_names, feats_importance))
  67. onehot_feats = {**CustConfig.ONEHOT_CAT, **ProductConfig.ONEHOT_CAT}
  68. for feat, categories in onehot_feats.items():
  69. related_columns = [col for col in feats_names if col.startswith(feat)]
  70. if related_columns:
  71. # 合并类别重要性
  72. combined_importance = sum(importance_dict[col] for col in related_columns)
  73. # 删除onehot类别列
  74. for col in related_columns:
  75. del importance_dict[col]
  76. # 添加合并后的重要性
  77. importance_dict[feat] = combined_importance
  78. # 排序
  79. sorted_importance = sorted(importance_dict.items(), key=lambda x: x[1], reverse=True)
  80. # 输出特征重要性
  81. features_importance = []
  82. for feat, importance in sorted_importance:
  83. features_importance.append({feat: float(importance)})
  84. return features_importance
  85. if __name__ == "__main__":
  86. model_path = "./models/rank/weights/model.pkl"
  87. city_uuid = "00000000000000000000000011445301"
  88. product_id = "110102"
  89. gbdt_sort = GbdtLrModel(model_path)
  90. gbdt_sort.sort(city_uuid, product_id)
  91. # importances = gbdt_sort.generate_feats_importance()
  92. # for importance in importances:
  93. # print(importance)