"""Command-line entry point for the GBDT-LR ranking model.

Three independent actions can be selected with flags (see ``run``):

* ``--run_train``   build the training data and train/save a model,
* ``--recommend``   load a saved model and print a ranked list for a product,
* ``--importance``  export per-model feature importances to CSV files.
"""
import argparse
import os
import time

import pandas as pd

from models.rank import DataProcess, Trainer, GbdtLrModel

# Model names whose feature importances are exported, mapped to the
# sub-directory (below "<model_dir>/importance") that receives their CSV files.
_IMPORTANCE_SUBDIRS = {
    "ori_model": "ori",
    "pos_model": "pos",
    "shopping_model": "shopping",
}

# Column layout of every exported feature-importance CSV.
_IMPORTANCE_COLUMNS = ["Features", "Importance"]


def train(args: argparse.Namespace) -> None:
    """Prepare the training data for one city and train a GBDT-LR model on it.

    Reads ``args.model_path``, ``args.city_uuid`` and ``args.train_data_dir``.
    The model is saved as ``gbdtlr_model.pkl`` inside
    ``<model_path>/<city_uuid>``.
    """
    model_dir = os.path.join(args.model_path, args.city_uuid)
    train_data_dir = args.train_data_dir

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(train_data_dir, exist_ok=True)

    # Prepare the data set.
    print("正在整合训练数据...")
    processor = DataProcess(args.city_uuid, train_data_dir)
    processor.data_process()
    print("训练数据整合完成!")

    # Run the training.
    # NOTE(review): this assumes data_process() leaves "train_data.csv" in
    # train_data_dir -- confirm against DataProcess.
    print("开始训练gbdt-lr模型")
    trainer(
        args,
        os.path.join(train_data_dir, "train_data.csv"),
        model_dir,
        "gbdtlr_model.pkl",
    )


def trainer(args, train_data_path, model_dir, model_name) -> None:
    """Train on ``train_data_path``, print timing and metrics, save the model.

    Args:
        args: Parsed command-line arguments. Currently unused; kept so the
            call signature stays unchanged.
        train_data_path: Path handed to ``Trainer``.
        model_dir: Directory the trained model is saved into.
        model_name: File name of the saved model inside ``model_dir``.
    """
    # Named differently from this function so the local does not shadow it.
    gbdt_trainer = Trainer(train_data_path)

    start_time = time.time()
    gbdt_trainer.train()
    training_time_hours = (time.time() - start_time) / 3600
    print(f"训练时间: {training_time_hours:.4f} 小时")

    # Print the evaluation results.
    eval_metrics = gbdt_trainer.evaluate()
    print("GBDT-LR Evaluation Metrics:")
    for metric, value in eval_metrics.items():
        print(f"{metric}: {value:.4f}")

    # Save the model.
    gbdt_trainer.save_model(os.path.join(model_dir, model_name))


def recommend_by_product(args: argparse.Namespace) -> None:
    """Print the first ``args.last_n`` ranked items for ``args.product_id``.

    The model is loaded from ``<model_path>/<city_uuid>/<model_name>``. If the
    city directory does not exist, a message is printed and nothing else
    happens.

    NOTE(review): ``train`` saves "gbdtlr_model.pkl" while ``--model_name``
    defaults to "model.pkl" -- confirm which file is expected here.
    """
    model_dir = os.path.join(args.model_path, args.city_uuid)
    if not os.path.exists(model_dir):
        print("暂无该城市的模型,请先进行模型训练")
        return

    # Load the model.
    model = GbdtLrModel(os.path.join(model_dir, args.model_name))
    recommend_list = model.sort(args.city_uuid, args.product_id)

    # Slicing already clamps to the sequence length, so no min() is needed.
    for item in recommend_list[:args.last_n]:
        print(item)


