gbdt_lr.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import numpy as np
  4. from models.rank.data import DataLoader
  5. from sklearn.ensemble import GradientBoostingClassifier
  6. from sklearn.linear_model import LogisticRegression
  7. class Trainer:
  8. def __init__(self, path):
  9. self._load_data(path)
  10. def _load_data(self, path):
  11. dataloader = DataLoader(path)
  12. self._train_dataset, self._val_dataset, self._test_dataset = dataloader.split_dataset()
  13. def train_gbdt(self):
  14. self._gbdt_model = GradientBoostingClassifier(
  15. n_estimators=100,
  16. learning_rate=0.1,
  17. max_depth=3,
  18. random_state=42,
  19. )
  20. # 模型训练
  21. self._gbdt_model.fit(self._train_dataset["data"], self._train_dataset["label"])
  22. def train_lr(self):
  23. gbdt_train_prdes = self._gbdt_model.predict_proba(self._train_dataset["data"])[:, 1] # 获取正类概率
  24. gbdt_val_prdes = self._gbdt_model.predict_proba(self._val_dataset["data"])[:, 1]
  25. # 将GBDT的预测结果作为额外特征来训练LR
  26. lr_train_data = np.column_stack([self._train_dataset["data"], gbdt_train_prdes])
  27. lr_val_data = np.column_stack([self._val_dataset["data"], gbdt_val_prdes])
  28. # 训练LR模型
  29. self.lr_model = LogisticRegression(solver='saga', max_iter=1000)
  30. self.lr_model.fit(lr_train_data, self._train_dataset["label"])
  31. if __name__ == "__main__":
  32. gbdt_data_path = "./models/rank/data/gbdt_data.csv"
  33. trainer = Trainer(gbdt_data_path)
  34. trainer.train_gbdt()
  35. trainer.train_lr()