# gbdt_lr.py
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import numpy as np
  4. from models.rank.data import DataLoader
  5. from sklearn.ensemble import GradientBoostingClassifier
  6. from sklearn.linear_model import LogisticRegression
  7. from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
  8. from sklearn.model_selection import GridSearchCV
  9. from sklearn.preprocessing import OneHotEncoder
  10. import joblib
  11. import time
  12. class Trainer:
  13. def __init__(self, path):
  14. self._load_data(path)
  15. # 初始化GBDT和LR模型参数
  16. self._gbdt_params = {
  17. 'n_estimators': 100,
  18. 'learning_rate': 0.01,
  19. 'max_depth': 6,
  20. 'subsample': 0.8,
  21. 'random_state': 42,
  22. }
  23. self._lr_params = {
  24. "max_iter": 1000,
  25. 'C': 1.0,
  26. 'penalty': 'elasticnet',
  27. 'l1_ratio': 0.8, # 添加 l1_ratio 参数,可以根据需要调整
  28. 'solver': 'saga',
  29. 'random_state': 42,
  30. 'class_weight': 'balanced'
  31. }
  32. # 初始化模型
  33. self._gbdt_model = GradientBoostingClassifier(**self._gbdt_params)
  34. self._lr_model = LogisticRegression(**self._lr_params)
  35. self._onehot_encoder = OneHotEncoder(sparse_output=True, handle_unknown='ignore')
  36. def _load_data(self, path):
  37. dataloader = DataLoader(path)
  38. self._train_dataset, self._test_dataset = dataloader.split_dataset()
  39. def train(self):
  40. """模型训练"""
  41. print("开始训练GBDT模型...")
  42. # 训练GBDT模型
  43. self._gbdt_model.fit(self._train_dataset["data"], self._train_dataset["label"])
  44. # 获取GBDT的每棵树的分数(决策值)
  45. gbdt_train_preds = self._gbdt_model.apply(self._train_dataset["data"])[:, :, 0] # 仅取每棵树的叶节点输出
  46. gbdt_feats_encoded = self._onehot_encoder.fit_transform(gbdt_train_preds)
  47. print("开始训练LR模型...")
  48. # 使用决策树输出作为LR的输入特征
  49. self._lr_model.fit(gbdt_feats_encoded, self._train_dataset["label"])
  50. def predict(self, X):
  51. # 获取GBDT模型的预测分数
  52. gbdt_preds = self._gbdt_model.apply(X)[:, :, 0]
  53. gbdt_feats_encoded = self._onehot_encoder.transform(gbdt_preds)
  54. # 使用训练好的LR模型输出概率
  55. return self._lr_model.predict(gbdt_feats_encoded)
  56. def predict_proba(self, X):
  57. # 获取GBDT模型的预测分数
  58. gbdt_preds = self._gbdt_model.apply(X)[:, :, 0]
  59. gbdt_feats_encoded = self._onehot_encoder.transform(gbdt_preds)
  60. # 使用训练好的LR模型输出概率
  61. return self._lr_model.predict_proba(gbdt_feats_encoded)
  62. def evaluate(self):
  63. # 对测试集进行预测
  64. y_pred = self.predict(self._test_dataset["data"])
  65. y_pred_proba = self.predict_proba(self._test_dataset["data"])[:, 1] # 获取正类的概率
  66. # 计算各类评估指标
  67. accuracy = accuracy_score(self._test_dataset["label"], y_pred)
  68. precision = precision_score(self._test_dataset["label"], y_pred)
  69. recall = recall_score(self._test_dataset["label"], y_pred)
  70. f1 = f1_score(self._test_dataset["label"], y_pred)
  71. roc_auc = roc_auc_score(self._test_dataset["label"], y_pred_proba)
  72. return {
  73. 'accuracy': accuracy,
  74. 'precision': precision,
  75. 'recall': recall,
  76. 'f1_score': f1,
  77. 'roc_auc': roc_auc
  78. }
  79. def save_model(self, model_path):
  80. """将模型保存到本地"""
  81. models = {"gbdt_model": self._gbdt_model, "lr_model": self._lr_model, "onehot_encoder": self._onehot_encoder}
  82. joblib.dump(models, model_path)
  83. if __name__ == "__main__":
  84. gbdt_data_path = "./models/rank/data/gbdt_data.csv"
  85. trainer = Trainer(gbdt_data_path)
  86. start_time = time.time()
  87. trainer.train()
  88. end_time = time.time()
  89. training_time_hours = (end_time - start_time) / 3600
  90. print(f"训练时间: {training_time_hours:.4f} 小时")
  91. eval_metrics = trainer.evaluate()
  92. # 输出评估结果
  93. print("GBDT-LR Evaluation Metrics:")
  94. for metric, value in eval_metrics.items():
  95. print(f"{metric}: {value:.4f}")
  96. # 保存模型
  97. model_path = "./models/rank/weights/model.pkl"
  98. trainer.save_model(model_path)