hot_recall.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. from sklearn.preprocessing import StandardScaler
  2. from config import load_model_config
  3. from database import RedisDatabaseHelper, MySqlDao
  4. from tqdm import tqdm
  5. from models.rank.data.config import OrderConfig
  6. import numpy as np
  7. cfgs = load_model_config()
  8. class HotRecallModel:
  9. def __init__(self, city_uuid):
  10. self._redis_db = RedisDatabaseHelper().redis
  11. self._dao = MySqlDao()
  12. self._load_data(city_uuid)
  13. self._hotkeys = cfgs["hot_recall"]["hot_keys"]
  14. def _load_data(self, city_uuid):
  15. """加载订单记录表"""
  16. print("hot_recall: 正在加载order_info...")
  17. self._order_data = self._dao.load_order_data(city_uuid)
  18. self._order_data =self._order_data[OrderConfig.FEATURE_COLUMNS]
  19. # 数据清洗
  20. self._order_data["sale_qty"] = self._order_data["sale_qty"].fillna(0)
  21. self._order_data = self._order_data.groupby(["cust_code", "product_code"], as_index=False)["sale_qty"].sum()
  22. self._order_data = self._order_data[self._order_data["sale_qty"] != 0]
  23. def _calculate_hot_score(self, hot_name):
  24. """
  25. 根据热度指标计算热度得分
  26. :param hot_name: 热度指标A
  27. :type param: string
  28. :return: 所有热度指标的得分
  29. :rtype: list
  30. """
  31. results = self._order_data.groupby("cust_code")[hot_name].mean().reset_index()
  32. sorted_results = results.sort_values(by=hot_name, ascending=False).reset_index(drop=True)
  33. scaler = StandardScaler()
  34. normalized = scaler.fit_transform(sorted_results["sale_qty"].values.reshape(-1, 1))
  35. sorted_results["sale_qty"] = ((1 / (1 + np.exp(-normalized))) * 100).flatten()
  36. item_hot_score = []
  37. for _, row in sorted_results.iterrows():
  38. item_hot_score.append({row["cust_code"]: row[hot_name]})
  39. return {"key":f"{hot_name}", "value":item_hot_score}
  40. def _to_redis(self, rec_content_score, city_uuid):
  41. hotkey_name = rec_content_score["key"]
  42. rec_item_id = f"hot:{city_uuid}:{str(hotkey_name)}" # 修正 rec_item_id 拼接方式
  43. # 清空 sorted set 数据,确保不会影响后续的存储
  44. self._redis_db.delete(rec_item_id)
  45. res = {}
  46. for item in rec_content_score["value"]:
  47. for content, score in item.items(): # item 形如 {A001: 75.0}
  48. res[content] = float(score) # 确保 score 是 float 类型
  49. if res: # 只有当 res 不为空时才执行 zadd
  50. self._redis_db.zadd(rec_item_id, res)
  51. def calculate_all_hot_score(self, city_uuid):
  52. """
  53. 计算所有的热度指标得分
  54. """
  55. # hot_datas = []
  56. for hotkey_name in tqdm(self._hotkeys, desc="hot_recall:正在计算热度分数"):
  57. self._to_redis(self._calculate_hot_score(hotkey_name), city_uuid)
  58. if __name__ == "__main__":
  59. hot_recall = HotRecallModel("00000000000000000000000011445301")
  60. hot_recall.calculate_all_hot_score("00000000000000000000000011445301")