# gbdt_lr_sort.py (1.6 KB)
  1. import joblib
  2. from dao import Redis, get_product_by_id, get_custs_by_ids
  3. from models.rank.data import ProductConfig, CustConfig
  4. class GbdtLrSort:
  5. def __init__(self, model_path):
  6. self.load_model(model_path)
  7. self.redis = Redis().redis
  8. def load_model(self, model_path):
  9. models = joblib.load(model_path)
  10. self.gbdt_model = models["gbdt_model"], models["lr_model"], models["onehot_encoder"]
  11. def get_recall_list(self, city_uuid, product_id):
  12. """根据卷烟id获取召回的商铺列表"""
  13. key = f"fc:{city_uuid}:{product_id}"
  14. self.recall_cust_list = self.redis.zrange(key, 0, -1, withscores=False)
  15. def load_recall_data(self, city_uuid, product_id):
  16. self.product_data = get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
  17. self.custs_data = get_custs_by_ids(city_uuid, self.recall_cust_list)[CustConfig.FEATURE_COLUMNS]
  18. print(self.product_data)
  19. def generate_feats_map(self, city_uuid, product_id):
  20. """组合卷烟、商户特征矩阵"""
  21. self.get_recall_list(city_uuid, product_id)
  22. self.load_recall_data(city_uuid, product_id)
  23. # 做数据清洗
  24. def sort(self, city_uuid, product_id):
  25. pass
  26. def generate_feats_importance(self):
  27. pass
  28. if __name__ == "__main__":
  29. model_path = "./models/rank/weights/model.pkl"
  30. city_uuid = "00000000000000000000000011445301"
  31. product_id = "110102"
  32. gbdt_sort = GbdtLrSort(model_path)
  33. gbdt_sort.generate_feats_map(city_uuid, product_id)