recommend.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. from database.dao.mysql_dao import MySqlDao
  2. from database.db.redis_db import RedisDatabaseHelper
  3. import os
  4. from models.item2vec.inference import Item2VecModel
  5. from models.rank.data.config import CustConfig, ProductConfig
  6. from models.rank.data.utils import sample_data_clear
  7. from models.rank import GbdtLrModel, generate_feats_map
  8. import pandas as pd
  9. from core import get_logger
  10. logger = get_logger("models.recommend")
  11. CORE_RERANK_CONFIG = {
  12. "existing": {
  13. "core_model_weight": 0.75,
  14. "core_quality_weight": 0.15,
  15. "core_boost": 35,
  16. "low_model_threshold": 35,
  17. "low_model_weight": 0.85,
  18. "low_quality_weight": 0.10,
  19. "low_core_boost": 65,
  20. "normal_model_weight": 0.90,
  21. },
  22. "new": {
  23. "core_model_weight": 0.55,
  24. "core_quality_weight": 0.25,
  25. "core_boost": 50,
  26. "normal_model_weight": 0.90,
  27. },
  28. }
  29. class Recommend:
  30. def __init__(self, city_uuid):
  31. self._redis = RedisDatabaseHelper().redis
  32. self._dao = MySqlDao()
  33. self._load_molde(city_uuid)
  34. def _load_molde(self, city_uuid):
  35. """加载推演模型"""
  36. self._city_uuid = city_uuid
  37. gbdtlr_model_path = os.path.join("./models/rank/weights", city_uuid, "gbdtlr_model.pkl")
  38. self._gbdtlr_model = GbdtLrModel(gbdtlr_model_path)
  39. self._item2vec_model = Item2VecModel(city_uuid)
  40. logger.info(f"Models loaded for city_uuid={city_uuid}")
  41. def _get_itemcf_recall(self, product_id):
  42. """协同召回"""
  43. key = f"fc:{self._city_uuid}:{product_id}"
  44. recall_list = self._redis.zrevrange(key, 0, -1, withscores=False)
  45. return recall_list
  46. def get_recal_cust(self, product_id, cust_code_list):
  47. """通过协同过滤召回与核心零售户列表取并集,得到待推荐商户列表"""
  48. itemcf_recall_list = self._get_itemcf_recall(product_id)
  49. seen = set(itemcf_recall_list)
  50. extra = [c for c in cust_code_list if c not in seen]
  51. result = list(itemcf_recall_list) + extra
  52. logger.info(f"Recall completed: {len(result)} customers (itemcf={len(itemcf_recall_list)}, core_extra={len(extra)}) for product {product_id}")
  53. return result
  54. def get_recommend_list_by_gbdtlr(self, product_id, cust_code_list=None):
  55. """根据gbdt_lr获取商户推荐列表"""
  56. logger.info(f"GBDT-LR recommend started for product {product_id}")
  57. # 获取召回的商户列表
  58. if cust_code_list is None:
  59. cust_code_list = []
  60. recall_cust_list = self.get_recal_cust(product_id, cust_code_list)
  61. # 获取卷烟数据
  62. product_data = self._dao.get_product_by_id(self._city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
  63. product_data = sample_data_clear(product_data, ProductConfig)
  64. # 获取整合商户数据
  65. cust_data = self._dao.get_cust_by_ids(self._city_uuid, recall_cust_list)[CustConfig.FEATURE_COLUMNS]
  66. # shop_data = self._dao.get_shop_by_ids(self._city_uuid, recall_cust_list)[ShopConfig.FEATURE_COLUMNS]
  67. cust_data = sample_data_clear(cust_data, CustConfig)
  68. # shop_data = sample_data_clear(shop_data, ShopConfig)
  69. # cust_feats = shop_data.set_index("cust_code")
  70. # cust_data = cust_data.join(cust_feats, on="BB_RETAIL_CUSTOMER_CODE", how="inner")
  71. # 按 recall_cust_list 顺序对齐 cust_data,确保 feats_map 行顺序与 recall_list 一致
  72. # 否则 get_recommend_list 中 zip(recall_list, scores) 会错配商户ID和分数
  73. cust_codes_in_data = set(cust_data["cust_code"].tolist())
  74. ordered_recall_list = [c for c in recall_cust_list if c in cust_codes_in_data]
  75. cust_order = {code: i for i, code in enumerate(ordered_recall_list)}
  76. cust_data = cust_data.sort_values("cust_code", key=lambda x: x.map(cust_order)).reset_index(drop=True)
  77. # 获取推理用的feats_map
  78. feats_map = generate_feats_map(product_data, cust_data)
  79. recommend_list = self._gbdtlr_model.get_recommend_list(feats_map, ordered_recall_list)
  80. recommend_list = self._rerank_existing_product(recommend_list, cust_code_list)
  81. # recommend_list = self.filter_recommend_list(recommend_list)
  82. logger.info(f"GBDT-LR recommend completed: {len(recommend_list)} results")
  83. return recommend_list
  84. def get_recommend_list_by_item2vec(self, product_id, cust_code_list=None):
  85. """根据item2vec获取商户推荐列表,核心商户并入候选集统一评分"""
  86. if cust_code_list is None:
  87. cust_code_list = []
  88. logger.info(f"Item2Vec recommend started for product {product_id}")
  89. recommend_list = self._item2vec_model.get_recommend_cust_list(product_id, cust_code_list=cust_code_list)
  90. recommend_list = recommend_list.drop(columns=["sale_qty"])
  91. recommend_list = recommend_list.to_dict(orient='records')
  92. recommend_list = self._rerank_new_product(recommend_list, cust_code_list)
  93. # recommend_list = self.filter_recommend_list(recommend_list)
  94. logger.info(f"Item2Vec recommend completed: {len(recommend_list)} results")
  95. return recommend_list
  96. def _rerank_existing_product(self, recommend_list, core_cust_list):
  97. """Rerank existing-product results with core-customer boosts, then sort by score."""
  98. core_set = {str(cust_code) for cust_code in (core_cust_list or [])}
  99. if not core_set or not recommend_list:
  100. return recommend_list
  101. quality_score_map = self._build_quality_score_map(core_set)
  102. cfg = CORE_RERANK_CONFIG["existing"]
  103. for item in recommend_list:
  104. cust_code = str(item["cust_code"])
  105. model_score = float(item.get("recommend_score", 0) or 0)
  106. is_core = cust_code in core_set
  107. quality_score = quality_score_map.get(cust_code, 60.0)
  108. if is_core:
  109. if model_score >= cfg["low_model_threshold"]:
  110. final_score = (
  111. model_score * cfg["core_model_weight"]
  112. + quality_score * cfg["core_quality_weight"]
  113. + cfg["core_boost"]
  114. )
  115. else:
  116. final_score = (
  117. model_score * cfg["low_model_weight"]
  118. + quality_score * cfg["low_quality_weight"]
  119. + cfg["low_core_boost"]
  120. )
  121. else:
  122. final_score = model_score * cfg["normal_model_weight"]
  123. item["model_score"] = model_score
  124. item["is_core_cust"] = is_core
  125. item["core_quality_score"] = quality_score if is_core else None
  126. item["recommend_score"] = min(float(final_score), 100.0)
  127. recommend_list.sort(key=lambda x: x["recommend_score"], reverse=True)
  128. logger.info(f"Core boost rerank completed for existing product: core_count={len(core_set)}")
  129. return recommend_list
  130. def _rerank_new_product(self, recommend_list, core_cust_list):
  131. """Rerank Item2Vec cold-start results with core-customer boosts, then sort by score."""
  132. core_set = {str(cust_code) for cust_code in (core_cust_list or [])}
  133. if not core_set or not recommend_list:
  134. return recommend_list
  135. quality_score_map = self._build_quality_score_map(core_set)
  136. cfg = CORE_RERANK_CONFIG["new"]
  137. for item in recommend_list:
  138. cust_code = str(item["cust_code"])
  139. model_score = float(item.get("recommend_score", 0) or 0)
  140. is_core = cust_code in core_set
  141. quality_score = quality_score_map.get(cust_code, 60.0)
  142. if is_core:
  143. final_score = (
  144. model_score * cfg["core_model_weight"]
  145. + quality_score * cfg["core_quality_weight"]
  146. + cfg["core_boost"]
  147. )
  148. else:
  149. final_score = model_score * cfg["normal_model_weight"]
  150. item["item2vec_score"] = model_score
  151. item["is_core_cust"] = is_core
  152. item["core_quality_score"] = quality_score if is_core else None
  153. item["recommend_score"] = min(float(final_score), 100.0)
  154. recommend_list.sort(key=lambda x: x["recommend_score"], reverse=True)
  155. logger.info(f"Core boost rerank completed for new product: core_count={len(core_set)}")
  156. return recommend_list
  157. def _build_quality_score_map(self, cust_list):
  158. """Build a 0-100 business-quality score for candidate customers."""
  159. if not cust_list:
  160. return {}
  161. unique_cust_list = list(dict.fromkeys(str(cust_code) for cust_code in cust_list))
  162. cust_data = self._dao.get_cust_by_ids(self._city_uuid, unique_cust_list)
  163. if cust_data.empty:
  164. return {cust_code: 60.0 for cust_code in unique_cust_list}
  165. score_map = {}
  166. for _, row in cust_data.iterrows():
  167. cust_code = row.get("cust_code")
  168. if pd.isna(cust_code):
  169. continue
  170. score_map[str(cust_code)] = self._calculate_core_quality_score(row)
  171. for cust_code in unique_cust_list:
  172. score_map.setdefault(cust_code, 60.0)
  173. return score_map
  174. def _calculate_core_quality_score(self, row):
  175. """Calculate a 0-100 quality score using only fields defined in CustConfig."""
  176. field_scores = [
  177. ("terminal_star_name", {
  178. "五星终端": 100,
  179. "四星终端": 90,
  180. "三星终端": 80,
  181. "二星终端": 70,
  182. "一星终端": 60,
  183. "其他": 50,
  184. "无": 40,
  185. }, 0.18),
  186. ("cooperate_codename", {
  187. "好": 90,
  188. "较好": 75,
  189. "一般": 60,
  190. }, 0.14),
  191. ("store_appearance_name", {
  192. "好": 90,
  193. "较好": 75,
  194. "一般": 60,
  195. "差": 40,
  196. }, 0.12),
  197. ("is_modern_terminalname", {
  198. "是": 85,
  199. "否": 55,
  200. }, 0.10),
  201. ("modern_terminal_name", {
  202. "直营终端": 95,
  203. "合作终端": 90,
  204. "加盟终端": 85,
  205. "一般现代终端": 75,
  206. "普通终端": 60,
  207. "无法识别": 50,
  208. }, 0.08),
  209. ("cooperate_type_name", {
  210. "品牌加盟": 90,
  211. "冠名加盟": 85,
  212. "无": 55,
  213. }, 0.08),
  214. ("creditclass_name", {
  215. "AAA": 95,
  216. "AA": 90,
  217. "A": 85,
  218. "C": 60,
  219. "D": 45,
  220. }, 0.10),
  221. ("counter_status_name", {
  222. "有": 80,
  223. "计划中": 65,
  224. "无": 50,
  225. }, 0.05),
  226. ("counter_put_type_name", {
  227. "独立陈列": 85,
  228. "混杂陈列": 70,
  229. "无陈列": 50,
  230. }, 0.05),
  231. ("back_counter_status_name", {
  232. "有": 80,
  233. "计划中": 65,
  234. "无": 50,
  235. }, 0.04),
  236. ("back_counter_put_type_name", {
  237. "独立陈列": 85,
  238. "混杂陈列": 70,
  239. "无陈列": 50,
  240. }, 0.03),
  241. ("back_counter_has_show_name", {
  242. "有": 80,
  243. "无": 50,
  244. }, 0.03),
  245. ]
  246. weighted_score = 0.0
  247. total_weight = 0.0
  248. for field, score_map, weight in field_scores:
  249. score = self._score_by_config_value(row, field, score_map)
  250. if score is None:
  251. continue
  252. weighted_score += score * weight
  253. total_weight += weight
  254. for field, weight in [("counter_number", 0.05), ("back_counter_number", 0.05)]:
  255. score = self._counter_score(self._get_row_value(row, field))
  256. if score is None:
  257. continue
  258. weighted_score += score * weight
  259. total_weight += weight
  260. if total_weight == 0:
  261. return 60.0
  262. return round(weighted_score / total_weight, 4)
  263. def _score_by_config_value(self, row, field, score_map):
  264. if field not in CustConfig.FEATURE_COLUMNS:
  265. return None
  266. value = self._get_row_value(row, field)
  267. if pd.isna(value):
  268. return None
  269. text = str(value)
  270. if text not in CustConfig.ONEHOT_CAT.get(field, []):
  271. return None
  272. return float(score_map.get(text, 60.0))
  273. @staticmethod
  274. def _get_row_value(row, field):
  275. if field not in row.index:
  276. return None
  277. return row.get(field)
  278. @staticmethod
  279. def _counter_score(value):
  280. if pd.isna(value):
  281. return None
  282. try:
  283. number = float(value)
  284. except (TypeError, ValueError):
  285. return None
  286. if number <= 0:
  287. return 50.0
  288. if number >= 4:
  289. return 90.0
  290. return 50.0 + number * 10.0
  291. def filter_recommend_list(self, recommend_list):
  292. """过滤掉已经歇业的商铺"""
  293. cust_set = set(self._dao.get_cust_list(self._city_uuid))
  294. filter_recommend_list = [
  295. item for item in recommend_list
  296. if item["cust_code"] in cust_set
  297. ]
  298. return filter_recommend_list
  299. if __name__ == "__main__":
  300. city_uuid = "00000000000000000000000011445301"
  301. product_id = '350139'
  302. recommend = Recommend(city_uuid)
  303. recommend_list = recommend.get_recommend_list_by_gbdtlr(product_id)
  304. # for i in recommend_list:
  305. # print(i)