9 Commits a88b37012b ... 2f0a0a565a

Auteur SHA1 Message Date
  Sherlock 2f0a0a565a docs: 添加 README.md 说明环境配置和功能启动 il y a 3 semaines
  Sherlock a35bba88b6 refactor: add logging to utils/train, secure config, create .env.example il y a 3 semaines
  Sherlock c84bbbeb69 fix(models): fix get_recommend_list bug, add logging to all model modules il y a 3 semaines
  Sherlock 75642c90ce refactor(api): add logging and error handling to all endpoints il y a 3 semaines
  Sherlock b92be3efd9 refactor: improve DAO logging, rewrite run_api with middleware and health check il y a 3 semaines
  Sherlock a7a8187584 refactor(database): add logging, session context manager, env-based config il y a 3 semaines
  Sherlock aaf6ae01e9 feat(core): add infrastructure layer (logging, config, exceptions, middleware) il y a 3 semaines
  Sherlock 89ebb90f13 docs: 添加项目重构实现计划 il y a 3 semaines
  Sherlock 346ff6fa5c docs: 添加项目级重构设计文档 il y a 3 semaines

+ 22 - 0
.env.example

@@ -0,0 +1,22 @@
+# BrandCultivation 环境变量配置
+# 复制此文件为 .env 并填入实际值
+
+# MySQL
+MYSQL_HOST=rm-t4n6rz18y4t5x47y70o.mysql.singapore.rds.aliyuncs.com
+MYSQL_PORT=3036
+MYSQL_USER=BrandCultivation
+MYSQL_PASSWORD=your_mysql_password_here
+MYSQL_DB=brand_cultivation
+
+# Redis
+REDIS_HOST=r-t4nb4n9i8je7u6ogk1pd.redis.singapore.rds.aliyuncs.com
+REDIS_PORT=5000
+REDIS_PASSWORD=your_redis_password_here
+REDIS_DB=10
+
+# Logging
+LOG_LEVEL=INFO
+
+# File Service
+FILE_UPLOAD_URL=http://file-center.jcpt:8080/file/fileUpload
+FILE_DOWNLOAD_URL=http://file-center.jcpt:8080/file/fileDownload

+ 2 - 1
.gitignore

@@ -3,4 +3,5 @@
 __pycache__/
 *.pyc
 data/
-models/rank/weights
+models/rank/weights
+.env

+ 241 - 0
README.md

@@ -0,0 +1,241 @@
+# BrandCultivation 卷烟品牌培育推荐系统
+
+基于协同过滤、Item2Vec 和 GBDT-LR 的卷烟品牌培育商户推荐系统,提供品规-商户匹配推荐、投放量分配、效果验证等功能。
+
+## 目录结构
+
+```
+BrandCultivation/
+├── core/                    # 基础设施层(日志、配置、异常、中间件)
+├── api/                     # FastAPI 路由层
+├── database/                # 数据访问层(MySQL DAO + Redis)
+├── models/                  # ML 模型(Item2Vec、ItemCF、GBDT-LR)
+├── utils/                   # 工具类(文件上传、报告生成)
+├── config/                  # 配置文件(YAML)
+├── run_api.py               # API 服务入口
+├── train.py                 # 模型训练入口
+├── requirements.txt         # Python 依赖
+└── .env.example             # 环境变量模板
+```
+
+## 环境要求
+
+- Python 3.10+
+- MySQL 5.7+
+- Redis 5.0+
+
+## 安装
+
+```bash
+# 克隆项目
+git clone <repo-url>
+cd BrandCultivation
+
+# 创建虚拟环境
+conda create -n recommend python=3.10
+conda activate recommend
+
+# 安装依赖
+pip install -r requirements.txt
+```
+
+## 配置
+
+### 环境变量
+
+复制 `.env.example` 为 `.env`,填入实际值:
+
+```bash
+cp .env.example .env
+```
+
+必须配置的环境变量:
+
+| 变量 | 说明 | 示例 |
+|------|------|------|
+| `MYSQL_HOST` | MySQL 主机地址 | `rm-xxx.mysql.rds.aliyuncs.com` |
+| `MYSQL_PORT` | MySQL 端口 | `3036` |
+| `MYSQL_USER` | MySQL 用户名 | `BrandCultivation` |
+| `MYSQL_PASSWORD` | MySQL 密码 | (必填) |
+| `MYSQL_DB` | 数据库名 | `brand_cultivation` |
+| `REDIS_HOST` | Redis 主机地址 | `r-xxx.redis.rds.aliyuncs.com` |
+| `REDIS_PORT` | Redis 端口 | `5000` |
+| `REDIS_PASSWORD` | Redis 密码 | (必填) |
+| `REDIS_DB` | Redis 数据库编号 | `10` |
+| `LOG_LEVEL` | 日志级别 | `INFO`(默认) |
+| `FILE_UPLOAD_URL` | 文件上传服务地址 | `http://file-center.jcpt:8080/file/fileUpload` |
+| `FILE_DOWNLOAD_URL` | 文件下载服务地址 | `http://file-center.jcpt:8080/file/fileDownload` |
+
+如果不使用 `.env` 文件,也可以直接 export 环境变量:
+
+```bash
+export MYSQL_PASSWORD='your_password'
+export REDIS_PASSWORD='your_password'
+```
+
+### YAML 配置
+
+非敏感配置保留在 `config/` 目录下的 YAML 文件中,环境变量优先级高于 YAML。
+
+## 运行
+
+### 启动 API 服务
+
+```bash
+python run_api.py
+```
+
+服务启动后监听 `0.0.0.0:7960`,可通过以下方式验证:
+
+```bash
+# 健康检查
+curl http://localhost:7960/health
+
+# 预期返回
+# {"status":"healthy","mysql":"ok","redis":"ok"}
+```
+
+也可以使用 uvicorn 直接启动(支持热重载):
+
+```bash
+uvicorn run_api:app --host 0.0.0.0 --port 7960 --reload
+```
+
+### 模型训练
+
+训练前确保 MySQL 和 Redis 均可连接。
+
+```bash
+# 完整训练(协同过滤 + 热度召回 + GBDT-LR)
+python train.py --run_train --city_uuid 00000000000000000000000011445301
+
+# 仅训练召回模型(协同过滤 + 热度召回)
+python train.py --run_recall --city_uuid 00000000000000000000000011445301
+
+# 仅训练排序模型(GBDT-LR)
+python train.py --run_gbdtlr --city_uuid 00000000000000000000000011445301
+```
+
+训练参数:
+
+| 参数 | 说明 | 默认值 |
+|------|------|--------|
+| `--city_uuid` | 城市 UUID | `00000000000000000000000011445301` |
+| `--train_data_dir` | 训练数据保存目录 | `./data/gbdt` |
+| `--model_path` | 模型权重保存目录 | `./models/rank/weights` |
+| `--largest_n` | ItemCF 热度 Top N | `300` |
+| `--similarity_k` | ItemCF 相似商户数 | `100` |
+| `--top_n` | ItemCF 推荐候选数 | `1500` |
+| `--n_jobs` | 并行计算线程数 | `2` |
+
+## API 接口
+
+基础路径:`/brandcultivation/api/v1`
+
+### POST /recommend
+
+生成商户推荐列表并分配投放量。
+
+请求体:
+```json
+{
+    "city_uuid": "00000000000000000000000011445301",
+    "product_code": "440298",
+    "recall_cust_count": 500,
+    "delivery_count": 5000,
+    "cultivacation_id": "10000001",
+    "limit_cycle_name": "202505W1(05.05-05.11)"
+}
+```
+
+响应:
+```json
+{
+    "code": 200,
+    "msg": "success",
+    "data": {
+        "recommendationInfo": [
+            {"id": 1, "cust_code": "445300108802", "recommend_score": 95.3, "delivery_count": 120}
+        ]
+    }
+}
+```
+
+### POST /report
+
+获取推荐相关报告文件 ID。
+
+请求体:
+```json
+{
+    "cultivacation_id": "10000001"
+}
+```
+
+### POST /eval_report
+
+生成投放效果验证报告。
+
+请求体:
+```json
+{
+    "city_uuid": "00000000000000000000000011445301",
+    "product_code": "440298",
+    "cultivacation_id": "10000001",
+    "start_time": "2025/2/10",
+    "end_time": "2025/2/16"
+}
+```
+
+### GET /health
+
+健康检查,返回 MySQL 和 Redis 连接状态。
+
+## 日志
+
+系统使用 JSON 格式日志输出到 stdout,每条日志包含:
+
+```json
+{
+    "timestamp": "2026-05-21T03:35:48.869426+00:00",
+    "level": "INFO",
+    "module": "recommend",
+    "function": "recommend",
+    "line": 18,
+    "message": "Recommend request: city=xxx, product=440298, recall=500",
+    "request_id": "a1b2c3d4"
+}
+```
+
+通过 `LOG_LEVEL` 环境变量控制日志级别(DEBUG / INFO / WARNING / ERROR)。
+
+API 请求会自动生成 `request_id`,贯穿整个请求链路,方便问题追踪。响应头中也会返回 `X-Request-ID`。
+
+## Docker 部署
+
+```dockerfile
+FROM python:3.10-slim
+
+WORKDIR /app
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+COPY . .
+
+ENV MYSQL_PASSWORD=""
+ENV REDIS_PASSWORD=""
+ENV LOG_LEVEL=INFO
+
+EXPOSE 7960
+CMD ["python", "run_api.py"]
+```
+
+```bash
+docker build -t brand-cultivation .
+docker run -d \
+    -p 7960:7960 \
+    -e MYSQL_PASSWORD='your_password' \
+    -e REDIS_PASSWORD='your_password' \
+    brand-cultivation
+```
+

+ 18 - 9
api/eval_report.py

@@ -1,24 +1,27 @@
-from config import load_service_config
 from database import MySqlDao
 from fastapi import APIRouter, status, HTTPException
 import os
 from .request_body import EvalReportRequest
 import requests
 from utils import ReportUtils, FileStreamUtils
+from core import get_logger
 
-cfgs = load_service_config()
+logger = get_logger("api.eval_report")
 dao = MySqlDao()
 router = APIRouter()
 
 @router.post('/eval_report')
 async def eval_report(request: EvalReportRequest):
     """生成并上传验证报告到阿里云文件数据库"""
+    logger.info(f"Eval report request: cultivacation_id={request.cultivacation_id}, city={request.city_uuid}, product={request.product_code}")
+
     reports_dir = os.path.join('./data/reports', request.city_uuid, request.product_code)
     report_util = ReportUtils(request.city_uuid, request.product_code)
-    
+
     # 获取report数据表中eval_table的file_id,如果不为空,直接返回结果,如果为空则先创建验证数据
     eval_file_id = dao.get_report_file_id(request.cultivacation_id)['val_table'].item()
     if eval_file_id:
+        logger.info(f"Existing eval report found: file_id={eval_file_id}")
         content = [
             {
                 "id": 1,
@@ -27,25 +30,31 @@ async def eval_report(request: EvalReportRequest):
             }
         ]
         return {"code": 200, "msg": "success", "data": {"evalReportInfo": content}}
-    
+
     # 获取推荐列表
     file_id = dao.get_report_file_id(request.cultivacation_id)['recommend_table'].item()
     if file_id is None:
+        logger.error(f"Recommend table missing for cultivacation_id={request.cultivacation_id}")
         return {"code": 405, "msg": "推荐表丢失,生成验证报告失败!", "data": {"reportInfo": "推荐表丢失,生成验证报告失败!"}}
-    
+
+    logger.info(f"Downloading recommend data: file_id={file_id}")
     recommend_data = FileStreamUtils.download_file(file_id)
     if recommend_data is None:
+        logger.error(f"Failed to download recommend data: file_id={file_id}")
         return {"code": 405, "msg": "下载推荐数据出错,生成验证报告失败!", "data": {"reportInfo": "下载推荐数据出错,生成验证报告失败!"}}
-    
+
     # 生成验证报告
+    logger.info(f"Generating eval report for period {request.start_time} to {request.end_time}")
     report_util.generate_eval_data(request.start_time, request.end_time, recommend_data)
-    
+
     # 上传报告
+    logger.info(f"Uploading eval report to {reports_dir}")
     eval_report = ['投放验证报告']
     file_id_map = FileStreamUtils.upload_files(reports_dir, eval_report)
-    
+
     dao.update_eval_report_data(request.cultivacation_id, file_id_map.get('投放验证报告'))
-    
+    logger.info(f"Eval report uploaded: file_id={file_id_map.get('投放验证报告')}")
+
     content = [
         {
             "id": 1,

+ 53 - 45
api/recommend.py

@@ -1,77 +1,85 @@
 from database import MySqlDao
-from fastapi import APIRouter, BackgroundTasks
+from fastapi import APIRouter, BackgroundTasks, HTTPException, status
 from .request_body import RecommendRequest
+from core import get_logger
 
 from models import Recommend
 import os
 from utils import FileStreamUtils, ReportUtils
 
+logger = get_logger("api.recommend")
 dao = MySqlDao()
-
 router = APIRouter()
 
+
 @router.post("/recommend")
 async def recommend(request: RecommendRequest, backgroundTasks: BackgroundTasks):
     """推荐接口"""
+    logger.info(f"Recommend request: city={request.city_uuid}, product={request.product_code}, recall={request.recall_cust_count}")
+
     gbdtlr_model_path = os.path.join("./models/rank/weights", request.city_uuid, "gbdtlr_model.pkl")
     if not os.path.exists(gbdtlr_model_path):
-        return {"code": 200, "msg": "model not defined", "data": {"recommendationInfo": "该城市的模型未训练,请先进行训练"}}
-    
-    # 初始化模型
+        logger.warning(f"Model not found: {gbdtlr_model_path}")
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="该城市的模型未训练,请先进行训练",
+        )
+
     recommend_model = Recommend(request.city_uuid)
-    
-    # 判断该品规是否是新品规
-    products_in_oreder = dao.get_product_from_order(request.city_uuid)["product_code"].unique().tolist()
-    if request.product_code in products_in_oreder:
+
+    products_in_order = dao.get_product_from_order(request.city_uuid)["product_code"].unique().tolist()
+    if request.product_code in products_in_order:
+        logger.info(f"Using GBDT-LR model for existing product {request.product_code}")
         recommend_list = recommend_model.get_recommend_list_by_gbdtlr(request.product_code, recall_count=request.recall_cust_count)
     else:
+        logger.info(f"Using Item2Vec model for new product {request.product_code}")
         recommend_list = recommend_model.get_recommend_list_by_item2vec(request.product_code, recall_count=request.recall_cust_count)
+
     recommend_data = recommend_model.get_recommend_and_delivery(recommend_list, delivery_count=request.delivery_count)
     request_data = []
     for index, data in enumerate(recommend_data):
-        id = index + 1
         request_data.append(
             {
-                "id": id,
+                "id": index + 1,
                 "cust_code": data["cust_code"],
                 "recommend_score": data["recommend_score"],
-                "delivery_count": data["delivery_count"]
+                "delivery_count": data["delivery_count"],
             }
         )
-    
-    # 异步执行报告生成任务
-    backgroundTasks.add_task(
-        generate_and_upload_report,
-        request
-    )
-    
+
+    logger.info(f"Recommend completed: {len(request_data)} customers recommended")
+
+    backgroundTasks.add_task(generate_and_upload_report, request)
+
     return {"code": 200, "msg": "success", "data": {"recommendationInfo": request_data}}
 
+
 def generate_and_upload_report(request: RecommendRequest):
     """生成并上传报告到阿里云文件数据库"""
-    # 生成相关报告
-    report_util = ReportUtils(request.city_uuid, request.product_code)
-    report_util.generate_all_data(request.recall_cust_count, request.delivery_count)
-    
-    # 上传报告
-    reports_dir = os.path.join('./data/reports', request.city_uuid, request.product_code)
-    report_files = [
-        '卷烟信息表',
-        '品规商户特征关系表',
-        '相似卷烟表',
-        '商户售卖推荐表'
-    ]
-    file_id_map = FileStreamUtils.upload_files(reports_dir, report_files)
-    
-    # 将返回的file_id保存到数据库中
-    data_dict = {
-        'cultivacation_id': request.cultivacation_id,
-        'city_uuid': request.city_uuid,
-        'limit_cycle_name': request.limit_cycle_name,
-        'product_code': request.product_code,
-        'product_info_table': file_id_map.get('卷烟信息表'),
-        'relation_table': file_id_map.get('品规商户特征关系表'),
-        'similarity_product_table': file_id_map.get('相似卷烟表'),
-        'recommend_table': file_id_map.get('商户售卖推荐表'),
-    }
-    dao.insert_report(data_dict)
+    logger.info(f"Background task started: generating report for {request.city_uuid}/{request.product_code}")
+    try:
+        report_util = ReportUtils(request.city_uuid, request.product_code)
+        report_util.generate_all_data(request.recall_cust_count, request.delivery_count)
+
+        reports_dir = os.path.join("./data/reports", request.city_uuid, request.product_code)
+        report_files = ["卷烟信息表", "品规商户特征关系表", "相似卷烟表", "商户售卖推荐表"]
+        file_id_map = FileStreamUtils.upload_files(reports_dir, report_files)
+
+        if file_id_map is None:
+            logger.error(f"Report upload failed for {request.city_uuid}/{request.product_code}")
+            return
+
+        data_dict = {
+            "cultivacation_id": request.cultivacation_id,
+            "city_uuid": request.city_uuid,
+            "limit_cycle_name": request.limit_cycle_name,
+            "product_code": request.product_code,
+            "product_info_table": file_id_map.get("卷烟信息表"),
+            "relation_table": file_id_map.get("品规商户特征关系表"),
+            "similarity_product_table": file_id_map.get("相似卷烟表"),
+            "recommend_table": file_id_map.get("商户售卖推荐表"),
+        }
+        dao.insert_report(data_dict)
+        logger.info(f"Background task completed: report uploaded for {request.city_uuid}/{request.product_code}")
+    except Exception as e:
+        logger.error(f"Background task failed: {e}", exc_info=True)

+ 9 - 38
api/report.py

@@ -1,68 +1,39 @@
 from database import MySqlDao
 from fastapi import APIRouter, status, HTTPException
-import os
 from .request_body import ReportRequest
+from core import get_logger
 
-dao =MySqlDao()
+logger = get_logger("api.report")
+dao = MySqlDao()
 router = APIRouter()
 
 @router.post("/report")
 async def report(request: ReportRequest):
     """获取推荐相关报告接口"""
+    logger.info(f"Report request: cultivacation_id={request.cultivacation_id}")
+
     file_id_record = dao.get_report_file_id(request.cultivacation_id)
     if file_id_record.empty:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
             detail="Reports not found"
         )
