mysql.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
import re
from urllib.parse import quote

import pandas as pd
from sqlalchemy import create_engine, text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
from tqdm import tqdm

from config import laod_database_config
# Project configuration, loaded once when this module is imported.
# MySqlDatabaseHelper reads its connection settings from cfgs['mysql']
# ('host', 'port', 'user', 'passwd', 'db').
# NOTE: "laod" (sic) is the spelling of the function imported from config.
cfgs = laod_database_config()
  9. def replace_select_content(sql):
  10. """
  11. 将SELECT和FROM之间的内容替换为SELECT COUNT(原内容) FROM
  12. """
  13. pattern = r'SELECT\s+(.*?)\s+FROM'
  14. def replace_func(match):
  15. content = match.group(1) # 获取SELECT和FROM之间的内容
  16. # 如果是DISTINCT或ALL等特殊关键字,需要特殊处理
  17. if re.match(r'^(DISTINCT|ALL)\s+', content, re.IGNORECASE):
  18. # 保留DISTINCT/ALL关键字
  19. return f'SELECT COUNT({content}) FROM'
  20. else:
  21. return f'SELECT COUNT(*) FROM'
  22. # 使用re.sub进行替换,re.IGNORECASE忽略大小写
  23. result = re.sub(pattern, replace_func, sql, flags=re.IGNORECASE)
  24. return result
  25. class MySqlDatabaseHelper:
  26. _instance = None
  27. def __new__(cls):
  28. if not cls._instance:
  29. cls._instance = super(MySqlDatabaseHelper, cls).__new__(cls)
  30. cls._instance._initialized = False
  31. return cls._instance
  32. def __init__(self):
  33. if self._initialized:
  34. return
  35. self._host = cfgs['mysql']['host']
  36. self._port = cfgs['mysql']['port']
  37. self._user = cfgs['mysql']['user']
  38. self._passwd = cfgs['mysql']['passwd']
  39. self._dbname = cfgs['mysql']['db']
  40. self.connect_database()
  41. self._initialized = True
  42. def connect_database(self):
  43. # 创建数据库连接
  44. try:
  45. conn = "mysql+pymysql://" + self._user + ":" + self._passwd + "@" + self._host + ":" + str(self._port) + "/" + self._dbname
  46. except Exception as e:
  47. raise ConnectionAbortedError(f"failed to create connection string: {e}")
  48. # 通过连接池创建engine
  49. self.engine = create_engine(
  50. conn,
  51. pool_size=20, # 设置连接池大小
  52. max_overflow=30, # 超过连接池大小时的额外连接数
  53. pool_recycle=1800, # 回收连接时间
  54. pool_pre_ping=True, # 防止断开连接
  55. isolation_level="READ COMMITTED" # 降低隔离级别
  56. )
  57. self._DBSession = sessionmaker(bind=self.engine)
  58. def load_data_with_page(self, query, params, page_size=100000):
  59. """分页查询数据"""
  60. data = pd.DataFrame()
  61. count_query = text(replace_select_content(query))
  62. query += " LIMIT :limit OFFSET :offset"
  63. query = text(query)
  64. print(count_query)
  65. # 获取总行数
  66. total_rows = self.fetch_one(count_query, params)[0]
  67. page = 1
  68. with tqdm(total=total_rows, desc="Loading data", unit="rows") as pbar: # 初始化进度条
  69. while True:
  70. offset = (page - 1) * page_size # 计算偏移量
  71. params["limit"] = page_size
  72. params["offset"] = offset
  73. df = pd.DataFrame(self.fetch_all(query, params))
  74. if df.empty:
  75. break
  76. data = pd.concat([data, df], ignore_index=True)
  77. # 更新进度条
  78. pbar.update(len(df)) # 更新进度条的行数
  79. page += 1
  80. return data
  81. def fetch_all(self, query, params=None):
  82. """执行SQL查询并返回所有结果"""
  83. session = self._DBSession()
  84. try:
  85. results = session.execute(query, params or {}).fetchall()
  86. return results
  87. except SQLAlchemyError as e:
  88. session.rollback()
  89. print(f"error: {e}")
  90. finally:
  91. session.close()
  92. def fetch_one(self, query, params=None):
  93. """执行SQL查询并返回单条结果"""
  94. session = self._DBSession()
  95. try:
  96. result = session.execute(query, params or {}).fetchone()
  97. return result
  98. except SQLAlchemyError as e:
  99. session.rollback()
  100. print(f"error: {e}")
  101. finally:
  102. session.close()
  103. def insert_data(self, table_name, data_dict):
  104. """插入单条数据到指定表"""
  105. if not data_dict:
  106. return 0
  107. columns = ", ".join(data_dict.keys())
  108. values = ", ".join([f":{key}" for key in data_dict.keys()])
  109. query = text(f"INSERT INTO {table_name} ({columns}) VALUES ({values})")
  110. session = self._DBSession()
  111. try:
  112. result = session.execute(query, data_dict)
  113. session.commit()
  114. return result.rowcount
  115. except SQLAlchemyError as e:
  116. session.rollback()
  117. print(f"Error inserting data: {e}")
  118. return 0
  119. finally:
  120. session.close()
  121. def update_data(self, table_name, update_dict, conditions, condition_params=None):
  122. """更新表中符合条件的数据"""
  123. if not update_dict:
  124. return 0
  125. set_clause = ", ".join([f"{key} = :{key}" for key in update_dict.keys()])
  126. if len(conditions) == 1:
  127. where_clause = f"WHERE {conditions[0]}"
  128. elif len(conditions) > 1:
  129. where_clause = f"WHERE {' AND '.join(conditions)}"
  130. else:
  131. where_clause = ""
  132. query = text(f"UPDATE {table_name} SET {set_clause} {where_clause}")
  133. params = update_dict.copy()
  134. if condition_params:
  135. params.update(condition_params)
  136. session = self._DBSession()
  137. try:
  138. result = session.execute(query, params)
  139. session.commit()
  140. return result.rowcount
  141. except SQLAlchemyError as e:
  142. session.rollback()
  143. print(f"Error updating data: {e}")
  144. return 0
  145. finally:
  146. session.close()
  147. def execute_query(self, query, params=None):
  148. """执行SQL语句 (无返回值, 如INSERT, UPDATE, DELETE)"""
  149. session = self._DBSession()
  150. try:
  151. session.execute(query, params or {})
  152. session.commit()
  153. except SQLAlchemyError as e:
  154. session.rollback()
  155. print(f"Error: {e}")
  156. finally:
  157. session.close()