gbdt_lr.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import argparse
  2. import os
  3. from models.rank import DataProcess, Trainer, GbdtLrModel
  4. import time
  5. import pandas as pd
# Default locations (mirrored by the argparse defaults in run()):
# train_data_path = "./models/rank/data/gbdt_data.csv"
# model_path = "./models/rank/weights"
  8. def train(args):
  9. model_dir = os.path.join(args.model_path, args.city_uuid)
  10. if not os.path.exists(model_dir):
  11. os.makedirs(model_dir)
  12. # 准备数据集
  13. print("正在整合训练数据...")
  14. processor = DataProcess(args.city_uuid, args.train_data_path)
  15. processor.data_process()
  16. print("训练数据整合完成!")
  17. # 进行训练
  18. trainer(args, model_dir)
  19. def trainer(args, model_dir):
  20. trainer = Trainer(args.train_data_path)
  21. start_time = time.time()
  22. trainer.train()
  23. end_time = time.time()
  24. training_time_hours = (end_time - start_time) / 3600
  25. print(f"训练时间: {training_time_hours:.4f} 小时")
  26. eval_metrics = trainer.evaluate()
  27. # 输出评估结果
  28. print("GBDT-LR Evaluation Metrics:")
  29. for metric, value in eval_metrics.items():
  30. print(f"{metric}: {value:.4f}")
  31. # 保存模型
  32. trainer.save_model(os.path.join(model_dir, args.model_name))
  33. def recommend_by_product(args):
  34. model_dir = os.path.join(args.model_path, args.city_uuid)
  35. if not os.path.exists(model_dir):
  36. print("暂无该城市的模型,请先进行模型训练")
  37. return
  38. # 加载模型
  39. model = GbdtLrModel(os.path.join(model_dir, args.model_name))
  40. recommend_list = model.sort(args.city_uuid, args.product_id)
  41. for item in recommend_list[:min(args.last_n, len(recommend_list))]:
  42. print(item)
  43. def get_features_importance(args):
  44. model_dir = os.path.join(args.model_path, args.city_uuid)
  45. if not os.path.exists(model_dir):
  46. print("暂无该城市的模型,请先进行模型训练")
  47. return
  48. # 加载模型
  49. model = GbdtLrModel(os.path.join(model_dir, args.model_name))
  50. cust_features_importance, product_features_importance = model.generate_feats_importance()
  51. # 将字典列表转换为 DataFrame
  52. cust_df = pd.DataFrame([
  53. {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
  54. for item in cust_features_importance
  55. ])
  56. product_df = pd.DataFrame([
  57. {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
  58. for item in product_features_importance
  59. ])
  60. cust_file_path = os.path.join(model_dir, "cust_features_importance.csv")
  61. product_file_path = os.path.join(model_dir, "product_features_importance.csv")
  62. cust_df.to_csv(cust_file_path, index=False, encoding='utf-8')
  63. product_df.to_csv(product_file_path, index=False, encoding='utf-8')
  64. def run():
  65. parser = argparse.ArgumentParser()
  66. parser.add_argument("--run_train", action='store_true')
  67. parser.add_argument("--recommend", action='store_true')
  68. parser.add_argument("--importance", action='store_true')
  69. parser.add_argument("--train_data_path", type=str, default="./models/rank/data/gbdt_data.csv")
  70. parser.add_argument("--model_path", type=str, default="./models/rank/weights")
  71. parser.add_argument("--model_name", type=str, default='model.pkl')
  72. parser.add_argument("--last_n", type=int, default=200)
  73. parser.add_argument("--city_uuid", type=str, default='00000000000000000000000011445301')
  74. parser.add_argument("--product_id", type=str, default='110102')
  75. args = parser.parse_args()
  76. if args.run_train:
  77. train(args)
  78. if args.recommend:
  79. recommend_by_product(args)
  80. if args.importance:
  81. get_features_importance(args)
  82. if __name__ == "__main__":
  83. run()