app.py 3.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import argparse
  2. from dao import load_order_data_from_mysql
  3. from dao.redis_db import Redis
  4. from models import HotRecallModel, UserItemScore, ItemCFModel, calculate_similarity_and_save_results
  5. import os
  6. def run_hot_recall(order_data):
  7. """运行热度召回算法"""
  8. hot_model = HotRecallModel(order_data)
  9. hot_model.calculate_all_hot_score()
  10. print("热度召回已完成!")
  11. def run_itemcf(order_data, args):
  12. # """运行协同过滤算法"""
  13. if os.path.exists(args.interst_score_path) and os.path.exists(args.similarity_matrix_path):
  14. os.remove(args.interst_score_path)
  15. os.remove(args.similarity_matrix_path)
  16. # 计算user-score-item数据
  17. cal_interest_scores_model = UserItemScore()
  18. scores = cal_interest_scores_model.score(order_data)
  19. scores.to_csv(args.interst_score_path, index=False, encoding="utf-8")
  20. print("Interest Scores cal done!")
  21. # 计算商户共现矩阵及相似度矩阵
  22. calculate_similarity_and_save_results(order_data, args.similarity_matrix_path)
  23. print("Shops similarity matrix cal done!")
  24. # 运行协同过滤召回
  25. itemcf_model = ItemCFModel()
  26. itemcf_model.train(args.interst_score_path, args.similarity_matrix_path, args.n, args.k, args.top_n, args.n_jobs)
  27. print("协同过滤已完成!")
  28. def run_itemcf_inference(product_code):
  29. """
  30. 从 Redis 中读取推荐结果,并返回 {shop_id: score} 的列表
  31. """
  32. redis_db = Redis()
  33. redis_key = f"fc:{product_code}"
  34. recommendations = redis_db.redis.zrange(redis_key, 0, -1, withscores=True, desc=True)
  35. # 将推荐结果转换为 {shop_id: score} 的字典列表
  36. result = [{shop_id: float(score)} for shop_id, score in recommendations]
  37. return result
  38. def run():
  39. parser = argparse.ArgumentParser()
  40. # 运行方式
  41. parser.add_argument("--run_all", action='store_true')
  42. parser.add_argument("--run_hot", action='store_true')
  43. parser.add_argument("--run_itemcf", action='store_true')
  44. parser.add_argument("--run_itemcf_inference", action='store_true')
  45. # 协同过滤相关配置
  46. parser.add_argument("--matrix_path", type=str, default="./models/recall/itemCF/matrix")
  47. # parser.add_argument("--interst_score_path", type=str, default="./models/recall/itemCF/matrix/score.csv")
  48. # parser.add_argument("--similarity_matrix_path", type=str, default="./models/recall/itemCF/matrix/similarity.csv")
  49. parser.add_argument("--n", type=int, default=100)
  50. parser.add_argument("--k", type=int, default=10)
  51. parser.add_argument("--top_n", type=int, default=200, help='default n * k')
  52. parser.add_argument("--n_jobs", type=int, default=4)
  53. # 协同过滤推理配置
  54. parser.add_argument("--product_code", type=int, default=110111)
  55. args = parser.parse_args()
  56. # 初始化文件保存相关配置
  57. if not os.path.exists(args.matrix_path):
  58. os.makedirs(args.matrix_path)
  59. args.interst_score_path = os.path.join(args.matrix_path, "score.csv")
  60. args.similarity_matrix_path = os.path.join(args.matrix_path, "similarity.csv")
  61. if args.run_all:
  62. order_data = load_order_data_from_mysql()
  63. run_hot_recall(order_data)
  64. run_itemcf(order_data, args)
  65. elif args.run_hot:
  66. order_data = load_order_data_from_mysql()
  67. run_hot_recall(order_data)
  68. elif args.run_itemcf:
  69. order_data = load_order_data_from_mysql()
  70. run_itemcf(order_data, args)
  71. elif args.run_itemcf_inference:
  72. recomments = run_itemcf_inference(args.product_code)
  73. print(recomments)
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    run()