ItemCF.py 4.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. from database import RedisDatabaseHelper, MySqlDao
  2. import pandas as pd
  3. from models import UserItemScore, SimilarityMatrix
  4. import numpy as np
  5. from tqdm import tqdm
  6. from scipy.sparse import csr_matrix
  7. from joblib import Parallel, delayed
  8. class ItemCFModel:
  9. def __init__(self):
  10. self._recommendations = {}
  11. self._dao = MySqlDao()
  12. def train(self, city_uuid, n=300, k=100, top_n=300, n_jobs=4):
  13. # self._score_df = pd.read_csv(score_path)
  14. # self._similarity_df = pd.read_csv(similatity_path, index_col=0)
  15. print("itemcf: 正在加载order_info...")
  16. self._order_data = self._dao.load_order_data(city_uuid)
  17. print("正在计算品规培育分数...")
  18. self._score_df = UserItemScore(self._order_data).generate_product_scores()
  19. print("正在计算商户相似度矩阵...")
  20. self._similarity_df = SimilarityMatrix(self._order_data).generate_similarity_matrix()
  21. similarity_matrix = csr_matrix(self._similarity_df.values)
  22. shop_index = {shop: idx for idx, shop in enumerate(self._similarity_df.index)}
  23. index_shop = {idx: shop for idx, shop in enumerate(self._similarity_df.index)}
  24. def process_product(product_code, scores):
  25. # 获取热度最高的n个商户
  26. top_n_shops = scores.nlargest(n, "score")["cust_code"].values
  27. top_n_indices = [shop_index[shop] for shop in top_n_shops]
  28. # 找到每个商户最相似的k个商户
  29. similar_shops = {}
  30. for shop_idx in top_n_indices:
  31. similarities = similarity_matrix[shop_idx].toarray().flatten()
  32. similar_indices = np.argpartition(similarities, -k-1)[-k-1:]
  33. similar_indices = similar_indices[similar_indices != shop_idx][:k]
  34. similar_shops[index_shop[shop_idx]] = [index_shop[idx] for idx in similar_indices]
  35. # 生成候选商户列表
  36. candidate_shops = list(set(top_n_shops).union(set([m for sublist in similar_shops.values() for m in sublist])))
  37. candidate_indices = [shop_index[shop] for shop in candidate_shops]
  38. # 计算每个候选商户的兴趣得分
  39. interest_scores = {}
  40. for candidate_idx in candidate_indices:
  41. interest_score = 0
  42. for shop_idx in top_n_indices:
  43. if index_shop[candidate_idx] in similar_shops[index_shop[shop_idx]]:
  44. shop_score = scores[scores["cust_code"]==index_shop[shop_idx]]["score"].values[0]
  45. interest_score += shop_score * similarity_matrix[shop_idx, candidate_idx]
  46. interest_scores[index_shop[candidate_idx]] = interest_score
  47. # 将候选商户的兴趣得分转换为字典列表,并按照从大到小排列
  48. sorted_candidates = sorted([{shop_id: s} for shop_id, s in interest_scores.items()],
  49. key=lambda x: list(x.values())[0], reverse=True)[:top_n]
  50. return product_code, sorted_candidates
  51. # 并行处理每个品规
  52. results = Parallel(n_jobs=n_jobs)(delayed(process_product)(product_code, scores)
  53. for product_code, scores in tqdm(self._score_df.groupby("product_code"), desc="train:正在计算候选得分"))
  54. # 存储结果
  55. self._recommendations = {product_code: sorted_candidates for product_code, sorted_candidates in results}
  56. self.to_redis_zset(city_uuid)
  57. def to_redis_zset(self, city_uuid):
  58. """
  59. 将 self._recommendations 中的数据保存到 Redis 的 Sorted Set (ZSET) 中
  60. 存储格式为 fc:product_code,其中商户 ID 作为成员,得分作为分数
  61. """
  62. redis_db = RedisDatabaseHelper()
  63. # 存redis之前,先进行删除操作
  64. pattern = f"fc:{city_uuid}:*"
  65. keys_to_delete = redis_db.redis.keys(pattern)
  66. if keys_to_delete:
  67. redis_db.redis.delete(*keys_to_delete)
  68. for product_code, recommendations in tqdm(self._recommendations.items(), desc="train:正在存储推荐结果"):
  69. redis_key = f"fc:{city_uuid}:{product_code}"
  70. zset_data = {}
  71. for rec in recommendations:
  72. for shop_id, score in rec.items():
  73. try:
  74. zset_data[shop_id] = float(score)
  75. except ValueError as e:
  76. print(f"Error converting score to float for shop_id {shop_id}: {score}")
  77. raise e
  78. redis_db.redis.zadd(redis_key, zset_data)
  79. if __name__ == "__main__":
  80. itemcf_model = ItemCFModel()
  81. itemcf_model.train("00000000000000000000000011445301", n_jobs=4)