| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154 |
from contextlib import contextmanager
from urllib.parse import quote

import pandas as pd
from sqlalchemy import create_engine, text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker

from core import get_logger, settings, DatabaseException
# Module-level logger used by every method of the helper below; the name
# "database.mysql" is handed to the project's ``get_logger`` factory.
logger = get_logger("database.mysql")
class MySqlDatabaseHelper:
    """Singleton wrapper around a pooled SQLAlchemy engine for MySQL.

    Every ``MySqlDatabaseHelper()`` call returns the same object; the engine
    and session factory are created once, on the first successful
    construction. Connection parameters are read from ``settings``
    (``mysql_user``, ``mysql_password``, ``mysql_host``, ``mysql_port``,
    ``mysql_db``). No locking is performed around instance creation.
    """

    # The single shared instance, created lazily by ``__new__``.
    _instance = None

    def __new__(cls):
        """Return the shared instance, allocating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # ``__init__`` runs on every ``MySqlDatabaseHelper()`` call, so it
            # needs this flag to know whether the engine already exists.
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Create the engine on first construction; later calls are no-ops."""
        if self._initialized:
            return
        self._connect_database()
        # Set only after a successful connect, so a failed attempt is retried
        # the next time the class is instantiated.
        self._initialized = True

    @staticmethod
    def _build_connection_url(user, password, host, port, database) -> str:
        """Build the ``mysql+pymysql`` URL with percent-encoded credentials."""
        # Characters such as "@", "/", ":" or "#" in the credentials would
        # otherwise be read as URL delimiters. ``safe=""`` encodes "/" too.
        # NOTE(review): assumes settings hold the raw (not pre-encoded) values.
        return (
            f"mysql+pymysql://{quote(str(user), safe='')}:{quote(str(password), safe='')}"
            f"@{host}:{port}/{database}"
        )

    @staticmethod
    def _as_statement(query):
        """Wrap a plain SQL string in ``text()``; pass other objects through."""
        return text(query) if isinstance(query, str) else query

    @staticmethod
    def _build_update_statement(table_name, update_dict, conditions, condition_params=None):
        """Return ``(sql, params)`` for an UPDATE with non-clashing bind names."""
        if conditions is None:
            raise TypeError(
                "conditions must be a sequence of SQL fragments; pass [] to update every row"
            )
        if isinstance(conditions, str):
            # A bare string would otherwise be joined character by character.
            conditions = [conditions]
        conditions = list(conditions)

        params = dict(condition_params or {})
        assignments = []
        for column, value in update_dict.items():
            bind_name = column
            # A WHERE bind with the same name as an updated column must keep
            # its own value, so the SET side gets a distinct bind name.
            # Columns without a clash keep ":<column>", which conditions that
            # refer to update values by name continue to rely on.
            while bind_name in params or (bind_name != column and bind_name in update_dict):
                bind_name = f"set_{bind_name}"
            params[bind_name] = value
            assignments.append(f"{column} = :{bind_name}")

        sql = f"UPDATE {table_name} SET {', '.join(assignments)}"
        if conditions:
            sql += f" WHERE {' AND '.join(conditions)}"
        return sql, params

    def _connect_database(self):
        """Create the engine and session factory from ``settings``.

        Raises:
            DatabaseException: if anything fails while building the engine;
                the original error is chained as ``__cause__``.
        """
        try:
            conn_str = self._build_connection_url(
                settings.mysql_user,
                settings.mysql_password,
                settings.mysql_host,
                settings.mysql_port,
                settings.mysql_db,
            )
            self.engine = create_engine(
                conn_str,
                pool_size=20,
                max_overflow=30,  # up to 20 + 30 connections under load
                pool_recycle=1800,  # replace connections older than 30 minutes
                pool_pre_ping=True,  # test a connection before handing it out
                isolation_level="READ COMMITTED",
            )
            self._DBSession = sessionmaker(bind=self.engine)
            logger.info(
                "MySQL connection pool created",
                extra={"extra_data": {"host": settings.mysql_host, "db": settings.mysql_db}},
            )
        except Exception as e:
            logger.error("Failed to create MySQL connection", exc_info=True)
            raise DatabaseException(message="数据库连接失败", detail=str(e)) from e

    @contextmanager
    def get_session(self):
        """Yield a new session and always close it afterwards.

        The session is not committed here; callers commit explicitly. On a
        ``SQLAlchemyError`` the session is rolled back, the error is logged
        and re-raised as ``DatabaseException`` (original chained as cause).
        Other exceptions propagate unchanged after the session is closed.
        """
        session = self._DBSession()
        try:
            yield session
        except SQLAlchemyError as e:
            session.rollback()
            logger.error("Database operation failed", exc_info=True)
            raise DatabaseException(message="数据库操作失败", detail=str(e)) from e
        finally:
            session.close()

    def load_data_with_page(self, query, params=None, page_size=100000) -> "pd.DataFrame":
        """Run a SELECT page by page and return all rows as one DataFrame.

        Args:
            query: SQL text of the SELECT, without LIMIT/OFFSET; a trailing
                semicolon is tolerated. The query should carry its own
                ORDER BY, otherwise the database does not guarantee that
                LIMIT/OFFSET pages are disjoint.
            params: bind parameters for ``query``; ``None`` means no parameters.
            page_size: number of rows requested per round trip.

        Returns:
            A DataFrame with every fetched row (empty if nothing matched).
        """
        # A trailing ";" would break both the COUNT sub-query and the
        # appended LIMIT clause.
        base_sql = query.strip().rstrip(";").rstrip()
        base_params = dict(params or {})
        count_stmt = text(f"SELECT COUNT(*) FROM ({base_sql}) AS _count_subq")
        page_stmt = text(f"{base_sql} LIMIT :limit OFFSET :offset")

        count_row = self.fetch_one(count_stmt, base_params)
        total_rows = count_row[0] if count_row is not None else 0
        if total_rows == 0:
            logger.debug("Query returned 0 rows")
            return pd.DataFrame()

        logger.debug(f"Loading {total_rows} rows with page_size={page_size}")
        frames = []
        offset = 0
        while True:
            page_params = {**base_params, "limit": page_size, "offset": offset}
            rows = self.fetch_all(page_stmt, page_params)
            if not rows:
                break
            frames.append(pd.DataFrame(rows))
            # A short page is the last one; skip the extra empty round trip.
            if len(rows) < page_size:
                break
            offset += page_size

        # Concatenate once: growing a DataFrame inside the loop re-copies all
        # previously loaded rows on every page.
        data = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
        logger.debug(f"Loaded {len(data)} rows in {len(frames)} pages")
        return data

    def fetch_all(self, query, params=None):
        """Execute a query and return all result rows.

        ``query`` may be an executable SQLAlchemy statement or a SQL string.
        """
        with self.get_session() as session:
            return session.execute(self._as_statement(query), params or {}).fetchall()

    def fetch_one(self, query, params=None):
        """Execute a query and return the first row, or ``None`` if there is none.

        ``query`` may be an executable SQLAlchemy statement or a SQL string.
        """
        with self.get_session() as session:
            return session.execute(self._as_statement(query), params or {}).fetchone()

    def insert_data(self, table_name, data_dict) -> int:
        """Insert a single row into ``table_name``.

        ``table_name`` and the keys of ``data_dict`` are interpolated into the
        SQL text as identifiers, so they must come from trusted code; only the
        values are sent as bind parameters.

        Returns:
            The affected row count, or 0 when ``data_dict`` is empty.
        """
        if not data_dict:
            return 0
        columns = ", ".join(data_dict)
        placeholders = ", ".join(f":{key}" for key in data_dict)
        query = text(f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})")
        with self.get_session() as session:
            result = session.execute(query, data_dict)
            session.commit()
            inserted = result.rowcount
        logger.info(f"Inserted 1 row into {table_name}")
        return inserted

    def update_data(self, table_name, update_dict, conditions, condition_params=None) -> int:
        """Update the rows of ``table_name`` that match ``conditions``.

        Args:
            table_name: target table; interpolated into the SQL text.
            update_dict: column -> new value; column names are interpolated,
                values are bound.
            conditions: SQL fragments joined with AND (a single string is
                treated as one fragment). An empty sequence updates every row.
            condition_params: bind values used by ``conditions``. These may
                share names with updated columns; the SET side is then bound
                under a different name so both values are kept.

        Returns:
            The affected row count, or 0 when ``update_dict`` is empty.
        """
        if not update_dict:
            return 0
        sql, params = self._build_update_statement(
            table_name, update_dict, conditions, condition_params
        )
        with self.get_session() as session:
            result = session.execute(text(sql), params)
            session.commit()
            affected = result.rowcount
        logger.info(f"Updated {affected} rows in {table_name}")
        return affected

    def execute_query(self, query, params=None) -> None:
        """Execute a statement and commit it.

        ``query`` may be an executable SQLAlchemy statement or a SQL string.
        """
        with self.get_session() as session:
            session.execute(self._as_statement(query), params or {})
            session.commit()

    def check_connection(self) -> bool:
        """Return ``True`` if a trivial ``SELECT 1`` succeeds, else ``False``."""
        try:
            self.fetch_one(text("SELECT 1"), {})
        except Exception:
            # Health probe: report failure instead of raising, but leave a
            # trace so the cause is not lost.
            logger.warning("MySQL connectivity check failed", exc_info=True)
            return False
        return True
|