|
|
@@ -1,209 +1,154 @@
|
|
|
-from config import load_config
|
|
|
-import pandas as pd
|
|
|
-from sqlalchemy import create_engine, text
|
|
|
-from sqlalchemy.orm import sessionmaker
|
|
|
-from sqlalchemy.exc import SQLAlchemyError
|
|
|
-from tqdm import tqdm
|
|
|
-
|
|
|
-cfgs = load_config()
|
|
|
-
|
|
|
-
|
|
|
class MySqlDatabaseHelper:
    """Singleton helper around a pooled SQLAlchemy engine for MySQL.

    Every ``MySqlDatabaseHelper()`` call returns the same instance. The engine
    and session factory are created once, on first construction, from the
    module-level ``cfgs['mysql']`` mapping (keys ``host``, ``port``, ``user``,
    ``passwd``, ``db``).

    Read helpers (``fetch_all``/``fetch_one``) re-raise ``SQLAlchemyError``.
    Write helpers (``insert_data``/``update_data``/``execute_query``) are
    best-effort: they roll back, print the error and do not raise.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every MySqlDatabaseHelper() call; only the first
        # call is allowed to read the config and connect.
        if self._initialized:
            return

        mysql_cfg = cfgs['mysql']
        self._host = mysql_cfg['host']
        self._port = mysql_cfg['port']
        self._user = mysql_cfg['user']
        self._passwd = mysql_cfg['passwd']
        self._dbname = mysql_cfg['db']

        self.connect_database()
        self._initialized = True

    @staticmethod
    def _build_connection_url(user, password, host, port, dbname):
        """Return the ``mysql+pymysql`` URL with user and password percent-encoded.

        The configured credentials are treated as raw (not pre-encoded)
        values. Without escaping, a password containing ``@``, ``/``, ``:``
        or ``%`` would be misparsed by SQLAlchemy's URL parser.
        """
        from urllib.parse import quote

        safe_user = quote(str(user), safe="")
        safe_password = quote(str(password), safe="")
        return f"mysql+pymysql://{safe_user}:{safe_password}@{host}:{port}/{dbname}"

    @staticmethod
    def _build_update_statement(table_name, update_dict, conditions, condition_params=None):
        """Build the UPDATE SQL string and its bind parameters.

        SET values are bound under ``_set_<column>`` so that a WHERE bind with
        the same name as an updated column (e.g. ``status = :status``) cannot
        overwrite the new value. The plain ``update_dict`` keys stay available
        to the conditions, with ``condition_params`` taking precedence.

        Returns:
            A ``(sql, params)`` tuple.
        """
        set_clause = ", ".join(f"{col} = :_set_{col}" for col in update_dict)
        # An empty (or None) condition list produces an UPDATE without WHERE.
        where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
        sql = f"UPDATE {table_name} SET {set_clause} {where_clause}"

        params = dict(update_dict)
        if condition_params:
            params.update(condition_params)
        params.update({f"_set_{col}": val for col, val in update_dict.items()})
        return sql, params

    def connect_database(self):
        """Create the pooled engine and the session factory.

        Raises:
            ConnectionAbortedError: if the engine cannot be created; the
                original exception is chained as ``__cause__``.
        """
        try:
            conn = self._build_connection_url(
                self._user, self._passwd, self._host, self._port, self._dbname
            )
            self.engine = create_engine(
                conn,
                pool_size=20,        # connections kept in the pool
                max_overflow=30,     # extra connections allowed beyond pool_size
                pool_recycle=1800,   # recycle connections after this many seconds
                pool_pre_ping=True,  # test connections before handing them out
                isolation_level="READ COMMITTED",
            )
        except Exception as e:
            raise ConnectionAbortedError(f"failed to create connection: {e}") from e

        self._DBSession = sessionmaker(bind=self.engine)

    def load_data_with_page(self, query, params, page_size=100000):
        """Run ``query`` page by page and return all rows as one DataFrame.

        Args:
            query: SQL string without LIMIT/OFFSET; ``:name`` binds allowed.
            params: Mapping of bind values (``None`` is treated as empty).
                It is not modified.
            page_size: Rows per page; must be >= 1.

        Returns:
            A ``pandas.DataFrame`` with every fetched row (empty if none).

        Raises:
            ValueError: if ``page_size`` is smaller than 1.
        """
        if page_size < 1:
            raise ValueError("page_size must be a positive integer")

        base_params = dict(params) if params else {}
        # Count through a wrapping sub-query so the original SQL is untouched.
        count_query = text(f"SELECT COUNT(*) FROM ({query}) AS _count_subq")
        paged_query = text(f"{query} LIMIT :limit OFFSET :offset")

        result = self.fetch_one(count_query, base_params)
        total_rows = result[0] if result is not None else 0
        if total_rows == 0:
            return pd.DataFrame()

        # NOTE: LIMIT/OFFSET paging is only stable if ``query`` has an ORDER BY.
        frames = []
        with tqdm(total=total_rows, desc="Loading data", unit="rows") as pbar:
            while True:
                page_params = {
                    **base_params,
                    "limit": page_size,
                    "offset": len(frames) * page_size,
                }
                df = pd.DataFrame(self.fetch_all(paged_query, page_params))
                if df.empty:
                    break
                frames.append(df)
                pbar.update(len(df))
                if len(df) < page_size:
                    # A short page is the last one; skip the extra empty query.
                    break

        # Concatenate once instead of re-copying the accumulated frame per page.
        return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

    def fetch_all(self, query, params=None):
        """Execute ``query`` and return all rows; re-raises ``SQLAlchemyError``."""
        session = self._DBSession()
        try:
            return session.execute(query, params or {}).fetchall()
        except SQLAlchemyError as e:
            session.rollback()
            print(f"error: {e}")
            raise
        finally:
            session.close()

    def fetch_one(self, query, params=None):
        """Execute ``query`` and return the first row (or ``None``); re-raises ``SQLAlchemyError``."""
        session = self._DBSession()
        try:
            return session.execute(query, params or {}).fetchone()
        except SQLAlchemyError as e:
            session.rollback()
            print(f"error: {e}")
            raise
        finally:
            session.close()

    def insert_data(self, table_name, data_dict):
        """Insert one row into ``table_name``.

        ``table_name`` and the keys of ``data_dict`` are interpolated into the
        SQL text, so they must come from trusted code; values are bound.

        Returns:
            The affected row count, or 0 if ``data_dict`` is empty or the
            insert failed (the error is printed, not raised).
        """
        if not data_dict:
            return 0

        columns = ", ".join(data_dict)
        values = ", ".join(f":{key}" for key in data_dict)
        query = text(f"INSERT INTO {table_name} ({columns}) VALUES ({values})")

        session = self._DBSession()
        try:
            result = session.execute(query, data_dict)
            session.commit()
            return result.rowcount
        except SQLAlchemyError as e:
            session.rollback()
            print(f"Error inserting data: {e}")
            return 0
        finally:
            session.close()

    def update_data(self, table_name, update_dict, conditions, condition_params=None):
        """Update the rows of ``table_name`` that match ``conditions``.

        Args:
            table_name: Trusted table identifier (interpolated into the SQL).
            update_dict: Column -> new value mapping; keys are interpolated.
            conditions: SQL predicate strings joined with AND; an empty list
                updates every row.
            condition_params: Bind values referenced by ``conditions``.

        Returns:
            The affected row count, or 0 if ``update_dict`` is empty or the
            update failed (the error is printed, not raised).
        """
        if not update_dict:
            return 0

        sql, params = self._build_update_statement(
            table_name, update_dict, conditions, condition_params
        )
        query = text(sql)

        session = self._DBSession()
        try:
            result = session.execute(query, params)
            session.commit()
            return result.rowcount
        except SQLAlchemyError as e:
            session.rollback()
            print(f"Error updating data: {e}")
            return 0
        finally:
            session.close()

    def execute_query(self, query, params=None):
        """Execute a statement with no result set (INSERT/UPDATE/DELETE) and commit.

        Best-effort: on ``SQLAlchemyError`` the transaction is rolled back and
        the error is printed; nothing is raised.
        """
        session = self._DBSession()
        try:
            session.execute(query, params or {})
            session.commit()
        except SQLAlchemyError as e:
            session.rollback()
            print(f"Error: {e}")
        finally:
            session.close()
|
|
|
-
|
|
|
if __name__ == '__main__':
    # Manual smoke test against the report table: builds a sample insert
    # payload (insert left disabled) and runs one conditional update.
    report_table = 'tads_brandcul_report'

    sample_row = {
        'cultivacation_id': 10000002,
        'city_uuid': '00000000000000000000000011445301',
        'limit_cycle_name': '202505W1(05.05-05.11)',
        'product_code': '440298',
        'product_info_table': 'D72E3FAE8DCE4270BD23C3EC015C0A35',
        'relation_table': 'AD889019FD4F4EE7B887981162BA09EC',
        'similarity_product_table': 'CE436AC24D96461FA0C091CB01E9BC05',
        'recommend_table': 'A7C5918B8DDB4BEA9D921936955CBAF6',
    }

    new_values = {"val_table": "A7C5918B8DDB4BEA9D921936955CBAF6"}
    where_parts = [
        "cultivacation_id = :cultivacation_id",
        "city_uuid = :city_uuid",
    ]
    where_values = {
        'cultivacation_id': 10000001,
        'city_uuid': '00000000000000000000000011445301',
    }

    helper = MySqlDatabaseHelper()

    # Enable to insert the sample row first:
    # helper.insert_data(report_table, sample_row)

    helper.update_data(report_table, new_values, where_parts, where_values)
|
|
|
+from contextlib import contextmanager
|
|
|
+from core import get_logger, settings, DatabaseException
|
|
|
+import pandas as pd
|
|
|
+from sqlalchemy import create_engine, text
|
|
|
+from sqlalchemy.orm import sessionmaker
|
|
|
+from sqlalchemy.exc import SQLAlchemyError
|
|
|
+
|
|
|
+logger = get_logger("database.mysql")
|
|
|
+
|
|
|
+
|
|
|
class MySqlDatabaseHelper:
    """Singleton helper around a pooled SQLAlchemy engine for MySQL.

    Every ``MySqlDatabaseHelper()`` call returns the same instance. The engine
    and session factory are created once, on first construction, from the
    ``settings.mysql_*`` values. All query methods run inside
    :meth:`get_session`, so SQLAlchemy failures surface as
    ``DatabaseException``.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every MySqlDatabaseHelper() call; only the first
        # call is allowed to connect.
        if self._initialized:
            return
        self._connect_database()
        self._initialized = True

    @staticmethod
    def _build_connection_url(user, password, host, port, db):
        """Return the ``mysql+pymysql`` URL with user and password percent-encoded.

        The configured credentials are treated as raw (not pre-encoded)
        values. Without escaping, a password containing ``@``, ``/``, ``:``
        or ``%`` would be misparsed by SQLAlchemy's URL parser.
        """
        from urllib.parse import quote

        safe_user = quote(str(user), safe="")
        safe_password = quote(str(password), safe="")
        return f"mysql+pymysql://{safe_user}:{safe_password}@{host}:{port}/{db}"

    @staticmethod
    def _build_update_statement(table_name, update_dict, conditions, condition_params=None):
        """Build the UPDATE SQL string and its bind parameters.

        SET values are bound under ``_set_<column>`` so that a WHERE bind with
        the same name as an updated column (e.g. ``status = :status``) cannot
        overwrite the new value. The plain ``update_dict`` keys stay available
        to the conditions, with ``condition_params`` taking precedence.

        Returns:
            A ``(sql, params)`` tuple.
        """
        set_clause = ", ".join(f"{col} = :_set_{col}" for col in update_dict)
        # An empty (or None) condition list produces an UPDATE without WHERE.
        where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
        sql = f"UPDATE {table_name} SET {set_clause} {where_clause}"

        params = dict(update_dict)
        if condition_params:
            params.update(condition_params)
        params.update({f"_set_{col}": val for col, val in update_dict.items()})
        return sql, params

    def _connect_database(self):
        """Create the pooled engine and the session factory.

        Raises:
            DatabaseException: if the engine or session factory cannot be
                created; the original exception is chained as ``__cause__``.
        """
        try:
            conn_str = self._build_connection_url(
                settings.mysql_user,
                settings.mysql_password,
                settings.mysql_host,
                settings.mysql_port,
                settings.mysql_db,
            )
            self.engine = create_engine(
                conn_str,
                pool_size=20,
                max_overflow=30,
                pool_recycle=1800,
                pool_pre_ping=True,
                isolation_level="READ COMMITTED",
            )
            self._DBSession = sessionmaker(bind=self.engine)
            logger.info("MySQL connection pool created", extra={"extra_data": {"host": settings.mysql_host, "db": settings.mysql_db}})
        except Exception as e:
            logger.error("Failed to create MySQL connection", exc_info=True)
            raise DatabaseException(message="数据库连接失败", detail=str(e)) from e

    @contextmanager
    def get_session(self):
        """Yield a new session and always close it.

        The caller is responsible for committing. On ``SQLAlchemyError`` the
        session is rolled back, the error is logged, and a
        ``DatabaseException`` chained to the original error is raised.
        """
        session = self._DBSession()
        try:
            yield session
        except SQLAlchemyError as e:
            session.rollback()
            logger.error("Database operation failed", exc_info=True)
            raise DatabaseException(message="数据库操作失败", detail=str(e)) from e
        finally:
            session.close()

    def load_data_with_page(self, query, params, page_size=100000):
        """Run ``query`` page by page and return all rows as one DataFrame.

        Args:
            query: SQL string without LIMIT/OFFSET; ``:name`` binds allowed.
            params: Mapping of bind values (``None`` is treated as empty).
                It is not modified.
            page_size: Rows per page; must be >= 1.

        Returns:
            A ``pandas.DataFrame`` with every fetched row (empty if none).

        Raises:
            ValueError: if ``page_size`` is smaller than 1.
        """
        if page_size < 1:
            raise ValueError("page_size must be a positive integer")

        base_params = dict(params) if params else {}
        # Count through a wrapping sub-query so the original SQL is untouched.
        count_query = text(f"SELECT COUNT(*) FROM ({query}) AS _count_subq")
        paged_query = text(f"{query} LIMIT :limit OFFSET :offset")

        result = self.fetch_one(count_query, base_params)
        total_rows = result[0] if result is not None else 0

        if total_rows == 0:
            logger.debug("Query returned 0 rows")
            return pd.DataFrame()

        logger.debug(f"Loading {total_rows} rows with page_size={page_size}")
        # NOTE: LIMIT/OFFSET paging is only stable if ``query`` has an ORDER BY.
        frames = []
        while True:
            page_params = {
                **base_params,
                "limit": page_size,
                "offset": len(frames) * page_size,
            }
            df = pd.DataFrame(self.fetch_all(paged_query, page_params))
            if df.empty:
                break
            frames.append(df)
            if len(df) < page_size:
                # A short page is the last one; skip the extra empty query.
                break

        # Concatenate once instead of re-copying the accumulated frame per page.
        data = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
        logger.debug(f"Loaded {len(data)} rows in {len(frames)} pages")
        return data

    def fetch_all(self, query, params=None):
        """Execute ``query`` and return all rows."""
        with self.get_session() as session:
            return session.execute(query, params or {}).fetchall()

    def fetch_one(self, query, params=None):
        """Execute ``query`` and return the first row, or ``None``."""
        with self.get_session() as session:
            return session.execute(query, params or {}).fetchone()

    def insert_data(self, table_name, data_dict):
        """Insert one row into ``table_name`` and commit.

        ``table_name`` and the keys of ``data_dict`` are interpolated into the
        SQL text, so they must come from trusted code; values are bound.

        Returns:
            The affected row count, or 0 if ``data_dict`` is empty.
        """
        if not data_dict:
            return 0

        columns = ", ".join(data_dict)
        values = ", ".join(f":{key}" for key in data_dict)
        query = text(f"INSERT INTO {table_name} ({columns}) VALUES ({values})")

        with self.get_session() as session:
            result = session.execute(query, data_dict)
            session.commit()
            logger.info(f"Inserted 1 row into {table_name}")
            return result.rowcount

    def update_data(self, table_name, update_dict, conditions, condition_params=None):
        """Update the rows of ``table_name`` that match ``conditions`` and commit.

        Args:
            table_name: Trusted table identifier (interpolated into the SQL).
            update_dict: Column -> new value mapping; keys are interpolated.
            conditions: SQL predicate strings joined with AND; an empty list
                updates every row.
            condition_params: Bind values referenced by ``conditions``.

        Returns:
            The affected row count, or 0 if ``update_dict`` is empty.
        """
        if not update_dict:
            return 0

        sql, params = self._build_update_statement(
            table_name, update_dict, conditions, condition_params
        )
        query = text(sql)

        with self.get_session() as session:
            result = session.execute(query, params)
            session.commit()
            logger.info(f"Updated {result.rowcount} rows in {table_name}")
            return result.rowcount

    def execute_query(self, query, params=None):
        """Execute a statement that returns no rows and commit."""
        with self.get_session() as session:
            session.execute(query, params or {})
            session.commit()

    def check_connection(self) -> bool:
        """Return True if a trivial ``SELECT 1`` succeeds, False on any error."""
        try:
            self.fetch_one(text("SELECT 1"), {})
            return True
        except Exception:
            # Health probe: any failure simply means "not healthy".
            return False
|