Browse Source

refactor: improve DAO logging, rewrite run_api with middleware and health check

Sherlock 3 weeks ago
parent
commit
b92be3efd9
2 changed files with 94 additions and 30 deletions
  1. 29 19
      database/dao/mysql_dao.py
  2. 65 11
      run_api.py

+ 29 - 19
database/dao/mysql_dao.py

@@ -1,7 +1,10 @@
+from core import get_logger
 from database import MySqlDatabaseHelper
 from sqlalchemy import text, bindparam
 import pandas as pd
 
+logger = get_logger("database.dao")
+
 class MySqlDao:
     _instance = None
     
@@ -31,6 +34,7 @@ class MySqlDao:
         
     def load_product_data(self, city_uuid):
         """从数据库中读取商品信息"""
+        logger.info(f"Loading product data for city_uuid={city_uuid}")
         query = f"SELECT * FROM {self._product_tablename} WHERE city_uuid = :city_uuid AND org_is_active = '是'"
         params = {"city_uuid": city_uuid}
         
@@ -39,6 +43,7 @@ class MySqlDao:
         
     def load_cust_data(self, city_uuid):
         """从数据库中读取商户信息"""
+        logger.info(f"Loading cust data for city_uuid={city_uuid}")
         query = f"SELECT * FROM {self._cust_tablename} WHERE corp_uuid = :city_uuid"
         params = {"city_uuid": city_uuid}
         
@@ -47,6 +52,7 @@ class MySqlDao:
     
     def load_order_data(self, city_uuid):
         """从数据库中读取订单信息"""
+        logger.info(f"Loading order data for city_uuid={city_uuid}")
         query = f"SELECT * FROM {self._order_tablename} WHERE city_uuid = :city_uuid"
         params = {"city_uuid": city_uuid}
         
@@ -61,6 +67,7 @@ class MySqlDao:
     
     def load_order_analysis_index_data(self, city_uuid):
         """从数据库中读取销售指标评估表"""
+        logger.info(f"Loading order analysis index data for city_uuid={city_uuid}")
         query = f"SELECT * FROM {self._order_analysis_table_name} WHERE city_uuid = :city_uuid"
         params = {"city_uuid": city_uuid}
         
@@ -69,6 +76,7 @@ class MySqlDao:
 
     def load_delivery_order_data(self, city_uuid, start_time, end_time):
         """从数据库中读取订单信息"""
