Przeglądaj źródła

refactor(database): add logging, session context manager, env-based config

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Sherlock 3 tygodni temu
rodzic
commit
a7a8187584
2 zmienionych plików z 196 dodań i 261 usunięć
  1. 154 209
      database/db/mysql.py
  2. 42 52
      database/db/redis_db.py

+ 154 - 209
database/db/mysql.py

@@ -1,209 +1,154 @@
-from config import load_config
-import pandas as pd
-from sqlalchemy import create_engine, text
-from sqlalchemy.orm import sessionmaker
-from sqlalchemy.exc import SQLAlchemyError
-from tqdm import tqdm
-
-cfgs = load_config()
-
-
-class MySqlDatabaseHelper:
-    _instance = None
-    
-    def __new__(cls):
-        if not cls._instance:
-            cls._instance = super(MySqlDatabaseHelper, cls).__new__(cls)
-            cls._instance._initialized = False
-        return cls._instance
-        
-    def __init__(self):
-        if self._initialized:
-            return
-        
-        self._host = cfgs['mysql']['host']
-        self._port = cfgs['mysql']['port']
-        self._user = cfgs['mysql']['user']
-        self._passwd = cfgs['mysql']['passwd']
-        self._dbname = cfgs['mysql']['db']
-        
-        self.connect_database()
-        self._initialized = True
-        
-    def connect_database(self):
-        # 创建数据库连接
-        try:
-            conn = "mysql+pymysql://" + self._user + ":" + self._passwd + "@" + self._host + ":" + str(self._port) + "/" + self._dbname
-
-            # 通过连接池创建engine
-            self.engine = create_engine(
-                conn,
-                pool_size=20, # 设置连接池大小
-                max_overflow=30, # 超过连接池大小时的额外连接数
-                pool_recycle=1800, # 回收连接时间
-                pool_pre_ping=True, # 防止断开连接
-                isolation_level="READ COMMITTED" # 降低隔离级别
-            )
-        except Exception as e:
-            raise ConnectionAbortedError(f"failed to create connection: {e}")
-
-        self._DBSession = sessionmaker(bind=self.engine)
-    
-    def load_data_with_page(self, query, params, page_size=100000):
-        """分页查询数据"""
-        data = pd.DataFrame()
-        # 用子查询包裹原始查询来计数,避免字符串替换
-        count_query = text(f"SELECT COUNT(*) FROM ({query}) AS _count_subq")
-        query += " LIMIT :limit OFFSET :offset"
-        query = text(query)
-
-        # 获取总行数
-        result = self.fetch_one(count_query, params)
-        total_rows = result[0] if result is not None else 0
-
-        if total_rows == 0:
-            return data
-
-        page = 1
-        with tqdm(total=total_rows, desc="Loading data", unit="rows") as pbar:
-            while True:
-                offset = (page - 1) * page_size
-                # 复制 params 避免修改调用方的字典
-                page_params = dict(params)
-                page_params["limit"] = page_size
-                page_params["offset"] = offset
-
-                df = pd.DataFrame(self.fetch_all(query, page_params))
-                if df.empty:
-                    break
-                data = pd.concat([data, df], ignore_index=True)
-
-                pbar.update(len(df))
-
-                page += 1
-        return data
-        
-        
-    def fetch_all(self, query, params=None):
-        """执行SQL查询并返回所有结果"""
-        session = self._DBSession()
-        try:
-            results = session.execute(query, params or {}).fetchall()
-            return results
-        except SQLAlchemyError as e:
-            session.rollback()
-            print(f"error: {e}")
-            raise
-        finally:
-            session.close()
-            
-    def fetch_one(self, query, params=None):
-        """执行SQL查询并返回单条结果"""
-        session = self._DBSession()
-        try:
-            result = session.execute(query, params or {}).fetchone()
-            return result
-
-        except SQLAlchemyError as e:
-            session.rollback()
-            print(f"error: {e}")
-            raise
-        finally:
-            session.close()
-            
-    def insert_data(self, table_name, data_dict):
-        """插入单条数据到指定表"""
-        if not data_dict:
-            return 0
-        
-        columns = ", ".join(data_dict.keys())
-        values = ", ".join([f":{key}" for key in data_dict.keys()])
-        query = text(f"INSERT INTO {table_name} ({columns}) VALUES ({values})")
-        
-        session = self._DBSession()
-        
-        try:
-            result = session.execute(query, data_dict)
-            session.commit()
-            return result.rowcount
-        
-        except SQLAlchemyError as e:
-            session.rollback()
-            print(f"Error inserting data: {e}")
-            return 0
-        finally:
-            session.close()
-            
-    def update_data(self, table_name, update_dict, conditions, condition_params=None):
-        """更新表中符合条件的数据"""
-        if not update_dict:
-            return 0
-        
-        set_clause = ", ".join([f"{key} = :{key}" for key in update_dict.keys()])
-        
-        if len(conditions) == 1:
-            where_clause = f"WHERE {conditions[0]}"
-        elif len(conditions) > 1:
-            where_clause = f"WHERE {' AND '.join(conditions)}"
-        else:
-            where_clause = ""
-        
-        query = text(f"UPDATE {table_name} SET {set_clause} {where_clause}")
-        
-        params = update_dict.copy()
-        if condition_params:
-            params.update(condition_params)
-            
-        session = self._DBSession()
-        try:
-            result = session.execute(query, params)
-            session.commit()
-            return result.rowcount
-        except SQLAlchemyError as e:
-            session.rollback()
-            print(f"Error updating data: {e}")
-            return 0
-        
-        finally:
-            session.close()
-    
-    def execute_query(self, query, params=None):
-        """执行SQL语句 (无返回值, 如INSERT, UPDATE, DELETE)"""
-        session = self._DBSession()
-        try:
-            session.execute(query, params or {})
-            session.commit()
-        except SQLAlchemyError as e:
-            session.rollback()
-            print(f"Error: {e}")
-        finally:
-            session.close()
-            
-if __name__ == '__main__':
-    db_helper = MySqlDatabaseHelper()
-    
-    table_name = 'tads_brandcul_report'
-    data_dict = {
-        'cultivacation_id': 10000002,
-        'city_uuid': '00000000000000000000000011445301',
-        'limit_cycle_name': '202505W1(05.05-05.11)',
-        'product_code': '440298',
-        'product_info_table': 'D72E3FAE8DCE4270BD23C3EC015C0A35',
-        'relation_table': 'AD889019FD4F4EE7B887981162BA09EC',
-        'similarity_product_table': 'CE436AC24D96461FA0C091CB01E9BC05',
-        'recommend_table': 'A7C5918B8DDB4BEA9D921936955CBAF6',
-    }
-    
-    # db_helper.insert_data(table_name, data_dict)
-    
-    update_data = {"val_table": "A7C5918B8DDB4BEA9D921936955CBAF6"}
-    conditions = [
-        "cultivacation_id = :cultivacation_id",
-        "city_uuid = :city_uuid"
-    ]
-    condition_params = {
-        'cultivacation_id': 10000001,
-        'city_uuid': '00000000000000000000000011445301',
-    }
-    
-    db_helper.update_data(table_name, update_data, conditions, condition_params)
+from contextlib import contextmanager
+from core import get_logger, settings, DatabaseException
+import pandas as pd
+from sqlalchemy import create_engine, text
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.exc import SQLAlchemyError
+
+logger = get_logger("database.mysql")
+
+
+class MySqlDatabaseHelper:
+    _instance = None
+
+    def __new__(cls):
+        if not cls._instance:
+            cls._instance = super(MySqlDatabaseHelper, cls).__new__(cls)
+            cls._instance._initialized = False
+        return cls._instance
+
+    def __init__(self):
+        if self._initialized:
+            return
+        self._connect_database()
+        self._initialized = True
+
+    def _connect_database(self):
+        try:
+            conn_str = (
+                f"mysql+pymysql://{settings.mysql_user}:{settings.mysql_password}"
+                f"@{settings.mysql_host}:{settings.mysql_port}/{settings.mysql_db}"
+            )
+            self.engine = create_engine(
+                conn_str,
+                pool_size=20,
+                max_overflow=30,
+                pool_recycle=1800,
+                pool_pre_ping=True,
+                isolation_level="READ COMMITTED",
+            )
+            self._DBSession = sessionmaker(bind=self.engine)
+            logger.info("MySQL connection pool created", extra={"extra_data": {"host": settings.mysql_host, "db": settings.mysql_db}})
+        except Exception as e:
+            logger.error("Failed to create MySQL connection", exc_info=True)
+            raise DatabaseException(message="数据库连接失败", detail=str(e))
+
+    @contextmanager
+    def get_session(self):
+        session = self._DBSession()
+        try:
+            yield session
+        except SQLAlchemyError as e:
+            session.rollback()
+            logger.error("Database operation failed", exc_info=True)
+            raise DatabaseException(message="数据库操作失败", detail=str(e))
+        finally:
+            session.close()
+
+    def load_data_with_page(self, query, params, page_size=100000):
+        """分页查询数据"""
+        count_query = text(f"SELECT COUNT(*) FROM ({query}) AS _count_subq")
+        query += " LIMIT :limit OFFSET :offset"
+        query = text(query)
+
+        result = self.fetch_one(count_query, params)
+        total_rows = result[0] if result is not None else 0
+
+        if total_rows == 0:
+            logger.debug("Query returned 0 rows")
+            return pd.DataFrame()
+
+        logger.debug(f"Loading {total_rows} rows with page_size={page_size}")
+        data = pd.DataFrame()
+        page = 1
+        while True:
+            offset = (page - 1) * page_size
+            page_params = dict(params)
+            page_params["limit"] = page_size
+            page_params["offset"] = offset
+
+            df = pd.DataFrame(self.fetch_all(query, page_params))
+            if df.empty:
+                break
+            data = pd.concat([data, df], ignore_index=True)
+            page += 1
+
+        logger.debug(f"Loaded {len(data)} rows in {page - 1} pages")
+        return data
+
+    def fetch_all(self, query, params=None):
+        """执行SQL查询并返回所有结果"""
+        with self.get_session() as session:
+            results = session.execute(query, params or {}).fetchall()
+            return results
+
+    def fetch_one(self, query, params=None):
+        """执行SQL查询并返回单条结果"""
+        with self.get_session() as session:
+            result = session.execute(query, params or {}).fetchone()
+            return result
+
+    def insert_data(self, table_name, data_dict):
+        """插入单条数据到指定表"""
+        if not data_dict:
+            return 0
+
+        columns = ", ".join(data_dict.keys())
+        values = ", ".join([f":{key}" for key in data_dict.keys()])
+        query = text(f"INSERT INTO {table_name} ({columns}) VALUES ({values})")
+
+        with self.get_session() as session:
+            result = session.execute(query, data_dict)
+            session.commit()
+            logger.info(f"Inserted 1 row into {table_name}")
+            return result.rowcount
+
+    def update_data(self, table_name, update_dict, conditions, condition_params=None):
+        """更新表中符合条件的数据"""
+        if not update_dict:
+            return 0
+
+        set_clause = ", ".join([f"{key} = :{key}" for key in update_dict.keys()])
+
+        if len(conditions) == 1:
+            where_clause = f"WHERE {conditions[0]}"
+        elif len(conditions) > 1:
+            where_clause = f"WHERE {' AND '.join(conditions)}"
+        else:
+            where_clause = ""
+
+        query = text(f"UPDATE {table_name} SET {set_clause} {where_clause}")
+
+        params = update_dict.copy()
+        if condition_params:
+            params.update(condition_params)
+
+        with self.get_session() as session:
+            result = session.execute(query, params)
+            session.commit()
+            logger.info(f"Updated {result.rowcount} rows in {table_name}")
+            return result.rowcount
+
+    def execute_query(self, query, params=None):
+        """执行SQL语句"""
+        with self.get_session() as session:
+            session.execute(query, params or {})
+            session.commit()
+
+    def check_connection(self) -> bool:
+        """检查数据库连接是否正常"""
+        try:
+            self.fetch_one(text("SELECT 1"), {})
+            return True
+        except Exception:
+            return False

