hot_recall.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. from sklearn.preprocessing import StandardScaler
  2. from config import load_model_config
  3. from database import RedisDatabaseHelper, MySqlDao
  4. from tqdm import tqdm
  5. from models.rank.data.config import OrderConfig
  6. import numpy as np
# Model configuration, loaded once when this module is imported.
# HotRecallModel reads cfgs["hot_recall"]["hot_keys"] from it to decide which
# metrics to score.
cfgs = load_model_config()
  8. class HotRecallModel:
  9. def __init__(self, city_uuid):
  10. self._city_uuid = city_uuid
  11. self._redis_db = RedisDatabaseHelper().redis
  12. self._dao = MySqlDao()
  13. self._load_data()
  14. self._hotkeys = cfgs["hot_recall"]["hot_keys"]
  15. def _load_data(self):
  16. """加载订单记录表"""
  17. print("hot_recall: 正在加载order_info...")
  18. self._order_data = self._dao.load_order_data(self._city_uuid)
  19. self._order_data =self._order_data[OrderConfig.FEATURE_COLUMNS]
  20. # 数据清洗
  21. self._order_data["sale_qty"] = self._order_data["sale_qty"].fillna(0)
  22. self._order_data = self._order_data.groupby(["cust_code", "product_code"], as_index=False)["sale_qty"].sum()
  23. self._order_data = self._order_data[self._order_data["sale_qty"] != 0]
  24. def _calculate_hot_score(self, hot_name):
  25. """
  26. 根据热度指标计算热度得分
  27. :param hot_name: 热度指标A
  28. :type param: string
  29. :return: 所有热度指标的得分
  30. :rtype: list
  31. """
  32. results = self._order_data.groupby("cust_code")[hot_name].sum().reset_index()
  33. sorted_results = results.sort_values(by=hot_name, ascending=False).reset_index(drop=True)
  34. scaler = StandardScaler()
  35. normalized = scaler.fit_transform(sorted_results["sale_qty"].values.reshape(-1, 1))
  36. sorted_results["sale_qty"] = ((1 / (1 + np.exp(-normalized))) * 100).flatten()
  37. item_hot_score = []
  38. for _, row in sorted_results.iterrows():
  39. item_hot_score.append({row["cust_code"]: row[hot_name]})
  40. return {"key":f"{hot_name}", "value":item_hot_score}
  41. def _to_redis(self, rec_content_score):
  42. hotkey_name = rec_content_score["key"]
  43. rec_item_id = f"hot:{self._city_uuid}:{str(hotkey_name)}" # 修正 rec_item_id 拼接方式
  44. # 清空 sorted set 数据,确保不会影响后续的存储
  45. self._redis_db.delete(rec_item_id)
  46. res = {}
  47. for item in rec_content_score["value"]:
  48. for content, score in item.items(): # item 形如 {A001: 75.0}
  49. res[content] = float(score) # 确保 score 是 float 类型
  50. if res: # 只有当 res 不为空时才执行 zadd
  51. self._redis_db.zadd(rec_item_id, res)
  52. def calculate_all_hot_score(self):
  53. """
  54. 计算所有的热度指标得分
  55. """
  56. # hot_datas = []
  57. for hotkey_name in tqdm(self._hotkeys, desc="hot_recall:正在计算热度分数"):
  58. self._to_redis(self._calculate_hot_score(hotkey_name))
  59. if __name__ == "__main__":
  60. hot_recall = HotRecallModel("00000000000000000000000011445301")
  61. hot_recall.calculate_all_hot_score()