Browse Source

refactor: add logging to utils/train, secure config, create .env.example

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sherlock 3 weeks ago
parent
commit
a35bba88b6
8 changed files with 153 additions and 99 deletions
  1. 22 0
      .env.example
  2. 2 1
      .gitignore
  3. 32 16
      config/config.py
  4. 12 12
      config/database_config.yaml
  5. 6 3
      models/rank/data/preprocess.py
  6. 24 21
      train.py
  7. 36 43
      utils/file_stream.py
  8. 19 3
      utils/report_utils.py

+ 22 - 0
.env.example

@@ -0,0 +1,22 @@
+# BrandCultivation 环境变量配置
+# 复制此文件为 .env 并填入实际值
+
+# MySQL
+MYSQL_HOST=rm-t4n6rz18y4t5x47y70o.mysql.singapore.rds.aliyuncs.com
+MYSQL_PORT=3036
+MYSQL_USER=BrandCultivation
+MYSQL_PASSWORD=your_mysql_password_here
+MYSQL_DB=brand_cultivation
+
+# Redis
+REDIS_HOST=r-t4nb4n9i8je7u6ogk1pd.redis.singapore.rds.aliyuncs.com
+REDIS_PORT=5000
+REDIS_PASSWORD=your_redis_password_here
+REDIS_DB=10
+
+# Logging
+LOG_LEVEL=INFO
+
+# File Service
+FILE_UPLOAD_URL=http://file-center.jcpt:8080/file/fileUpload
+FILE_DOWNLOAD_URL=http://file-center.jcpt:8080/file/fileDownload

+ 2 - 1
.gitignore

@@ -3,4 +3,5 @@
 __pycache__/
 *.pyc
 data/
-models/rank/weights
+models/rank/weights
+.env

+ 32 - 16
config/config.py

@@ -1,16 +1,32 @@
-import yaml
-
-def load_config():
-    with open('./config/database_config.yaml', encoding='utf-8') as file:
-        config = yaml.safe_load(file)
-    return config
-
-def load_model_config():
-    with open('./config/model_config.yaml', encoding='utf-8') as file:
-        config = yaml.safe_load(file)
-    return config
-
-def load_service_config():
-    with open("./config/service_config.yaml", encoding='utf-8') as file:
-        config = yaml.safe_load(file)
-    return config
+from core.config import settings
+
+
+def load_config():
+    return {
+        "mysql": {
+            "host": settings.mysql_host,
+            "port": settings.mysql_port,
+            "user": settings.mysql_user,
+            "passwd": settings.mysql_password,
+            "db": settings.mysql_db,
+        },
+        "redis": {
+            "host": settings.redis_host,
+            "port": settings.redis_port,
+            "passwd": settings.redis_password,
+            "db": settings.redis_db,
+        },
+    }
+
+
+def load_model_config():
+    return settings.model_config
+
+
+def load_service_config():
+    return {
+        "aliyun": {
+            "upload_url": settings.file_upload_url,
+            "download_url": settings.file_download_url,
+        }
+    }

+ 12 - 12
config/database_config.yaml

@@ -1,12 +1,12 @@
-mysql:
-  host: 'rm-t4n6rz18y4t5x47y70o.mysql.singapore.rds.aliyuncs.com'
-  port: 3036
-  db: 'brand_cultivation'
-  user: 'BrandCultivation'
-  passwd: '8BfWBc18NBXl#CMd'
-
-redis:
-  host: 'r-t4nb4n9i8je7u6ogk1pd.redis.singapore.rds.aliyuncs.com'
-  port: 5000
-  db: 10
-  passwd: 'gHmNkVBd88sZybj'
+mysql:
+  host: 'rm-t4n6rz18y4t5x47y70o.mysql.singapore.rds.aliyuncs.com'
+  port: 3036
+  db: 'brand_cultivation'
+  user: 'BrandCultivation'
+  # passwd moved to environment variable MYSQL_PASSWORD
+
+redis:
+  host: 'r-t4nb4n9i8je7u6ogk1pd.redis.singapore.rds.aliyuncs.com'
+  port: 5000
+  db: 10
+  # passwd moved to environment variable REDIS_PASSWORD

+ 6 - 3
models/rank/data/preprocess.py

@@ -5,16 +5,19 @@ import pandas as pd
 from sklearn.preprocessing import MinMaxScaler
 from sklearn.utils import shuffle
 import numpy as np