+ 42 - 52
database/db/redis_db.py

@@ -1,52 +1,42 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-import redis
-from config import load_config
-
-cfgs = load_config()
-
-
-class RedisDatabaseHelper:
-    _instance = None
-    
-    def __new__(cls):
-        if not cls._instance:
-            cls._instance = super(RedisDatabaseHelper, cls).__new__(cls)
-            cls._instance._initialized = False
-        return cls._instance
-        
-    def __init__(self):
-        if self._initialized:
-            return
-        self.redis = redis.StrictRedis(host=cfgs['redis']['host'],
-                                       port=cfgs['redis']['port'],
-                                       password=cfgs['redis']['passwd'],
-                                       db=cfgs['redis']['db'],
-                                       decode_responses=True)
-        
-        self._initialized = True
-
-
-if __name__ == '__main__':
-    import random
-    # 连接到 Redis 服务器
-    r = RedisDatabaseHelper().redis
-
-    # 有序集合的键名
-    zset_key = 'configs:hotkeys'
-
-    data_list = ['ORDER_FULLORDR_RATE', 'MONTH6_SALE_QTY_YOY', 'MONTH6_SALE_QTY_MOM', 'MONTH6_SALE_QTY']
-
-    # 清空已有的有序集合(可选,若需要全新的集合可执行此操作)
-    r.delete(zset_key)
-    
-    for item in data_list:
-        # 生成 80 到 100 之间的随机数,小数点后保留 4 位
-        score = round(random.uniform(80, 100), 4)
-        # 将元素和对应的分数添加到有序集合中
-        r.zadd(zset_key, {item: score})
-
-    # # 从 Redis 中读取有序集合并打印
-    # result = r.zrange(zset_key, 0, -1, withscores=True)
-    # for item, score in result:
-    #     print(f"元素: {item}, 分数: {score}")
+import redis
+from core import get_logger, settings, DatabaseException
+
+logger = get_logger("database.redis")
+
+
+class RedisDatabaseHelper:
+    _instance = None
+
+    def __new__(cls):
+        if not cls._instance:
+            cls._instance = super(RedisDatabaseHelper, cls).__new__(cls)
+            cls._instance._initialized = False
+        return cls._instance
+
+    def __init__(self):
+        if self._initialized:
+            return
+        try:
+            pool = redis.ConnectionPool(
+                host=settings.redis_host,
+                port=settings.redis_port,
+                password=settings.redis_password,
+                db=settings.redis_db,
+                decode_responses=True,
+                max_connections=50,
+            )
+            self.redis = redis.StrictRedis(connection_pool=pool)
+            self.redis.ping()
+            logger.info("Redis connection established", extra={"extra_data": {"host": settings.redis_host, "db": settings.redis_db}})
+        except redis.ConnectionError as e:
+            logger.error("Failed to connect to Redis", exc_info=True)
+            raise DatabaseException(message="Redis连接失败", detail=str(e))
+        self._initialized = True
+
+    def check_connection(self) -> bool:
+        """检查 Redis 连接是否正常"""
+        try:
+            self.redis.ping()
+            return True
+        except Exception:
+            return False