-    
+
     file_id_map = {
         '卷烟信息表': file_id_record['product_info_table'].item(),
         '品规商户特征关系表': file_id_record['relation_table'].item(),
         '相似卷烟表': file_id_record['similarity_product_table'].item()
     }
-    
+
     request_data = []
     for index, filename in enumerate(file_id_map):
         request_data.append(
             {
-                "id": index+1,
+                "id": index + 1,
                 "filename": filename,
                 "file_id": file_id_map.get(filename)
             }
         )
-        
-    return {"code": 200, "msg": "success", "data": {"reportInfo": request_data}}
 
-# @router.post("/report")
-# async def report_api(request: ReportRequest):
-#     """获取推荐相关报告接口"""
-#     files_id_path = os.path.join("./data/reports", request.city_uuid, request.product_code, "files_id.txt")
-#     if not os.path.exists(files_id_path):
-#         raise HTTPException(
-#             status_code=status.HTTP_404_NOT_FOUND,
-#             detail="Reports not found"
-#         )
-    
-#     with open(files_id_path, 'r', encoding="utf-8") as file:
-#         lines = file.readlines()
-        
-#     if lines[0].strip() == "failed":
-#         raise HTTPException(
-#             status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
-#             detail="reports upload failed"
-#         )
-    
-#     request_data = []
-#     for index, line in enumerate(lines):
-#         filename, file_id = line.strip().split(",")
-#         request_data.append(
-#             {
-#                 "id": index+1,
-#                 "filename": filename,
-#                 "file_id": file_id
-#             }
-#         )
-    
-#     return {"code": 200, "msg": "success", "data": {"reportInfo": request_data}}
+    return {"code": 200, "msg": "success", "data": {"reportInfo": request_data}}
 

+ 32 - 16
config/config.py

@@ -1,16 +1,32 @@
-import yaml
-
-def load_config():
-    with open('./config/database_config.yaml', encoding='utf-8') as file:
-        config = yaml.safe_load(file)
-    return config
-
-def load_model_config():
-    with open('./config/model_config.yaml', encoding='utf-8') as file:
-        config = yaml.safe_load(file)
-    return config
-
-def load_service_config():
-    with open("./config/service_config.yaml", encoding='utf-8') as file:
-        config = yaml.safe_load(file)
-    return config
+from core.config import settings
+
+
+def load_config():
+    return {
+        "mysql": {
+            "host": settings.mysql_host,
+            "port": settings.mysql_port,
+            "user": settings.mysql_user,
+            "passwd": settings.mysql_password,
+            "db": settings.mysql_db,
+        },
+        "redis": {
+            "host": settings.redis_host,
+            "port": settings.redis_port,
+            "passwd": settings.redis_password,
+            "db": settings.redis_db,
+        },
+    }
+
+
+def load_model_config():
+    return settings.model_config
+
+
+def load_service_config():
+    return {
+        "aliyun": {
+            "upload_url": settings.file_upload_url,
+            "download_url": settings.file_download_url,
+        }
+    }

+ 12 - 12
config/database_config.yaml

@@ -1,12 +1,12 @@
-mysql:
-  host: 'rm-t4n6rz18y4t5x47y70o.mysql.singapore.rds.aliyuncs.com'
-  port: 3036
-  db: 'brand_cultivation'
-  user: 'BrandCultivation'
-  passwd: '8BfWBc18NBXl#CMd'
-
-redis:
-  host: 'r-t4nb4n9i8je7u6ogk1pd.redis.singapore.rds.aliyuncs.com'
-  port: 5000
-  db: 10
-  passwd: 'gHmNkVBd88sZybj'
+mysql:
+  host: 'rm-t4n6rz18y4t5x47y70o.mysql.singapore.rds.aliyuncs.com'
+  port: 3036
+  db: 'brand_cultivation'
+  user: 'BrandCultivation'
+  # passwd moved to environment variable MYSQL_PASSWORD
+
+redis:
+  host: 'r-t4nb4n9i8je7u6ogk1pd.redis.singapore.rds.aliyuncs.com'
+  port: 5000
+  db: 10
+  # passwd moved to environment variable REDIS_PASSWORD

+ 23 - 0
core/__init__.py

@@ -0,0 +1,23 @@
+from core.logging import get_logger, request_id_var
+from core.config import settings
+from core.exceptions import (
+    AppException,
+    DatabaseException,
+    ModelException,
+    FileServiceException,
+    ValidationException,
+)
+from core.middleware import RequestLoggingMiddleware, get_request_id
+
+__all__ = [
+    "get_logger",
+    "request_id_var",
+    "settings",
+    "AppException",
+    "DatabaseException",
+    "ModelException",
+    "FileServiceException",
+    "ValidationException",
+    "RequestLoggingMiddleware",
+    "get_request_id",
+]

+ 77 - 0
core/config.py

@@ -0,0 +1,77 @@
+import os
+from pathlib import Path
+import yaml
+
+PROJECT_ROOT = Path(__file__).resolve().parent.parent
+
+
+def _get_env(key: str, default=None):
+    return os.environ.get(key, default)
+
+
+def _load_yaml(filename: str) -> dict:
+    filepath = PROJECT_ROOT / "config" / filename
+    with open(filepath, encoding="utf-8") as f:
+        return yaml.safe_load(f) or {}
+
+
+class _Settings:
+    def __init__(self):
+        self._db_cfg = _load_yaml("database_config.yaml")
+        self._model_cfg = _load_yaml("model_config.yaml")
+        self._service_cfg = _load_yaml("service_config.yaml")
+
+    @property
+    def mysql_host(self) -> str:
+        return _get_env("MYSQL_HOST", self._db_cfg.get("mysql", {}).get("host", "localhost"))
+
+    @property
+    def mysql_port(self) -> int:
+        return int(_get_env("MYSQL_PORT", self._db_cfg.get("mysql", {}).get("port", 3306)))
+
+    @property
+    def mysql_user(self) -> str:
+        return _get_env("MYSQL_USER", self._db_cfg.get("mysql", {}).get("user", "root"))
+
+    @property
+    def mysql_password(self) -> str:
+        return _get_env("MYSQL_PASSWORD", self._db_cfg.get("mysql", {}).get("passwd", ""))
+
+    @property
+    def mysql_db(self) -> str:
+        return _get_env("MYSQL_DB", self._db_cfg.get("mysql", {}).get("db", ""))
+
+    @property
+    def redis_host(self) -> str:
+        return _get_env("REDIS_HOST", self._db_cfg.get("redis", {}).get("host", "localhost"))
+
+    @property
+    def redis_port(self) -> int:
+        return int(_get_env("REDIS_PORT", self._db_cfg.get("redis", {}).get("port", 6379)))
+
+    @property
+    def redis_password(self) -> str:
+        return _get_env("REDIS_PASSWORD", self._db_cfg.get("redis", {}).get("passwd", ""))
+
+    @property
+    def redis_db(self) -> int:
+        return int(_get_env("REDIS_DB", self._db_cfg.get("redis", {}).get("db", 0)))
+
+    @property
+    def log_level(self) -> str:
+        return _get_env("LOG_LEVEL", "INFO").upper()
+
+    @property
+    def file_upload_url(self) -> str:
+        return _get_env("FILE_UPLOAD_URL", self._service_cfg.get("aliyun", {}).get("upload_url", ""))
+
+    @property
+    def file_download_url(self) -> str:
+        return _get_env("FILE_DOWNLOAD_URL", self._service_cfg.get("aliyun", {}).get("download_url", ""))
+
+    @property
+    def model_config(self) -> dict:
+        return self._model_cfg
+
+
+settings = _Settings()

+ 31 - 0
core/exceptions.py

@@ -0,0 +1,31 @@
+class AppException(Exception):
+    """应用异常基类"""
+    def __init__(self, code: int, message: str, detail: str = None):
+        self.code = code
+        self.message = message
+        self.detail = detail
+        super().__init__(message)
+
+
+class DatabaseException(AppException):
+    """数据库操作失败"""
+    def __init__(self, message: str = "数据库操作失败", detail: str = None):
+        super().__init__(code=500, message=message, detail=detail)
+
+
+class ModelException(AppException):
+    """模型推理失败"""
+    def __init__(self, message: str = "模型推理失败", detail: str = None):
+        super().__init__(code=500, message=message, detail=detail)
+
+
+class FileServiceException(AppException):
+    """文件服务失败"""
+    def __init__(self, message: str = "文件服务操作失败", detail: str = None):
+        super().__init__(code=500, message=message, detail=detail)
+
+
+class ValidationException(AppException):
+    """业务校验失败"""
+    def __init__(self, message: str = "参数校验失败", detail: str = None):
+        super().__init__(code=400, message=message, detail=detail)

+ 38 - 0
core/logging.py

@@ -0,0 +1,38 @@
+import logging
+import json
+import sys
+from contextvars import ContextVar
+from datetime import datetime, timezone
+
+request_id_var: ContextVar[str] = ContextVar("request_id", default="-")
+
+
+class JSONFormatter(logging.Formatter):
+    def format(self, record: logging.LogRecord) -> str:
+        log_data = {
+            "timestamp": datetime.now(timezone.utc).isoformat(),
+            "level": record.levelname,
+            "module": record.module,
+            "function": record.funcName,
+            "line": record.lineno,
+            "message": record.getMessage(),
+            "request_id": request_id_var.get("-"),
+        }
+        if record.exc_info and record.exc_info[0] is not None:
+            log_data["exception"] = self.formatException(record.exc_info)
+        if hasattr(record, "extra_data"):
+            log_data["extra"] = record.extra_data
+        return json.dumps(log_data, ensure_ascii=False)
+
+
+def get_logger(name: str) -> logging.Logger:
+    from core.config import settings
+
+    logger = logging.getLogger(name)
+    if not logger.handlers:
+        handler = logging.StreamHandler(sys.stdout)
+        handler.setFormatter(JSONFormatter())
+        logger.addHandler(handler)
+        logger.setLevel(getattr(logging, settings.log_level, logging.INFO))
+        logger.propagate = False
+    return logger

+ 45 - 0
core/middleware.py

@@ -0,0 +1,45 @@
+import time
+import uuid
+from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.requests import Request
+from starlette.responses import Response
+from core.logging import get_logger, request_id_var
+
+logger = get_logger("middleware")
+
+
+def get_request_id() -> str:
+    return request_id_var.get("-")
+
+
+class RequestLoggingMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request: Request, call_next) -> Response:
+        req_id = str(uuid.uuid4())[:8]
+        request_id_var.set(req_id)
+
+        start_time = time.time()
+        client_ip = request.client.host if request.client else "unknown"
+
+        logger.info(
+            f"Request started: {request.method} {request.url.path}",
+            extra={"extra_data": {"client_ip": client_ip, "method": request.method, "path": str(request.url.path)}},
+        )
+
+        try:
+            response = await call_next(request)
+        except Exception:
+            duration_ms = (time.time() - start_time) * 1000
+            logger.error(
+                f"Request failed: {request.method} {request.url.path} ({duration_ms:.1f}ms)",
+                exc_info=True,
+            )
+            raise
+
+        duration_ms = (time.time() - start_time) * 1000
+        logger.info(
+            f"Request completed: {request.method} {request.url.path} -> {response.status_code} ({duration_ms:.1f}ms)",
+            extra={"extra_data": {"status_code": response.status_code, "duration_ms": round(duration_ms, 1)}},
+        )
+
+        response.headers["X-Request-ID"] = req_id
+        return response

+ 29 - 19
database/dao/mysql_dao.py

@@ -1,7 +1,10 @@
+from core import get_logger
 from database import MySqlDatabaseHelper
 from sqlalchemy import text, bindparam
 import pandas as pd
 
+logger = get_logger("database.dao")
+
 class MySqlDao:
     _instance = None
     
@@ -31,6 +34,7 @@ class MySqlDao:
         
     def load_product_data(self, city_uuid):
         """从数据库中读取商品信息"""
+        logger.info(f"Loading product data for city_uuid={city_uuid}")
         query = f"SELECT * FROM {self._product_tablename} WHERE city_uuid = :city_uuid AND org_is_active = '是'"
         params = {"city_uuid": city_uuid}
         
@@ -39,6 +43,7 @@ class MySqlDao:
         
     def load_cust_data(self, city_uuid):
         """从数据库中读取商户信息"""
+        logger.info(f"Loading cust data for city_uuid={city_uuid}")
         query = f"SELECT * FROM {self._cust_tablename} WHERE corp_uuid = :city_uuid"
         params = {"city_uuid": city_uuid}
         
@@ -47,6 +52,7 @@ class MySqlDao:
     
     def load_order_data(self, city_uuid):
         """从数据库中读取订单信息"""
+        logger.info(f"Loading order data for city_uuid={city_uuid}")
         query = f"SELECT * FROM {self._order_tablename} WHERE city_uuid = :city_uuid"
         params = {"city_uuid": city_uuid}
         
@@ -61,6 +67,7 @@ class MySqlDao:
     
     def load_order_analysis_index_data(self, city_uuid):
         """从数据库中读取销售指标评估表"""
+        logger.info(f"Loading order analysis index data for city_uuid={city_uuid}")
         query = f"SELECT * FROM {self._order_analysis_table_name} WHERE city_uuid = :city_uuid"
         params = {"city_uuid": city_uuid}
         
@@ -69,6 +76,7 @@ class MySqlDao:
 
     def load_delivery_order_data(self, city_uuid, start_time, end_time):
         """从数据库中读取订单信息"""
