train.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import argparse
  2. import os
  3. from models.rank import DataProcess, Trainer, GbdtLrModel
  4. from models import ItemCFModel, HotRecallModel
  5. import time
  6. import pandas as pd
  7. # train_data_path = "./moldes/rank/data/gbdt_data.csv"
  8. # model_path = "./models/rank/weights"
  9. def gbdtlr_train(args):
  10. model_dir = os.path.join(args.model_path, args.city_uuid)
  11. train_data_dir = args.train_data_dir
  12. if not os.path.exists(model_dir):
  13. os.makedirs(model_dir)
  14. if not os.path.exists(train_data_dir):
  15. os.makedirs(train_data_dir)
  16. # 准备数据集
  17. print("正在整合训练数据...")
  18. processor = DataProcess(args.city_uuid, args.train_data_dir)
  19. processor.data_process()
  20. print("训练数据整合完成!")
  21. # 进行训练
  22. print("开始训练gbdt-lr模型")
  23. gbdtlr_trainer(os.path.join(args.train_data_dir, "train_data.csv"), model_dir, "gbdtlr_model.pkl")
  24. def gbdtlr_trainer(train_data_path, model_dir, model_name):
  25. trainer = Trainer(train_data_path)
  26. start_time = time.time()
  27. trainer.train()
  28. end_time = time.time()
  29. training_time_hours = (end_time - start_time) / 3600
  30. print(f"训练时间: {training_time_hours:.4f} 小时")
  31. eval_metrics = trainer.evaluate()
  32. # 输出评估结果
  33. print("GBDT-LR Evaluation Metrics:")
  34. for metric, value in eval_metrics.items():
  35. print(f"{metric}: {value:.4f}")
  36. # 保存模型
  37. trainer.save_model(os.path.join(model_dir, model_name))
  38. def itemCF(args):
  39. itemcf_model = ItemCFModel()
  40. itemcf_model.train(city_uuid=args.city_uuid, n=args.largest_n, k=args.similarity_k, top_n=args.top_n, n_jobs=args.n_jobs)
  41. def hot_recall(args):
  42. hot_recall = HotRecallModel(args.city_uuid)
  43. hot_recall.calculate_all_hot_score()
  44. def run():
  45. parser = argparse.ArgumentParser()
  46. # 全局参数
  47. parser.add_argument("--run_train", action='store_true')
  48. parser.add_argument("--city_uuid", type=str, default='00000000000000000000000011445301')
  49. # GBDT-LR模型训练参数
  50. parser.add_argument("--train_data_dir", type=str, default="./data/gbdt")
  51. parser.add_argument("--model_path", type=str, default="./models/rank/weights")
  52. # 协同过滤参数
  53. parser.add_argument("--largest_n", type=int, default=300)
  54. parser.add_argument("--similarity_k", type=int, default=100)
  55. parser.add_argument("--top_n", type=int, default=1500)
  56. parser.add_argument("--n_jobs", type=int, default=4)
  57. args = parser.parse_args()
  58. if args.run_train:
  59. print("正在计算协同过滤...")
  60. itemCF(args)
  61. print("正在计算热度召回...")
  62. hot_recall(args)
  63. print("正在进行gbdt_lr训练...")
  64. gbdtlr_train(args)
  65. if __name__ == "__main__":
  66. run()