| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105 |
- import argparse
- from dao import load_order_data_from_mysql
- from dao.redis_db import Redis
- from models import HotRecallModel, UserItemScore, ItemCFModel, calculate_similarity_and_save_results
- import os
def run_hot_recall(order_data, city_uuid):
    """Run the popularity ("hot") recall algorithm for one city.

    Builds a ``HotRecallModel`` over ``order_data`` and asks it to compute
    hot scores for ``city_uuid``, then prints a completion message.

    Args:
        order_data: order records as returned by the caller's data loader.
        city_uuid: identifier of the city whose hot scores are computed.
    """
    HotRecallModel(order_data).calculate_all_hot_score(city_uuid)
    print("热度召回已完成!")
def run_itemcf(order_data, args):
    """Run the item-based collaborative-filtering (ItemCF) recall pipeline.

    Steps:
      1. Delete previous score / similarity files so stale results are never
         reused.
      2. Compute user-item interest scores and write them as CSV to
         ``args.interst_score_path``.
      3. Compute the shop co-occurrence and similarity matrices and save them
         to ``args.similarity_matrix_path``.
      4. Train ``ItemCFModel`` from those two files.

    Args:
        order_data: order records loaded by the caller (presumably a pandas
            DataFrame — TODO confirm against the loader).
        args: parsed CLI namespace. Reads ``interst_score_path``,
            ``similarity_matrix_path``, ``city_uuid``, ``n``, ``k``,
            ``top_n`` and ``n_jobs``.
    """
    # Remove each stale output independently. Previously both were removed
    # only when *both* existed, so a single leftover file survived.
    for stale_path in (args.interst_score_path, args.similarity_matrix_path):
        try:
            os.remove(stale_path)
        except FileNotFoundError:
            pass  # nothing to clean up for this path

    # Compute user-item interest scores.
    # `scores` is written via .to_csv, presumably a DataFrame — confirm.
    cal_interest_scores_model = UserItemScore()
    scores = cal_interest_scores_model.score(order_data)
    scores.to_csv(args.interst_score_path, index=False, encoding="utf-8")
    print("Interest Scores cal done!")

    # Compute the shop co-occurrence matrix and similarity matrix.
    calculate_similarity_and_save_results(order_data, args.similarity_matrix_path)
    print("Shops similarity matrix cal done!")

    # Run collaborative-filtering recall from the two files written above.
    itemcf_model = ItemCFModel()
    itemcf_model.train(
        args.interst_score_path,
        args.similarity_matrix_path,
        args.city_uuid,
        args.n,
        args.k,
        args.top_n,
        args.n_jobs,
    )
    print("协同过滤已完成!")
def run_itemcf_inference(product_code):
    """Fetch stored ItemCF recommendations for a product from Redis.

    Reads every member of the sorted set ``fc:<product_code>`` together with
    its score, in descending score order (``desc=True``), and returns them as
    a list of single-entry dicts ``[{shop_id: score}, ...]``.

    Args:
        product_code: product identifier used to build the Redis key.

    Returns:
        list of ``{shop_id: float_score}`` dicts. Whether ``shop_id`` is
        ``bytes`` or ``str`` depends on the client's decoding settings
        (not visible here — TODO confirm).
    """
    client = Redis().redis
    ranked = client.zrange(f"fc:{product_code}", 0, -1, withscores=True, desc=True)

    # One dict per (member, score) pair, preserving Redis' ordering.
    return [{member: float(score)} for member, score in ranked]
def run():
    """Parse command-line arguments and run the selected recall job.

    Mode flags are checked in priority order and only the first one set is
    honoured:

    * ``--run_all``: hot recall followed by ItemCF.
    * ``--run_hot``: hot recall only.
    * ``--run_itemcf``: ItemCF only.
    * ``--run_itemcf_inference``: read stored recommendations for
      ``--product_code`` from Redis and print them.

    If no mode flag is given, the usage text is printed.
    """
    parser = argparse.ArgumentParser()

    # Run modes.
    parser.add_argument("--run_all", action='store_true')
    parser.add_argument("--run_hot", action='store_true')
    parser.add_argument("--run_itemcf", action='store_true')
    parser.add_argument("--run_itemcf_inference", action='store_true')

    # Collaborative-filtering settings.
    parser.add_argument("--matrix_path", type=str, default="./models/recall/itemCF/matrix")
    parser.add_argument("--n", type=int, default=100)
    parser.add_argument("--k", type=int, default=20)
    # Default is None so the documented "n * k" default follows --n / --k
    # instead of being frozen at 100 * 20 = 2000.
    parser.add_argument("--top_n", type=int, default=None, help='default n * k')
    parser.add_argument("--n_jobs", type=int, default=4)
    parser.add_argument("--city_uuid", type=str, default='00000000000000000000000011441801', help="City UUID for filtering data")

    # Collaborative-filtering inference settings.
    # NOTE(review): type=int drops leading zeros (e.g. "0110" -> 110), which
    # changes the Redis key; confirm product codes never start with 0.
    parser.add_argument("--product_code", type=int, default=110111)

    args = parser.parse_args()
    if args.top_n is None:
        args.top_n = args.n * args.k

    # Output locations for ItemCF intermediates. exist_ok=True avoids the
    # race between an existence check and directory creation.
    os.makedirs(args.matrix_path, exist_ok=True)
    args.interst_score_path = os.path.join(args.matrix_path, "score.csv")
    args.similarity_matrix_path = os.path.join(args.matrix_path, "similarity.csv")

    # Decide which training steps to run (first matching flag wins).
    if args.run_all:
        do_hot, do_itemcf = True, True
    elif args.run_hot:
        do_hot, do_itemcf = True, False
    elif args.run_itemcf:
        do_hot, do_itemcf = False, True
    elif args.run_itemcf_inference:
        recommendations = run_itemcf_inference(args.product_code)
        print(recommendations)
        return
    else:
        # Previously a run with no mode flag silently did nothing.
        parser.print_help()
        return

    # All training modes need the same order data; load it once.
    order_data = load_order_data_from_mysql(args.city_uuid)
    if order_data is None:
        print("数据库中暂无数据")
        return

    if do_hot:
        run_hot_recall(order_data, args.city_uuid)
    if do_itemcf:
        run_itemcf(order_data, args)
-
if __name__ == "__main__":
    # Execute the CLI only when this file is run as a script, not on import.
    run()
|