| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103 |
- import argparse
- import os
- from models.rank import DataProcess, Trainer, GbdtLrModel
- from models import ItemCFModel, HotRecallModel
- import time
- import pandas as pd
# train_data_path = "./models/rank/data/gbdt_data.csv"
# model_path = "./models/rank/weights"
def gbdtlr_train(args):
    """Prepare GBDT-LR training data for one city, then train and save the model.

    Creates ``<args.model_path>/<args.city_uuid>`` and ``args.train_data_dir``
    when they are missing. It then runs ``DataProcess(...).data_process()`` for
    the city; presumably this writes ``train_data.csv`` into
    ``args.train_data_dir`` (verify against DataProcess). Finally it trains on
    that CSV via ``gbdtlr_trainer``, saving ``gbdtlr_model.pkl`` in the model
    directory.

    Args:
        args: Parsed CLI namespace; reads ``model_path``, ``city_uuid`` and
            ``train_data_dir``.
    """
    model_dir = os.path.join(args.model_path, args.city_uuid)
    train_data_dir = args.train_data_dir
    # exist_ok=True replaces the racy "exists? then makedirs" check.
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(train_data_dir, exist_ok=True)

    # Assemble the training dataset.
    print("正在整合训练数据...")
    processor = DataProcess(args.city_uuid, train_data_dir)
    processor.data_process()
    print("训练数据整合完成!")

    # Train, evaluate and persist the model.
    print("开始训练gbdt-lr模型")
    gbdtlr_trainer(os.path.join(train_data_dir, "train_data.csv"), model_dir, "gbdtlr_model.pkl")
def gbdtlr_trainer(train_data_path, model_dir, model_name):
    """Train a GBDT-LR model, print training time and metrics, and save it.

    Args:
        train_data_path: Path to the training data passed to ``Trainer``.
        model_dir: Directory the model file is written into; created if missing.
        model_name: File name of the saved model inside ``model_dir``.
    """
    trainer = Trainer(train_data_path)

    # perf_counter is monotonic, so the reported duration is not distorted by
    # wall-clock adjustments (NTP sync, manual changes) during a long run.
    start = time.perf_counter()
    trainer.train()
    training_time_hours = (time.perf_counter() - start) / 3600
    print(f"训练时间: {training_time_hours:.4f} 小时")

    eval_metrics = trainer.evaluate()

    # Report evaluation results (assumes a mapping of metric name -> number).
    print("GBDT-LR Evaluation Metrics:")
    for metric, value in eval_metrics.items():
        print(f"{metric}: {value:.4f}")

    # Persist the model; make sure the target directory exists when this
    # function is called on its own rather than via gbdtlr_train.
    os.makedirs(model_dir, exist_ok=True)
    trainer.save_model(os.path.join(model_dir, model_name))
def itemCF(args):
    """Train the item-based collaborative-filtering model for ``args.city_uuid``.

    CLI options map onto ``ItemCFModel.train`` keywords as follows:
    ``largest_n`` -> ``n``, ``similarity_k`` -> ``k``; ``top_n`` and
    ``n_jobs`` are passed through unchanged.
    """
    cf_model = ItemCFModel()
    cf_model.train(
        city_uuid=args.city_uuid,
        n=args.largest_n,
        k=args.similarity_k,
        top_n=args.top_n,
        n_jobs=args.n_jobs,
    )
def hot_recall(args):
    """Compute hot-recall scores for ``args.city_uuid``.

    Delegates to ``HotRecallModel(city_uuid).calculate_all_hot_score()``.
    """
    # Distinct local name so it does not shadow this function inside its body.
    recall_model = HotRecallModel(args.city_uuid)
    recall_model.calculate_all_hot_score()
-
- def run():
- parser = argparse.ArgumentParser()
- # 全局参数
- parser.add_argument("--run_train", action='store_true')
- parser.add_argument("--run_recall", action='store_true')
- parser.add_argument("--run_gbdtlr", action='store_true')
-
- parser.add_argument("--city_uuid", type=str, default='00000000000000000000000011445301')
-
- # GBDT-LR模型训练参数
- parser.add_argument("--train_data_dir", type=str, default="./data/gbdt")
- parser.add_argument("--model_path", type=str, default="./models/rank/weights")
-
- # 协同过滤参数
- parser.add_argument("--largest_n", type=int, default=300)
- parser.add_argument("--similarity_k", type=int, default=100)
- parser.add_argument("--top_n", type=int, default=1500)
- parser.add_argument("--n_jobs", type=int, default=4)
-
-
- args = parser.parse_args()
-
- if args.run_train:
- print("正在计算协同过滤...")
- itemCF(args)
-
- print("正在计算热度召回...")
- hot_recall(args)
-
- print("正在进行gbdt_lr训练...")
- gbdtlr_train(args)
-
- if args.run_recall:
- print("正在计算协同过滤...")
- itemCF(args)
-
- print("正在计算热度召回...")
- hot_recall(args)
-
- if args.run_gbdtlr:
- print("正在进行gbdt_lr训练...")
- gbdtlr_train(args)
-
- if __name__ == "__main__":
- run()
|