import argparse
import os

from dao import load_order_data_from_mysql
from dao.redis_db import Redis
from models import HotRecallModel, UserItemScore, ItemCFModel, calculate_similarity_and_save_results


def run_hot_recall(order_data):
    """Run the hot (popularity-based) recall algorithm.

    Args:
        order_data: Order data as returned by ``load_order_data_from_mysql``.
    """
    hot_model = HotRecallModel(order_data)
    hot_model.calculate_all_hot_score()
    print("热度召回已完成!")


def run_itemcf(order_data, args):
    """Run the item-based collaborative filtering (ItemCF) pipeline.

    Steps, in order:
      1. Remove any previous score / similarity output files.
      2. Compute user-item interest scores and write them to
         ``args.interst_score_path`` as CSV.
      3. Compute the shop co-occurrence / similarity matrix and save it to
         ``args.similarity_matrix_path``.
      4. Train the ItemCF recall model from those two files.

    Args:
        order_data: Order data as returned by ``load_order_data_from_mysql``.
        args: Parsed CLI namespace. Reads ``interst_score_path``,
            ``similarity_matrix_path``, ``n``, ``k``, ``top_n`` and ``n_jobs``.
    """
    # Remove each stale output on its own: a previous run that was interrupted
    # may have left only one of the two files behind, and it must not survive
    # into this run.
    for stale_path in (args.interst_score_path, args.similarity_matrix_path):
        if os.path.exists(stale_path):
            os.remove(stale_path)

    # Compute the user-score-item data.
    cal_interest_scores_model = UserItemScore()
    scores = cal_interest_scores_model.score(order_data)
    scores.to_csv(args.interst_score_path, index=False, encoding="utf-8")
    print("Interest Scores cal done!")

    # Compute the shop co-occurrence matrix and the similarity matrix.
    calculate_similarity_and_save_results(order_data, args.similarity_matrix_path)
    print("Shops similarity matrix cal done!")

    # Run the collaborative filtering recall.
    itemcf_model = ItemCFModel()
    itemcf_model.train(
        args.interst_score_path,
        args.similarity_matrix_path,
        args.n,
        args.k,
        args.top_n,
        args.n_jobs,
    )
    print("协同过滤已完成!")


def run_itemcf_inference(product_code):
    """Read recommendations for a product code from Redis.

    Queries the full range of the key ``fc:{product_code}`` with scores.

    Args:
        product_code: Code used to build the Redis key.

    Returns:
        A list of single-entry dicts ``{shop_id: score}`` with the score
        converted to ``float``, in the order returned by the query.
    """
    redis_db = Redis()
    redis_key = f"fc:{product_code}"
    # NOTE(review): presumably ``redis_db.redis`` is a redis-py client, in which
    # case ``desc=True`` yields highest score first -- confirm against dao.redis_db.
    recommendations = redis_db.redis.zrange(redis_key, 0, -1, withscores=True, desc=True)

    # Convert the (member, score) pairs into a list of {shop_id: score} dicts.
    result = [{shop_id: float(score)} for shop_id, score in recommendations]
    return result


def run():
    """Parse the command line and dispatch to the selected run mode.

    Exactly one mode is executed, chosen with this precedence:
    ``--run_all``, ``--run_hot``, ``--run_itemcf``, ``--run_itemcf_inference``.
    When no mode flag is given, the usage help is printed instead of exiting
    silently.
    """
    parser = argparse.ArgumentParser()

    # Run modes.
    parser.add_argument("--run_all", action='store_true')
    parser.add_argument("--run_hot", action='store_true')
    parser.add_argument("--run_itemcf", action='store_true')
    parser.add_argument("--run_itemcf_inference", action='store_true')

    # Collaborative filtering configuration.
    # The score and similarity file paths are not separate options; they are
    # derived from --matrix_path after parsing (see below).
    parser.add_argument("--matrix_path", type=str, default="./models/recall/itemCF/matrix")
    parser.add_argument("--n", type=int, default=100)
    parser.add_argument("--k", type=int, default=10)
    # NOTE(review): the help text says "n * k", but the default (200) does not
    # equal the default n * k (100 * 10) -- confirm which one is intended.
    parser.add_argument("--top_n", type=int, default=200, help='default n * k')
    parser.add_argument("--n_jobs", type=int, default=4)

    # Collaborative filtering inference configuration.
    parser.add_argument("--product_code", type=int, default=110111)

    args = parser.parse_args()

    # Prepare the output location. exist_ok avoids the check-then-create race
    # when several jobs start at the same time.
    os.makedirs(args.matrix_path, exist_ok=True)
    args.interst_score_path = os.path.join(args.matrix_path, "score.csv")
    args.similarity_matrix_path = os.path.join(args.matrix_path, "similarity.csv")

    if args.run_all:
        order_data = load_order_data_from_mysql()
        run_hot_recall(order_data)
        run_itemcf(order_data, args)
    elif args.run_hot:
        order_data = load_order_data_from_mysql()
        run_hot_recall(order_data)
    elif args.run_itemcf:
        order_data = load_order_data_from_mysql()
        run_itemcf(order_data, args)
    elif args.run_itemcf_inference:
        recomments = run_itemcf_inference(args.product_code)
        print(recomments)
    else:
        # No mode selected: show usage rather than doing nothing silently.
        parser.print_help()


if __name__ == "__main__":
    run()