# Merchant recommendation for a product: Redis recall (item-CF + hot list) ranked by a GBDT+LR model.
import logging
import os

from database.dao.mysql_dao import MySqlDao
from database.db.redis_db import RedisDatabaseHelper
from models.item2vec.inference import Item2VecModel
from models.rank.data.config import CustConfig, ProductConfig, ShopConfig
from models.rank.data.utils import sample_data_clear
from models.rank.gbdt_lr_inference import GbdtLrModel
class Recommend:
    """Recommend merchants for a product within one city.

    Candidates are recalled from Redis sorted sets (an item-CF set for the
    product, topped up from a city-wide hot list) and then handed to a
    GBDT+LR model for ranking.  An Item2Vec model is also loaded at
    construction time, although no method in this class uses it.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, city_uuid):
        """Open the Redis and MySQL handles and load the models for ``city_uuid``."""
        self._redis = RedisDatabaseHelper().redis
        self._dao = MySqlDao()

        self._load_molde(city_uuid)

    def _load_molde(self, city_uuid):
        """Load the inference models for ``city_uuid``.

        The method name keeps its original spelling so existing callers
        continue to work.
        """
        self._city_uuid = city_uuid
        # NOTE: the weights directory is relative to the current working
        # directory, not to this file.
        gbdtlr_model_path = os.path.join("./models/rank/weights", city_uuid, "gbdtlr_model.pkl")
        self._gbdtlr_model = GbdtLrModel(gbdtlr_model_path)
        self._item2vec_model = Item2VecModel(city_uuid)

    def _get_itemcf_recall(self, product_id):
        """Collaborative-filtering recall.

        Returns every member of the sorted set ``fc:<city_uuid>:<product_id>``,
        highest score first.
        """
        key = f"fc:{self._city_uuid}:{product_id}"
        return self._redis.zrevrange(key, 0, -1, withscores=False)

    def _get_hot_recall(self):
        """Popularity recall.

        Returns every member of the sorted set ``hot:<city_uuid>:sale_qty``,
        highest score first.
        """
        key = f"hot:{self._city_uuid}:sale_qty"
        return self._redis.zrevrange(key, 0, -1, withscores=False)

    def _get_recal_cust(self, product_id, recall_count):
        """Recall up to ``recall_count`` candidate merchants for ``product_id``.

        Item-CF results come first (de-duplicated, order preserved).  Only
        when they fall short is the hot list consulted, and only entries not
        already present are appended.

        Returns:
            A list of at most ``recall_count`` merchant identifiers; empty
            when ``recall_count`` is not positive.
        """
        if recall_count <= 0:
            # A negative count would otherwise slice items off the tail.
            return []

        result = list(dict.fromkeys(self._get_itemcf_recall(product_id)))

        missing = recall_count - len(result)
        if missing > 0:
            # Query the hot list lazily: skip the Redis round trip entirely
            # when item-CF already supplied enough candidates.
            already_recalled = set(result)
            fillers = [
                cust for cust in dict.fromkeys(self._get_hot_recall())
                if cust not in already_recalled
            ]
            result.extend(fillers[:missing])

        return result[:recall_count]

    def get_recommend_list_by_gbdtlr(self, product_id, recall_count=100, discovery_count=500):
        """Build the merchant recommendation list for ``product_id`` with GBDT+LR.

        Args:
            product_id: Product to recommend merchants for.
            recall_count: Maximum number of merchants to recall before ranking.
            discovery_count: Currently unused; kept so existing callers that
                pass it keep working.

        Returns:
            Whatever ``GbdtLrModel.get_recommend_list`` returns for the
            recalled merchants, or an empty list when nothing was recalled.
        """
        # Recall the candidate merchants.
        recall_cust_list = self._get_recal_cust(product_id, recall_count)
        self._logger.debug(
            "recalled %d merchants for product %s", len(recall_cust_list), product_id
        )
        if not recall_cust_list:
            # Nothing to rank; avoid sending an empty id list to the DAO.
            return []

        # Fetch and clean the product (cigarette) features.
        product_data = self._dao.get_product_by_id(self._city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
        product_data = sample_data_clear(product_data, ProductConfig)

        # Fetch and clean the merchant and shop features.
        cust_data = self._dao.get_cust_by_ids(self._city_uuid, recall_cust_list)[CustConfig.FEATURE_COLUMNS]
        shop_data = self._dao.get_shop_by_ids(self._city_uuid, recall_cust_list)[ShopConfig.FEATURE_COLUMNS]
        cust_data = sample_data_clear(cust_data, CustConfig)
        shop_data = sample_data_clear(shop_data, ShopConfig)

        # Attach shop features to each merchant row; the inner join drops
        # merchants that have no matching shop record.
        cust_feats = shop_data.set_index("cust_code")
        cust_data = cust_data.join(cust_feats, on="BB_RETAIL_CUSTOMER_CODE", how="inner")

        if len(cust_data) != len(recall_cust_list):
            # NOTE(review): the model below still receives the full recall
            # list; confirm get_recommend_list tolerates a feature table with
            # a different number (or order) of merchants.
            self._logger.warning(
                "feature rows (%d) differ from recalled merchants (%d) for product %s",
                len(cust_data), len(recall_cust_list), product_id,
            )

        # Build the model input and rank.
        feats_map = self._gbdtlr_model.generate_feats_map(product_data, cust_data)
        return self._gbdtlr_model.get_recommend_list(feats_map, recall_cust_list)
-
def main():
    """Smoke-run the GBDT+LR recommendation for one hard-coded city and product."""
    city_uuid = "00000000000000000000000011445301"
    product_id = '110110'
    recommend = Recommend(city_uuid)
    recommend_list = recommend.get_recommend_list_by_gbdtlr(product_id)
    # Show the result so a manual run produces visible output.
    print(recommend_list)


if __name__ == "__main__":
    main()
|