yangzeyu 11 місяців тому
батько
коміт
8cbde2f5de
2 змінених файлів з 16 додано та 5 видалено
  1. 4 5
      report.py
  2. 12 0
      train.py

+ 4 - 5
report.py

@@ -136,11 +136,10 @@ def run():
     parser = argparse.ArgumentParser()
     
     parser.add_argument("--city_uuid", type=str, default="00000000000000000000000011445301")
-    parser.add_argument("--product_id", type=str, default="350139")
+    parser.add_argument("--product_id", type=str, default="510149")
     parser.add_argument("--recall_count", type=int, default=100)
     parser.add_argument("--delivery_count", type=int, default=5000)
     
-    parser.add_argument("--all_report", action='store_true')
     # parser.add_argument()
     # parser.add_argument()
     
@@ -148,7 +147,7 @@ def run():
     
     # 查找该城市的gbdt模型是否存在
     args.gbdtlr_model_path = os.path.join("./models/rank/weights/", args.city_uuid, "gbdtlr_model.pkl")
-    args.report_dir = os.path.join("./data/report", args.city_uuid)
+    args.report_dir = os.path.join("./data/report", args.city_uuid, args.product_id)
     if not os.path.exists(args.gbdtlr_model_path):
         print("该城市的模型还未训练,请先启动训练!!!")
         
@@ -159,8 +158,8 @@ def run():
     if not os.path.exists(args.report_dir):
         os.makedirs(args.report_dir)
         
-    if args.all_report:
-        generate_all_data(args, report_utils)
+    # 生成报告
+    generate_all_data(args, report_utils)
     
     
 if __name__ == "__main__":

+ 12 - 0
train.py

@@ -60,6 +60,8 @@ def run():
     parser = argparse.ArgumentParser()
     # 全局参数
     parser.add_argument("--run_train", action='store_true')
+    parser.add_argument("--run_recall", action='store_true')
+    parser.add_argument("--run_gbdtlr", action='store_true')
     
     parser.add_argument("--city_uuid", type=str, default='00000000000000000000000011445301')
     
@@ -86,6 +88,16 @@ def run():
         print("正在进行gbdt_lr训练...")
         gbdtlr_train(args)
         
+    if args.run_recall:
+        print("正在计算协同过滤...")
+        itemCF(args)
+        
+        print("正在计算热度召回...")
+        hot_recall(args)
+        
+    if args.run_gbdtlr:
+        print("正在进行gbdt_lr训练...")
+        gbdtlr_train(args)
         
 if __name__ == "__main__":
     run()