| 12345678910111213141516171819202122232425262728293031323334353637383940414243 |
- #!/usr/bin/env python3
- # -*- coding:utf-8 -*-
- import numpy as np
- from models.rank.data import DataLoader
- from sklearn.ensemble import GradientBoostingClassifier
- from sklearn.linear_model import LogisticRegression
class Trainer:
    """Train a GBDT classifier, then a logistic regression stacked on its output.

    The LR model is trained on the original features plus one extra column:
    the GBDT's ``predict_proba`` column 1 (the positive class for binary 0/1
    labels).

    Data comes from ``DataLoader(path).split_dataset()`` as a
    (train, validation, test) triple. Each split is indexed with ``"data"``
    (features) and ``"label"`` (targets); the concrete container type is not
    visible here -- presumably dicts of arrays, TODO confirm against DataLoader.
    """

    def __init__(self, path):
        """Load and split the dataset located at *path*."""
        self._load_data(path)

    def _load_data(self, path):
        """Populate the train / validation / test splits via ``DataLoader``."""
        dataloader = DataLoader(path)
        self._train_dataset, self._val_dataset, self._test_dataset = dataloader.split_dataset()

    def train_gbdt(self, n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42):
        """Fit a ``GradientBoostingClassifier`` on the training split.

        Args:
            n_estimators: Number of boosting stages.
            learning_rate: Shrinkage applied to each tree's contribution.
            max_depth: Maximum depth of each regression tree.
            random_state: Seed, fixed by default for reproducible runs.
        """
        self._gbdt_model = GradientBoostingClassifier(
            n_estimators=n_estimators,
            learning_rate=learning_rate,
            max_depth=max_depth,
            random_state=random_state,
        )
        # Fit on the training split only; validation is held out for train_lr.
        self._gbdt_model.fit(self._train_dataset["data"], self._train_dataset["label"])

    def _stack_gbdt_feature(self, dataset):
        """Return *dataset*'s features with the GBDT column-1 probability appended."""
        # Column 1 of predict_proba is the probability of the positive class.
        positive_proba = self._gbdt_model.predict_proba(dataset["data"])[:, 1]
        return np.column_stack([dataset["data"], positive_proba])

    def train_lr(self, solver="saga", max_iter=1000, random_state=42):
        """Fit a logistic regression on the features augmented with GBDT output.

        ``train_gbdt`` must have been called first.

        Args:
            solver: ``LogisticRegression`` solver name.
            max_iter: Maximum number of solver iterations.
            random_state: Seed for the solver's data shuffling (used by the
                'sag', 'saga' and 'liblinear' solvers). It is fixed by default so
                runs are reproducible, matching the GBDT seed.

        Returns:
            Mean accuracy of the fitted LR model on the validation split.

        Raises:
            RuntimeError: If ``train_gbdt`` has not been called yet.
        """
        if getattr(self, "_gbdt_model", None) is None:
            raise RuntimeError("train_gbdt() must be called before train_lr()")

        # Use the GBDT prediction as an extra feature for the LR model.
        # NOTE(review): the GBDT scores its own training rows here, so this
        # feature is likely more confident on train than on unseen data and LR
        # may over-weight it; out-of-fold predictions would avoid that.
        lr_train_data = self._stack_gbdt_feature(self._train_dataset)
        lr_val_data = self._stack_gbdt_feature(self._val_dataset)

        # NOTE(review): features are not scaled; sag/saga convergence is only
        # fast on similarly-scaled features, so max_iter may be reached.
        self.lr_model = LogisticRegression(
            solver=solver,
            max_iter=max_iter,
            random_state=random_state,
        )
        self.lr_model.fit(lr_train_data, self._train_dataset["label"])

        # Evaluate on the held-out validation split.
        return self.lr_model.score(lr_val_data, self._val_dataset["label"])
-
if __name__ == "__main__":
    import argparse

    # The data path can be given on the command line; the default keeps the
    # original no-argument invocation working unchanged.
    parser = argparse.ArgumentParser(
        description="Train a GBDT model, then an LR model stacked on its predictions.",
    )
    parser.add_argument(
        "data_path",
        nargs="?",
        default="./models/rank/data/gbdt_data.csv",
        help="path to the training data CSV (default: %(default)s)",
    )
    args = parser.parse_args()

    trainer = Trainer(args.data_path)
    # GBDT must be trained first: the LR step consumes its predictions.
    trainer.train_gbdt()
    trainer.train_lr()
|