|
@@ -7,6 +7,7 @@ from sklearn.linear_model import LogisticRegression
|
|
|
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
|
|
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
|
|
|
from sklearn.model_selection import GridSearchCV
|
|
from sklearn.model_selection import GridSearchCV
|
|
|
from sklearn.preprocessing import OneHotEncoder
|
|
from sklearn.preprocessing import OneHotEncoder
|
|
|
|
|
+import joblib
|
|
|
|
|
|
|
|
class Trainer:
|
|
class Trainer:
|
|
|
def __init__(self, path):
|
|
def __init__(self, path):
|
|
@@ -24,7 +25,8 @@ class Trainer:
|
|
|
"max_iter": 1000,
|
|
"max_iter": 1000,
|
|
|
'C': 1.0,
|
|
'C': 1.0,
|
|
|
'penalty': 'l2',
|
|
'penalty': 'l2',
|
|
|
- 'solver': 'liblinear',
|
|
|
|
|
|
|
+ # 'l1_ratio': 0.5,  # elastic-net mixing parameter; only used with penalty='elasticnet' (requires solver='saga'), tune as needed
|
|
|
|
|
+ 'solver': 'sag',
|
|
|
'random_state': 42,
|
|
'random_state': 42,
|
|
|
'class_weight': 'balanced'
|
|
'class_weight': 'balanced'
|
|
|
}
|
|
}
|
|
@@ -82,7 +84,7 @@ class Trainer:
|
|
|
precision = precision_score(self._test_dataset["label"], y_pred)
|
|
precision = precision_score(self._test_dataset["label"], y_pred)
|
|
|
recall = recall_score(self._test_dataset["label"], y_pred)
|
|
recall = recall_score(self._test_dataset["label"], y_pred)
|
|
|
f1 = f1_score(self._test_dataset["label"], y_pred)
|
|
f1 = f1_score(self._test_dataset["label"], y_pred)
|
|
|
- roc_auc = roc_auc_score(self._test_dataset["label"], y_pred_proba)
|
|
|
|
|
|
|
+ roc_auc = roc_auc_score(self._test_dataset["label"], y_pred_proba)
|
|
|
|
|
|
|
|
return {
|
|
return {
|
|
|
'accuracy': accuracy,
|
|
'accuracy': accuracy,
|
|
@@ -91,6 +93,11 @@ class Trainer:
|
|
|
'f1_score': f1,
|
|
'f1_score': f1,
|
|
|
'roc_auc': roc_auc
|
|
'roc_auc': roc_auc
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
def save_model(self, model_path):
    """Save the trained models to a local file.

    The GBDT model, the LR model and the one-hot encoder held on this
    instance are bundled into one dict (keys ``gbdt_model``, ``lr_model``
    and ``onehot_encoder``) and serialized with ``joblib.dump``.

    Args:
        model_path: Destination of the serialized bundle. When it is a
            filesystem path, any missing parent directories are created
            first.
    """
    import os  # local import so this method does not depend on file-level imports

    # joblib.dump does not create missing directories; without this a fresh
    # checkout fails with FileNotFoundError for nested paths such as
    # "./models/rank/weights/model.pkl".
    if isinstance(model_path, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(model_path))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    models = {
        "gbdt_model": self._gbdt_model,
        "lr_model": self._lr_model,
        "onehot_encoder": self._onehot_encoder,
    }
    joblib.dump(models, model_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
@@ -103,4 +110,8 @@ if __name__ == "__main__":
|
|
|
print("GBDT-LR Evaluation Metrics:")
|
|
print("GBDT-LR Evaluation Metrics:")
|
|
|
for metric, value in eval_metrics.items():
|
|
for metric, value in eval_metrics.items():
|
|
|
print(f"{metric}: {value:.4f}")
|
|
print(f"{metric}: {value:.4f}")
|
|
|
|
|
+
|
|
|
|
|
+ # Save the trained models
|
|
|
|
|
+ model_path = "./models/rank/weights/model.pkl"
|
|
|
|
|
+ trainer.save_model(model_path)
|
|
|
|
|
|