train.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. import argparse
  2. import os
  3. from models.rank import DataProcess, Trainer, GbdtLrModel
  4. from models import ItemCFModel, HotRecallModel
  5. import time
  6. import pandas as pd
  7. from core import get_logger
  8. logger = get_logger("train")
  9. # train_data_path = "./models/rank/data/gbdt_data.csv"
  10. # model_path = "./models/rank/weights"
  11. def gbdtlr_train(args):
  12. model_dir = os.path.join(args.model_path, args.city_uuid)
  13. train_data_dir = args.train_data_dir
  14. if not os.path.exists(model_dir):
  15. os.makedirs(model_dir)
  16. if not os.path.exists(train_data_dir):
  17. os.makedirs(train_data_dir)
  18. # 准备数据集
  19. logger.info("正在整合训练数据...")
  20. processor = DataProcess(args.city_uuid, args.train_data_dir)
  21. processor.data_process()
  22. logger.info("训练数据整合完成")
  23. # 进行训练
  24. logger.info("开始训练gbdt-lr模型")
  25. gbdtlr_trainer(os.path.join(args.train_data_dir, "train_data.csv"), model_dir, "gbdtlr_model.pkl")
  26. def gbdtlr_trainer(train_data_path, model_dir, model_name):
  27. trainer = Trainer(train_data_path)
  28. start_time = time.time()
  29. trainer.train()
  30. end_time = time.time()
  31. training_time_hours = (end_time - start_time) / 3600
  32. logger.info(f"训练时间: {training_time_hours:.4f} 小时")
  33. eval_metrics = trainer.evaluate()
  34. # 输出评估结果
  35. logger.info("GBDT-LR Evaluation Metrics:")
  36. for metric, value in eval_metrics.items():
  37. logger.info(f"{metric}: {value:.4f}")
  38. # 保存模型
  39. trainer.save_model(os.path.join(model_dir, model_name))
  40. def itemCF(args):
  41. itemcf_model = ItemCFModel()
  42. itemcf_model.train(city_uuid=args.city_uuid, n=args.largest_n, k=args.similarity_k, top_n=args.top_n, n_jobs=args.n_jobs)
  43. def hot_recall(args):
  44. hot_recall = HotRecallModel(args.city_uuid)
  45. hot_recall.calculate_all_hot_score()
  46. def run():
  47. parser = argparse.ArgumentParser()
  48. # 全局参数
  49. parser.add_argument("--run_train", action='store_true')
  50. parser.add_argument("--run_recall", action='store_true')
  51. parser.add_argument("--run_gbdtlr", action='store_true')
  52. parser.add_argument("--city_uuid", type=str, default='00000000000000000000000011445301')
  53. # GBDT-LR模型训练参数
  54. parser.add_argument("--train_data_dir", type=str, default="./data/gbdt")
  55. parser.add_argument("--model_path", type=str, default="./models/rank/weights")
  56. # 协同过滤参数
  57. parser.add_argument("--largest_n", type=int, default=300)
  58. parser.add_argument("--similarity_k", type=int, default=100)
  59. parser.add_argument("--top_n", type=int, default=1500)
  60. parser.add_argument("--n_jobs", type=int, default=2)
  61. args = parser.parse_args()
  62. if args.run_train:
  63. logger.info("正在计算协同过滤...")
  64. itemCF(args)
  65. logger.info("正在计算热度召回...")
  66. hot_recall(args)
  67. logger.info("正在进行gbdt_lr训练...")
  68. gbdtlr_train(args)
  69. if args.run_recall:
  70. logger.info("正在计算协同过滤...")
  71. itemCF(args)
  72. logger.info("正在计算热度召回...")
  73. hot_recall(args)
  74. if args.run_gbdtlr:
  75. logger.info("正在进行gbdt_lr训练...")
  76. gbdtlr_train(args)
  77. if __name__ == "__main__":
  78. run()