+from core import get_logger
+
+logger = get_logger("models.rank.preprocess")
 
 class DataProcess():
     def __init__(self, city_uuid, save_dir):
         self._mysql_dao = MySqlDao()
         self.save_dir = save_dir
-        print("gbdr-lr: 正在加载cust_info...")
+        logger.info("Loading cust_info")
         self._cust_data = self._mysql_dao.load_cust_data(city_uuid)
-        print("gbdr-lr: 正在加载product_info...")
+        logger.info("Loading product_info")
         self._product_data = self._mysql_dao.load_product_data(city_uuid)
-        print("gbdr-lr: 正在加载order_info...")
+        logger.info("Loading order_info")
         self._order_data = self._mysql_dao.load_order_data(city_uuid)
         # self._order_data = self._mysql_dao.load_mock_order_data()
         # print("gbdr-lr: 正在加载shopping_info...")

+ 24 - 21
train.py

@@ -4,6 +4,9 @@ from models.rank import DataProcess, Trainer, GbdtLrModel
 from models import ItemCFModel, HotRecallModel
 import time
 import pandas as pd
+from core import get_logger
+
+logger = get_logger("train")
 
 # train_data_path = "./moldes/rank/data/gbdt_data.csv"
 # model_path = "./models/rank/weights"
@@ -17,14 +20,14 @@ def gbdtlr_train(args):
     if not os.path.exists(train_data_dir):
         os.makedirs(train_data_dir)
     
-    # 准备数据集  
-    print("正在整合训练数据...")
+    # 准备数据集
+    logger.info("正在整合训练数据...")
     processor = DataProcess(args.city_uuid, args.train_data_dir)
     processor.data_process()
-    print("训练数据整合完成!")
-    
+    logger.info("训练数据整合完成")
+
     # 进行训练
-    print("开始训练gbdt-lr模型")
+    logger.info("开始训练gbdt-lr模型")
     gbdtlr_trainer(os.path.join(args.train_data_dir, "train_data.csv"), model_dir, "gbdtlr_model.pkl")
 
 def gbdtlr_trainer(train_data_path, model_dir, model_name):
@@ -35,14 +38,14 @@ def gbdtlr_trainer(train_data_path, model_dir, model_name):
     end_time = time.time()
     
     training_time_hours = (end_time - start_time) / 3600
-    print(f"训练时间: {training_time_hours:.4f} 小时")
-    
+    logger.info(f"训练时间: {training_time_hours:.4f} 小时")
+
     eval_metrics = trainer.evaluate()
-    
+
     # 输出评估结果
-    print("GBDT-LR Evaluation Metrics:")
+    logger.info("GBDT-LR Evaluation Metrics:")
     for metric, value in eval_metrics.items():
-        print(f"{metric}: {value:.4f}")
+        logger.info(f"{metric}: {value:.4f}")
         
     # 保存模型
     trainer.save_model(os.path.join(model_dir, model_name))
@@ -79,24 +82,24 @@ def run():
     args = parser.parse_args()
     
     if args.run_train:
-        print("正在计算协同过滤...")
+        logger.info("正在计算协同过滤...")
         itemCF(args)
-        
-        print("正在计算热度召回...")
+
+        logger.info("正在计算热度召回...")
         hot_recall(args)
-        
-        print("正在进行gbdt_lr训练...")
+
+        logger.info("正在进行gbdt_lr训练...")
         gbdtlr_train(args)
-        
+
     if args.run_recall:
-        print("正在计算协同过滤...")
+        logger.info("正在计算协同过滤...")
         itemCF(args)
-        
-        print("正在计算热度召回...")
+
+        logger.info("正在计算热度召回...")
         hot_recall(args)
-        
+
     if args.run_gbdtlr:
-        print("正在进行gbdt_lr训练...")
+        logger.info("正在进行gbdt_lr训练...")
         gbdtlr_train(args)
         
 if __name__ == "__main__":

+ 36 - 43
utils/file_stream.py

@@ -1,87 +1,80 @@
-from config import load_service_config
+import time
+from core import get_logger, settings
 from io import BytesIO
 import os
 import pandas as pd
 import requests
 
+logger = get_logger("utils.file_stream")
+
 
 class FileStreamUtils:
