hot_recall.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. from sklearn.preprocessing import StandardScaler
  2. from config import load_model_config
  3. from database import RedisDatabaseHelper, MySqlDao
  4. from tqdm import tqdm
  5. from models.rank.data.config import OrderConfig
  6. import numpy as np
  7. from core import get_logger
# Module-level logger, obtained from the project's `get_logger` helper under
# the name "models.recall.hot".
logger = get_logger("models.recall.hot")
# Model configuration, loaded once when this module is imported. This module
# reads cfgs["hot_recall"]["hot_keys"] from it.
cfgs = load_model_config()
  10. class HotRecallModel:
  11. def __init__(self, city_uuid):
  12. self._city_uuid = city_uuid
  13. self._redis_db = RedisDatabaseHelper().redis
  14. self._dao = MySqlDao()
  15. self._load_data()
  16. self._hotkeys = cfgs["hot_recall"]["hot_keys"]
  17. def _load_data(self):
  18. """加载订单记录表"""
  19. logger.info("Loading order data")
  20. self._order_data = self._dao.load_order_data(self._city_uuid)
  21. self._order_data =self._order_data[OrderConfig.FEATURE_COLUMNS]
  22. # 数据清洗
  23. self._order_data["sale_qty"] = self._order_data["sale_qty"].fillna(0)
  24. self._order_data = self._order_data.groupby(["cust_code", "product_code"], as_index=False)["sale_qty"].sum()
  25. self._order_data = self._order_data[self._order_data["sale_qty"] != 0]
  26. def _calculate_hot_score(self, hot_name):
  27. """
  28. 根据热度指标计算热度得分
  29. :param hot_name: 热度指标A
  30. :type param: string
  31. :return: 所有热度指标的得分
  32. :rtype: list
  33. """
  34. results = self._order_data.groupby("cust_code")[hot_name].sum().reset_index()
  35. sorted_results = results.sort_values(by=hot_name, ascending=False).reset_index(drop=True)
  36. scaler = StandardScaler()
  37. normalized = scaler.fit_transform(sorted_results["sale_qty"].values.reshape(-1, 1))
  38. sorted_results["sale_qty"] = ((1 / (1 + np.exp(-normalized))) * 100).flatten()
  39. item_hot_score = []
  40. for _, row in sorted_results.iterrows():
  41. item_hot_score.append({row["cust_code"]: row[hot_name]})
  42. return {"key":f"{hot_name}", "value":item_hot_score}
  43. def _to_redis(self, rec_content_score):
  44. hotkey_name = rec_content_score["key"]
  45. rec_item_id = f"hot:{self._city_uuid}:{str(hotkey_name)}" # 修正 rec_item_id 拼接方式
  46. # 清空 sorted set 数据,确保不会影响后续的存储
  47. self._redis_db.delete(rec_item_id)
  48. res = {}
  49. for item in rec_content_score["value"]:
  50. for content, score in item.items(): # item 形如 {A001: 75.0}
  51. res[content] = float(score) # 确保 score 是 float 类型
  52. if res: # 只有当 res 不为空时才执行 zadd
  53. self._redis_db.zadd(rec_item_id, res)
  54. def calculate_all_hot_score(self):
  55. """
  56. 计算所有的热度指标得分
  57. """
  58. # hot_datas = []
  59. for hotkey_name in tqdm(self._hotkeys, desc="hot_recall:正在计算热度分数"):
  60. self._to_redis(self._calculate_hot_score(hotkey_name))
  61. if __name__ == "__main__":
  62. hot_recall = HotRecallModel("00000000000000000000000011445301")
  63. hot_recall.calculate_all_hot_score()