Browse Source

gbdt-lr训练流程

yangzeyu 1 year ago
parent
commit
9a097ac775
2 changed files with 44 additions and 1 deletions
  1. 3 1
      models/rank/data/__init__.py
  2. 41 0
      models/rank/gbdt_lr.py

+ 3 - 1
models/rank/data/__init__.py

@@ -1,5 +1,7 @@
 from models.rank.data.config import CustConfig, ProductConfig
+from models.rank.data.dataloader import DataLoader
 __all__ = [
     "CustConfig",
-    "ProductConfig"
+    "ProductConfig",
+    "DataLoader"
 ]

+ 41 - 0
models/rank/gbdt_lr.py

@@ -1,2 +1,43 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
+import numpy as np
+from models.rank.data import DataLoader
+from sklearn.ensemble import GradientBoostingClassifier
+from sklearn.linear_model import LogisticRegression
+
+class Trainer:
+    def __init__(self, path):
+        self._load_data(path)
+        
+    def _load_data(self, path):
+        dataloader = DataLoader(path)
+        self._train_dataset, self._val_dataset, self._test_dataset = dataloader.split_dataset()
+        
+    def train_gbdt(self):
+        self._gbdt_model = GradientBoostingClassifier(
+            n_estimators=100,
+            learning_rate=0.1,
+            max_depth=3,
+            random_state=42,
+        )
+        
+        # 模型训练
+        self._gbdt_model.fit(self._train_dataset["data"], self._train_dataset["label"])
+        
+    def train_lr(self):
+        gbdt_train_prdes = self._gbdt_model.predict_proba(self._train_dataset["data"])[:, 1] # 获取正类概率
+        gbdt_val_prdes = self._gbdt_model.predict_proba(self._val_dataset["data"])[:, 1]
+        
+        # 将GBDT的预测结果作为额外特征来训练LR
+        lr_train_data = np.column_stack([self._train_dataset["data"], gbdt_train_prdes])
+        lr_val_data = np.column_stack([self._val_dataset["data"], gbdt_val_prdes])
+        
+        # 训练LR模型
+        self.lr_model = LogisticRegression(solver='saga', max_iter=1000)
+        self.lr_model.fit(lr_train_data, self._train_dataset["label"])
+        
+if __name__ == "__main__":
+    gbdt_data_path = "./models/rank/data/gbdt_data.csv"
+    trainer = Trainer(gbdt_data_path)
+    trainer.train_gbdt()
+    trainer.train_lr()