-    cfgs = load_service_config()
-    upload_url = cfgs["aliyun"]["upload_url"]
-    download_url = cfgs["aliyun"]["download_url"]
-    # cookies = cfgs["aliyun"]['cookies']
-     # 设置请求头
+    upload_url = settings.file_upload_url
+    download_url = settings.file_download_url
     headers = {
-        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
+        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
         "Accept": "*/*",
     }
-    
+
     @staticmethod
     def upload_files(reports_dir, files):
         files_id = {}
         for filename in files:
-            file_path = os.path.join(reports_dir, f'{filename}.xlsx')
+            file_path = os.path.join(reports_dir, f"{filename}.xlsx")
+            start_time = time.time()
             try:
-                with open(file_path, 'rb') as f:
-                    files = {'file': (os.path.basename(file_path), f)}
-
+                with open(file_path, "rb") as f:
+                    upload_files = {"file": (os.path.basename(file_path), f)}
                     response = requests.post(
                         FileStreamUtils.upload_url,
                         headers=FileStreamUtils.headers,
-                        files=files,
-                        # cookies=FileStreamUtils.cookies,
-                        verify=True
+                        files=upload_files,
+                        verify=True,
                     )
-                    
-                    if response.json()["success"]:
+                    duration_ms = (time.time() - start_time) * 1000
+                    if response.json().get("success"):
                         file_id = response.json()["data"]["file_info"]["fileid"]
                         files_id[filename] = file_id
+                        logger.info(f"File uploaded: {filename} -> {file_id} ({duration_ms:.0f}ms)")
+                    else:
+                        logger.error(f"Upload failed for {filename}: {response.text}")
+                        return None
             except requests.exceptions.RequestException as e:
-                print("请求出错:", e)
+                logger.error(f"Upload request error for {filename}: {e}", exc_info=True)
                 return None
             except Exception as e:
+                logger.error(f"Upload error for {filename}: {e}", exc_info=True)
                 return None
-                
         return files_id
-    
+
     @staticmethod
-    def download_file(file_id, file_type='xlsx'):
+    def download_file(file_id, file_type="xlsx"):
         """通过file_id从阿里云文件数据库下载文件"""
+        start_time = time.time()
         try:
-            # params = {
-            #     'fileid': file_id,
-            #     'action': 'download'
-            # }
             response = requests.get(
                 f"{FileStreamUtils.download_url}/{file_id}",
                 headers=FileStreamUtils.headers,
-                # cookies=FileStreamUtils.cookies,
-                # params=params,
-                verify=True
+                verify=True,
             )
-            
+            duration_ms = (time.time() - start_time) * 1000
+
             if response.status_code == 200:
                 file_content = BytesIO(response.content)
-                if file_type == 'xlsx':
-                    data = pd.read_excel(file_content, engine='openpyxl')
-                elif file_type == 'csv':
+                if file_type == "xlsx":
+                    data = pd.read_excel(file_content, engine="openpyxl")
+                elif file_type == "csv":
                     data = pd.read_csv(file_content)
                 else:
-                    raise ValueError(f"不支持的文件类型:{file_type}" )
-                
+                    raise ValueError(f"不支持的文件类型:{file_type}")
+                logger.info(f"File downloaded: {file_id} ({duration_ms:.0f}ms, {len(response.content)} bytes)")
                 return data
             else:
+                logger.error(f"Download failed: file_id={file_id}, status={response.status_code}")
                 return None
         except requests.exceptions.RequestException as e:
-            print("Request Error: ", e)
+            logger.error(f"Download request error: file_id={file_id}, error={e}", exc_info=True)
             return None
         except Exception as e:
-            print("File download Error: ", e)
+            logger.error(f"Download error: file_id={file_id}, error={e}", exc_info=True)
             return None
-    
-if __name__ == '__main__':
-    # print(FileStreamUtils.cfgs["aliyun"]["cookies"])
-    file_id = '11C1AC088863421C9BC32A5E722F5147'
-    
-    data = FileStreamUtils.download_file(file_id)
-    data.to_excel('./recommend_list.xlsx', index=False)

+ 19 - 3
utils/report_utils.py

@@ -3,10 +3,14 @@ from models import Recommend
 from models.rank.data.config import CustConfig, ImportanceFeaturesMap, ProductConfig, DeliveryConfig
 from models.rank.data.utils import sample_data_clear
 from models.rank import generate_feats_map
+from core import get_logger
 
 import os
 import pandas as pd
 from utils.reports_process import feats_relation_process, calculate_delivery_by_recommend_data, eval_report_process_pre, eval_report_process
