recommend.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. from database.dao.mysql_dao import MySqlDao
  2. from database.db.redis_db import RedisDatabaseHelper
  3. import os
  4. from models.item2vec.inference import Item2VecModel
  5. from models.rank.data.config import CustConfig, ProductConfig
  6. from models.rank.data.utils import sample_data_clear
  7. from models.rank import GbdtLrModel, generate_feats_map
  8. import pandas as pd
  9. from core import get_logger
# Module-level logger used by every method in this file; the name
# "models.recommend" is the identifier passed to the project's get_logger factory.
logger = get_logger("models.recommend")
  11. class Recommend:
  12. def __init__(self, city_uuid):
  13. self._redis = RedisDatabaseHelper().redis
  14. self._dao = MySqlDao()
  15. self._load_molde(city_uuid)
  16. def _load_molde(self, city_uuid):
  17. """加载推演模型"""
  18. self._city_uuid = city_uuid
  19. gbdtlr_model_path = os.path.join("./models/rank/weights", city_uuid, "gbdtlr_model.pkl")
  20. self._gbdtlr_model = GbdtLrModel(gbdtlr_model_path)
  21. self._item2vec_model = Item2VecModel(city_uuid)
  22. logger.info(f"Models loaded for city_uuid={city_uuid}")
  23. def _get_itemcf_recall(self, product_id):
  24. """协同召回"""
  25. key = f"fc:{self._city_uuid}:{product_id}"
  26. recall_list = self._redis.zrevrange(key, 0, -1, withscores=False)
  27. return recall_list
  28. def get_recal_cust(self, product_id, cust_code_list):
  29. """通过协同过滤召回与核心零售户列表取并集,得到待推荐商户列表"""
  30. itemcf_recall_list = self._get_itemcf_recall(product_id)
  31. seen = set(itemcf_recall_list)
  32. extra = [c for c in cust_code_list if c not in seen]
  33. result = list(itemcf_recall_list) + extra
  34. logger.info(f"Recall completed: {len(result)} customers (itemcf={len(itemcf_recall_list)}, core_extra={len(extra)}) for product {product_id}")
  35. return result
  36. def get_recommend_list_by_gbdtlr(self, product_id, cust_code_list=None):
  37. """根据gbdt_lr获取商户推荐列表"""
  38. logger.info(f"GBDT-LR recommend started for product {product_id}")
  39. # 获取召回的商户列表
  40. if cust_code_list is None:
  41. cust_code_list = []
  42. recall_cust_list = self.get_recal_cust(product_id, cust_code_list)
  43. # 获取卷烟数据
  44. product_data = self._dao.get_product_by_id(self._city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
  45. product_data = sample_data_clear(product_data, ProductConfig)
  46. # 获取整合商户数据
  47. cust_data = self._dao.get_cust_by_ids(self._city_uuid, recall_cust_list)[CustConfig.FEATURE_COLUMNS]
  48. # shop_data = self._dao.get_shop_by_ids(self._city_uuid, recall_cust_list)[ShopConfig.FEATURE_COLUMNS]
  49. cust_data = sample_data_clear(cust_data, CustConfig)
  50. # shop_data = sample_data_clear(shop_data, ShopConfig)
  51. # cust_feats = shop_data.set_index("cust_code")
  52. # cust_data = cust_data.join(cust_feats, on="BB_RETAIL_CUSTOMER_CODE", how="inner")
  53. # 按 recall_cust_list 顺序对齐 cust_data,确保 feats_map 行顺序与 recall_list 一致
  54. # 否则 get_recommend_list 中 zip(recall_list, scores) 会错配商户ID和分数
  55. cust_codes_in_data = set(cust_data["cust_code"].tolist())
  56. ordered_recall_list = [c for c in recall_cust_list if c in cust_codes_in_data]
  57. cust_order = {code: i for i, code in enumerate(ordered_recall_list)}
  58. cust_data = cust_data.sort_values("cust_code", key=lambda x: x.map(cust_order)).reset_index(drop=True)
  59. # 获取推理用的feats_map
  60. feats_map = generate_feats_map(product_data, cust_data)
  61. recommend_list = self._gbdtlr_model.get_recommend_list(feats_map, ordered_recall_list)
  62. # recommend_list = self.filter_recommend_list(recommend_list)
  63. logger.info(f"GBDT-LR recommend completed: {len(recommend_list)} results")
  64. return recommend_list
  65. def get_recommend_list_by_item2vec(self, product_id, cust_code_list=None):
  66. """根据item2vec获取商户推荐列表,核心商户并入候选集统一评分"""
  67. if cust_code_list is None:
  68. cust_code_list = []
  69. logger.info(f"Item2Vec recommend started for product {product_id}")
  70. recommend_list = self._item2vec_model.get_recommend_cust_list(product_id, cust_code_list=cust_code_list)
  71. recommend_list = recommend_list.drop(columns=["sale_qty"])
  72. recommend_list = recommend_list.to_dict(orient='records')
  73. # recommend_list = self.filter_recommend_list(recommend_list)
  74. logger.info(f"Item2Vec recommend completed: {len(recommend_list)} results")
  75. return recommend_list
  76. def filter_recommend_list(self, recommend_list):
  77. """过滤掉已经歇业的商铺"""
  78. cust_set = set(self._dao.get_cust_list(self._city_uuid))
  79. filter_recommend_list = [
  80. item for item in recommend_list
  81. if item["cust_code"] in cust_set
  82. ]
  83. return filter_recommend_list
  84. if __name__ == "__main__":
  85. city_uuid = "00000000000000000000000011445301"
  86. product_id = '350139'
  87. recommend = Recommend(city_uuid)
  88. recommend_list = recommend.get_recommend_list_by_gbdtlr(product_id)
  89. # for i in recommend_list:
  90. # print(i)