import pandas as pd
from models.rank.data.config import CustConfig, ProductConfig
from sklearn.preprocessing import OneHotEncoder  # NOTE(review): unused here — kept to avoid breaking any external reliance
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


class DataLoader:
    """Load the GBDT feature CSV, one-hot encode its categorical columns,
    and split it into standardized train / validation / test sets.

    Attributes set during loading:
        _gbdt_data_path: path to the source CSV.
        _gbdt_data: the loaded (and one-hot encoded) DataFrame.
        _onehot_columns: categorical column names (from CustConfig/ProductConfig).
        _numeric_columns: numeric feature columns, captured before encoding.
    """

    def __init__(self, path):
        # Path to the GBDT feature CSV file.
        self._gbdt_data_path = path
        self._load_data()

    def _load_data(self):
        """Read the CSV, drop identifier columns, and one-hot encode categoricals."""
        self._gbdt_data = pd.read_csv(self._gbdt_data_path, encoding="utf-8")
        # Identifier columns carry no predictive signal — drop them.
        self._gbdt_data.drop('BB_RETAIL_CUSTOMER_CODE', axis=1, inplace=True)
        self._gbdt_data.drop('product_code', axis=1, inplace=True)
        self._onehot_columns = CustConfig.ONEHOT + ProductConfig.ONEHOT
        # Capture numeric columns BEFORE one-hot encoding expands the frame,
        # so that standardization later touches only the original numeric features.
        self._numeric_columns = self._gbdt_data.drop(self._onehot_columns + ["label"], axis=1).columns
        # One-hot encode the categorical columns.
        self.one_hot_embedding(self._onehot_columns)

    def one_hot_embedding(self, onehot_columns):
        """One-hot encode the given categorical columns of the loaded frame."""
        self._gbdt_data = pd.get_dummies(self._gbdt_data, columns=onehot_columns, drop_first=False)

    def split_dataset(self):
        """Split the data 70% train / 15% val / 15% test and standardize it.

        Returns:
            Tuple of (train_dataset, val_dataset, test_dataset); each is a
            dict with keys "data" (feature DataFrame) and "label" (Series).
        """
        # 1. Separate features and label.
        features = self._gbdt_data.drop("label", axis=1)
        labels = self._gbdt_data["label"]

        # 2. 70/15/15 split: carve off 30%, then halve that into val/test.
        X_train, X_temp, y_train, y_temp = train_test_split(
            features, labels, test_size=0.3, random_state=42, shuffle=True)
        X_val, X_test, y_val, y_test = train_test_split(
            X_temp, y_temp, test_size=0.5, random_state=42, shuffle=True)

        # 3. Standardize numeric features. Fit the scaler on the TRAINING set
        # only, then apply that same transform to val/test. (The previous code
        # called fit_transform on every split, which standardized each split
        # with its own statistics — leaking val/test distribution information
        # into preprocessing and producing inconsistent scales across splits.)
        scaler = StandardScaler()
        X_train[self._numeric_columns] = scaler.fit_transform(X_train[self._numeric_columns])
        X_val[self._numeric_columns] = scaler.transform(X_val[self._numeric_columns])
        X_test[self._numeric_columns] = scaler.transform(X_test[self._numeric_columns])

        train_dataset = {"data": X_train, "label": y_train}
        val_dataset = {"data": X_val, "label": y_val}
        test_dataset = {"data": X_test, "label": y_test}
        return train_dataset, val_dataset, test_dataset


if __name__ == '__main__':
    path = './models/rank/data/gbdt_data.csv'
    dataloader = DataLoader(path)
    dataloader.split_dataset()