+        logger.info(f"Loading delivery order data for city_uuid={city_uuid}, start_time={start_time}, end_time={end_time}")
         query = f"SELECT * FROM {self._eval_order_name} WHERE city_uuid = :city_uuid AND cycle_begin_date = :start_time AND cycle_end_date = :end_time"
         params = {
             "city_uuid": city_uuid,
@@ -83,6 +91,7 @@ class MySqlDao:
     
     def load_mock_order_data(self):
         """从数据库中读取mock的订单信息"""
+        logger.info("Loading mock order data")
         query = f"SELECT * FROM {self._mock_order_tablename}"
         
         data = self.db_helper.load_data_with_page(query, {})
@@ -91,6 +100,7 @@ class MySqlDao:
     
     def load_shopping_data(self, city_uuid):
         """从数据库中读取商圈数据"""
+        logger.info(f"Loading shopping data for city_uuid={city_uuid}")
         query = f"SELECT * FROM {self._shopping_tablename} WHERE city_uuid = :city_uuid"
         params = {"city_uuid": city_uuid}
         
@@ -100,6 +110,7 @@ class MySqlDao:
     
     def get_product_by_id(self, city_uuid, product_id):
         """根据city_uuid 和 product_id 从表中获取拼柜信息"""
+        logger.info(f"Getting product by id for city_uuid={city_uuid}, product_id={product_id}")
         query = text(f"""
             SELECT *
             FROM {self._product_tablename}
@@ -112,6 +123,7 @@ class MySqlDao:
     
     def get_cust_by_ids(self, city_uuid, cust_id_list):
         """根据零售户列表查询其信息"""
+        logger.info(f"Getting cust by ids for city_uuid={city_uuid}, count={len(cust_id_list) if cust_id_list else 0}")
         if not cust_id_list:
             return pd.DataFrame()
 
@@ -128,6 +140,7 @@ class MySqlDao:
     
     def get_shop_by_ids(self, city_uuid, cust_id_list):
         """根据零售户列表查询其信息"""
+        logger.info(f"Getting shop by ids for city_uuid={city_uuid}, count={len(cust_id_list) if cust_id_list else 0}")
         if not cust_id_list:
             return pd.DataFrame()
 
@@ -144,6 +157,7 @@ class MySqlDao:
     
     def get_product_by_ids(self, city_uuid, product_id_list):
         """根据product_code列表查询其信息"""
+        logger.info(f"Getting products by ids for city_uuid={city_uuid}, count={len(product_id_list) if product_id_list else 0}")
         if not product_id_list:
             return pd.DataFrame()
 
@@ -166,6 +180,7 @@ class MySqlDao:
     
     def get_order_by_product_ids(self, city_uuid, product_ids):
         """获取指定香烟列表的所有售卖记录"""
+        logger.info(f"Getting orders by product ids for city_uuid={city_uuid}, count={len(product_ids) if product_ids else 0}")
         if not product_ids:
             return pd.DataFrame()
 
@@ -193,7 +208,7 @@ class MySqlDao:
         return data
     
     def get_order_by_product(self, city_uuid, product_id):
-        
+        logger.info(f"Getting orders by product for city_uuid={city_uuid}, product_id={product_id}")
         query = f"""
             SELECT *
             FROM {self._order_tablename}
@@ -210,6 +225,7 @@ class MySqlDao:
         return data
     
     def get_eval_order_by_product(self, city_uuid, product_id):
+        logger.info(f"Getting eval orders by product for city_uuid={city_uuid}, product_id={product_id}")
         query = f"""
             SELECT *
             FROM {self._eval_order_name}
@@ -223,6 +239,7 @@ class MySqlDao:
     
     def get_delivery_data_by_product(self, city_uuid, product_id, start_time, end_time):
         """通过品规获取验证数据"""
+        logger.info(f"Getting delivery data by product for city_uuid={city_uuid}, product_id={product_id}, start_time={start_time}, end_time={end_time}")
         query = f"""
             SELECT *
             FROM {self._eval_order_name}
@@ -242,6 +259,7 @@ class MySqlDao:
         return data
     
     def get_order_by_cust(self, city_uuid, cust_id):
+        logger.info(f"Getting orders by cust for city_uuid={city_uuid}, cust_id={cust_id}")
         query = f"""
             SELECT *
             FROM {self._order_tablename}
@@ -254,6 +272,7 @@ class MySqlDao:
         return data
     
     def get_order_by_cust_and_product(self, city_uuid, cust_id, product_id):
+        logger.info(f"Getting orders by cust and product for city_uuid={city_uuid}, cust_id={cust_id}, product_id={product_id}")
         query = f"""
             SELECT *
             FROM {self._order_tablename}
@@ -267,6 +286,7 @@ class MySqlDao:
         return data
     
     def get_product_from_order(self, city_uuid):
+        logger.info(f"Getting products from order for city_uuid={city_uuid}")
         query = f"SELECT DISTINCT product_code FROM {self._order_tablename} WHERE city_uuid = :city_uuid ORDER BY product_code"
         params = {"city_uuid": city_uuid}
         
@@ -275,6 +295,7 @@ class MySqlDao:
         return data
 
     def get_cust_list(self, city_uuid):
+        logger.info(f"Getting cust list for city_uuid={city_uuid}")
         query = f"SELECT DISTINCT cust_code FROM {self._cust_tablename} WHERE corp_uuid = :city_uuid ORDER BY cust_code"
         params = {"city_uuid": city_uuid}
         
@@ -282,26 +303,14 @@ class MySqlDao:
 
         return data
 
-    def data_preprocess(self, data: pd.DataFrame):
-        """数据预处理"""
-        data.drop(["cust_uuid", "longitude", "latitude", "range_radius"], axis=1, inplace=True)
-        remaining_cols = data.columns.drop(["city_uuid", "cust_code"])
-        col_with_missing = remaining_cols[data[remaining_cols].isnull().any()].tolist() # 判断有缺失的字段
-        col_all_missing = remaining_cols[data[remaining_cols].isnull().all()].to_list() # 全部缺失的字段
-        col_partial_missing = list(set(col_with_missing) - set(col_all_missing)) # 部分缺失的字段
-        
-        for col in col_partial_missing:
-            data[col] = data[col].fillna(data[col].mean())
-        
-        for col in col_all_missing:
-            data[col] = data[col].fillna(0).infer_objects(copy=False)
-        
     def insert_report(self, data_dict):
         """向report中插入数据"""
+        logger.info("Inserting report data")
         return self.db_helper.insert_data(self._report_tablename, data_dict)
     
     def update_eval_report_data(self, cultivacation_id, eval_fileid):
         """更新投放记录中的验证报告fileid"""
+        logger.info(f"Updating eval report data for cultivacation_id={cultivacation_id}")
         update_data = {"val_table": eval_fileid}
         conditions = [
             "cultivacation_id = :cultivacation_id",
@@ -314,13 +323,14 @@ class MySqlDao:
     
     def get_report_file_id(self, cultivacation_id):
         """从report中根据cultivacation_id获取对应文件的fileid"""
+        logger.info(f"Getting report file id for cultivacation_id={cultivacation_id}")
         query = f"SELECT product_info_table, relation_table, similarity_product_table, recommend_table, val_table FROM {self._report_tablename} WHERE cultivacation_id = :cultivacation_id"
         params = {"cultivacation_id": cultivacation_id}
-        
         result = self.db_helper.fetch_one(text(query), params)
-        data = pd.DataFrame([dict(result._mapping)] if result else None)
-        
-        return data
+        if result is None:
+            logger.warning(f"No report found for cultivacation_id={cultivacation_id}")
+            return pd.DataFrame()
+        return pd.DataFrame([dict(result._mapping)])
         
 if __name__ == "__main__":
     dao = MySqlDao()

+ 65 - 11
run_api.py

@@ -1,35 +1,89 @@
 from api import recommend_router, report_router, eval_report_router
+from core import get_logger, AppException, RequestLoggingMiddleware, get_request_id
+from database.db.mysql import MySqlDatabaseHelper
+from database.db.redis_db import RedisDatabaseHelper
 from fastapi import FastAPI, Request, status
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse
 
 import uvicorn
 
+logger = get_logger("app")
+
 app = FastAPI()
 
-# 添加全局异常处理器
+app.add_middleware(RequestLoggingMiddleware)
+
+
 @app.exception_handler(RequestValidationError)
 async def validation_exception_handler(request: Request, exc: RequestValidationError):
+    logger.warning(f"Validation error: {exc.errors()}")
     return JSONResponse(
         status_code=status.HTTP_400_BAD_REQUEST,
         content={
             "code": 400,
             "msg": "请求参数错误",
-            "data": {
-                "detail": exc.errors(),
-                "body": exc.body
-            }
+            "data": {"detail": exc.errors(), "body": exc.body},
+            "request_id": get_request_id(),
         },
     )
 
-url_prefix = '/brandcultivation/api/v1'
-  
-# 注册路由
+
+@app.exception_handler(AppException)
+async def app_exception_handler(request: Request, exc: AppException):
+    logger.error(f"AppException: {exc.message}", extra={"extra_data": {"detail": exc.detail}})
+    return JSONResponse(
+        status_code=exc.code,
+        content={
+            "code": exc.code,
+            "msg": exc.message,
+            "data": {"detail": exc.detail},
+            "request_id": get_request_id(),
+        },
+    )
+
+
+@app.exception_handler(Exception)
+async def unhandled_exception_handler(request: Request, exc: Exception):
+    logger.error("Unhandled exception", exc_info=True)
+    return JSONResponse(
+        status_code=500,
+        content={
+            "code": 500,
+            "msg": "服务器内部错误",
+            "data": None,
+            "request_id": get_request_id(),
+        },
+    )
+
+
+@app.get("/health")
+async def health_check():
+    mysql_ok = False
+    redis_ok = False
+    try:
+        mysql_ok = MySqlDatabaseHelper().check_connection()
+    except Exception:
+        pass
+    try:
+        redis_ok = RedisDatabaseHelper().check_connection()
+    except Exception:
+        pass
+
+    healthy = mysql_ok and redis_ok
+    return {
+        "status": "healthy" if healthy else "degraded",
+        "mysql": "ok" if mysql_ok else "error",
+        "redis": "ok" if redis_ok else "error",
+    }
+
+
+url_prefix = "/brandcultivation/api/v1"
+
 app.include_router(recommend_router, prefix=url_prefix)
 app.include_router(report_router, prefix=url_prefix)
 app.include_router(eval_report_router, prefix=url_prefix)
-    
+
 if __name__ == "__main__":
+    logger.info("Starting BrandCultivation API server on port 7960")
     uvicorn.run(app, host="0.0.0.0", port=7960)
-    # report_dir = "./data/reports/00000000000000000000000011445301/440298"
-    # upload_file(report_dir)