|
|
@@ -0,0 +1,105 @@
|
|
|
+from config import load_config
|
|
|
+import pandas as pd
|
|
|
+from sqlalchemy import create_engine, text
|
|
|
+from sqlalchemy.orm import sessionmaker
|
|
|
+from sqlalchemy.exc import SQLAlchemyError
|
|
|
+
|
|
|
+cfgs = load_config()
|
|
|
+
|
|
|
+
|
|
|
+class MySqlDatabaseHelper:
|
|
|
+ _instance = None
|
|
|
+
|
|
|
+ def __new__(cls, *args, **kwargs):
|
|
|
+ if not cls._instance:
|
|
|
+ cls._instance = super(MySqlDatabaseHelper, cls).__new__(cls)
|
|
|
+ cls._instance._initialized = False
|
|
|
+ return cls._instance
|
|
|
+
|
|
|
+ def __init__(self):
|
|
|
+ if self._initialized:
|
|
|
+ return
|
|
|
+
|
|
|
+ self._host = cfgs['mysql']['host']
|
|
|
+ self._port = cfgs['mysql']['port']
|
|
|
+ self._user = cfgs['mysql']['user']
|
|
|
+ self._passwd = cfgs['mysql']['passwd']
|
|
|
+ self._dbname = cfgs['mysql']['db']
|
|
|
+
|
|
|
+ self.connect_database()
|
|
|
+ self._initialized = True
|
|
|
+
|
|
|
+ def connect_database(self):
|
|
|
+ # 创建数据库连接
|
|
|
+ try:
|
|
|
+ conn = "mysql+pymysql://" + self._user + ":" + self._passwd + "@" + self._host + ":" + str(self._port) + "/" + self._dbname
|
|
|
+ except Exception as e:
|
|
|
+ raise ConnectionAbortedError(f"failed to create connection string: {e}")
|
|
|
+
|
|
|
+ # 通过连接池创建engine
|
|
|
+ self.engine = create_engine(
|
|
|
+ conn,
|
|
|
+ pool_size=10, # 设置连接池大小
|
|
|
+ max_overflow=20, # 超过连接池大小时的额外连接数
|
|
|
+ pool_recycle=3600 # 回收连接时间
|
|
|
+ )
|
|
|
+
|
|
|
+ self._DBSession = sessionmaker(bind=self.engine)
|
|
|
+
|
|
|
+ def load_data_with_page(self, query, params, page_size=1000):
|
|
|
+ """分页查询数据"""
|
|
|
+ data = pd.DataFrame()
|
|
|
+ query += " LIMIT :limit OFFSET :offset"
|
|
|
+ query = text(query)
|
|
|
+
|
|
|
+ page = 1
|
|
|
+ while True:
|
|
|
+ offset = (page - 1) * page_size # 计算偏移量
|
|
|
+ params["limit"] = page_size
|
|
|
+ params["offset"] = offset
|
|
|
+
|
|
|
+ df = pd.DataFrame(self.fetch_all(query, params))
|
|
|
+ if df.empty:
|
|
|
+ break
|
|
|
+ data = pd.concat([data, df], ignore_index=True)
|
|
|
+ print(f"Page {page}: Retrieved {len(df)} rows, Total rows so far: {len(data)}")
|
|
|
+
|
|
|
+ page += 1
|
|
|
+ return data
|
|
|
+
|
|
|
+
|
|
|
+ def fetch_all(self, query, params=None):
|
|
|
+ """执行SQL查询并返回所有结果"""
|
|
|
+ session = self._DBSession()
|
|
|
+ try:
|
|
|
+ results = session.execute(query, params or {}).fetchall()
|
|
|
+ return results
|
|
|
+ except SQLAlchemyError as e:
|
|
|
+ session.rollback()
|
|
|
+ print(f"error: {e}")
|
|
|
+ finally:
|
|
|
+ session.close()
|
|
|
+
|
|
|
+ def fetch_one(self, query, params=None):
|
|
|
+ """执行SQL查询并返回单条结果"""
|
|
|
+ session = self._DBSession()
|
|
|
+ try:
|
|
|
+ result = session.execute(query, params or {}).fetchone()
|
|
|
+ return result
|
|
|
+ except SQLAlchemyError as e:
|
|
|
+ session.rollback()
|
|
|
+ print(f"error: {e}")
|
|
|
+ finally:
|
|
|
+ session.close()
|
|
|
+
|
|
|
+ def execute_query(self, query, params=None):
|
|
|
+ """执行SQL语句 (无返回值, 如INSERT, UPDATE, DELETE)"""
|
|
|
+ session = self._DBSession()
|
|
|
+ try:
|
|
|
+ session.execute(query, params or {})
|
|
|
+ session.commit()
|
|
|
+ except SQLAlchemyError as e:
|
|
|
+ session.rollback()
|
|
|
+ print(f"Error: {e}")
|
|
|
+ finally:
|
|
|
+ session.close()
|