gbdt_lr.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. import argparse
  2. import os
  3. from models.rank import DataProcess, Trainer, GbdtLrModel
  4. import time
  5. import pandas as pd
  6. # train_data_path = "./models/rank/data/gbdt_data.csv"
  7. # model_path = "./models/rank/weights"
  8. def train(args):
  9. model_dir = os.path.join(args.model_path, args.city_uuid)
  10. train_data_dir = args.train_data_dir
  11. if not os.path.exists(model_dir):
  12. os.makedirs(model_dir)
  13. if not os.path.exists(train_data_dir):
  14. os.makedirs(train_data_dir)
  15. # 准备数据集
  16. print("正在整合训练数据...")
  17. processor = DataProcess(args.city_uuid, args.train_data_dir)
  18. processor.data_process()
  19. print("训练数据整合完成!")
  20. # 进行训练
  21. print("开始训练原始模型")
  22. trainer(args, os.path.join(args.train_data_dir, "original_train_data.csv"), model_dir, "ori_model.pkl")
  23. print("开始训练pos模型")
  24. trainer(args, os.path.join(args.train_data_dir, "pos_train_data.csv"), model_dir, "pos_model.pkl")
  25. print("开始训练shopping模型")
  26. trainer(args, os.path.join(args.train_data_dir, "shopping_train_data.csv"), model_dir, "shopping_model.pkl")
  27. def trainer(args, train_data_path, model_dir, model_name):
  28. trainer = Trainer(train_data_path)
  29. start_time = time.time()
  30. trainer.train()
  31. end_time = time.time()
  32. training_time_hours = (end_time - start_time) / 3600
  33. print(f"训练时间: {training_time_hours:.4f} 小时")
  34. eval_metrics = trainer.evaluate()
  35. # 输出评估结果
  36. print("GBDT-LR Evaluation Metrics:")
  37. for metric, value in eval_metrics.items():
  38. print(f"{metric}: {value:.4f}")
  39. # 保存模型
  40. trainer.save_model(os.path.join(model_dir, model_name))
  41. def recommend_by_product(args):
  42. model_dir = os.path.join(args.model_path, args.city_uuid)
  43. if not os.path.exists(model_dir):
  44. print("暂无该城市的模型,请先进行模型训练")
  45. return
  46. # 加载模型
  47. model = GbdtLrModel(os.path.join(model_dir, args.model_name))
  48. recommend_list = model.sort(args.city_uuid, args.product_id)
  49. for item in recommend_list[:min(args.last_n, len(recommend_list))]:
  50. print(item)
  51. def get_features_importance(args):
  52. model_dir = os.path.join(args.model_path, args.city_uuid)
  53. if not os.path.exists(model_dir):
  54. print("暂无该城市的模型,请先进行模型训练")
  55. return
  56. # # 加载模型
  57. # model = GbdtLrModel(os.path.join(model_dir, args.model_name))
  58. # cust_features_importance, product_features_importance = model.generate_feats_importance()
  59. # # 将字典列表转换为 DataFrame
  60. # cust_df = pd.DataFrame([
  61. # {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
  62. # for item in cust_features_importance
  63. # ])
  64. # product_df = pd.DataFrame([
  65. # {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
  66. # for item in product_features_importance
  67. # ])
  68. # cust_file_path = os.path.join(model_dir, "cust_features_importance.csv")
  69. # product_file_path = os.path.join(model_dir, "product_features_importance.csv")
  70. # cust_df.to_csv(cust_file_path, index=False, encoding='utf-8')
  71. # product_df.to_csv(product_file_path, index=False, encoding='utf-8')
  72. get_features_importance_by_model(model_dir, "ori_model")
  73. get_features_importance_by_model(model_dir, "pos_model")
  74. get_features_importance_by_model(model_dir, "shopping_model")
  75. def get_features_importance_by_model(model_dir, modelname):
  76. model = GbdtLrModel(os.path.join(model_dir, f"{modelname}.pkl"))
  77. cust_features_importance, product_features_importance, order_features_importance = model.generate_feats_importance()
  78. # 将字典列表转换为 DataFrame
  79. cust_df = pd.DataFrame([
  80. {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
  81. for item in cust_features_importance
  82. ])
  83. product_df = pd.DataFrame([
  84. {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
  85. for item in product_features_importance
  86. ])
  87. order_df = pd.DataFrame([
  88. {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
  89. for item in order_features_importance
  90. ])
  91. importance_dir = os.path.join(model_dir, "importance")
  92. if modelname == 'ori_model':
  93. importance_dir = os.path.join(importance_dir, "ori")
  94. elif modelname == 'pos_model':
  95. importance_dir = os.path.join(importance_dir, "pos")
  96. elif modelname == 'shopping_model':
  97. importance_dir = os.path.join(importance_dir, "shopping")
  98. if not os.path.exists(importance_dir):
  99. os.makedirs(importance_dir)
  100. cust_file_path = os.path.join(importance_dir, "cust_features_importance.csv")
  101. product_file_path = os.path.join(importance_dir, "product_features_importance.csv")
  102. order_file_path = os.path.join(importance_dir, "order_features_importance.csv")
  103. cust_df.to_csv(cust_file_path, index=False, encoding='utf-8')
  104. product_df.to_csv(product_file_path, index=False, encoding='utf-8')
  105. order_df.to_csv(order_file_path, index=False, encoding='utf-8')
  106. def run():
  107. parser = argparse.ArgumentParser()
  108. parser.add_argument("--run_train", action='store_true')
  109. parser.add_argument("--recommend", action='store_true')
  110. parser.add_argument("--importance", action='store_true')
  111. parser.add_argument("--train_data_dir", type=str, default="./data")
  112. parser.add_argument("--model_path", type=str, default="./models/rank/weights")
  113. parser.add_argument("--model_name", type=str, default='model.pkl')
  114. parser.add_argument("--last_n", type=int, default=200)
  115. parser.add_argument("--city_uuid", type=str, default='00000000000000000000000011445301')
  116. parser.add_argument("--product_id", type=str, default='110102')
  117. args = parser.parse_args()
  118. if args.run_train:
  119. train(args)
  120. if args.recommend:
  121. recommend_by_product(args)
  122. if args.importance:
  123. get_features_importance(args)
  124. if __name__ == "__main__":
  125. run()