فهرست منبع

重构数据读取库代码

yangzeyu 1 سال پیش
والد
کامیت
c3e7ab7136
3 فایلهای تغییر یافته به همراه 175 افزوده شده و 0 حذف شده
  1. 5 0
      database/__init__.py
  2. 65 0
      database/dao/mysql_dao.py
  3. 105 0
      database/db/mysql.py

+ 5 - 0
database/__init__.py

@@ -0,0 +1,5 @@
+from database.db.mysql import MySqlDatabaseHelper
+
+# Names exported by a star-import of this package.
+__all__ = ["MySqlDatabaseHelper"]

+ 65 - 0
database/dao/mysql_dao.py

@@ -0,0 +1,65 @@
+from database import MySqlDatabaseHelper
+
+class MySqlDao:
+    """Data-access object for the product, customer and order tables.
+
+    Every loader builds a ``SELECT *`` statement and hands it to
+    ``MySqlDatabaseHelper.load_data_with_page``, returning the helper's result.
+    """
+
+    def __init__(self):
+        self.db_helper = MySqlDatabaseHelper()
+
+    def _load_by_city(self, tablename, city_column, city_uuid):
+        """Load all rows of ``tablename`` whose ``city_column`` equals ``city_uuid``."""
+        # Table and column names are literals supplied by the methods of this
+        # class; only the city value comes from the caller and it is bound as
+        # a named parameter rather than interpolated.
+        query = f"SELECT * FROM {tablename} WHERE {city_column} = :city_uuid"
+        params = {"city_uuid": city_uuid}
+        return self.db_helper.load_data_with_page(query, params)
+
+    @staticmethod
+    def _clean(data):
+        """Drop duplicate rows, replace missing values with 0 and re-infer object dtypes."""
+        data.drop_duplicates(inplace=True)
+        data.fillna(0, inplace=True)
+        return data.infer_objects(copy=False)
+
+    def load_product_data(self, city_uuid):
+        """Read product information for ``city_uuid`` from the database."""
+        return self._load_by_city("tads_brandcul_product_info", "city_uuid", city_uuid)
+
+    def load_cust_data(self, city_uuid):
+        """Read merchant (customer) information for ``city_uuid`` from the database."""
+        # This table stores the city identifier in BA_CITY_ORG_CODE.
+        return self._load_by_city("tads_brandcul_cust_info", "BA_CITY_ORG_CODE", city_uuid)
+
+    def load_order_data(self, city_uuid):
+        """Read order information for ``city_uuid`` and return it cleaned.
+
+        The ``stat_month`` and ``city_uuid`` columns are removed, duplicate
+        rows are dropped and missing values are filled with 0.
+        """
+        data = self._load_by_city("tads_brandcul_cust_order", "city_uuid", city_uuid)
+
+        # When the query matches nothing the helper returns a DataFrame with no
+        # columns at all, and dropping a named column would raise KeyError.
+        # Skip the drop in that case; a non-empty result that lacks these
+        # columns still raises, so schema problems stay visible.
+        if not data.empty:
+            data.drop(columns=["stat_month", "city_uuid"], inplace=True)
+
+        return self._clean(data)
+
+    def load_mock_order_data(self, city_uuid):
+        """Read the mock order information from the database and return it cleaned."""
+        tablename = "yunfu_mock_data"
+
+        # NOTE(review): the statement has no :city_uuid placeholder, so the
+        # mock table is read unfiltered even though city_uuid is still passed
+        # along in params -- confirm this is intended.
+        query = f"SELECT * FROM {tablename}"
+        params = {"city_uuid": city_uuid}
+
+        data = self.db_helper.load_data_with_page(query, params)
+        return self._clean(data)
+        
+if __name__ == "__main__":
+    # Manual run: load the mock order table for one city through a new DAO.
+    # Alternative city id kept for manual testing: "00000000000000000000000011441801"
+    mock_city_uuid = "00000000000000000000000011445301"
+    MySqlDao().load_mock_order_data(mock_city_uuid)

+ 105 - 0
database/db/mysql.py

