| 123456789101112131415161718192021222324252627282930313233343536373839404142 |
- import joblib
- from dao import Redis, get_product_by_id, get_custs_by_ids
- from models.rank.data import ProductConfig, CustConfig
class GbdtLrSort:
    """Rank recalled shops for a cigarette product using a GBDT/LR model bundle.

    The model bundle is a joblib file holding a GBDT model, an LR model and a
    one-hot encoder under the keys ``"gbdt_model"``, ``"lr_model"`` and
    ``"onehot_encoder"``. Candidate shops are read from a Redis sorted set and
    their features are loaded through the ``dao`` helpers.
    """

    def __init__(self, model_path):
        """Load the model bundle and open a Redis connection.

        Args:
            model_path: Path to the joblib model bundle (see ``load_model``).
        """
        self.load_model(model_path)
        self.redis = Redis().redis

    def load_model(self, model_path):
        """Load the GBDT model, LR model and one-hot encoder from *model_path*.

        Sets ``self.gbdt_model``, ``self.lr_model`` and ``self.onehot_encoder``.

        Args:
            model_path: Path to a joblib file whose content is (presumably) a
                dict with the keys ``"gbdt_model"``, ``"lr_model"`` and
                ``"onehot_encoder"``.

        Raises:
            KeyError: If one of the three keys is missing from the bundle.
        """
        # NOTE(security): joblib.load unpickles arbitrary objects; only load
        # model files from a trusted location.
        models = joblib.load(model_path)
        # Each component gets its own attribute so later steps can use the
        # GBDT model, the LR model and the encoder independently.
        self.gbdt_model = models["gbdt_model"]
        self.lr_model = models["lr_model"]
        self.onehot_encoder = models["onehot_encoder"]

    def get_recall_list(self, city_uuid, product_id):
        """Fetch the recalled shop list for a cigarette product from Redis.

        Reads the whole sorted set stored at ``fc:{city_uuid}:{product_id}``
        (``ZRANGE key 0 -1``, scores not requested), stores the members on
        ``self.recall_cust_list`` and returns them.

        Args:
            city_uuid: City identifier used in the Redis key.
            product_id: Cigarette product id used in the Redis key.

        Returns:
            The list of recalled shop ids as returned by the Redis client.
        """
        key = f"fc:{city_uuid}:{product_id}"
        # NOTE(review): a redis-py client returns bytes members unless it was
        # created with decode_responses=True -- confirm what Redis().redis uses
        # and whether get_custs_by_ids accepts bytes ids.
        self.recall_cust_list = self.redis.zrange(key, 0, -1, withscores=False)
        return self.recall_cust_list

    def load_recall_data(self, city_uuid, product_id):
        """Load the feature columns for the product and the recalled shops.

        ``get_recall_list`` must have been called first, because this reads
        ``self.recall_cust_list``. Sets ``self.product_data`` and
        ``self.custs_data``.

        Args:
            city_uuid: City identifier passed to the ``dao`` loaders.
            product_id: Cigarette product id.

        Returns:
            A ``(product_data, custs_data)`` tuple with the selected feature
            columns (presumably pandas DataFrames -- confirm in ``dao``).
        """
        product_rows = get_product_by_id(city_uuid, product_id)
        self.product_data = product_rows[ProductConfig.FEATURE_COLUMNS]
        cust_rows = get_custs_by_ids(city_uuid, self.recall_cust_list)
        self.custs_data = cust_rows[CustConfig.FEATURE_COLUMNS]
        # NOTE(review): debug output; this is currently the only visible result
        # of the __main__ smoke run. Consider switching to logging.
        print(self.product_data)
        return self.product_data, self.custs_data

    def generate_feats_map(self, city_uuid, product_id):
        """Build the combined cigarette/shop feature matrix (work in progress).

        For now this only fetches the recall list and loads the raw product and
        shop features onto the instance. The cleaning and combination step is
        not implemented yet, so nothing is returned.

        Args:
            city_uuid: City identifier.
            product_id: Cigarette product id.
        """
        self.get_recall_list(city_uuid, product_id)
        self.load_recall_data(city_uuid, product_id)
        # TODO: clean the loaded data and combine product and shop features.

    def sort(self, city_uuid, product_id):
        """Not implemented yet; presumably meant to rank the recalled shops."""
        pass

    def generate_feats_importance(self):
        """Not implemented yet; presumably meant to report feature importance."""
        pass
-
if __name__ == "__main__":
    # Manual smoke run: load the model bundle, then build features for one
    # sample city / cigarette product pair.
    sample_request = {
        "city_uuid": "00000000000000000000000011445301",
        "product_id": "110102",
    }
    sorter = GbdtLrSort(model_path="./models/rank/weights/model.pkl")
    sorter.generate_feats_map(**sample_request)
|