+        logger.info(f"Loading delivery order data for city_uuid={city_uuid}, start_time={start_time}, end_time={end_time}")
         query = f"SELECT * FROM {self._eval_order_name} WHERE city_uuid = :city_uuid AND cycle_begin_date = :start_time AND cycle_end_date = :end_time"
         params = {
             "city_uuid": city_uuid,
@@ -83,6 +91,7 @@ class MySqlDao:
     
     def load_mock_order_data(self):
         """从数据库中读取mock的订单信息"""
+        logger.info("Loading mock order data")
         query = f"SELECT * FROM {self._mock_order_tablename}"
         
         data = self.db_helper.load_data_with_page(query, {})
@@ -91,6 +100,7 @@ class MySqlDao:
     
     def load_shopping_data(self, city_uuid):
         """从数据库中读取商圈数据"""
+        logger.info(f"Loading shopping data for city_uuid={city_uuid}")
         query = f"SELECT * FROM {self._shopping_tablename} WHERE city_uuid = :city_uuid"
         params = {"city_uuid": city_uuid}
         
@@ -100,6 +110,7 @@ class MySqlDao:
     
     def get_product_by_id(self, city_uuid, product_id):
         """根据city_uuid 和 product_id 从表中获取拼柜信息"""
+        logger.info(f"Getting product by id for city_uuid={city_uuid}, product_id={product_id}")
         query = text(f"""
             SELECT *
             FROM {self._product_tablename}
@@ -112,6 +123,7 @@ class MySqlDao:
     
     def get_cust_by_ids(self, city_uuid, cust_id_list):
         """根据零售户列表查询其信息"""
+        logger.info(f"Getting cust by ids for city_uuid={city_uuid}, count={len(cust_id_list) if cust_id_list else 0}")
         if not cust_id_list:
             return pd.DataFrame()
 
@@ -128,6 +140,7 @@ class MySqlDao:
     
     def get_shop_by_ids(self, city_uuid, cust_id_list):
         """根据零售户列表查询其信息"""
+        logger.info(f"Getting shop by ids for city_uuid={city_uuid}, count={len(cust_id_list) if cust_id_list else 0}")
         if not cust_id_list:
             return pd.DataFrame()
 
@@ -144,6 +157,7 @@ class MySqlDao:
     
     def get_product_by_ids(self, city_uuid, product_id_list):
         """根据product_code列表查询其信息"""
+        logger.info(f"Getting products by ids for city_uuid={city_uuid}, count={len(product_id_list) if product_id_list else 0}")
         if not product_id_list:
             return pd.DataFrame()
 
@@ -166,6 +180,7 @@ class MySqlDao:
     
     def get_order_by_product_ids(self, city_uuid, product_ids):
         """获取指定香烟列表的所有售卖记录"""
+        logger.info(f"Getting orders by product ids for city_uuid={city_uuid}, count={len(product_ids) if product_ids else 0}")
         if not product_ids:
             return pd.DataFrame()
 
@@ -193,7 +208,7 @@ class MySqlDao:
         return data
     
     def get_order_by_product(self, city_uuid, product_id):
-        
+        logger.info(f"Getting orders by product for city_uuid={city_uuid}, product_id={product_id}")
         query = f"""
             SELECT *
             FROM {self._order_tablename}
@@ -210,6 +225,7 @@ class MySqlDao:
         return data
     
     def get_eval_order_by_product(self, city_uuid, product_id):
+        logger.info(f"Getting eval orders by product for city_uuid={city_uuid}, product_id={product_id}")
         query = f"""
             SELECT *
             FROM {self._eval_order_name}
@@ -223,6 +239,7 @@ class MySqlDao:
     
     def get_delivery_data_by_product(self, city_uuid, product_id, start_time, end_time):
         """通过品规获取验证数据"""
+        logger.info(f"Getting delivery data by product for city_uuid={city_uuid}, product_id={product_id}, start_time={start_time}, end_time={end_time}")
         query = f"""
             SELECT *
             FROM {self._eval_order_name}
@@ -242,6 +259,7 @@ class MySqlDao:
         return data
     
     def get_order_by_cust(self, city_uuid, cust_id):
+        logger.info(f"Getting orders by cust for city_uuid={city_uuid}, cust_id={cust_id}")
         query = f"""
             SELECT *
             FROM {self._order_tablename}
@@ -254,6 +272,7 @@ class MySqlDao:
         return data
     
     def get_order_by_cust_and_product(self, city_uuid, cust_id, product_id):
+        logger.info(f"Getting orders by cust and product for city_uuid={city_uuid}, cust_id={cust_id}, product_id={product_id}")
         query = f"""
             SELECT *
             FROM {self._order_tablename}
@@ -267,6 +286,7 @@ class MySqlDao:
         return data
     
     def get_product_from_order(self, city_uuid):
+        logger.info(f"Getting products from order for city_uuid={city_uuid}")
         query = f"SELECT DISTINCT product_code FROM {self._order_tablename} WHERE city_uuid = :city_uuid ORDER BY product_code"
         params = {"city_uuid": city_uuid}
         
@@ -275,6 +295,7 @@ class MySqlDao:
         return data
 
     def get_cust_list(self, city_uuid):
+        logger.info(f"Getting cust list for city_uuid={city_uuid}")
         query = f"SELECT DISTINCT cust_code FROM {self._cust_tablename} WHERE corp_uuid = :city_uuid ORDER BY cust_code"
         params = {"city_uuid": city_uuid}
         
@@ -282,26 +303,14 @@ class MySqlDao:
 
         return data
 
-    def data_preprocess(self, data: pd.DataFrame):
-        """数据预处理"""
-        data.drop(["cust_uuid", "longitude", "latitude", "range_radius"], axis=1, inplace=True)
-        remaining_cols = data.columns.drop(["city_uuid", "cust_code"])
-        col_with_missing = remaining_cols[data[remaining_cols].isnull().any()].tolist() # 判断有缺失的字段
-        col_all_missing = remaining_cols[data[remaining_cols].isnull().all()].to_list() # 全部缺失的字段
-        col_partial_missing = list(set(col_with_missing) - set(col_all_missing)) # 部分缺失的字段
-        
-        for col in col_partial_missing:
-            data[col] = data[col].fillna(data[col].mean())
-        
-        for col in col_all_missing:
-            data[col] = data[col].fillna(0).infer_objects(copy=False)
-        
     def insert_report(self, data_dict):
         """向report中插入数据"""
+        logger.info("Inserting report data")
         return self.db_helper.insert_data(self._report_tablename, data_dict)
     
     def update_eval_report_data(self, cultivacation_id, eval_fileid):
         """更新投放记录中的验证报告fileid"""
+        logger.info(f"Updating eval report data for cultivacation_id={cultivacation_id}")
         update_data = {"val_table": eval_fileid}
         conditions = [
             "cultivacation_id = :cultivacation_id",
@@ -314,13 +323,14 @@ class MySqlDao:
     
     def get_report_file_id(self, cultivacation_id):
         """从report中根据cultivacation_id获取对应文件的fileid"""
+        logger.info(f"Getting report file id for cultivacation_id={cultivacation_id}")
         query = f"SELECT product_info_table, relation_table, similarity_product_table, recommend_table, val_table FROM {self._report_tablename} WHERE cultivacation_id = :cultivacation_id"
         params = {"cultivacation_id": cultivacation_id}
-        
         result = self.db_helper.fetch_one(text(query), params)
-        data = pd.DataFrame([dict(result._mapping)] if result else None)
-        
-        return data
+        if result is None:
+            logger.warning(f"No report found for cultivacation_id={cultivacation_id}")
+            return pd.DataFrame()
+        return pd.DataFrame([dict(result._mapping)])
         
 if __name__ == "__main__":
     dao = MySqlDao()

+ 154 - 209
database/db/mysql.py

@@ -1,209 +1,154 @@
-from config import load_config
-import pandas as pd
-from sqlalchemy import create_engine, text
-from sqlalchemy.orm import sessionmaker
-from sqlalchemy.exc import SQLAlchemyError
-from tqdm import tqdm
-
-cfgs = load_config()
-
-
-class MySqlDatabaseHelper:
-    _instance = None
-    
-    def __new__(cls):
-        if not cls._instance:
-            cls._instance = super(MySqlDatabaseHelper, cls).__new__(cls)
-            cls._instance._initialized = False
-        return cls._instance
-        
-    def __init__(self):
-        if self._initialized:
-            return
-        
-        self._host = cfgs['mysql']['host']
-        self._port = cfgs['mysql']['port']
-        self._user = cfgs['mysql']['user']
-        self._passwd = cfgs['mysql']['passwd']
-        self._dbname = cfgs['mysql']['db']
-        
-        self.connect_database()
-        self._initialized = True
-        
-    def connect_database(self):
-        # 创建数据库连接
-        try:
-            conn = "mysql+pymysql://" + self._user + ":" + self._passwd + "@" + self._host + ":" + str(self._port) + "/" + self._dbname
-
-            # 通过连接池创建engine
-            self.engine = create_engine(
-                conn,
-                pool_size=20, # 设置连接池大小
-                max_overflow=30, # 超过连接池大小时的额外连接数
-                pool_recycle=1800, # 回收连接时间
-                pool_pre_ping=True, # 防止断开连接
-                isolation_level="READ COMMITTED" # 降低隔离级别
-            )
-        except Exception as e:
-            raise ConnectionAbortedError(f"failed to create connection: {e}")
-
-        self._DBSession = sessionmaker(bind=self.engine)
-    
-    def load_data_with_page(self, query, params, page_size=100000):
-        """分页查询数据"""
-        data = pd.DataFrame()
-        # 用子查询包裹原始查询来计数,避免字符串替换
-        count_query = text(f"SELECT COUNT(*) FROM ({query}) AS _count_subq")
-        query += " LIMIT :limit OFFSET :offset"
-        query = text(query)
-
-        # 获取总行数
-        result = self.fetch_one(count_query, params)
-        total_rows = result[0] if result is not None else 0
-
-        if total_rows == 0:
-            return data
-
-        page = 1
-        with tqdm(total=total_rows, desc="Loading data", unit="rows") as pbar:
-            while True:
-                offset = (page - 1) * page_size
-                # 复制 params 避免修改调用方的字典
-                page_params = dict(params)
-                page_params["limit"] = page_size
-                page_params["offset"] = offset
-
-                df = pd.DataFrame(self.fetch_all(query, page_params))
-                if df.empty:
-                    break
-                data = pd.concat([data, df], ignore_index=True)
-
-                pbar.update(len(df))
-
-                page += 1
-        return data
-        
-        
-    def fetch_all(self, query, params=None):
-        """执行SQL查询并返回所有结果"""
-        session = self._DBSession()
-        try:
-            results = session.execute(query, params or {}).fetchall()
-            return results
-        except SQLAlchemyError as e:
-            session.rollback()
-            print(f"error: {e}")
-            raise
-        finally:
-            session.close()
-            
-    def fetch_one(self, query, params=None):
-        """执行SQL查询并返回单条结果"""
-        session = self._DBSession()
-        try:
-            result = session.execute(query, params or {}).fetchone()
-            return result
-
-        except SQLAlchemyError as e:
-            session.rollback()
-            print(f"error: {e}")
-            raise
-        finally:
-            session.close()
-            
-    def insert_data(self, table_name, data_dict):
-        """插入单条数据到指定表"""
-        if not data_dict:
-            return 0
-        
-        columns = ", ".join(data_dict.keys())
-        values = ", ".join([f":{key}" for key in data_dict.keys()])
-        query = text(f"INSERT INTO {table_name} ({columns}) VALUES ({values})")
-        
-        session = self._DBSession()
-        
-        try:
-            result = session.execute(query, data_dict)
-            session.commit()
-            return result.rowcount
-        
-        except SQLAlchemyError as e:
-            session.rollback()
-            print(f"Error inserting data: {e}")
-            return 0
-        finally:
-            session.close()
-            
-    def update_data(self, table_name, update_dict, conditions, condition_params=None):
-        """更新表中符合条件的数据"""
-        if not update_dict:
-            return 0
-        
-        set_clause = ", ".join([f"{key} = :{key}" for key in update_dict.keys()])
-        
-        if len(conditions) == 1:
-            where_clause = f"WHERE {conditions[0]}"
-        elif len(conditions) > 1:
-            where_clause = f"WHERE {' AND '.join(conditions)}"
-        else:
-            where_clause = ""
-        
-        query = text(f"UPDATE {table_name} SET {set_clause} {where_clause}")
-        
-        params = update_dict.copy()
-        if condition_params:
-            params.update(condition_params)
-            
-        session = self._DBSession()
-        try:
-            result = session.execute(query, params)
-            session.commit()
-            return result.rowcount
-        except SQLAlchemyError as e:
-            session.rollback()
-            print(f"Error updating data: {e}")
-            return 0
-        
-        finally:
-            session.close()
-    
-    def execute_query(self, query, params=None):
-        """执行SQL语句 (无返回值, 如INSERT, UPDATE, DELETE)"""
-        session = self._DBSession()
-        try:
-            session.execute(query, params or {})
-            session.commit()
-        except SQLAlchemyError as e:
-            session.rollback()
-            print(f"Error: {e}")
-        finally:
-            session.close()
-            
-if __name__ == '__main__':
-    db_helper = MySqlDatabaseHelper()
-    
-    table_name = 'tads_brandcul_report'
-    data_dict = {
-        'cultivacation_id': 10000002,
-        'city_uuid': '00000000000000000000000011445301',
-        'limit_cycle_name': '202505W1(05.05-05.11)',
-        'product_code': '440298',
-        'product_info_table': 'D72E3FAE8DCE4270BD23C3EC015C0A35',
-        'relation_table': 'AD889019FD4F4EE7B887981162BA09EC',
-        'similarity_product_table': 'CE436AC24D96461FA0C091CB01E9BC05',
-        'recommend_table': 'A7C5918B8DDB4BEA9D921936955CBAF6',
-    }
-    
-    # db_helper.insert_data(table_name, data_dict)
-    
-    update_data = {"val_table": "A7C5918B8DDB4BEA9D921936955CBAF6"}
-    conditions = [
-        "cultivacation_id = :cultivacation_id",
-        "city_uuid = :city_uuid"
-    ]
-    condition_params = {
-        'cultivacation_id': 10000001,
-        'city_uuid': '00000000000000000000000011445301',
-    }
-    
-    db_helper.update_data(table_name, update_data, conditions, condition_params)
+from contextlib import contextmanager
+from core import get_logger, settings, DatabaseException
+import pandas as pd
+from sqlalchemy import create_engine, text
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.exc import SQLAlchemyError
+
+logger = get_logger("database.mysql")
+
+
+class MySqlDatabaseHelper:
+    _instance = None
+
+    def __new__(cls):
+        if not cls._instance:
+            cls._instance = super(MySqlDatabaseHelper, cls).__new__(cls)
+            cls._instance._initialized = False
+        return cls._instance
+
+    def __init__(self):
+        if self._initialized:
+            return
+        self._connect_database()
+        self._initialized = True
+
+    def _connect_database(self):
+        try:
+            conn_str = (
+                f"mysql+pymysql://{settings.mysql_user}:{settings.mysql_password}"
+                f"@{settings.mysql_host}:{settings.mysql_port}/{settings.mysql_db}"
+            )
+            self.engine = create_engine(
+                conn_str,
+                pool_size=20,
+                max_overflow=30,
+                pool_recycle=1800,
+                pool_pre_ping=True,
+                isolation_level="READ COMMITTED",
+            )
+            self._DBSession = sessionmaker(bind=self.engine)
+            logger.info("MySQL connection pool created", extra={"extra_data": {"host": settings.mysql_host, "db": settings.mysql_db}})
+        except Exception as e:
+            logger.error("Failed to create MySQL connection", exc_info=True)
+            raise DatabaseException(message="数据库连接失败", detail=str(e))
+
+    @contextmanager
+    def get_session(self):
+        session = self._DBSession()
+        try:
+            yield session
+        except SQLAlchemyError as e:
+            session.rollback()
+            logger.error("Database operation failed", exc_info=True)
+            raise DatabaseException(message="数据库操作失败", detail=str(e))
+        finally:
+            session.close()
+
+    def load_data_with_page(self, query, params, page_size=100000):
+        """分页查询数据"""
+        count_query = text(f"SELECT COUNT(*) FROM ({query}) AS _count_subq")
+        query += " LIMIT :limit OFFSET :offset"
+        query = text(query)
+
+        result = self.fetch_one(count_query, params)
+        total_rows = result[0] if result is not None else 0
+
+        if total_rows == 0:
+            logger.debug("Query returned 0 rows")
+            return pd.DataFrame()
+
+        logger.debug(f"Loading {total_rows} rows with page_size={page_size}")
+        data = pd.DataFrame()
+        page = 1
+        while True:
+            offset = (page - 1) * page_size
+            page_params = dict(params)
+            page_params["limit"] = page_size
+            page_params["offset"] = offset
+
+            df = pd.DataFrame(self.fetch_all(query, page_params))
+            if df.empty:
+                break
+            data = pd.concat([data, df], ignore_index=True)
+            page += 1
+
+        logger.debug(f"Loaded {len(data)} rows in {page - 1} pages")
+        return data
+
+    def fetch_all(self, query, params=None):
+        """执行SQL查询并返回所有结果"""
+        with self.get_session() as session:
+            results = session.execute(query, params or {}).fetchall()
+            return results
+
+    def fetch_one(self, query, params=None):
+        """执行SQL查询并返回单条结果"""
+        with self.get_session() as session:
+            result = session.execute(query, params or {}).fetchone()
+            return result
+
+    def insert_data(self, table_name, data_dict):
+        """插入单条数据到指定表"""
+        if not data_dict:
+            return 0
+
+        columns = ", ".join(data_dict.keys())
+        values = ", ".join([f":{key}" for key in data_dict.keys()])
+        query = text(f"INSERT INTO {table_name} ({columns}) VALUES ({values})")
+
+        with self.get_session() as session:
+            result = session.execute(query, data_dict)
+            session.commit()
+            logger.info(f"Inserted 1 row into {table_name}")
+            return result.rowcount
+
+    def update_data(self, table_name, update_dict, conditions, condition_params=None):
+        """更新表中符合条件的数据"""
+        if not update_dict:
+            return 0
+
+        set_clause = ", ".join([f"{key} = :{key}" for key in update_dict.keys()])
+
+        if len(conditions) == 1:
+            where_clause = f"WHERE {conditions[0]}"
+        elif len(conditions) > 1:
+            where_clause = f"WHERE {' AND '.join(conditions)}"
+        else:
+            where_clause = ""
+
+        query = text(f"UPDATE {table_name} SET {set_clause} {where_clause}")
+
+        params = update_dict.copy()
+        if condition_params:
+            params.update(condition_params)
+
+        with self.get_session() as session:
+            result = session.execute(query, params)
+            session.commit()
+            logger.info(f"Updated {result.rowcount} rows in {table_name}")
+            return result.rowcount
+
+    def execute_query(self, query, params=None):
+        """执行SQL语句"""
+        with self.get_session() as session:
+            session.execute(query, params or {})
+            session.commit()
+
+    def check_connection(self) -> bool:
+        """检查数据库连接是否正常"""
+        try:
+            self.fetch_one(text("SELECT 1"), {})
+            return True
+        except Exception:
+            return False

+ 42 - 52
database/db/redis_db.py

@@ -1,52 +1,42 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-import redis
-from config import load_config
-
-cfgs = load_config()
-
-
-class RedisDatabaseHelper:
-    _instance = None
-    
-    def __new__(cls):
-        if not cls._instance:
-            cls._instance = super(RedisDatabaseHelper, cls).__new__(cls)
-            cls._instance._initialized = False
-        return cls._instance
-        
-    def __init__(self):
-        if self._initialized:
-            return
-        self.redis = redis.StrictRedis(host=cfgs['redis']['host'],
-                                       port=cfgs['redis']['port'],
-                                       password=cfgs['redis']['passwd'],
-                                       db=cfgs['redis']['db'],
-                                       decode_responses=True)
-        
-        self._initialized = True
-
-
-if __name__ == '__main__':
-    import random
-    # 连接到 Redis 服务器
-    r = RedisDatabaseHelper().redis
-
-    # 有序集合的键名
-    zset_key = 'configs:hotkeys'
-
-    data_list = ['ORDER_FULLORDR_RATE', 'MONTH6_SALE_QTY_YOY', 'MONTH6_SALE_QTY_MOM', 'MONTH6_SALE_QTY']
-
-    # 清空已有的有序集合(可选,若需要全新的集合可执行此操作)
-    r.delete(zset_key)
-    
-    for item in data_list:
-        # 生成 80 到 100 之间的随机数,小数点后保留 4 位
-        score = round(random.uniform(80, 100), 4)
-        # 将元素和对应的分数添加到有序集合中
-        r.zadd(zset_key, {item: score})
-
-    # # 从 Redis 中读取有序集合并打印
-    # result = r.zrange(zset_key, 0, -1, withscores=True)
-    # for item, score in result:
-    #     print(f"元素: {item}, 分数: {score}")
+import redis
+from core import get_logger, settings, DatabaseException
+
+logger = get_logger("database.redis")
+
+
+class RedisDatabaseHelper:
+    _instance = None
+
+    def __new__(cls):
+        if not cls._instance:
+            cls._instance = super(RedisDatabaseHelper, cls).__new__(cls)
+            cls._instance._initialized = False
+        return cls._instance
+
+    def __init__(self):
+        if self._initialized:
+            return
+        try:
+            pool = redis.ConnectionPool(
+                host=settings.redis_host,
+                port=settings.redis_port,
+                password=settings.redis_password,
+                db=settings.redis_db,
+                decode_responses=True,
+                max_connections=50,
+            )
+            self.redis = redis.StrictRedis(connection_pool=pool)
+            self.redis.ping()
+            logger.info("Redis connection established", extra={"extra_data": {"host": settings.redis_host, "db": settings.redis_db}})
+        except redis.ConnectionError as e:
+            logger.error("Failed to connect to Redis", exc_info=True)
+            raise DatabaseException(message="Redis连接失败", detail=str(e))
+        self._initialized = True
+
+    def check_connection(self) -> bool:
+        """检查 Redis 连接是否正常"""
+        try:
+            self.redis.ping()
+            return True
+        except Exception:
+            return False

+ 1316 - 0
docs/superpowers/plans/2026-05-21-project-refactoring.md

@@ -0,0 +1,1316 @@
+# BrandCultivation 项目级重构实现计划
+
+> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
+
+**Goal:** 为 BrandCultivation 项目添加日志系统、配置管理、异常处理、请求追踪,修复已知 bug,移除明文密码。
+
+**Architecture:** 渐进式重构 — 新增 `core/` 基础设施层,在现有模块中逐步替换 print 为 logger,添加错误处理和请求追踪。保持现有目录结构和业务逻辑不变。
+
+**Tech Stack:** Python 3.x, FastAPI, SQLAlchemy, Redis, logging (stdlib), pathlib, contextvars, uuid
+
+---
+
+## File Structure
+
+### New Files
+- `core/__init__.py` — 公共接口导出
+- `core/logging.py` — JSON 格式日志系统
+- `core/config.py` — 配置管理(YAML + 环境变量)
+- `core/exceptions.py` — 自定义异常体系
+- `core/middleware.py` — 请求日志和 request_id 中间件
+- `.env.example` — 环境变量模板
+
+### Modified Files
+- `config/config.py` — 废弃,改为从 core/config.py 导入
+- `config/database_config.yaml` — 移除密码
+- `database/db/mysql.py` — 日志、session context manager、配置来源
+- `database/db/redis_db.py` — 日志、配置来源、连接池
+- `database/__init__.py` — 更新导出
+- `database/dao/mysql_dao.py` — 日志、异常处理
+- `run_api.py` — 注册中间件、健康检查、异常处理器
+- `api/recommend.py` — 日志、错误处理
+- `api/eval_report.py` — 日志、错误处理
+- `api/report.py` — 日志
+- `models/rank/gbdt_lr_inference.py` — bug 修复、日志
+- `models/recommend.py` — 日志
+- `models/recall/hot_recall.py` — 替换 print
+- `models/recall/itemCF/ItemCF.py` — 替换 print
+- `models/item2vec/inference.py` — 日志
+- `models/rank/data/preprocess.py` — 替换 print
+- `utils/file_stream.py` — 日志、错误处理
+- `utils/report_utils.py` — 日志
+- `train.py` — 替换 print
+
+---
+
+### Task 1: 创建 `core/exceptions.py` — 自定义异常体系
+
+**Files:**
+- Create: `core/exceptions.py`
+
+- [ ] **Step 1: 创建异常定义文件**
+
+```python
+class AppException(Exception):
+    """应用异常基类"""
+    def __init__(self, code: int, message: str, detail: str = None):
+        self.code = code
+        self.message = message
+        self.detail = detail
+        super().__init__(message)
+
+
+class DatabaseException(AppException):
+    """数据库操作失败"""
+    def __init__(self, message: str = "数据库操作失败", detail: str = None):
+        super().__init__(code=500, message=message, detail=detail)
+
+
+class ModelException(AppException):
+    """模型推理失败"""
+    def __init__(self, message: str = "模型推理失败", detail: str = None):
+        super().__init__(code=500, message=message, detail=detail)
+
+
+class FileServiceException(AppException):
+    """文件服务失败"""
+    def __init__(self, message: str = "文件服务操作失败", detail: str = None):
+        super().__init__(code=500, message=message, detail=detail)
+
+
+class ValidationException(AppException):
+    """业务校验失败"""
+    def __init__(self, message: str = "参数校验失败", detail: str = None):
+        super().__init__(code=400, message=message, detail=detail)
+```
+
+- [ ] **Step 2: 验证导入**
+
+Run: `cd /home/yangzeyu/project/BrandCultivation && python -c "from core.exceptions import AppException, DatabaseException, ModelException, FileServiceException, ValidationException; print('OK')"`
+Expected: `OK`
+
+- [ ] **Step 3: Commit**
+
+```bash
+git add core/exceptions.py
+git commit -m "feat(core): add custom exception hierarchy"
+```
+
+---
+
+### Task 2: 创建 `core/config.py` — 配置管理
+
+**Files:**
+- Create: `core/config.py`
+
+- [ ] **Step 1: 创建配置管理模块**
+
+```python
+import os
+from pathlib import Path
+import yaml
+
+PROJECT_ROOT = Path(__file__).resolve().parent.parent
+
+
+def _get_env(key: str, default=None):
+    """从环境变量获取值"""
+    return os.environ.get(key, default)
+
+
+def _load_yaml(filename: str) -> dict:
+    """加载 YAML 配置文件"""
+    filepath = PROJECT_ROOT / "config" / filename
+    with open(filepath, encoding="utf-8") as f:
+        return yaml.safe_load(f) or {}
+
+
+class _Settings:
+    """配置单例,支持环境变量覆盖 YAML 默认值"""
+
+    def __init__(self):
+        self._db_cfg = _load_yaml("database_config.yaml")
+        self._model_cfg = _load_yaml("model_config.yaml")
+        self._service_cfg = _load_yaml("service_config.yaml")
+
+    @property
+    def mysql_host(self) -> str:
+        return _get_env("MYSQL_HOST", self._db_cfg.get("mysql", {}).get("host", "localhost"))
+
+    @property
+    def mysql_port(self) -> int:
+        return int(_get_env("MYSQL_PORT", self._db_cfg.get("mysql", {}).get("port", 3306)))
+
+    @property
+    def mysql_user(self) -> str:
+        return _get_env("MYSQL_USER", self._db_cfg.get("mysql", {}).get("user", "root"))
+
+    @property
+    def mysql_password(self) -> str:
+        return _get_env("MYSQL_PASSWORD", self._db_cfg.get("mysql", {}).get("passwd", ""))
+
+    @property
+    def mysql_db(self) -> str:
+        return _get_env("MYSQL_DB", self._db_cfg.get("mysql", {}).get("db", ""))
+
+    @property
+    def redis_host(self) -> str:
+        return _get_env("REDIS_HOST", self._db_cfg.get("redis", {}).get("host", "localhost"))
+
+    @property
+    def redis_port(self) -> int:
+        return int(_get_env("REDIS_PORT", self._db_cfg.get("redis", {}).get("port", 6379)))
+
+    @property
+    def redis_password(self) -> str:
+        return _get_env("REDIS_PASSWORD", self._db_cfg.get("redis", {}).get("passwd", ""))
+
+    @property
+    def redis_db(self) -> int:
+        return int(_get_env("REDIS_DB", self._db_cfg.get("redis", {}).get("db", 0)))
+
+    @property
+    def log_level(self) -> str:
+        return _get_env("LOG_LEVEL", "INFO").upper()
+
+    @property
+    def file_upload_url(self) -> str:
+        return _get_env("FILE_UPLOAD_URL", self._service_cfg.get("aliyun", {}).get("upload_url", ""))
+
+    @property
+    def file_download_url(self) -> str:
+        return _get_env("FILE_DOWNLOAD_URL", self._service_cfg.get("aliyun", {}).get("download_url", ""))
+
+    @property
+    def model_config(self) -> dict:
+        return self._model_cfg
+
+
+settings = _Settings()
+```
+
+- [ ] **Step 2: 验证导入**
+
+Run: `cd /home/yangzeyu/project/BrandCultivation && python -c "from core.config import settings; print(settings.mysql_host)"`
+Expected: 输出数据库 host 地址
+
+- [ ] **Step 3: Commit**
+
+```bash
+git add core/config.py
+git commit -m "feat(core): add config management with env var override"
+```
+
+---
+
+### Task 3: 创建 `core/logging.py` — 日志系统
+
+**Files:**
+- Create: `core/logging.py`
+
+- [ ] **Step 1: 创建日志模块**
+
+```python
+import logging
+import json
+import sys
+from contextvars import ContextVar
+from datetime import datetime, timezone
+
+request_id_var: ContextVar[str] = ContextVar("request_id", default="-")
+
+
+class JSONFormatter(logging.Formatter):
+    """JSON 格式日志输出"""
+
+    def format(self, record: logging.LogRecord) -> str:
+        log_data = {
+            "timestamp": datetime.now(timezone.utc).isoformat(),
+            "level": record.levelname,
+            "module": record.module,
+            "function": record.funcName,
+            "line": record.lineno,
+            "message": record.getMessage(),
+            "request_id": request_id_var.get("-"),
+        }
+        if record.exc_info and record.exc_info[0] is not None:
+            log_data["exception"] = self.formatException(record.exc_info)
+        if hasattr(record, "extra_data"):
+            log_data["extra"] = record.extra_data
+        return json.dumps(log_data, ensure_ascii=False)
+
+
+def get_logger(name: str) -> logging.Logger:
+    """获取指定名称的 logger"""
+    from core.config import settings
+
+    logger = logging.getLogger(name)
+    if not logger.handlers:
+        handler = logging.StreamHandler(sys.stdout)
+        handler.setFormatter(JSONFormatter())
+        logger.addHandler(handler)
+        logger.setLevel(getattr(logging, settings.log_level, logging.INFO))
+        logger.propagate = False
+    return logger
+```
+
+- [ ] **Step 2: 验证导入和基本功能**
+
+Run: `cd /home/yangzeyu/project/BrandCultivation && python -c "from core.logging import get_logger; logger = get_logger('test'); logger.info('hello')"`
+Expected: JSON 格式日志输出
+
+- [ ] **Step 3: Commit**
+
+```bash
+git add core/logging.py
+git commit -m "feat(core): add JSON logging system with request_id support"
+```
+
+---
+
+### Task 4: 创建 `core/middleware.py` — 请求中间件
+
+**Files:**
+- Create: `core/middleware.py`
+
+- [ ] **Step 1: 创建中间件模块**
+
+```python
+import time
+import uuid
+from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.requests import Request
+from starlette.responses import Response
+from core.logging import get_logger, request_id_var
+
+logger = get_logger("middleware")
+
+
+def get_request_id() -> str:
+    """获取当前请求的 request_id"""
+    return request_id_var.get("-")
+
+
+class RequestLoggingMiddleware(BaseHTTPMiddleware):
+    """请求日志中间件:生成 request_id,记录请求开始/结束"""
+
+    async def dispatch(self, request: Request, call_next) -> Response:
+        req_id = str(uuid.uuid4())[:8]
+        request_id_var.set(req_id)
+
+        start_time = time.time()
+        client_ip = request.client.host if request.client else "unknown"
+
+        logger.info(
+            f"Request started: {request.method} {request.url.path}",
+            extra={"extra_data": {"client_ip": client_ip, "method": request.method, "path": str(request.url.path)}},
+        )
+
+        try:
+            response = await call_next(request)
+        except Exception as e:
+            duration_ms = (time.time() - start_time) * 1000
+            logger.error(
+                f"Request failed: {request.method} {request.url.path} ({duration_ms:.1f}ms)",
+                exc_info=True,
+            )
+            raise
+
+        duration_ms = (time.time() - start_time) * 1000
+        logger.info(
+            f"Request completed: {request.method} {request.url.path} -> {response.status_code} ({duration_ms:.1f}ms)",
+            extra={"extra_data": {"status_code": response.status_code, "duration_ms": round(duration_ms, 1)}},
+        )
+
+        response.headers["X-Request-ID"] = req_id
+        return response
+```
+
+- [ ] **Step 2: 验证导入**
+
+Run: `cd /home/yangzeyu/project/BrandCultivation && python -c "from core.middleware import RequestLoggingMiddleware, get_request_id; print('OK')"`
+Expected: `OK`
+
+- [ ] **Step 3: Commit**
+
+```bash
+git add core/middleware.py
+git commit -m "feat(core): add request logging middleware with request_id tracking"
+```
+
+---
+
+### Task 5: 创建 `core/__init__.py` — 公共接口导出
+
+**Files:**
+- Create: `core/__init__.py`
+
+- [ ] **Step 1: 创建 init 文件**
+
+```python
+from core.logging import get_logger, request_id_var
+from core.config import settings
+from core.exceptions import (
+    AppException,
+    DatabaseException,
+    ModelException,
+    FileServiceException,
+    ValidationException,
+)
+from core.middleware import RequestLoggingMiddleware, get_request_id
+
+__all__ = [
+    "get_logger",
+    "request_id_var",
+    "settings",
+    "AppException",
+    "DatabaseException",
+    "ModelException",
+    "FileServiceException",
+    "ValidationException",
+    "RequestLoggingMiddleware",
+    "get_request_id",
+]
+```
+
+- [ ] **Step 2: 验证完整导入**
+
+Run: `cd /home/yangzeyu/project/BrandCultivation && python -c "from core import get_logger, settings, AppException, DatabaseException, RequestLoggingMiddleware, get_request_id; print('All imports OK')"`
+Expected: `All imports OK`
+
+- [ ] **Step 3: Commit**
+
+```bash
+git add core/__init__.py
+git commit -m "feat(core): add public interface exports"
+```
+
+---
+
+### Task 6: 改进 `database/db/mysql.py` — 日志、session 管理、配置来源
+
+**Files:**
+- Modify: `database/db/mysql.py`
+
+- [ ] **Step 1: 重写 mysql.py**
+
+替换整个文件内容为:
+
+```python
+from contextlib import contextmanager
+from core import get_logger, settings, DatabaseException
+import pandas as pd
+from sqlalchemy import create_engine, text
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.exc import SQLAlchemyError
+
+logger = get_logger("database.mysql")
+
+
+class MySqlDatabaseHelper:
+    _instance = None
+
+    def __new__(cls):
+        if not cls._instance:
+            cls._instance = super(MySqlDatabaseHelper, cls).__new__(cls)
+            cls._instance._initialized = False
+        return cls._instance
+
+    def __init__(self):
+        if self._initialized:
+            return
+        self._connect_database()
+        self._initialized = True
+
+    def _connect_database(self):
+        try:
+            conn_str = (
+                f"mysql+pymysql://{settings.mysql_user}:{settings.mysql_password}"
+                f"@{settings.mysql_host}:{settings.mysql_port}/{settings.mysql_db}"
+            )
+            self.engine = create_engine(
+                conn_str,
+                pool_size=20,
+                max_overflow=30,
+                pool_recycle=1800,
+                pool_pre_ping=True,
+                isolation_level="READ COMMITTED",
+            )
+            self._DBSession = sessionmaker(bind=self.engine)
+            logger.info("MySQL connection pool created", extra={"extra_data": {"host": settings.mysql_host, "db": settings.mysql_db}})
+        except Exception as e:
+            logger.error("Failed to create MySQL connection", exc_info=True)
+            raise DatabaseException(message="数据库连接失败", detail=str(e))
+
+    @contextmanager
+    def get_session(self):
+        session = self._DBSession()
+        try:
+            yield session
+            session.commit()
+        except SQLAlchemyError as e:
+            session.rollback()
+            logger.error("Database operation failed", exc_info=True)
+            raise DatabaseException(message="数据库操作失败", detail=str(e))
+        finally:
+            session.close()
+
+    def load_data_with_page(self, query, params, page_size=100000):
+        """分页查询数据"""
+        count_query = text(f"SELECT COUNT(*) FROM ({query}) AS _count_subq")
+        query += " LIMIT :limit OFFSET :offset"
+        query = text(query)
+
+        result = self.fetch_one(count_query, params)
+        total_rows = result[0] if result is not None else 0
+
+        if total_rows == 0:
+            logger.debug("Query returned 0 rows")
+            return pd.DataFrame()
+
+        logger.debug(f"Loading {total_rows} rows with page_size={page_size}")
+        data = pd.DataFrame()
+        page = 1
+        while True:
+            offset = (page - 1) * page_size
+            page_params = dict(params)
+            page_params["limit"] = page_size
+            page_params["offset"] = offset
+
+            df = pd.DataFrame(self.fetch_all(query, page_params))
+            if df.empty:
+                break
+            data = pd.concat([data, df], ignore_index=True)
+            page += 1
+
+        logger.debug(f"Loaded {len(data)} rows in {page - 1} pages")
+        return data
+
+    def fetch_all(self, query, params=None):
+        """执行SQL查询并返回所有结果"""
+        with self.get_session() as session:
+            results = session.execute(query, params or {}).fetchall()
+            return results
+
+    def fetch_one(self, query, params=None):
+        """执行SQL查询并返回单条结果"""
+        with self.get_session() as session:
+            result = session.execute(query, params or {}).fetchone()
+            return result
+
+    def insert_data(self, table_name, data_dict):
+        """插入单条数据到指定表"""
+        if not data_dict:
+            return 0
+
+        columns = ", ".join(data_dict.keys())
+        values = ", ".join([f":{key}" for key in data_dict.keys()])
+        query = text(f"INSERT INTO {table_name} ({columns}) VALUES ({values})")
+
+        with self.get_session() as session:
+            result = session.execute(query, data_dict)
+            logger.info(f"Inserted 1 row into {table_name}")
+            return result.rowcount
+
+    def update_data(self, table_name, update_dict, conditions, condition_params=None):
+        """更新表中符合条件的数据"""
+        if not update_dict:
+            return 0
+
+        set_clause = ", ".join([f"{key} = :{key}" for key in update_dict.keys()])
+
+        if len(conditions) == 1:
+            where_clause = f"WHERE {conditions[0]}"
+        elif len(conditions) > 1:
+            where_clause = f"WHERE {' AND '.join(conditions)}"
+        else:
+            where_clause = ""
+
+        query = text(f"UPDATE {table_name} SET {set_clause} {where_clause}")
+
+        params = update_dict.copy()
+        if condition_params:
+            params.update(condition_params)
+
+        with self.get_session() as session:
+            result = session.execute(query, params)
+            logger.info(f"Updated {result.rowcount} rows in {table_name}")
+            return result.rowcount
+
+    def execute_query(self, query, params=None):
+        """执行SQL语句"""
+        with self.get_session() as session:
+            session.execute(query, params or {})
+
+    def check_connection(self) -> bool:
+        """检查数据库连接是否正常"""
+        try:
+            self.fetch_one(text("SELECT 1"), {})
+            return True
+        except Exception:
+            return False
+```
+
+- [ ] **Step 2: 验证导入**
+
+Run: `cd /home/yangzeyu/project/BrandCultivation && python -c "from database.db.mysql import MySqlDatabaseHelper; print('OK')"`
+Expected: `OK`
+
+- [ ] **Step 3: Commit**
+
+```bash
+git add database/db/mysql.py
+git commit -m "refactor(database): add logging, session context manager, env-based config"
+```
+
+---
+
+### Task 7: 改进 `database/db/redis_db.py` — 日志、配置来源、连接池
+
+**Files:**
+- Modify: `database/db/redis_db.py`
+
+- [ ] **Step 1: 重写 redis_db.py**
+
+替换整个文件内容为:
+
+```python
+import redis
+from core import get_logger, settings, DatabaseException
+
+logger = get_logger("database.redis")
+
+
+class RedisDatabaseHelper:
+    _instance = None
+
+    def __new__(cls):
+        if not cls._instance:
+            cls._instance = super(RedisDatabaseHelper, cls).__new__(cls)
+            cls._instance._initialized = False
+        return cls._instance
+
+    def __init__(self):
+        if self._initialized:
+            return
+        try:
+            pool = redis.ConnectionPool(
+                host=settings.redis_host,
+                port=settings.redis_port,
+                password=settings.redis_password,
+                db=settings.redis_db,
+                decode_responses=True,
+                max_connections=50,
+            )
+            self.redis = redis.StrictRedis(connection_pool=pool)
+            self.redis.ping()
+            logger.info("Redis connection established", extra={"extra_data": {"host": settings.redis_host, "db": settings.redis_db}})
+        except redis.ConnectionError as e:
+            logger.error("Failed to connect to Redis", exc_info=True)
+            raise DatabaseException(message="Redis连接失败", detail=str(e))
+        self._initialized = True
+
+    def check_connection(self) -> bool:
+        """检查 Redis 连接是否正常"""
+        try:
+            self.redis.ping()
+            return True
+        except Exception:
+            return False
+```
+
+- [ ] **Step 2: 验证导入**
+
+Run: `cd /home/yangzeyu/project/BrandCultivation && python -c "from database.db.redis_db import RedisDatabaseHelper; print('OK')"`
+Expected: `OK`
+
+- [ ] **Step 3: Commit**
+
+```bash
+git add database/db/redis_db.py
+git commit -m "refactor(database): redis with logging, connection pool, env config"
+```
+
+---
+
+### Task 8: 改进 `database/dao/mysql_dao.py` — 日志和异常处理
+
+**Files:**
+- Modify: `database/dao/mysql_dao.py`
+
+- [ ] **Step 1: 添加日志和异常处理**
+
+在文件顶部添加 logger,在每个方法中添加日志记录。修复 `get_report_file_id` 的 None 处理。
+
+关键改动:
+- 文件顶部添加: `from core import get_logger` 和 `logger = get_logger("database.dao")`
+- 每个 public 方法入口添加 `logger.info(...)` 记录调用参数
+- `get_report_file_id` 中 result 为 None 时返回空 DataFrame 而非抛异常
+- 移除 `data_preprocess` 方法(不属于 DAO 层)
+
+- [ ] **Step 2: 验证导入**
+
+Run: `cd /home/yangzeyu/project/BrandCultivation && python -c "from database.dao.mysql_dao import MySqlDao; print('OK')"`
+Expected: `OK`
+
+- [ ] **Step 3: Commit**
+
+```bash
+git add database/dao/mysql_dao.py
+git commit -m "refactor(dao): add logging, fix get_report_file_id null handling"
+```
+
+---
+
+### Task 9: 改进 `run_api.py` — 注册中间件、健康检查、异常处理器
+
+**Files:**
+- Modify: `run_api.py`
+
+- [ ] **Step 1: 重写 run_api.py**
+
+```python
+from api import recommend_router, report_router, eval_report_router
+from core import get_logger, AppException, RequestLoggingMiddleware, get_request_id
+from core.exceptions import DatabaseException
+from database.db.mysql import MySqlDatabaseHelper
+from database.db.redis_db import RedisDatabaseHelper
+from fastapi import FastAPI, Request, status
+from fastapi.exceptions import RequestValidationError
+from fastapi.responses import JSONResponse
+
+import uvicorn
+
+logger = get_logger("app")
+
+app = FastAPI()
+
+app.add_middleware(RequestLoggingMiddleware)
+
+
+@app.exception_handler(RequestValidationError)
+async def validation_exception_handler(request: Request, exc: RequestValidationError):
+    logger.warning(f"Validation error: {exc.errors()}")
+    return JSONResponse(
+        status_code=status.HTTP_400_BAD_REQUEST,
+        content={
+            "code": 400,
+            "msg": "请求参数错误",
+            "data": {"detail": exc.errors(), "body": exc.body},
+            "request_id": get_request_id(),
+        },
+    )
+
+
+@app.exception_handler(AppException)
+async def app_exception_handler(request: Request, exc: AppException):
+    logger.error(f"AppException: {exc.message}", extra={"extra_data": {"detail": exc.detail}})
+    return JSONResponse(
+        status_code=exc.code,
+        content={
+            "code": exc.code,
+            "msg": exc.message,
+            "data": {"detail": exc.detail},
+            "request_id": get_request_id(),
+        },
+    )
+
+
+@app.exception_handler(Exception)
+async def unhandled_exception_handler(request: Request, exc: Exception):
+    logger.error("Unhandled exception", exc_info=True)
+    return JSONResponse(
+        status_code=500,
+        content={
+            "code": 500,
+            "msg": "服务器内部错误",
+            "data": None,
+            "request_id": get_request_id(),
+        },
+    )
+
+
+@app.get("/health")
+async def health_check():
+    """健康检查端点"""
+    mysql_ok = False
+    redis_ok = False
+    try:
+        mysql_ok = MySqlDatabaseHelper().check_connection()
+    except Exception:
+        pass
+    try:
+        redis_ok = RedisDatabaseHelper().check_connection()
+    except Exception:
+        pass
+
+    healthy = mysql_ok and redis_ok
+    return {
+        "status": "healthy" if healthy else "degraded",
+        "mysql": "ok" if mysql_ok else "error",
+        "redis": "ok" if redis_ok else "error",
+    }
+
+
+url_prefix = "/brandcultivation/api/v1"
+
+app.include_router(recommend_router, prefix=url_prefix)
+app.include_router(report_router, prefix=url_prefix)
+app.include_router(eval_report_router, prefix=url_prefix)
+
+if __name__ == "__main__":
+    logger.info("Starting BrandCultivation API server on port 7960")
+    uvicorn.run(app, host="0.0.0.0", port=7960)
+```
+
+- [ ] **Step 2: 验证语法**
+
+Run: `cd /home/yangzeyu/project/BrandCultivation && python -c "import ast; ast.parse(open('run_api.py').read()); print('Syntax OK')"`
+Expected: `Syntax OK`
+
+- [ ] **Step 3: Commit**
+
+```bash
+git add run_api.py
+git commit -m "refactor(api): add middleware, health check, global exception handlers"
+```
+
+---
+
+### Task 10: 改进 `api/recommend.py` — 日志和错误处理
+
+**Files:**
+- Modify: `api/recommend.py`
+
+- [ ] **Step 1: 添加日志和错误处理**
+
+关键改动:
+- 添加 `from core import get_logger` 和 `logger = get_logger("api.recommend")`
+- 模型不存在时返回 404 状态码
+- `generate_and_upload_report` 加入 try/except + logger.error
+- 推荐过程添加关键日志(开始、召回数量、完成)
+
+```python
+from database import MySqlDao
+from fastapi import APIRouter, BackgroundTasks, HTTPException, status
+from .request_body import RecommendRequest
+from core import get_logger
+
+from models import Recommend
+import os
+from utils import FileStreamUtils, ReportUtils
+
+logger = get_logger("api.recommend")
+dao = MySqlDao()
+router = APIRouter()
+
+
+@router.post("/recommend")
+async def recommend(request: RecommendRequest, backgroundTasks: BackgroundTasks):
+    """推荐接口"""
+    logger.info(f"Recommend request: city={request.city_uuid}, product={request.product_code}, recall={request.recall_cust_count}")
+
+    gbdtlr_model_path = os.path.join("./models/rank/weights", request.city_uuid, "gbdtlr_model.pkl")
+    if not os.path.exists(gbdtlr_model_path):
+        logger.warning(f"Model not found: {gbdtlr_model_path}")
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="该城市的模型未训练,请先进行训练",
+        )
+
+    recommend_model = Recommend(request.city_uuid)
+
+    products_in_order = dao.get_product_from_order(request.city_uuid)["product_code"].unique().tolist()
+    if request.product_code in products_in_order:
+        logger.info(f"Using GBDT-LR model for existing product {request.product_code}")
+        recommend_list = recommend_model.get_recommend_list_by_gbdtlr(request.product_code, recall_count=request.recall_cust_count)
+    else:
+        logger.info(f"Using Item2Vec model for new product {request.product_code}")
+        recommend_list = recommend_model.get_recommend_list_by_item2vec(request.product_code, recall_count=request.recall_cust_count)
+
+    recommend_data = recommend_model.get_recommend_and_delivery(recommend_list, delivery_count=request.delivery_count)
+    request_data = []
+    for index, data in enumerate(recommend_data):
+        request_data.append(
+            {
+                "id": index + 1,
+                "cust_code": data["cust_code"],
+                "recommend_score": data["recommend_score"],
+                "delivery_count": data["delivery_count"],
+            }
+        )
+
+    logger.info(f"Recommend completed: {len(request_data)} customers recommended")
+
+    backgroundTasks.add_task(generate_and_upload_report, request)
+
+    return {"code": 200, "msg": "success", "data": {"recommendationInfo": request_data}}
+
+
+def generate_and_upload_report(request: RecommendRequest):
+    """生成并上传报告到阿里云文件数据库"""
+    logger.info(f"Background task started: generating report for {request.city_uuid}/{request.product_code}")
+    try:
+        report_util = ReportUtils(request.city_uuid, request.product_code)
+        report_util.generate_all_data(request.recall_cust_count, request.delivery_count)
+
+        reports_dir = os.path.join("./data/reports", request.city_uuid, request.product_code)
+        report_files = ["卷烟信息表", "品规商户特征关系表", "相似卷烟表", "商户售卖推荐表"]
+        file_id_map = FileStreamUtils.upload_files(reports_dir, report_files)
+
+        if file_id_map is None:
+            logger.error(f"Report upload failed for {request.city_uuid}/{request.product_code}")
+            return
+
+        data_dict = {
+            "cultivacation_id": request.cultivacation_id,
+            "city_uuid": request.city_uuid,
+            "limit_cycle_name": request.limit_cycle_name,
+            "product_code": request.product_code,
+            "product_info_table": file_id_map.get("卷烟信息表"),
+            "relation_table": file_id_map.get("品规商户特征关系表"),
+            "similarity_product_table": file_id_map.get("相似卷烟表"),
+            "recommend_table": file_id_map.get("商户售卖推荐表"),
+        }
+        dao.insert_report(data_dict)
+        logger.info(f"Background task completed: report uploaded for {request.city_uuid}/{request.product_code}")
+    except Exception as e:
+        logger.error(f"Background task failed: {e}", exc_info=True)
+```
+
+- [ ] **Step 2: 验证语法**
+
+Run: `cd /home/yangzeyu/project/BrandCultivation && python -c "import ast; ast.parse(open('api/recommend.py').read()); print('Syntax OK')"`
+Expected: `Syntax OK`
+
+- [ ] **Step 3: Commit**
+
+```bash
+git add api/recommend.py
+git commit -m "refactor(api): add logging and error handling to recommend endpoint"
+```
+
+---
+
+### Task 11: 改进 `api/eval_report.py` 和 `api/report.py` — 日志
+
+**Files:**
+- Modify: `api/eval_report.py`
+- Modify: `api/report.py`
+
+- [ ] **Step 1: 改进 eval_report.py**
+
+添加 `from core import get_logger` 和 `logger = get_logger("api.eval_report")`,在每个步骤添加日志。
+
+- [ ] **Step 2: 改进 report.py**
+
+添加 `from core import get_logger` 和 `logger = get_logger("api.report")`,在关键步骤添加日志。
+
+- [ ] **Step 3: 验证语法**
+
+Run: `cd /home/yangzeyu/project/BrandCultivation && python -c "import ast; ast.parse(open('api/eval_report.py').read()); ast.parse(open('api/report.py').read()); print('OK')"`
+Expected: `OK`
+
+- [ ] **Step 4: Commit**
+
+```bash
+git add api/eval_report.py api/report.py
+git commit -m "refactor(api): add logging to eval_report and report endpoints"
+```
+
+---
+
+### Task 12: 修复 `models/rank/gbdt_lr_inference.py` — Bug 修复 + 日志
+
+**Files:**
+- Modify: `models/rank/gbdt_lr_inference.py`
+
+- [ ] **Step 1: 修复 get_recommend_list 方法**
+
+将 `get_recommend_list` 方法中的循环从:
+
+```python
+recommend_list = []
+for cust_id, score in zip(recall_list, scores):
+    recommend_list.append({cust_id: float(score)})
+    recommend_list.append({"cust_code": cust_id, "recommend_score": score})
+
+recommend_list = sorted(
+    [item for item in recommend_list if "recommend_score" in item],
+    key=lambda x: x["recommend_score"],
+    reverse=True
+)
+```
+
+修改为:
+
+```python
+recommend_list = [
+    {"cust_code": cust_id, "recommend_score": float(score)}
+    for cust_id, score in zip(recall_list, scores)
+]
+recommend_list.sort(key=lambda x: x["recommend_score"], reverse=True)
+```
+
+- [ ] **Step 2: 添加日志**
+
+在文件顶部添加 `from core import get_logger` 和 `logger = get_logger("models.rank.gbdtlr")`。
+在 `get_recommend_list` 中添加推理耗时日志。
+在 `generate_shap_interance` 中替换 print 为 logger。
+
+- [ ] **Step 3: 验证语法**
+
+Run: `cd /home/yangzeyu/project/BrandCultivation && python -c "import ast; ast.parse(open('models/rank/gbdt_lr_inference.py').read()); print('OK')"`
+Expected: `OK`
+
+- [ ] **Step 4: Commit**
+
+```bash
+git add models/rank/gbdt_lr_inference.py
+git commit -m "fix(models): fix double-append bug in get_recommend_list, add logging"
+```
+
+---
+
+### Task 13: 改进 `models/recommend.py` — 日志
+
+**Files:**
+- Modify: `models/recommend.py`
+
+- [ ] **Step 1: 添加日志**
+
+添加 `from core import get_logger` 和 `logger = get_logger("models.recommend")`。
+在关键方法中添加日志:
+- `_load_molde`: 模型加载完成
+- `get_recal_cust`: 召回数量
+- `get_recommend_list_by_gbdtlr`: 开始/完成 + 耗时
+- `get_recommend_list_by_item2vec`: 开始/完成 + 耗时
+- `get_recommend_and_delivery`: 投放分配完成
+
+- [ ] **Step 2: 验证语法**
+
+Run: `cd /home/yangzeyu/project/BrandCultivation && python -c "import ast; ast.parse(open('models/recommend.py').read()); print('OK')"`
+Expected: `OK`
+
+- [ ] **Step 3: Commit**
+
+```bash
+git add models/recommend.py
+git commit -m "refactor(models): add logging to recommend module"
+```
+
+---
+
+### Task 14: 改进召回模块 — 替换 print 为 logger
+
+**Files:**
+- Modify: `models/recall/hot_recall.py`
+- Modify: `models/recall/itemCF/ItemCF.py`
+- Modify: `models/item2vec/inference.py`
+
+- [ ] **Step 1: 改进 hot_recall.py**
+
+替换 `print("hot_recall: ...")` 为 `logger.info(...)`。添加 `from core import get_logger` 和 `logger = get_logger("models.recall.hot")`。
+
+- [ ] **Step 2: 改进 ItemCF.py**
+
+替换 `print(...)` 为 `logger.info(...)`。添加 `from core import get_logger` 和 `logger = get_logger("models.recall.itemcf")`。
+
+- [ ] **Step 3: 改进 inference.py (item2vec)**
+
+添加 `from core import get_logger` 和 `logger = get_logger("models.item2vec")`。在关键步骤添加日志。
+
+- [ ] **Step 4: 验证语法**
+
+Run: `cd /home/yangzeyu/project/BrandCultivation && python -c "import ast; ast.parse(open('models/recall/hot_recall.py').read()); ast.parse(open('models/recall/itemCF/ItemCF.py').read()); ast.parse(open('models/item2vec/inference.py').read()); print('OK')"`
+Expected: `OK`
+
+- [ ] **Step 5: Commit**
+
+```bash
+git add models/recall/hot_recall.py models/recall/itemCF/ItemCF.py models/item2vec/inference.py
+git commit -m "refactor(models): replace print with logger in recall modules"
+```
+
+---
+
+### Task 15: 改进 `utils/file_stream.py` — 日志和错误处理
+
+**Files:**
+- Modify: `utils/file_stream.py`
+
+- [ ] **Step 1: 添加日志和改进错误处理**
+
+添加 `from core import get_logger` 和 `logger = get_logger("utils.file_stream")`。
+- 上传时记录 file_id 和耗时
+- 下载时记录 HTTP 状态码
+- 失败时记录具体错误信息(替换 print)
+- 使用 `settings` 获取 URL 配置
+
+```python
+import time
+from core import get_logger, settings
+from core.exceptions import FileServiceException
+from io import BytesIO
+import os
+import pandas as pd
+import requests
+
+logger = get_logger("utils.file_stream")
+
+
+class FileStreamUtils:
+    upload_url = settings.file_upload_url
+    download_url = settings.file_download_url
+    headers = {
+        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
+        "Accept": "*/*",
+    }
+
+    @staticmethod
+    def upload_files(reports_dir, files):
+        files_id = {}
+        for filename in files:
+            file_path = os.path.join(reports_dir, f"{filename}.xlsx")
+            start_time = time.time()
+            try:
+                with open(file_path, "rb") as f:
+                    upload_files = {"file": (os.path.basename(file_path), f)}
+                    response = requests.post(
+                        FileStreamUtils.upload_url,
+                        headers=FileStreamUtils.headers,
+                        files=upload_files,
+                        verify=True,
+                    )
+                    duration_ms = (time.time() - start_time) * 1000
+                    if response.json().get("success"):
+                        file_id = response.json()["data"]["file_info"]["fileid"]
+                        files_id[filename] = file_id
+                        logger.info(f"File uploaded: {filename} -> {file_id} ({duration_ms:.0f}ms)")
+                    else:
+                        logger.error(f"Upload failed for {filename}: {response.text}")
+                        return None
+            except requests.exceptions.RequestException as e:
+                logger.error(f"Upload request error for {filename}: {e}", exc_info=True)
+                return None
+            except Exception as e:
+                logger.error(f"Upload error for {filename}: {e}", exc_info=True)
+                return None
+        return files_id
+
+    @staticmethod
+    def download_file(file_id, file_type="xlsx"):
+        """通过file_id从阿里云文件数据库下载文件"""
+        start_time = time.time()
+        try:
+            response = requests.get(
+                f"{FileStreamUtils.download_url}/{file_id}",
+                headers=FileStreamUtils.headers,
+                verify=True,
+            )
+            duration_ms = (time.time() - start_time) * 1000
+
+            if response.status_code == 200:
+                file_content = BytesIO(response.content)
+                if file_type == "xlsx":
+                    data = pd.read_excel(file_content, engine="openpyxl")
+                elif file_type == "csv":
+                    data = pd.read_csv(file_content)
+                else:
+                    raise ValueError(f"不支持的文件类型:{file_type}")
+                logger.info(f"File downloaded: {file_id} ({duration_ms:.0f}ms, {len(response.content)} bytes)")
+                return data
+            else:
+                logger.error(f"Download failed: file_id={file_id}, status={response.status_code}")
+                return None
+        except requests.exceptions.RequestException as e:
+            logger.error(f"Download request error: file_id={file_id}, error={e}", exc_info=True)
+            return None
+        except Exception as e:
+            logger.error(f"Download error: file_id={file_id}, error={e}", exc_info=True)
+            return None
+```
+
+- [ ] **Step 2: 验证语法**
+
+Run: `cd /home/yangzeyu/project/BrandCultivation && python -c "import ast; ast.parse(open('utils/file_stream.py').read()); print('OK')"`
+Expected: `OK`
+
+- [ ] **Step 3: Commit**
+
+```bash
+git add utils/file_stream.py
+git commit -m "refactor(utils): add logging and error details to file_stream"
+```
+
+---
+
+### Task 16: 改进 `utils/report_utils.py` 和 `train.py` — 日志
+
+**Files:**
+- Modify: `utils/report_utils.py`
+- Modify: `train.py`
+- Modify: `models/rank/data/preprocess.py`
+
+- [ ] **Step 1: 改进 report_utils.py**
+
+添加 `from core import get_logger` 和 `logger = get_logger("utils.report")`。
+每个 `generate_*` 方法添加开始/完成日志。
+
+- [ ] **Step 2: 改进 train.py**
+
+替换所有 `print(...)` 为 `logger.info(...)`。添加 `from core import get_logger` 和 `logger = get_logger("train")`。
+
+- [ ] **Step 3: 改进 preprocess.py**
+
+替换 `print("gbdr-lr: ...")` 为 `logger.info(...)`。添加 `from core import get_logger` 和 `logger = get_logger("models.rank.preprocess")`。
+
+- [ ] **Step 4: 验证语法**
+
+Run: `cd /home/yangzeyu/project/BrandCultivation && python -c "import ast; ast.parse(open('utils/report_utils.py').read()); ast.parse(open('train.py').read()); ast.parse(open('models/rank/data/preprocess.py').read()); print('OK')"`
+Expected: `OK`
+
+- [ ] **Step 5: Commit**
+
+```bash
+git add utils/report_utils.py train.py models/rank/data/preprocess.py
+git commit -m "refactor: replace print with logger in report_utils, train, preprocess"
+```
+
+---
+
+### Task 17: 配置文件变更 — 移除密码、创建 .env.example
+
+**Files:**
+- Modify: `config/database_config.yaml`
+- Create: `.env.example`
+- Modify: `config/config.py`
+
+- [ ] **Step 1: 更新 database_config.yaml(移除密码)**
+
+```yaml
+mysql:
+  host: 'rm-t4n6rz18y4t5x47y70o.mysql.singapore.rds.aliyuncs.com'
+  port: 3036
+  db: 'brand_cultivation'
+  user: 'BrandCultivation'
+  # passwd 已移至环境变量 MYSQL_PASSWORD
+
+redis:
+  host: 'r-t4nb4n9i8je7u6ogk1pd.redis.singapore.rds.aliyuncs.com'
+  port: 5000
+  db: 10
+  # passwd 已移至环境变量 REDIS_PASSWORD
+```
+
+- [ ] **Step 2: 创建 .env.example**
+
+```bash
+# BrandCultivation 环境变量配置
+# 复制此文件为 .env 并填入实际值
+
+# MySQL
+MYSQL_HOST=rm-t4n6rz18y4t5x47y70o.mysql.singapore.rds.aliyuncs.com
+MYSQL_PORT=3036
+MYSQL_USER=BrandCultivation
+MYSQL_PASSWORD=your_mysql_password_here
+MYSQL_DB=brand_cultivation
+
+# Redis
+REDIS_HOST=r-t4nb4n9i8je7u6ogk1pd.redis.singapore.rds.aliyuncs.com
+REDIS_PORT=5000
+REDIS_PASSWORD=your_redis_password_here
+REDIS_DB=10
+
+# Logging
+LOG_LEVEL=INFO
+
+# File Service
+FILE_UPLOAD_URL=http://file-center.jcpt:8080/file/fileUpload
+FILE_DOWNLOAD_URL=http://file-center.jcpt:8080/file/fileDownload
+```
+
+- [ ] **Step 3: 更新 config/config.py 为兼容层**
+
+保留旧接口以兼容未迁移的代码:
+
+```python
+from core.config import settings
+
+def load_config():
+    return {
+        "mysql": {
+            "host": settings.mysql_host,
+            "port": settings.mysql_port,
+            "user": settings.mysql_user,
+            "passwd": settings.mysql_password,
+            "db": settings.mysql_db,
+        },
+        "redis": {
+            "host": settings.redis_host,
+            "port": settings.redis_port,
+            "passwd": settings.redis_password,
+            "db": settings.redis_db,
+        },
+    }
+
+def load_model_config():
+    return settings.model_config
+
+def load_service_config():
+    return {
+        "aliyun": {
+            "upload_url": settings.file_upload_url,
+            "download_url": settings.file_download_url,
+        }
+    }
+```
+
+- [ ] **Step 4: 确保 .gitignore 包含 .env**
+
+Run: `cd /home/yangzeyu/project/BrandCultivation && grep -q "^\.env$" .gitignore 2>/dev/null || echo ".env" >> .gitignore`
+
+- [ ] **Step 5: Commit**
+
+```bash
+git add config/database_config.yaml .env.example config/config.py .gitignore
+git commit -m "security: remove passwords from yaml, add .env.example"
+```
+
+---
+
+### Task 18: 最终验证
+
+**Files:** None (verification only)
+
+- [ ] **Step 1: 验证所有模块可导入**
+
+Run:
+```bash
+cd /home/yangzeyu/project/BrandCultivation && python -c "
+from core import get_logger, settings, AppException, DatabaseException, RequestLoggingMiddleware
+from database.db.mysql import MySqlDatabaseHelper
+from database.db.redis_db import RedisDatabaseHelper
+from database.dao.mysql_dao import MySqlDao
+from api.recommend import router as rec_router
+from api.report import router as rep_router
+from api.eval_report import router as eval_router
+from models.recommend import Recommend
+from utils.file_stream import FileStreamUtils
+from utils.report_utils import ReportUtils
+print('All imports successful')
+"
+```
+Expected: `All imports successful`
+
+- [ ] **Step 2: 验证 FastAPI 应用可启动(语法级别)**
+
+Run: `cd /home/yangzeyu/project/BrandCultivation && python -c "import ast; ast.parse(open('run_api.py').read()); print('App syntax OK')"`
+Expected: `App syntax OK`
+
+- [ ] **Step 3: 最终 Commit(如有遗漏文件)**
+
+```bash
+git status
+# 如有未提交的改动,补充提交
+```

+ 238 - 0
docs/superpowers/specs/2026-05-21-project-refactoring-design.md

@@ -0,0 +1,238 @@
+# BrandCultivation 项目级重构设计
+
+## 概述
+
+对 BrandCultivation(卷烟品牌培育推荐系统)进行全面重构,解决安全漏洞、缺失日志、错误处理不足、代码缺陷等问题。采用渐进式重构策略,保持现有目录结构,新增基础设施层。
+
+## 当前问题清单
+
+### 安全漏洞
+- `database_config.yaml` 明文存储 MySQL/Redis 密码
+- API 无认证中间件(内网服务可接受,但需加基础防护)
+- 无请求频率限制
+
+### 代码缺陷
+- `gbdt_lr_inference.py:60-77` — `get_recommend_list` 每次循环 append 两种格式字典后再过滤
+- 全局使用 `print()` 无日志系统
+- 后台任务 `generate_and_upload_report` 失败时静默无通知
+- 配置加载使用相对路径 `./config/...`,CWD 变化即崩溃
+- `get_report_file_id` 中 result 为 None 时抛 TypeError
+
+### 架构问题
+- 无请求追踪(request_id)
+- 无健康检查端点
+- 模块级别实例化(`cfgs = load_config()` 在 import 时执行)
+- Session 管理重复 try/finally 模式
+
+## 设计方案
+
+### 1. 基础设施层 (`core/`)
+
+#### `core/__init__.py`
+导出公共接口:`get_logger`, `settings`, `AppException` 等。
+
+#### `core/logging.py` — 统一日志系统
+
+```python
+# 基于 Python 标准库 logging
+# JSON 格式输出,便于日志收集系统解析
+# 工厂函数 get_logger(name) 按模块名创建 logger
+# 日志级别通过环境变量 LOG_LEVEL 控制(默认 INFO)
+# 输出字段:时间戳、级别、模块名、函数名、行号、消息、request_id(如有)
+```
+
+日志级别规范:
+- DEBUG: 数据库查询参数、模型推理中间结果
+- INFO: 请求开始/结束、任务开始/完成、关键业务步骤
+- WARNING: 非致命异常(如文件下载重试)
+- ERROR: 操作失败(含完整 traceback)
+
+#### `core/config.py` — 配置管理
+
+- 保留 YAML 文件作为默认值和非敏感配置
+- 敏感信息(密码)通过环境变量注入:`MYSQL_PASSWORD`、`REDIS_PASSWORD`
+- 使用 `pathlib` 基于 `__file__` 定位项目根目录,解决相对路径问题
+- 配置加载一次后缓存(模块级单例)
+- 支持 Docker 环境变量覆盖所有数据库连接参数
+
+环境变量清单:
+```
+MYSQL_HOST, MYSQL_PORT, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DB
+REDIS_HOST, REDIS_PORT, REDIS_PASSWORD, REDIS_DB
+LOG_LEVEL (default: INFO)
+FILE_UPLOAD_URL, FILE_DOWNLOAD_URL
+```
+
+#### `core/exceptions.py` — 自定义异常体系
+
+```python
+class AppException(Exception):
+    """应用异常基类"""
+    def __init__(self, code: int, message: str, detail: str = None): ...
+
+class DatabaseException(AppException): ...    # 数据库操作失败
+class ModelException(AppException): ...       # 模型推理失败
+class FileServiceException(AppException): ... # 文件服务失败
+class ValidationException(AppException): ...  # 业务校验失败
+```
+
+#### `core/middleware.py` — 请求中间件
+
+- 为每个请求生成 UUID4 作为 request_id
+- 通过 `contextvars.ContextVar` 传递 request_id,所有下游日志自动携带
+- 记录请求开始(method, path, client_ip)和结束(status_code, 耗时ms)
+- 异常时记录完整 traceback
+
+### 2. 数据库层改进
+
+#### `database/db/mysql.py`
+
+- 替换 `print(f"error: {e}")` 为 `logger.error("...", exc_info=True)`
+- 连接字符串从 `core/config.py` 获取(支持环境变量覆盖)
+- `load_data_with_page` 中 tqdm 改为 `logger.debug` 输出
+- Session 管理改为 context manager:
+
+```python
+@contextmanager
+def get_session(self):
+    session = self._DBSession()
+    try:
+        yield session
+        session.commit()
+    except SQLAlchemyError as e:
+        session.rollback()
+        logger.error("Database operation failed", exc_info=True)
+        raise DatabaseException(500, "数据库操作失败", str(e))
+    finally:
+        session.close()
+```
+
+#### `database/db/redis_db.py`
+
+- 从 `core/config.py` 获取配置
+- 添加连接池配置(`max_connections=50`)
+- 连接失败时记录日志并抛出 `DatabaseException`
+
+#### `database/dao/mysql_dao.py`
+
+- 每个方法入口记录 `logger.info("Loading product data", extra={"city_uuid": city_uuid})`
+- 异常时抛出 `DatabaseException` 而非静默返回空 DataFrame
+- 修复 `get_report_file_id` 中 result 为 None 时的处理
+
+### 3. API 层改进
+
+#### 全局异常处理器
+
+```python
+@app.exception_handler(AppException)
+async def app_exception_handler(request, exc):
+    return JSONResponse(
+        status_code=exc.code,
+        content={"code": exc.code, "msg": exc.message, "data": {"detail": exc.detail}, "request_id": get_request_id()}
+    )
+
+@app.exception_handler(Exception)
+async def unhandled_exception_handler(request, exc):
+    logger.error("Unhandled exception", exc_info=True)
+    return JSONResponse(
+        status_code=500,
+        content={"code": 500, "msg": "服务器内部错误", "data": None, "request_id": get_request_id()}
+    )
+```
+
+#### 健康检查端点
+
+```python
+@app.get("/health")
+async def health_check():
+    # 检查 MySQL 和 Redis 连接状态
+    return {"status": "healthy", "mysql": "ok", "redis": "ok"}
+```
+
+#### `api/recommend.py`
+
+- 后台任务加入 try/except + logger.error
+- 模型不存在时返回 HTTP 404(而非 200 + 错误消息)
+- 添加推荐过程的关键日志
+
+#### `api/eval_report.py`
+
+- 每个步骤添加日志
+- 文件下载失败时记录具体错误
+
+### 4. Bug 修复
+
+#### `models/rank/gbdt_lr_inference.py:60-77`
+
+修复前(每次循环 append 两个字典,然后过滤):
+```python
+for cust_id, score in zip(recall_list, scores):
+    recommend_list.append({cust_id: float(score)})
+    recommend_list.append({"cust_code": cust_id, "recommend_score": score})
+
+recommend_list = sorted(
+    [item for item in recommend_list if "recommend_score" in item],
+    key=lambda x: x["recommend_score"], reverse=True
+)
+```
+
+修复后:
+```python
+recommend_list = [
+    {"cust_code": cust_id, "recommend_score": float(score)}
+    for cust_id, score in zip(recall_list, scores)
+]
+recommend_list.sort(key=lambda x: x["recommend_score"], reverse=True)
+```
+
+### 5. 模型层和工具层
+
+#### 模型层
+
+- `models/recommend.py` — 记录召回数量、排序结果数量、各步骤耗时
+- `models/recall/hot_recall.py` — 替换 print 为 logger
+- `models/recall/itemCF/ItemCF.py` — 替换 print 为 logger,记录 Redis 写入状态
+- `models/rank/gbdt_lr_inference.py` — 添加推理耗时日志
+- `models/item2vec/inference.py` — 添加相似度计算日志
+
+#### 工具层
+
+- `utils/file_stream.py` — 记录上传/下载的 file_id、耗时、HTTP 状态码
+- `utils/report_utils.py` — 每个报告生成步骤记录开始/完成/耗时
+
+#### 训练脚本
+
+- `train.py` — 替换 print 为 logger,记录训练全流程
+
+### 6. 配置文件变更
+
+- `database_config.yaml` — 移除密码,改为占位符 `passwd: "${MYSQL_PASSWORD}"`
+- 新增 `.env.example` — 列出所有环境变量及说明
+- 新增 `core/__init__.py` — 导出公共接口
+
+## 目录结构变更
+
+```
+BrandCultivation/
+├── core/                    # 新增:基础设施层
+│   ├── __init__.py
+│   ├── logging.py           # 日志系统
+│   ├── config.py            # 配置管理
+│   ├── exceptions.py        # 异常定义
+│   └── middleware.py        # 请求中间件
+├── api/                     # 改进:添加日志和错误处理
+├── database/                # 改进:日志、异常、session 管理
+├── models/                  # 改进:日志、bug 修复
+├── utils/                   # 改进:日志
+├── config/                  # 改进:移除敏感信息
+├── .env.example             # 新增:环境变量模板
+└── run_api.py               # 改进:注册中间件和健康检查
+```
+
+## 不在本次范围内
+
+- API 认证/鉴权(内网服务)
+- 数据库 ORM 模型定义(当前 raw SQL + pandas 模式保持不变)
+- ML 模型算法调整
+- 单元测试(可作为后续迭代)
+- CI/CD 配置

+ 5 - 0
models/item2vec/inference.py

@@ -5,6 +5,9 @@ from models.rank.data.utils import sample_data_clear
 import numpy as np
 import pandas as pd
 from sklearn.preprocessing import StandardScaler
+from core import get_logger
+
+logger = get_logger("models.item2vec")
 
 class Item2VecModel:
     def __init__(self, city_uuid):
@@ -14,6 +17,7 @@ class Item2VecModel:
         
     def generate_product_similarity_map(self, product_code):
         """根据product_code生成卷烟相似度矩阵"""
+        logger.info(f"Generating similarity map for product {product_code}")
         product = self._dao.get_product_by_id(self._city_uuid, product_code)[ProductConfig.FEATURE_COLUMNS]
         product = sample_data_clear(product, ProductConfig)
         
@@ -33,6 +37,7 @@ class Item2VecModel:
     
     def get_recommend_cust_list(self, product_code, top=100):
         """获取推荐的商户列表"""
+        logger.info(f"Getting recommend list for product {product_code}, top={top}")
         product_list = self.get_similarity_list(product_code)
         order_data = self._dao.get_order_by_product_ids(self._city_uuid, product_list)[OrderConfig.FEATURE_COLUMNS]
         order_data["sale_qty"] = order_data["sale_qty"].fillna(0)

+ 6 - 3
models/rank/data/preprocess.py

@@ -5,16 +5,19 @@ import pandas as pd
 from sklearn.preprocessing import MinMaxScaler
 from sklearn.utils import shuffle
 import numpy as np
+from core import get_logger
+
+logger = get_logger("models.rank.preprocess")
 
 class DataProcess():
     def __init__(self, city_uuid, save_dir):
         self._mysql_dao = MySqlDao()
         self.save_dir = save_dir
-        print("gbdr-lr: 正在加载cust_info...")
+        logger.info("Loading cust_info")
         self._cust_data = self._mysql_dao.load_cust_data(city_uuid)
-        print("gbdr-lr: 正在加载product_info...")
+        logger.info("Loading product_info")
         self._product_data = self._mysql_dao.load_product_data(city_uuid)
-        print("gbdr-lr: 正在加载order_info...")
+        logger.info("Loading order_info")
         self._order_data = self._mysql_dao.load_order_data(city_uuid)
         # self._order_data = self._mysql_dao.load_mock_order_data()
         # print("gbdr-lr: 正在加载shopping_info...")

+ 14 - 14
models/rank/gbdt_lr_inference.py

@@ -9,6 +9,9 @@ from models.rank.data.utils import one_hot_embedding, sample_data_clear
 import numpy as np
 import pandas as pd
 from sklearn.preprocessing import StandardScaler
+from core import get_logger
+
+logger = get_logger("models.rank.gbdtlr")
 
 def clean_column_name(col):
     """清理列名中的特殊字符,与 one_hot_embedding 保持一致"""
@@ -58,22 +61,19 @@ class GbdtLrModel:
         self.custs_data = self._mysql_dao.load_cust_data(city_uuid)[CustConfig.FEATURE_COLUMNS]
     
     def get_recommend_list(self, recommend_sample, recall_list):
-        
+
         gbdt_preds = self.gbdt_model.predict(recommend_sample, pred_leaf=True)
         gbdt_feats_encoded = self.onehot_encoder.transform(gbdt_preds)
         scores = self.lr_model.predict_proba(gbdt_feats_encoded)[:, 1] * 100
-        
-        recommend_list = []
-        for cust_id, score in zip(recall_list, scores):
-            recommend_list.append({cust_id: float(score)})
-            recommend_list.append({"cust_code": cust_id, "recommend_score": score})
-            
-        recommend_list = sorted(
-            [item for item in recommend_list if "recommend_score" in item],
-            key=lambda x: x["recommend_score"],
-            reverse=True
-        )
-        
+
+        recommend_list = [
+            {"cust_code": cust_id, "recommend_score": float(score)}
+            for cust_id, score in zip(recall_list, scores)
+        ]
+        recommend_list.sort(key=lambda x: x["recommend_score"], reverse=True)
+
+        logger.info(f"Scored {len(recommend_list)} items in recommend list")
+
         return recommend_list
         
     
@@ -236,7 +236,7 @@ class GbdtLrModel:
                 os.remove(temp_file)
                 os.rmdir(temp_dir)
             except Exception as e:
-                print(f"清理临时文件时出错: {e}")
+                logger.error(f"清理临时文件时出错: {e}")
     
 if __name__ == "__main__":
     model_path = "./models/rank/weights/00000000000000000000000011445301/gbdtlr_model.pkl"

+ 4 - 1
models/recall/hot_recall.py

@@ -5,6 +5,9 @@ from tqdm import tqdm
 
 from models.rank.data.config import OrderConfig
 import numpy as np
+from core import get_logger
+
+logger = get_logger("models.recall.hot")
 
 cfgs = load_model_config()
 
@@ -18,7 +21,7 @@ class HotRecallModel:
     
     def _load_data(self):
         """加载订单记录表"""
-        print("hot_recall: 正在加载order_info...")
+        logger.info("Loading order data")
         self._order_data = self._dao.load_order_data(self._city_uuid)
         self._order_data =self._order_data[OrderConfig.FEATURE_COLUMNS] 
         

+ 7 - 4
models/recall/itemCF/ItemCF.py

@@ -5,6 +5,9 @@ import numpy as np
 from tqdm import tqdm
 from scipy.sparse import csr_matrix
 from joblib import Parallel, delayed
+from core import get_logger
+
+logger = get_logger("models.recall.itemcf")
 
 class ItemCFModel:
     def __init__(self):
@@ -14,11 +17,11 @@ class ItemCFModel:
     def train(self, city_uuid, n=300, k=100, top_n=300, n_jobs=4):
         # self._score_df = pd.read_csv(score_path)
         # self._similarity_df = pd.read_csv(similatity_path, index_col=0)
-        print("itemcf: 正在加载order_info...")
+        logger.info("Loading order data")
         self._order_data = self._dao.load_order_data(city_uuid)
-        print("正在计算品规培育分数...")
+        logger.info("Calculating product scores")
         self._score_df = UserItemScore(self._order_data).generate_product_scores()
-        print("正在计算商户相似度矩阵...")
+        logger.info("Calculating similarity matrix")
         self._similarity_df = SimilarityMatrix(self._order_data).generate_similarity_matrix()
         
         similarity_matrix = csr_matrix(self._similarity_df.values)
@@ -87,7 +90,7 @@ class ItemCFModel:
                     try:
                         zset_data[shop_id] = float(score)
                     except ValueError as e:
-                        print(f"Error converting score to float for shop_id {shop_id}: {score}")
+                        logger.error(f"Error converting score to float for shop_id {shop_id}: {score}")
                         raise e
             
             redis_db.redis.zadd(redis_key, zset_data)

+ 11 - 1
models/recommend.py

@@ -6,6 +6,9 @@ from models.rank.data.config import CustConfig, ProductConfig
 from models.rank.data.utils import sample_data_clear
 from models.rank import GbdtLrModel, generate_feats_map
 import pandas as pd
+from core import get_logger
+
+logger = get_logger("models.recommend")
 
 
 class Recommend:
@@ -21,6 +24,7 @@ class Recommend:
         gbdtlr_model_path = os.path.join("./models/rank/weights", city_uuid, "gbdtlr_model.pkl")
         self._gbdtlr_model = GbdtLrModel(gbdtlr_model_path)
         self._item2vec_model = Item2VecModel(city_uuid)
+        logger.info(f"Models loaded for city_uuid={city_uuid}")
         
     def _get_itemcf_recall(self, product_id):
         """协同召回"""
@@ -46,11 +50,13 @@ class Recommend:
             additional_items = [item for item in hot_recall_list if item in hot_recall_set]
             needed = recall_count - len(result)
             result.extend(additional_items[:needed])
-            
+
+        logger.info(f"Recall completed: {len(result)} customers for product {product_id}")
         return result[:recall_count]
     
     def get_recommend_list_by_gbdtlr(self, product_id, recall_count=500):
         """根据gbdt_lr获取商户推荐列表"""
+        logger.info(f"GBDT-LR recommend started for product {product_id}")
         # 获取召回的商户列表
         recall_cust_list = self.get_recal_cust(product_id, recall_count)
         # 获取卷烟数据
@@ -77,15 +83,18 @@ class Recommend:
         feats_map = generate_feats_map(product_data, cust_data)
         recommend_list = self._gbdtlr_model.get_recommend_list(feats_map, ordered_recall_list)
         # recommend_list = self.filter_recommend_list(recommend_list)
+        logger.info(f"GBDT-LR recommend completed: {len(recommend_list)} results")
         return recommend_list
     
     def get_recommend_list_by_item2vec(self, product_id, recall_count=500):
         """根据item2vec获取商户推荐列表"""
+        logger.info(f"Item2Vec recommend started for product {product_id}")
         recommend_list = self._item2vec_model.get_recommend_cust_list(product_id, top=recall_count)
         recommend_list = recommend_list.drop(columns=["sale_qty"])
         recommend_list = recommend_list.to_dict(orient='records')
         recommend_list = recommend_list[:recall_count]
         # recommend_list = self.filter_recommend_list(recommend_list)
+        logger.info(f"Item2Vec recommend completed: {len(recommend_list)} results")
         return recommend_list
         
     def filter_recommend_list(self, recommend_list):
@@ -120,6 +129,7 @@ class Recommend:
         recommend_data = recommend_data.sort_values(["recommend_score", "cust_code"], ascending=[False, True])
         
         recommend_data = recommend_data.to_dict(orient='records')
+        logger.info(f"Delivery allocation completed for {len(recommend_data)} customers, total={delivery_count}")
         return recommend_data
         
     

+ 65 - 11
run_api.py

@@ -1,35 +1,89 @@
 from api import recommend_router, report_router, eval_report_router
+from core import get_logger, AppException, RequestLoggingMiddleware, get_request_id
+from database.db.mysql import MySqlDatabaseHelper
+from database.db.redis_db import RedisDatabaseHelper
 from fastapi import FastAPI, Request, status
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse
 
 import uvicorn
 
+logger = get_logger("app")
+
 app = FastAPI()
 
-# 添加全局异常处理器
+app.add_middleware(RequestLoggingMiddleware)
+
+
 @app.exception_handler(RequestValidationError)
 async def validation_exception_handler(request: Request, exc: RequestValidationError):
+    logger.warning(f"Validation error: {exc.errors()}")
     return JSONResponse(
         status_code=status.HTTP_400_BAD_REQUEST,
         content={
             "code": 400,
             "msg": "请求参数错误",
-            "data": {
-                "detail": exc.errors(),
-                "body": exc.body
-            }
+            "data": {"detail": exc.errors(), "body": exc.body},
+            "request_id": get_request_id(),
         },
     )
 
-url_prefix = '/brandcultivation/api/v1'
-  
-# 注册路由
+
+@app.exception_handler(AppException)
+async def app_exception_handler(request: Request, exc: AppException):
+    logger.error(f"AppException: {exc.message}", extra={"extra_data": {"detail": exc.detail}})
+    return JSONResponse(
+        status_code=exc.code,
+        content={
+            "code": exc.code,
+            "msg": exc.message,
+            "data": {"detail": exc.detail},
+            "request_id": get_request_id(),
+        },
+    )
+
+
+@app.exception_handler(Exception)
+async def unhandled_exception_handler(request: Request, exc: Exception):
+    logger.error("Unhandled exception", exc_info=True)
+    return JSONResponse(
+        status_code=500,
+        content={
+            "code": 500,
+            "msg": "服务器内部错误",
+            "data": None,
+            "request_id": get_request_id(),
+        },
+    )
+
+
+@app.get("/health")
+async def health_check():
+    mysql_ok = False
+    redis_ok = False
+    try:
+        mysql_ok = MySqlDatabaseHelper().check_connection()
+    except Exception:
+        pass
+    try:
+        redis_ok = RedisDatabaseHelper().check_connection()
+    except Exception:
+        pass
+
+    healthy = mysql_ok and redis_ok
+    return {
+        "status": "healthy" if healthy else "degraded",
+        "mysql": "ok" if mysql_ok else "error",
+        "redis": "ok" if redis_ok else "error",
+    }
+
+
+url_prefix = "/brandcultivation/api/v1"
+
 app.include_router(recommend_router, prefix=url_prefix)
 app.include_router(report_router, prefix=url_prefix)
 app.include_router(eval_report_router, prefix=url_prefix)
-    
+
 if __name__ == "__main__":
+    logger.info("Starting BrandCultivation API server on port 7960")
     uvicorn.run(app, host="0.0.0.0", port=7960)
-    # report_dir = "./data/reports/00000000000000000000000011445301/440298"
-    # upload_file(report_dir)

+ 24 - 21
train.py

@@ -4,6 +4,9 @@ from models.rank import DataProcess, Trainer, GbdtLrModel
 from models import ItemCFModel, HotRecallModel
 import time
 import pandas as pd
+from core import get_logger
+
+logger = get_logger("train")
 
 # train_data_path = "./moldes/rank/data/gbdt_data.csv"
 # model_path = "./models/rank/weights"
@@ -17,14 +20,14 @@ def gbdtlr_train(args):
     if not os.path.exists(train_data_dir):
         os.makedirs(train_data_dir)
     
-    # 准备数据集  
-    print("正在整合训练数据...")
+    # 准备数据集
+    logger.info("正在整合训练数据...")
     processor = DataProcess(args.city_uuid, args.train_data_dir)
     processor.data_process()
-    print("训练数据整合完成!")
-    
+    logger.info("训练数据整合完成")
+
     # 进行训练
-    print("开始训练gbdt-lr模型")
+    logger.info("开始训练gbdt-lr模型")
     gbdtlr_trainer(os.path.join(args.train_data_dir, "train_data.csv"), model_dir, "gbdtlr_model.pkl")
 
 def gbdtlr_trainer(train_data_path, model_dir, model_name):
@@ -35,14 +38,14 @@ def gbdtlr_trainer(train_data_path, model_dir, model_name):
     end_time = time.time()
     
     training_time_hours = (end_time - start_time) / 3600
-    print(f"训练时间: {training_time_hours:.4f} 小时")
-    
+    logger.info(f"训练时间: {training_time_hours:.4f} 小时")
+
     eval_metrics = trainer.evaluate()
-    
+
     # 输出评估结果
-    print("GBDT-LR Evaluation Metrics:")
+    logger.info("GBDT-LR Evaluation Metrics:")
     for metric, value in eval_metrics.items():
-        print(f"{metric}: {value:.4f}")
+        logger.info(f"{metric}: {value:.4f}")
         
     # 保存模型
     trainer.save_model(os.path.join(model_dir, model_name))
@@ -79,24 +82,24 @@ def run():
     args = parser.parse_args()
     
     if args.run_train:
-        print("正在计算协同过滤...")
+        logger.info("正在计算协同过滤...")
         itemCF(args)
-        
-        print("正在计算热度召回...")
+
+        logger.info("正在计算热度召回...")
         hot_recall(args)
-        
-        print("正在进行gbdt_lr训练...")
+
+        logger.info("正在进行gbdt_lr训练...")
         gbdtlr_train(args)
-        
+
     if args.run_recall:
-        print("正在计算协同过滤...")
+        logger.info("正在计算协同过滤...")
         itemCF(args)
-        
-        print("正在计算热度召回...")
+
+        logger.info("正在计算热度召回...")
         hot_recall(args)
-        
+
     if args.run_gbdtlr:
-        print("正在进行gbdt_lr训练...")
+        logger.info("正在进行gbdt_lr训练...")
         gbdtlr_train(args)
         
 if __name__ == "__main__":

+ 36 - 43
utils/file_stream.py

@@ -1,87 +1,80 @@
-from config import load_service_config
+import time
+from core import get_logger, settings
 from io import BytesIO
 import os
 import pandas as pd
 import requests
 
+logger = get_logger("utils.file_stream")
+
 
 class FileStreamUtils:
-    cfgs = load_service_config()
-    upload_url = cfgs["aliyun"]["upload_url"]
-    download_url = cfgs["aliyun"]["download_url"]
-    # cookies = cfgs["aliyun"]['cookies']
-     # 设置请求头
+    upload_url = settings.file_upload_url
+    download_url = settings.file_download_url
     headers = {
-        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
+        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
         "Accept": "*/*",
     }
-    
+
     @staticmethod
     def upload_files(reports_dir, files):
         files_id = {}
         for filename in files:
-            file_path = os.path.join(reports_dir, f'{filename}.xlsx')
+            file_path = os.path.join(reports_dir, f"{filename}.xlsx")
+            start_time = time.time()
             try:
-                with open(file_path, 'rb') as f:
-                    files = {'file': (os.path.basename(file_path), f)}
-
+                with open(file_path, "rb") as f:
+                    upload_files = {"file": (os.path.basename(file_path), f)}
                     response = requests.post(
                         FileStreamUtils.upload_url,
                         headers=FileStreamUtils.headers,
-                        files=files,
-                        # cookies=FileStreamUtils.cookies,
-                        verify=True
+                        files=upload_files,
+                        verify=True,
                     )
-                    
-                    if response.json()["success"]:
+                    duration_ms = (time.time() - start_time) * 1000
+                    if response.json().get("success"):
                         file_id = response.json()["data"]["file_info"]["fileid"]
                         files_id[filename] = file_id
+                        logger.info(f"File uploaded: {filename} -> {file_id} ({duration_ms:.0f}ms)")
+                    else:
+                        logger.error(f"Upload failed for {filename}: {response.text}")
+                        return None
             except requests.exceptions.RequestException as e:
-                print("请求出错:", e)
+                logger.error(f"Upload request error for {filename}: {e}", exc_info=True)
                 return None
             except Exception as e:
+                logger.error(f"Upload error for {filename}: {e}", exc_info=True)
                 return None
-                
         return files_id
-    
+
     @staticmethod
-    def download_file(file_id, file_type='xlsx'):
+    def download_file(file_id, file_type="xlsx"):
         """通过file_id从阿里云文件数据库下载文件"""
+        start_time = time.time()
         try:
-            # params = {
-            #     'fileid': file_id,
-            #     'action': 'download'
-            # }
             response = requests.get(
                 f"{FileStreamUtils.download_url}/{file_id}",
                 headers=FileStreamUtils.headers,
-                # cookies=FileStreamUtils.cookies,
-                # params=params,
-                verify=True
+                verify=True,
             )
-            
+            duration_ms = (time.time() - start_time) * 1000
+
             if response.status_code == 200:
                 file_content = BytesIO(response.content)
-                if file_type == 'xlsx':
-                    data = pd.read_excel(file_content, engine='openpyxl')
-                elif file_type == 'csv':
+                if file_type == "xlsx":
+                    data = pd.read_excel(file_content, engine="openpyxl")
+                elif file_type == "csv":
                     data = pd.read_csv(file_content)
                 else:
-                    raise ValueError(f"不支持的文件类型:{file_type}" )
-                
+                    raise ValueError(f"不支持的文件类型:{file_type}")
+                logger.info(f"File downloaded: {file_id} ({duration_ms:.0f}ms, {len(response.content)} bytes)")
                 return data
             else:
+                logger.error(f"Download failed: file_id={file_id}, status={response.status_code}")
                 return None
         except requests.exceptions.RequestException as e:
-            print("Request Error: ", e)
+            logger.error(f"Download request error: file_id={file_id}, error={e}", exc_info=True)
             return None
         except Exception as e:
-            print("File download Error: ", e)
+            logger.error(f"Download error: file_id={file_id}, error={e}", exc_info=True)
             return None
-    
-if __name__ == '__main__':
-    # print(FileStreamUtils.cfgs["aliyun"]["cookies"])
-    file_id = '11C1AC088863421C9BC32A5E722F5147'
-    
-    data = FileStreamUtils.download_file(file_id)
-    data.to_excel('./recommend_list.xlsx', index=False)

+ 19 - 3
utils/report_utils.py

@@ -3,10 +3,14 @@ from models import Recommend
 from models.rank.data.config import CustConfig, ImportanceFeaturesMap, ProductConfig, DeliveryConfig
 from models.rank.data.utils import sample_data_clear
 from models.rank import generate_feats_map
+from core import get_logger
 
 import os
 import pandas as pd
 from utils.reports_process import feats_relation_process, calculate_delivery_by_recommend_data, eval_report_process_pre, eval_report_process
+
+logger = get_logger("utils.report")
+
 class ReportUtils:
     def __init__(self, city_uuid, product_id):
         self._recommend_model = Recommend(city_uuid)
@@ -64,33 +68,40 @@ class ReportUtils:
     
     def generate_feats_ralation_report(self, recall_count):
         """生成特征相关性分析报告"""
+        logger.info("Generating feature relation report")
         feats_map = self._generate_feats_map(recall_count)
         product_content = self._get_product_content()
         # 计算SHAP值
         shap_result = self._recommend_model._gbdtlr_model.generate_shap_interance(feats_map)
         report = feats_relation_process(shap_result, product_content)
-        
+
         report.to_excel(os.path.join(self._save_dir, "品规商户特征关系表.xlsx"), index=False)
+        logger.info("Feature relation report saved")
         
     def generate_product_report(self):
         """生成推荐品规信息表"""
+        logger.info("Generating product report")
         product_data = self._get_product_content()
         with open(os.path.join(self._save_dir, "卷烟信息表.xlsx"), "w", encoding='utf-8-sig') as file:
             for key, value in product_data.items():
                 if key != 'product_code':
                     file.write(f"{ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP[key]}, {value}\n")
+        logger.info("Product report saved")
                     
     def generate_recommend_report(self, recall_count, delivery_count):
         """生成推荐报告,包括投放量"""
+        logger.info("Generating recommend report")
         recommend_data = self._get_recommend_data(recall_count)
         recommend_list = list(map(lambda x: x["cust_code"], recommend_data))
         recommend_cust_infos = self._dao.get_cust_by_ids(self._city_uuid, recommend_list)
         report = calculate_delivery_by_recommend_data(recommend_data, recommend_cust_infos, delivery_count)
-        
+
         report.to_excel(os.path.join(self._save_dir, "商户售卖推荐表.xlsx"), index=False)
+        logger.info("Recommend report saved")
         
     def generate_similarity_product_report(self):
         """生成相似卷烟表"""
+        logger.info("Generating similarity product report")
         product_similarity_map = self._recommend_model._item2vec_model.generate_product_similarity_map(self._product_id)
         product_similarity_map = product_similarity_map[["product_name", "similarity", "brand_name", "factory_name", "is_low_tar", "is_medium", "is_tiny", "is_coarse", "is_exploding_beads", "is_abnormity", "is_cig", "is_chuangxin", "direct_retail_price", "tbc_total_length", "product_style"]]
         product_similarity_map = product_similarity_map.rename(
@@ -113,6 +124,7 @@ class ReportUtils:
             }
         )
         product_similarity_map.to_excel(os.path.join(self._save_dir, "相似卷烟表.xlsx"), index=False)
+        logger.info("Similarity product report saved")
         
     def generate_eval_data_pre(self):
         if self._product_id == '350139':
@@ -121,7 +133,7 @@ class ReportUtils:
             eval_product_id = self._product_id
         eval_order_data = self._dao.get_eval_order_by_product(self._city_uuid, eval_product_id)
         if not os.path.exists(os.path.join(self._save_dir, "商户售卖推荐表.xlsx")):
-            print("请先生成'商户售卖推荐表'")
+            logger.error("商户售卖推荐表 not found")
         recommend_data = pd.read_excel(os.path.join(self._save_dir, "商户售卖推荐表.xlsx"))
         report = eval_report_process_pre(eval_order_data, recommend_data)
         
@@ -129,6 +141,7 @@ class ReportUtils:
         
     def generate_eval_data(self, start_time, end_time, recommend_data):
         """根据推荐列表生成验证报告"""
+        logger.info("Generating eval report")
         if self._product_id == '350139':
             eval_product_id = "350355"
         else:
@@ -142,13 +155,16 @@ class ReportUtils:
         report = eval_report_process(delivery_data, recommend_data)
         
         report.to_excel(os.path.join(self._save_dir, "投放验证报告.xlsx"), index=False)
+        logger.info("Eval report saved")
     
     def generate_all_data(self, recall_count, delivery_count):
+        logger.info("Generating all reports")
         self.generate_feats_ralation_report(recall_count)
         self.generate_product_report()
         self.generate_recommend_report(recall_count, delivery_count)
         self.generate_similarity_product_report()
         # self.generate_eval_data()
+        logger.info("All reports generated")
         
 if __name__ == "__main__":
     city_uuid = "00000000000000000000000011445301"