# calculate_similarity_matrix.py

  1. from database import MySqlDao
  2. import pandas as pd
  3. import numpy as np
  4. from itertools import combinations
  5. from tqdm import tqdm
# Module-level data-access object, created when this module is imported.
# The __main__ block below uses it to load the order data.
dao = MySqlDao()
  7. def build_co_occurence_matrix(order_data):
  8. """
  9. 构建商户共现矩阵
  10. """
  11. # 获取所有商户的唯一列表
  12. shops = order_data["BB_RETAIL_CUSTOMER_CODE"].unique()
  13. num_shops = len(shops)
  14. # 创建商户到索引的映射
  15. shops_to_index = {shop: idx for idx, shop in enumerate(shops)}
  16. # 初始化共现矩阵(上三角部分)
  17. co_occurrence_matrix = np.zeros((num_shops, num_shops), dtype=int)
  18. # 按照品规分组
  19. grouped = order_data.groupby("PRODUCT_CODE")["BB_RETAIL_CUSTOMER_CODE"].apply(list)
  20. # 遍历每个品规的商户列表
  21. for shop_in_product in grouped:
  22. # 生成商户对
  23. shop_pairs = combinations(shop_in_product, 2)
  24. for shop1, shop2 in shop_pairs:
  25. # 获取商户索引
  26. idx1 = shops_to_index[shop1]
  27. idx2 = shops_to_index[shop2]
  28. # 更新共现矩阵
  29. co_occurrence_matrix[idx1, idx2] += 1
  30. co_occurrence_matrix[idx2, idx1] += 1
  31. return co_occurrence_matrix, shops, shops_to_index
  32. def calculate_similarity_matrix(co_occurrence_matrix, order_data, shops_to_index):
  33. """
  34. 使用向量计算商铺之间的相似度矩阵
  35. """
  36. # 计算每个商铺售卖品规的总次数
  37. shop_counts = order_data.groupby("BB_RETAIL_CUSTOMER_CODE").size()
  38. # 将商户售卖次数转换为数组
  39. counts = np.array([shop_counts[shop] for shop in shops_to_index.keys()])
  40. # 计算分母部分 (sqrt(count_i * count_j))
  41. denominator = np.sqrt(np.outer(counts, counts))
  42. # 计算相似度矩阵
  43. similarity_matrix = co_occurrence_matrix / denominator
  44. # 将对角线设置为1
  45. np.fill_diagonal(similarity_matrix, 1.0)
  46. return similarity_matrix
  47. def save_matrix(matrix, shops, save_path):
  48. """
  49. 保存共现矩阵
  50. """
  51. matrix_df = pd.DataFrame(matrix, index=shops, columns=shops)
  52. matrix_df.to_csv(save_path, index=True, encoding="utf-8")
  53. def calculate_similarity_and_save_results(order_data, similarity_matrix_save_path):
  54. co_occurrence_matrix, shops, shops_to_index = build_co_occurence_matrix(order_data)
  55. similarity_matrix = calculate_similarity_matrix(co_occurrence_matrix, order_data, shops_to_index)
  56. save_matrix(similarity_matrix, shops, similarity_matrix_save_path)
  57. if __name__ == "__main__":
  58. co_occurrence_save_path = "./models/recall/itemCF/matrix/occurrence.csv"
  59. similarity_matrix_save_path = "./models/recall/itemCF/matrix/similarity.csv"
  60. # 从数据库中读取订单数据
  61. order_data = dao.load_order_data()
  62. calculate_similarity_and_save_results(order_data, similarity_matrix_save_path)