recommend.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
import logging
import os

from database.dao.mysql_dao import MySqlDao
from database.db.redis_db import RedisDatabaseHelper
from models.item2vec.inference import Item2VecModel
from models.rank.data.config import CustConfig, ProductConfig, ShopConfig
from models.rank.data.utils import sample_data_clear
from models.rank.gbdt_lr_inference import GbdtLrModel
  8. class Recommend:
  9. def __init__(self, city_uuid):
  10. self._redis = RedisDatabaseHelper().redis
  11. self._dao = MySqlDao()
  12. self._load_molde(city_uuid)
  13. def _load_molde(self, city_uuid):
  14. """加载推演模型"""
  15. self._city_uuid = city_uuid
  16. gbdtlr_model_path = os.path.join("./models/rank/weights", city_uuid, "gbdtlr_model.pkl")
  17. self._gbdtlr_model = GbdtLrModel(gbdtlr_model_path)
  18. self._item2vec_model = Item2VecModel(city_uuid)
  19. def _get_itemcf_recall(self, product_id):
  20. """协同召回"""
  21. key = f"fc:{self._city_uuid}:{product_id}"
  22. recall_list = self._redis.zrevrange(key, 0, -1, withscores=False)
  23. return recall_list
  24. def _get_hot_recall(self):
  25. """热度召回"""
  26. key = f"hot:{self._city_uuid}:sale_qty"
  27. recall_list = self._redis.zrevrange(key, 0, -1, withscores=False)
  28. return recall_list
  29. def _get_recal_cust(self, product_id, recall_count):
  30. """通过协同过滤和热度召回,召回待推荐商户列表"""
  31. itemcf_recall_list = self._get_itemcf_recall(product_id)
  32. hot_recall_list = self._get_hot_recall()
  33. result = list(dict.fromkeys(itemcf_recall_list))
  34. # 如果结果不足,从hot_recall中补齐
  35. if len(result) < recall_count:
  36. hot_recall_set = set(hot_recall_list) - set(result)
  37. additional_items = [item for item in hot_recall_list if item in hot_recall_set]
  38. needed = recall_count - len(result)
  39. result.extend(additional_items[:needed])
  40. return result[:recall_count]
  41. def get_recommend_list_by_gbdtlr(self, product_id, recall_count=100, discovery_count=500):
  42. """根据gbdt_lr获取商户推荐列表"""
  43. # 获取召回的商户列表
  44. recall_cust_list = self._get_recal_cust(product_id, recall_count)
  45. print(len(recall_cust_list))
  46. # 获取卷烟数据
  47. product_data = self._dao.get_product_by_id(self._city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
  48. product_data = sample_data_clear(product_data, ProductConfig)
  49. # 获取整合商户数据
  50. cust_data = self._dao.get_cust_by_ids(self._city_uuid, recall_cust_list)[CustConfig.FEATURE_COLUMNS]
  51. shop_data = self._dao.get_shop_by_ids(self._city_uuid, recall_cust_list)[ShopConfig.FEATURE_COLUMNS]
  52. cust_data = sample_data_clear(cust_data, CustConfig)
  53. shop_data = sample_data_clear(shop_data, ShopConfig)
  54. cust_feats = shop_data.set_index("cust_code")
  55. cust_data = cust_data.join(cust_feats, on="BB_RETAIL_CUSTOMER_CODE", how="inner")
  56. # 获取推理用的feats_map
  57. feats_map = self._gbdtlr_model.generate_feats_map(product_data, cust_data)
  58. print(len(cust_data))
  59. recommend_list = self._gbdtlr_model.get_recommend_list(feats_map, recall_cust_list)
  60. return recommend_list
  61. if __name__ == "__main__":
  62. city_uuid = "00000000000000000000000011445301"
  63. product_id = '110110'
  64. recommend = Recommend(city_uuid)
  65. recommend_list = recommend.get_recommend_list_by_gbdtlr(product_id)