@@ -0,0 +1,105 @@
+from config import load_config
+import pandas as pd
+from sqlalchemy import create_engine, text
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.exc import SQLAlchemyError
+
+# Project configuration, loaded once when this module is imported.
+cfgs = load_config()
+
+
+class MySqlDatabaseHelper:
+    """Singleton-style helper around a pooled SQLAlchemy engine for MySQL.
+
+    Every construction returns the same instance; the connection settings are
+    read from the ``mysql`` section of the module-level ``cfgs`` mapping the
+    first time the instance is initialised.
+    """
+
+    _instance = None
+
+    def __new__(cls, *args, **kwargs):
+        # Hand back the one shared object; __init__ checks _initialized so the
+        # engine is only built once.
+        if not cls._instance:
+            cls._instance = super(MySqlDatabaseHelper, cls).__new__(cls)
+            cls._instance._initialized = False
+        return cls._instance
+
+    def __init__(self):
+        if self._initialized:
+            return
+
+        self._host = cfgs['mysql']['host']
+        self._port = cfgs['mysql']['port']
+        self._user = cfgs['mysql']['user']
+        self._passwd = cfgs['mysql']['passwd']
+        self._dbname = cfgs['mysql']['db']
+
+        self.connect_database()
+        # Only flagged after a successful connect, so a failed attempt is
+        # retried on the next construction.
+        self._initialized = True
+
+    def connect_database(self):
+        """Build the connection URL, then create the engine and session factory.
+
+        Raises:
+            ConnectionAbortedError: if the URL cannot be assembled from the
+                configured values.
+        """
+        from urllib.parse import quote
+
+        try:
+            # Percent-encode the credentials so characters such as "@", ":" or
+            # "/" inside the user name or password cannot corrupt the URL.
+            # NOTE(review): assumes the configured values are stored raw, not
+            # already percent-encoded -- confirm against the config file.
+            user = quote(self._user, safe="")
+            passwd = quote(self._passwd, safe="")
+            conn = (
+                "mysql+pymysql://" + user + ":" + passwd
+                + "@" + self._host + ":" + str(self._port) + "/" + self._dbname
+            )
+        except Exception as e:
+            raise ConnectionAbortedError(f"failed to create connection string: {e}") from e
+
+        # Create the engine backed by a connection pool.
+        self.engine = create_engine(
+            conn,
+            pool_size=10,  # number of connections kept in the pool
+            max_overflow=20,  # extra connections allowed beyond pool_size
+            pool_recycle=3600,  # recycle connections older than this many seconds
+        )
+
+        self._DBSession = sessionmaker(bind=self.engine)
+
+    def load_data_with_page(self, query, params, page_size=1000):
+        """Run ``query`` page by page and return all rows as one DataFrame.
+
+        Args:
+            query: SQL string without a LIMIT/OFFSET clause; one is appended.
+            params: mapping of bind parameters for ``query`` (may be None).
+                It is not modified.
+            page_size: number of rows requested per page.
+
+        Returns:
+            A DataFrame with the rows of every page, re-indexed from 0; an
+            empty DataFrame when the first page yields nothing.
+        """
+        # NOTE(review): LIMIT/OFFSET paging without ORDER BY relies on the
+        # server returning rows in the same order for every page -- confirm.
+        paged_query = text(query + " LIMIT :limit OFFSET :offset")
+
+        # Work on a copy so the caller's dict does not gain limit/offset keys.
+        bind_params = dict(params or {})
+        bind_params["limit"] = page_size
+
+        frames = []
+        total_rows = 0
+        page = 1
+        while True:
+            bind_params["offset"] = (page - 1) * page_size  # rows to skip
+
+            # fetch_all returns None after a database error; DataFrame(None)
+            # is empty, so an error ends the loop like an exhausted result.
+            df = pd.DataFrame(self.fetch_all(paged_query, bind_params))
+            if df.empty:
+                break
+            frames.append(df)
+            total_rows += len(df)
+            print(f"Page {page}: Retrieved {len(df)} rows, Total rows so far: {total_rows}")
+
+            page += 1
+
+        if not frames:
+            return pd.DataFrame()
+        # One concat at the end instead of re-copying the accumulated frame
+        # on every page.
+        return pd.concat(frames, ignore_index=True)
+
+    def fetch_all(self, query, params=None):
+        """Execute a SQL query and return all result rows.
+
+        Returns None (after rolling back and printing the error) when
+        SQLAlchemy raises.
+        """
+        session = self._DBSession()
+        try:
+            return session.execute(query, params or {}).fetchall()
+        except SQLAlchemyError as e:
+            session.rollback()
+            print(f"error: {e}")
+            return None
+        finally:
+            session.close()
+
+    def fetch_one(self, query, params=None):
+        """Execute a SQL query and return a single result row.
+
+        Returns None (after rolling back and printing the error) when
+        SQLAlchemy raises.
+        """
+        session = self._DBSession()
+        try:
+            return session.execute(query, params or {}).fetchone()
+        except SQLAlchemyError as e:
+            session.rollback()
+            print(f"error: {e}")
+            return None
+        finally:
+            session.close()
+
+    def execute_query(self, query, params=None):
+        """Execute a SQL statement with no result (e.g. INSERT, UPDATE, DELETE).
+
+        The statement is committed on success; on a SQLAlchemy error the
+        session is rolled back and the error is printed.
+        """
+        session = self._DBSession()
+        try:
+            session.execute(query, params or {})
+            session.commit()
+        except SQLAlchemyError as e:
+            session.rollback()
+            print(f"Error: {e}")
+        finally:
+            session.close()