+
+logger = get_logger("utils.report")
+
 class ReportUtils:
     def __init__(self, city_uuid, product_id):
         self._recommend_model = Recommend(city_uuid)
@@ -64,33 +68,40 @@ class ReportUtils:
     
     def generate_feats_ralation_report(self, recall_count):
         """生成特征相关性分析报告"""
+        logger.info("Generating feature relation report")
         feats_map = self._generate_feats_map(recall_count)
         product_content = self._get_product_content()
         # 计算SHAP值
         shap_result = self._recommend_model._gbdtlr_model.generate_shap_interance(feats_map)
         report = feats_relation_process(shap_result, product_content)
-        
+
         report.to_excel(os.path.join(self._save_dir, "品规商户特征关系表.xlsx"), index=False)
+        logger.info("Feature relation report saved")
         
     def generate_product_report(self):
         """生成推荐品规信息表"""
+        logger.info("Generating product report")
         product_data = self._get_product_content()
         with open(os.path.join(self._save_dir, "卷烟信息表.xlsx"), "w", encoding='utf-8-sig') as file:
             for key, value in product_data.items():
                 if key != 'product_code':
                     file.write(f"{ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP[key]}, {value}\n")
+        logger.info("Product report saved")
                     
     def generate_recommend_report(self, recall_count, delivery_count):
         """生成推荐报告,包括投放量"""
+        logger.info("Generating recommend report")
         recommend_data = self._get_recommend_data(recall_count)
         recommend_list = list(map(lambda x: x["cust_code"], recommend_data))
         recommend_cust_infos = self._dao.get_cust_by_ids(self._city_uuid, recommend_list)
         report = calculate_delivery_by_recommend_data(recommend_data, recommend_cust_infos, delivery_count)
-        
+
         report.to_excel(os.path.join(self._save_dir, "商户售卖推荐表.xlsx"), index=False)
+        logger.info("Recommend report saved")
         
     def generate_similarity_product_report(self):
         """生成相似卷烟表"""
+        logger.info("Generating similarity product report")
         product_similarity_map = self._recommend_model._item2vec_model.generate_product_similarity_map(self._product_id)
         product_similarity_map = product_similarity_map[["product_name", "similarity", "brand_name", "factory_name", "is_low_tar", "is_medium", "is_tiny", "is_coarse", "is_exploding_beads", "is_abnormity", "is_cig", "is_chuangxin", "direct_retail_price", "tbc_total_length", "product_style"]]
         product_similarity_map = product_similarity_map.rename(
@@ -113,6 +124,7 @@ class ReportUtils:
             }
         )
         product_similarity_map.to_excel(os.path.join(self._save_dir, "相似卷烟表.xlsx"), index=False)
+        logger.info("Similarity product report saved")
         
     def generate_eval_data_pre(self):
         if self._product_id == '350139':
@@ -121,7 +133,7 @@ class ReportUtils:
             eval_product_id = self._product_id
         eval_order_data = self._dao.get_eval_order_by_product(self._city_uuid, eval_product_id)
         if not os.path.exists(os.path.join(self._save_dir, "商户售卖推荐表.xlsx")):
-            print("请先生成'商户售卖推荐表'")
+            logger.error("商户售卖推荐表 not found")
         recommend_data = pd.read_excel(os.path.join(self._save_dir, "商户售卖推荐表.xlsx"))
         report = eval_report_process_pre(eval_order_data, recommend_data)
         
@@ -129,6 +141,7 @@ class ReportUtils:
         
     def generate_eval_data(self, start_time, end_time, recommend_data):
         """根据推荐列表生成验证报告"""
+        logger.info("Generating eval report")
         if self._product_id == '350139':
             eval_product_id = "350355"
         else:
@@ -142,13 +155,16 @@ class ReportUtils:
         report = eval_report_process(delivery_data, recommend_data)
         
         report.to_excel(os.path.join(self._save_dir, "投放验证报告.xlsx"), index=False)
+        logger.info("Eval report saved")
     
     def generate_all_data(self, recall_count, delivery_count):
+        logger.info("Generating all reports")
         self.generate_feats_ralation_report(recall_count)
         self.generate_product_report()
         self.generate_recommend_report(recall_count, delivery_count)
         self.generate_similarity_product_report()
         # self.generate_eval_data()
+        logger.info("All reports generated")
         
 if __name__ == "__main__":
     city_uuid = "00000000000000000000000011445301"