import argparse
import os
import time

import pandas as pd

from models.rank import DataProcess, Trainer, GbdtLrModel
from models import ItemCFModel, HotRecallModel


def gbdtlr_train(args):
    """Prepare training data for one city, then train and save a GBDT-LR model.

    Ensures ``<args.model_path>/<args.city_uuid>`` and ``args.train_data_dir``
    exist, runs ``DataProcess(...).data_process()`` for the city, then trains
    on ``<args.train_data_dir>/train_data.csv`` and saves the model as
    ``gbdtlr_model.pkl`` inside the per-city model directory.

    Args:
        args: Parsed CLI namespace; reads ``model_path``, ``city_uuid`` and
            ``train_data_dir``.
    """
    model_dir = os.path.join(args.model_path, args.city_uuid)
    train_data_dir = args.train_data_dir
    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(train_data_dir, exist_ok=True)

    # Prepare the dataset.
    print("正在整合训练数据...")
    processor = DataProcess(args.city_uuid, train_data_dir)
    processor.data_process()
    print("训练数据整合完成!")

    # Train the model.
    # NOTE(review): assumes data_process() writes train_data.csv into
    # train_data_dir — confirm against DataProcess.
    print("开始训练gbdt-lr模型")
    gbdtlr_trainer(
        os.path.join(train_data_dir, "train_data.csv"),
        model_dir,
        "gbdtlr_model.pkl",
    )


def _format_metric(value):
    """Format a metric with 4 decimals, falling back to str() if not numeric."""
    try:
        return format(value, ".4f")
    except (TypeError, ValueError):
        # e.g. None, strings, arrays or other non-float metric payloads.
        return str(value)


def gbdtlr_trainer(train_data_path, model_dir, model_name):
    """Train a GBDT-LR model, print timing and evaluation metrics, and save it.

    Args:
        train_data_path: Path passed to ``Trainer`` as its training data.
        model_dir: Directory the trained model is saved into.
        model_name: File name of the saved model inside ``model_dir``.
    """
    trainer = Trainer(train_data_path)

    # perf_counter is monotonic, so the elapsed time is not distorted by
    # system-clock adjustments the way time.time() differences can be.
    start_time = time.perf_counter()
    trainer.train()
    training_time_hours = (time.perf_counter() - start_time) / 3600
    print(f"训练时间: {training_time_hours:.4f} 小时")

    eval_metrics = trainer.evaluate()

    # Print the evaluation results.
    print("GBDT-LR Evaluation Metrics:")
    for metric, value in eval_metrics.items():
        print(f"{metric}: {_format_metric(value)}")

    # Save the model.
    trainer.save_model(os.path.join(model_dir, model_name))


def itemCF(args):
    """Train the ItemCF model for ``args.city_uuid``.

    Args:
        args: Parsed CLI namespace; reads ``city_uuid``, ``largest_n``,
            ``similarity_k``, ``top_n`` and ``n_jobs`` and forwards them to
            ``ItemCFModel.train`` as ``city_uuid``, ``n``, ``k``, ``top_n``
            and ``n_jobs``.
    """
    itemcf_model = ItemCFModel()
    itemcf_model.train(
        city_uuid=args.city_uuid,
        n=args.largest_n,
        k=args.similarity_k,
        top_n=args.top_n,
        n_jobs=args.n_jobs,
    )


def hot_recall(args):
    """Compute all hot scores for ``args.city_uuid`` via ``HotRecallModel``."""
    # Use a distinct local name rather than shadowing this function's own name.
    model = HotRecallModel(args.city_uuid)
    model.calculate_all_hot_score()


def run():
    """Parse CLI arguments and, with ``--run_train``, run the training pipeline.

    The pipeline runs, in order: ItemCF training, hot-score calculation,
    then GBDT-LR data preparation and training.
    """
    parser = argparse.ArgumentParser()
    # Global arguments.
    parser.add_argument(
        "--run_train",
        action="store_true",
        help="Run ItemCF, hot recall and GBDT-LR training.",
    )
    parser.add_argument(
        "--city_uuid", type=str, default="00000000000000000000000011445301"
    )
    # GBDT-LR training arguments.
    parser.add_argument("--train_data_dir", type=str, default="./data/gbdt")
    parser.add_argument("--model_path", type=str, default="./models/rank/weights")
    # Collaborative-filtering arguments.
    parser.add_argument("--largest_n", type=int, default=300)
    parser.add_argument("--similarity_k", type=int, default=100)
    parser.add_argument("--top_n", type=int, default=1500)
    parser.add_argument("--n_jobs", type=int, default=4)
    args = parser.parse_args()

    if not args.run_train:
        # Without --run_train nothing would run; show usage instead of exiting silently.
        parser.print_help()
        return

    print("正在计算协同过滤...")
    itemCF(args)
    print("正在计算热度召回...")
    hot_recall(args)
    print("正在进行gbdt_lr训练...")
    gbdtlr_train(args)


if __name__ == "__main__":
    run()