def get_features_importance(args: argparse.Namespace) -> None:
    """Export feature importances for every model in ``_IMPORTANCE_SUBDIRS``.

    If the city directory ``<model_path>/<city_uuid>`` does not exist, a
    message is printed and nothing is exported.
    """
    model_dir = os.path.join(args.model_path, args.city_uuid)
    if not os.path.exists(model_dir):
        print("暂无该城市的模型,请先进行模型训练")
        return

    for modelname in _IMPORTANCE_SUBDIRS:
        get_features_importance_by_model(model_dir, modelname)


def _importance_frame(importances) -> pd.DataFrame:
    """Build a Features/Importance DataFrame from a list of dicts.

    Only the first key/value pair of each dict is used.
    NOTE(review): presumably every item is a single-entry
    ``{feature: importance}`` dict -- confirm against
    ``GbdtLrModel.generate_feats_importance``.
    """
    rows = [
        {"Features": feature, "Importance": importance}
        for feature, importance in (
            list(item.items())[0] for item in importances
        )
    ]
    # Explicit columns keep the CSV header even when there are no rows.
    return pd.DataFrame(rows, columns=_IMPORTANCE_COLUMNS)


def get_features_importance_by_model(model_dir, modelname) -> None:
    """Write customer/product/order feature importances of one model to CSV.

    Args:
        model_dir: Directory containing ``<modelname>.pkl``.
        modelname: Model file name without the ``.pkl`` extension. Names
            listed in ``_IMPORTANCE_SUBDIRS`` get their own sub-directory
            below ``<model_dir>/importance``; any other name writes directly
            into ``<model_dir>/importance``.
    """
    model = GbdtLrModel(os.path.join(model_dir, f"{modelname}.pkl"))
    (
        cust_features_importance,
        product_features_importance,
        order_features_importance,
    ) = model.generate_feats_importance()

    # Convert the lists of dicts to DataFrames before touching the disk.
    frames = {
        "cust": _importance_frame(cust_features_importance),
        "product": _importance_frame(product_features_importance),
        "order": _importance_frame(order_features_importance),
    }

    importance_dir = os.path.join(model_dir, "importance")
    subdir = _IMPORTANCE_SUBDIRS.get(modelname)
    if subdir is not None:
        importance_dir = os.path.join(importance_dir, subdir)
    os.makedirs(importance_dir, exist_ok=True)

    for prefix, frame in frames.items():
        file_path = os.path.join(
            importance_dir, f"{prefix}_features_importance.csv"
        )
        frame.to_csv(file_path, index=False, encoding="utf-8")


def run() -> None:
    """Parse the command line and run every action whose flag was given.

    Actions run in the order train, recommend, importance. When no action
    flag is present the usage text is printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_train", action="store_true",
                        help="prepare the training data and train a model")
    parser.add_argument("--recommend", action="store_true",
                        help="print the ranked list for --product_id")
    parser.add_argument("--importance", action="store_true",
                        help="export feature importances to CSV files")

    parser.add_argument("--train_data_dir", type=str, default="./data/gbdt",
                        help="directory holding the training data")
    parser.add_argument("--model_path", type=str,
                        default="./models/rank/weights",
                        help="root directory of the per-city model folders")
    parser.add_argument("--model_name", type=str, default='model.pkl',
                        help="model file name loaded by --recommend")
    parser.add_argument("--last_n", type=int, default=200,
                        help="maximum number of ranked items to print")
    parser.add_argument("--city_uuid", type=str,
                        default='00000000000000000000000011445301',
                        help="city identifier; names the model sub-directory")
    parser.add_argument("--product_id", type=str, default='110102',
                        help="product id passed to the model for ranking")

    args = parser.parse_args()

    if not (args.run_train or args.recommend or args.importance):
        # Nothing was requested: show the usage instead of exiting silently.
        parser.print_help()
        return

    if args.run_train:
        train(args)

    if args.recommend:
        recommend_by_product(args)

    if args.importance:
        get_features_importance(args)


if __name__ == "__main__":
    run()