| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263 |
- import pandas as pd
- from models.rank.data.config import CustConfig, ProductConfig, ShopConfig
- from sklearn.model_selection import train_test_split
- from sklearn.preprocessing import StandardScaler
- from models.rank.data.utils import one_hot_embedding
class DataLoader:
    """Load the GBDT training CSV, one-hot encode categorical features, and split it.

    The CSV is read once at construction time. The ``cust_code`` and
    ``product_code`` columns are dropped. The columns listed in
    ``CustConfig.ONEHOT_CAT``, ``ProductConfig.ONEHOT_CAT`` and
    ``ShopConfig.ONEHOT_CAT`` are encoded with ``one_hot_embedding``. Every
    other non-``label`` column is treated as numeric and is standardized in
    ``split_dataset``.
    """

    def __init__(self, path):
        """
        Args:
            path: Path to the CSV file. It must contain a ``label`` column,
                the ``cust_code`` / ``product_code`` columns, and every column
                named in the one-hot configs.
        """
        self._gbdt_data_path = path
        # Fitted StandardScaler from the last split_dataset() call (None until
        # then, or if there are no numeric columns). Keep it to preprocess
        # other data (e.g. inference rows) with the training statistics.
        self.scaler = None
        self._load_data()

    def _load_data(self):
        """Read the CSV and build the feature frame used by split_dataset()."""
        data = pd.read_csv(self._gbdt_data_path, encoding="utf-8")
        # Identifier columns are not used as features.
        data = data.drop(columns=["cust_code", "product_code"])

        # Merge the per-entity one-hot configs; their keys are column names.
        self._onehot_feats = {
            **CustConfig.ONEHOT_CAT,
            **ProductConfig.ONEHOT_CAT,
            **ShopConfig.ONEHOT_CAT,
        }
        self._onehot_columns = list(self._onehot_feats.keys())

        # Numeric columns are computed before encoding: everything that is
        # neither a one-hot categorical column nor the label.
        self._numeric_columns = data.drop(
            columns=self._onehot_columns + ["label"]
        ).columns

        # One-hot encode the categorical columns.
        self._gbdt_data = one_hot_embedding(data, self._onehot_feats)

    def split_dataset(self, test_size=0.2, random_state=42):
        """Split into stratified train/test sets and standardize numeric features.

        The StandardScaler is fitted on the training split only and then
        applied (``transform``) to the test split. This means no test-set
        statistics leak into preprocessing, and both splits share one scale.

        Args:
            test_size: Fraction of rows held out for the test split (default 0.2,
                i.e. an 80/20 split).
            random_state: Seed for the shuffle, for reproducible splits
                (default 42).

        Returns:
            ``(train_dataset, test_dataset)``. Each is a dict with key
            ``"data"`` (feature DataFrame) and key ``"label"`` (label Series).
        """
        # 1. Separate features from the label.
        features = self._gbdt_data.drop(columns="label")
        labels = self._gbdt_data["label"]

        # 2. Stratify on the label so both splits keep the same class ratio.
        X_train, X_test, y_train, y_test = train_test_split(
            features, labels,
            test_size=test_size,
            random_state=random_state,
            shuffle=True,
            stratify=labels,
        )
        # The splits are row subsets of `features`. Copy them so the column
        # assignments below don't trigger pandas' SettingWithCopyWarning.
        X_train = X_train.copy()
        X_test = X_test.copy()

        # 3. Standardize numeric feature columns only (labels untouched).
        numeric_cols = list(self._numeric_columns)
        self.scaler = None
        if numeric_cols:
            scaler = StandardScaler()
            X_train[numeric_cols] = scaler.fit_transform(X_train[numeric_cols])
            # transform, NOT fit_transform: reuse the training mean/std so
            # the test split is scaled exactly like the training split.
            X_test[numeric_cols] = scaler.transform(X_test[numeric_cols])
            self.scaler = scaler

        train_dataset = {"data": X_train, "label": y_train}
        test_dataset = {"data": X_test, "label": y_test}
        return train_dataset, test_dataset
-
if __name__ == '__main__':
    path = './data/train_data.csv'
    train_dataset, test_dataset = DataLoader(path).split_dataset()

    # Show the label distribution (fraction of each class) of both splits.
    report = (
        ("训练集正负样本分布:", train_dataset),
        ("测试集正负样本分布:", test_dataset),
    )
    for heading, split in report:
        print(heading)
        print(split["label"].value_counts(normalize=True))
|