import joblib

from dao import Redis, get_product_by_id, get_custs_by_ids
from models.rank.data import ProductConfig, CustConfig


class GbdtLrSort:
    """GBDT + LR ranker for the shops recalled for a cigarette product.

    Loads a bundle of fitted models from disk, reads recall candidates from
    Redis and loads product / shop features through the ``dao`` helpers.
    ``sort`` and ``generate_feats_importance`` are not implemented yet.
    """

    def __init__(self, model_path):
        """Load the model bundle at *model_path* and obtain a Redis handle."""
        self.load_model(model_path)
        self.redis = Redis().redis

    def load_model(self, model_path):
        """Load the GBDT model, LR model and one-hot encoder from *model_path*.

        The file must hold a mapping with the keys ``"gbdt_model"``,
        ``"lr_model"`` and ``"onehot_encoder"``; each is stored on the
        instance under the same attribute name.
        """
        # NOTE: joblib.load unpickles the file, which can execute arbitrary
        # code -- only load model files from a trusted location.
        models = joblib.load(model_path)
        # Assign each model to its own attribute. The previous single
        # assignment put the whole 3-tuple into ``self.gbdt_model`` and never
        # set ``lr_model`` / ``onehot_encoder``.
        self.gbdt_model = models["gbdt_model"]
        self.lr_model = models["lr_model"]
        self.onehot_encoder = models["onehot_encoder"]

    def get_recall_list(self, city_uuid, product_id):
        """Fetch the recalled shop list for a cigarette product from Redis.

        Reads all members (ranks 0..-1, without scores) of the sorted set
        ``fc:{city_uuid}:{product_id}``, stores them on
        ``self.recall_cust_list`` and also returns them.
        """
        key = f"fc:{city_uuid}:{product_id}"
        self.recall_cust_list = self.redis.zrange(key, 0, -1, withscores=False)
        return self.recall_cust_list

    def load_recall_data(self, city_uuid, product_id):
        """Load the product's features and the recalled shops' features.

        Must be called after ``get_recall_list``, because it reads
        ``self.recall_cust_list``. The results are stored on
        ``self.product_data`` and ``self.custs_data``, restricted to
        ``ProductConfig.FEATURE_COLUMNS`` and ``CustConfig.FEATURE_COLUMNS``.
        """
        # Presumably the dao helpers return pandas DataFrames (they are
        # indexed with a column list below) -- TODO confirm.
        self.product_data = get_product_by_id(city_uuid, product_id)[
            ProductConfig.FEATURE_COLUMNS
        ]
        self.custs_data = get_custs_by_ids(city_uuid, self.recall_cust_list)[
            CustConfig.FEATURE_COLUMNS
        ]

    def generate_feats_map(self, city_uuid, product_id):
        """Build the combined product / shop feature matrix for a product.

        Currently this only fetches the recall list and loads the raw
        product and shop features onto the instance.
        """
        self.get_recall_list(city_uuid, product_id)
        self.load_recall_data(city_uuid, product_id)
        # TODO: data cleaning, then combine product and shop features.

    def sort(self, city_uuid, product_id):
        """Rank the recalled shops for a product (not implemented yet)."""
        # TODO: implement GBDT + LR scoring of the feature map.

    def generate_feats_importance(self):
        """Report feature importances (not implemented yet)."""
        # TODO: implement.


if __name__ == "__main__":
    model_path = "./models/rank/weights/model.pkl"
    city_uuid = "00000000000000000000000011445301"
    product_id = "110102"
    gbdt_sort = GbdtLrSort(model_path)
    gbdt_sort.generate_feats_map(city_uuid, product_id)
    # Debug output for manual runs (moved out of load_recall_data).
    print(gbdt_sort.product_data)