2026-03-15-database-bugfix.md

Database Operations Bug Fix Implementation Plan

For agentic workers: REQUIRED: Use superpowers:subagent-driven-development (if subagents are available) or superpowers:executing-plans to implement this plan. Steps use checkbox (- [ ]) syntax for tracking.

Goal: Fix 6 bugs in the database access layer, eliminate the SQL injection risk, and improve error handling.

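Architecture: Surgical fixes that leave the existing architecture unchanged. Fix order: Fix 2 → Fixes 1/3 → Fix 6 → Fixes 4/5. Line numbers in the steps refer to each file before any edit in this plan; because earlier replacements change method lengths, locate each method by name when applying later steps.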

Tech stack: SQLAlchemy, pandas, pymysql


Chunk 1: mysql.py fixes

Task 1: Fix the count query and the params side effect in load_data_with_page (Fix 2)

Files:

  • Modify: database/db/mysql.py:52-78

  • [ ] Step 1: Modify the load_data_with_page method

Replace the load_data_with_page method at lines 52-78 with:

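def load_data_with_page(self, query, params=None, page_size=100000):
    """Load query results page by page into a single DataFrame."""
    params = params or {}
    # Count by wrapping the original query in a subquery instead of rewriting
    # its SELECT clause, so DISTINCT / GROUP BY queries are counted correctly
    count_query = text(f"SELECT COUNT(*) FROM ({query}) AS _count_subq")
    # Pages are only stable if `query` itself has a deterministic ORDER BY
    paged_query = text(query + " LIMIT :limit OFFSET :offset")

    # Total row count, used only for the progress bar
    total_rows = self.fetch_one(count_query, params)[0]

    frames = []
    offset = 0
    with tqdm(total=total_rows, desc="Loading data", unit="rows") as pbar:
        while True:
            # Build a fresh dict per page so the caller's params are never modified
            page_params = {**params, "limit": page_size, "offset": offset}

            df = pd.DataFrame(self.fetch_all(paged_query, page_params))
            if df.empty:
                break
            frames.append(df)
            pbar.update(len(df))
            offset += page_size

    # Concatenate once at the end instead of re-copying the result on every page
    return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()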
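
  • [ ] Step 2: Commit Fix 2

    git add database/db/mysql.py
    git commit -m "fix: correct the count query and params side effect in load_data_with_page
    
    - Count by wrapping the original query in a subquery instead of relying on string replacement
    - Copy params before each page so the caller's dict is not modified"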
    
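Both ideas in Fix 2 can be tried without MySQL. The sketch below is a stand-in under stated assumptions: it uses the standard-library sqlite3 module instead of SQLAlchemy, and `load_with_page` is a hypothetical helper that mirrors the plan's paging loop. The "naive" count is a hypothetical string rewrite, included only to show how replacing the column list of a DISTINCT query over-counts; it is not necessarily what the original code did.

```python
import sqlite3


def load_with_page(conn, query, params, page_size=2):
    """Hypothetical sqlite3 stand-in for load_data_with_page."""
    # Count by wrapping the whole query, so DISTINCT is respected
    count_sql = f"SELECT COUNT(*) FROM ({query}) AS _count_subq"
    total = conn.execute(count_sql, params).fetchone()[0]

    page_sql = query + " LIMIT :limit OFFSET :offset"
    rows, offset = [], 0
    while True:
        # New dict per page: the caller's params never gain limit/offset keys
        page_params = {**params, "limit": page_size, "offset": offset}
        page = conn.execute(page_sql, page_params).fetchall()
        if not page:
            break
        rows.extend(page)
        offset += page_size
    return total, rows


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE orders (city_uuid TEXT, product_code TEXT)")
conn.executemany(
    "INSERT INTO orders VALUES (?, ?)",
    [("c1", "p1"), ("c1", "p1"), ("c1", "p2"), ("c1", "p3"), ("c2", "p9")],
)

base = "SELECT DISTINCT product_code FROM orders WHERE city_uuid = :city_uuid"
params = {"city_uuid": "c1"}

# Hypothetical naive count: swapping the column list for COUNT(*) drops DISTINCT
naive_sql = base.replace("DISTINCT product_code", "COUNT(*)")
naive_total = conn.execute(naive_sql, params).fetchone()[0]

total, rows = load_with_page(conn, base + " ORDER BY product_code", params)
print(naive_total, total, rows, params)
```

The `ORDER BY` in the paged query matters: SQL does not guarantee the same row order across separate LIMIT/OFFSET queries without one, so pages could overlap or skip rows.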

Task 2: Fix error handling in connect_database and the fetch methods (Fixes 1 and 3)

Files:

  • Modify: database/db/mysql.py:33-48, 81-104

  • [ ] Step 1: Modify the connect_database method

Replace the connect_database method at lines 33-48 with:

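def connect_database(self):
    """Create the engine and session factory, failing fast if the database is unreachable."""
    try:
        conn = "mysql+pymysql://" + self._user + ":" + self._passwd + "@" + self._host + ":" + str(self._port) + "/" + self._dbname

        # Create the engine backed by a connection pool
        self.engine = create_engine(
            conn,
            pool_size=20,  # number of pooled connections
            max_overflow=30,  # extra connections allowed beyond pool_size
            pool_recycle=1800,  # recycle connections older than 1800 s
            pool_pre_ping=True,  # test connections on checkout to detect dropped ones
            isolation_level="READ COMMITTED",  # lower the isolation level
        )
        # create_engine() is lazy and opens no connection; open one now so that
        # bad credentials or an unreachable host are caught here
        with self.engine.connect():
            pass
    except Exception as e:
        # 'from e' keeps the original driver error as __cause__
        raise ConnectionAbortedError(f"failed to create connection: {e}") from e

    self._DBSession = sessionmaker(bind=self.engine)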
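SQLAlchemy's `create_engine()` does not open a database connection by itself; the first real connection happens on first use. Wrapping only `create_engine()` in `try` therefore catches a malformed URL or a missing driver, but not bad credentials or an unreachable host. The sketch below shows the fail-fast shape using the standard-library sqlite3 module as a stand-in (an assumption; the project itself uses pymysql). `connect_or_fail` is a hypothetical name.

```python
import sqlite3


def connect_or_fail(path):
    """Hypothetical helper: open a connection and prove it works before returning it."""
    try:
        conn = sqlite3.connect(path)
        # Run a trivial query so any failure surfaces here, not at first real use
        conn.execute("SELECT 1")
    except sqlite3.Error as e:
        # 'from e' keeps the original driver error available as __cause__
        raise ConnectionAbortedError(f"failed to create connection: {e}") from e
    return conn


ok = connect_or_fail(":memory:")

caught = None
try:
    # A file inside a directory that does not exist cannot be opened
    connect_or_fail("/nonexistent-dir-for-demo/app.db")
except ConnectionAbortedError as err:
    caught = err
```

Chaining with `from e` means a traceback shows both the wrapper and the underlying driver error, which helps when diagnosing which part of the connection failed.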

  • [ ] Step 2: Modify the fetch_all method

Replace the fetch_all method at lines 81-91 with:

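def fetch_all(self, query, params=None):
    """Execute a SQL query and return all rows."""
    # Requires at module level: from sqlalchemy.exc import SQLAlchemyError
    session = self._DBSession()
    try:
        return session.execute(query, params or {}).fetchall()
    except SQLAlchemyError as e:
        session.rollback()
        print(f"error: {e}")
        # Re-raise so callers see the failure instead of silently getting None
        raise
    finally:
        session.close()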

  • [ ] Step 3: Modify the fetch_one method

Replace the fetch_one method at lines 93-104 with:

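def fetch_one(self, query, params=None):
    """Execute a SQL query and return a single row."""
    # Requires at module level: from sqlalchemy.exc import SQLAlchemyError
    session = self._DBSession()
    try:
        return session.execute(query, params or {}).fetchone()
    except SQLAlchemyError as e:
        session.rollback()
        print(f"error: {e}")
        # Re-raise so callers see the failure instead of silently getting None
        raise
    finally:
        session.close()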
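
  • [ ] Step 4: Commit Fixes 1 and 3

    git add database/db/mysql.py
    git commit -m "fix: correct error handling in connect_database and the fetch methods
    
    Fix 1: move create_engine into the try block so connection failures are caught
    Fix 3: add raise to the except blocks in fetch_all/fetch_one
    
    Note: Fix 3 is a breaking behavior change; callers now receive an exception instead of None"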
    
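The breaking change noted in the commit message is the part callers will notice. The sketch below contrasts the two behaviors using the standard-library sqlite3 module as a stand-in (an assumption). `fetch_all_swallow` is a hypothetical reconstruction of the old behavior as the commit message describes it (print the error and fall through, so the caller gets `None`); `fetch_all_reraise` mirrors the new except block. Both names are hypothetical.

```python
import sqlite3


def fetch_all_swallow(conn, sql, params=None):
    """Old behavior (reconstructed): print the error and implicitly return None."""
    try:
        return conn.execute(sql, params or {}).fetchall()
    except sqlite3.Error as e:
        conn.rollback()
        print(f"error: {e}")


def fetch_all_reraise(conn, sql, params=None):
    """New behavior: roll back, report, then re-raise to the caller."""
    try:
        return conn.execute(sql, params or {}).fetchall()
    except sqlite3.Error as e:
        conn.rollback()
        print(f"error: {e}")
        raise


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (x INTEGER)")

# Old style: the failure is indistinguishable from "no result"
old_result = fetch_all_swallow(conn, "SELECT * FROM missing_table")

# New style: the caller must handle the exception explicitly
try:
    fetch_all_reraise(conn, "SELECT * FROM missing_table")
    new_raised = False
except sqlite3.OperationalError:
    new_raised = True
```

Callers that previously checked for `None` after a failed query need a `try`/`except` around these calls once Fix 3 lands.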

Chunk 2: mysql_dao.py fixes

Task 3: Fix the interface inconsistency between get_cust_list and get_product_from_order (Fix 6)

Files:

  • Modify: database/dao/mysql_dao.py:251-278

  • [ ] Step 1: Modify the get_product_from_order method

Replace the get_product_from_order method at lines 251-257 with:

def get_product_from_order(self, city_uuid):
    # ORDER BY gives LIMIT/OFFSET paging a deterministic row order
    query = f"SELECT DISTINCT product_code FROM {self._order_tablename} WHERE city_uuid = :city_uuid ORDER BY product_code"
    params = {"city_uuid": city_uuid}

    data = self.db_helper.load_data_with_page(query, params)

    return data

  • [ ] Step 2: Modify the get_cust_list method

Replace the get_cust_list method at lines 272-278 with:

def get_cust_list(self, city_uuid):
    # ORDER BY gives LIMIT/OFFSET paging a deterministic row order
    query = f"SELECT DISTINCT BB_RETAIL_CUSTOMER_CODE FROM {self._cust_tablename} WHERE BA_CITY_ORG_CODE = :city_uuid ORDER BY BB_RETAIL_CUSTOMER_CODE"
    params = {"city_uuid": city_uuid}

    data = self.db_helper.load_data_with_page(query, params)

    return data

  • [ ] Step 3: Commit Fix 6

    git add database/dao/mysql_dao.py
    git commit -m "fix: make get_cust_list and get_product_from_order use the same interface
    
    - Use load_data_with_page instead of calling fetch_all directly
    - Match the other query methods"
    
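One way to check Fix 6 without a database is to stub the helper and assert that both methods now go through `load_data_with_page` with a bound `:city_uuid` parameter. The sketch below is hypothetical: `OrderDao` is a trimmed stand-in for the real DAO class (only the two methods, with assumed table names), and `unittest.mock.MagicMock` replaces `db_helper`.

```python
from unittest.mock import MagicMock


class OrderDao:
    """Hypothetical trimmed stand-in for the DAO; table names are assumptions."""

    def __init__(self, db_helper):
        self.db_helper = db_helper
        self._order_tablename = "orders"
        self._cust_tablename = "customers"

    def get_product_from_order(self, city_uuid):
        query = f"SELECT DISTINCT product_code FROM {self._order_tablename} WHERE city_uuid = :city_uuid"
        return self.db_helper.load_data_with_page(query, {"city_uuid": city_uuid})

    def get_cust_list(self, city_uuid):
        query = (f"SELECT DISTINCT BB_RETAIL_CUSTOMER_CODE FROM {self._cust_tablename} "
                 "WHERE BA_CITY_ORG_CODE = :city_uuid")
        return self.db_helper.load_data_with_page(query, {"city_uuid": city_uuid})


helper = MagicMock()
dao = OrderDao(helper)
dao.get_product_from_order("c1")
dao.get_cust_list("c1")

# Each recorded call holds (query, params) as positional arguments
calls = helper.load_data_with_page.call_args_list
```

Checking `helper.fetch_all.called` as well confirms that neither method still bypasses the paging helper.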

Task 4: Fix the SQL injection risk in all IN clauses (Fix 4)

Files:

  • Modify: database/dao/mysql_dao.py:2, 105-175

  • [ ] Step 1: Add the bindparam import

Change line 2 from:

from sqlalchemy import text

to:

from sqlalchemy import text, bindparam

  • [ ] Step 2: Modify the get_cust_by_ids method

Replace the get_cust_by_ids method at lines 105-120 with:

def get_cust_by_ids(self, city_uuid, cust_id_list):
    """Query retailer information for a list of retailer codes."""
    if not cust_id_list:
        return None

    query = text(f"""
        SELECT *
        FROM {self._cust_tablename}
        WHERE BA_CITY_ORG_CODE = :city_uuid
        AND BB_RETAIL_CUSTOMER_CODE IN :ids
    """).bindparams(bindparam("ids", expanding=True))
    params = {"city_uuid": city_uuid, "ids": list(cust_id_list)}
    data = pd.DataFrame(self.db_helper.fetch_all(query, params))

    return data

  • [ ] Step 3: Modify the get_shop_by_ids method

Replace the get_shop_by_ids method at lines 122-137 with:

def get_shop_by_ids(self, city_uuid, cust_id_list):
    """Query information for a list of retailer codes."""
    if not cust_id_list:
        return None

    query = text(f"""
        SELECT *
        FROM {self._shopping_tablename}
        WHERE city_uuid = :city_uuid
        AND cust_code IN :ids
    """).bindparams(bindparam("ids", expanding=True))
    params = {"city_uuid": city_uuid, "ids": list(cust_id_list)}
    data = pd.DataFrame(self.db_helper.fetch_all(query, params))

    return data

  • [ ] Step 4: Modify the get_product_by_ids method

Replace the get_product_by_ids method at lines 139-154 with:

def get_product_by_ids(self, city_uuid, product_id_list):
    """Query product information for a list of product_code values."""
    if not product_id_list:
        return None

    query = text(f"""
        SELECT *
        FROM {self._product_tablename}
        WHERE city_uuid = :city_uuid
        AND product_code IN :ids
    """).bindparams(bindparam("ids", expanding=True))
    params = {"city_uuid": city_uuid, "ids": list(product_id_list)}
    data = pd.DataFrame(self.db_helper.fetch_all(query, params))

    return data

  • [ ] Step 5: Modify the get_order_by_product_ids method

Replace the get_order_by_product_ids method at lines 156-175 with:

def get_order_by_product_ids(self, city_uuid, product_ids):
    """Get all sales records for the given list of cigarette products."""
    if not product_ids:
        return None

    query = text(f"""
        SELECT *
        FROM {self._order_tablename}
        WHERE city_uuid = :city_uuid
        AND product_code IN :ids
    """).bindparams(bindparam("ids", expanding=True))
    params = {"city_uuid": city_uuid, "ids": list(product_ids)}
    data = pd.DataFrame(self.db_helper.fetch_all(query, params))
    if data.empty:
        # No matching orders: an empty result has no cust_code column to join on
        return data

    # Keep only orders whose customer appears in this city's customer list
    cust_list = self.get_cust_list(city_uuid)
    cust_index = cust_list.set_index("BB_RETAIL_CUSTOMER_CODE")
    data = data.join(cust_index, on="cust_code", how="inner")

    return data

  • [ ] Step 6: Commit Fix 4

    git add database/dao/mysql_dao.py
    git commit -m "fix: remove the SQL injection risk from all IN clauses
    
    - Replace string concatenation with bindparam(expanding=True)
    - Methods fixed: get_cust_by_ids, get_shop_by_ids, get_product_by_ids, get_order_by_product_ids
    - Query with fetch_all directly and skip pagination (result size depends on the input list)"
    
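`bindparam(..., expanding=True)` tells SQLAlchemy to render `IN :ids` as one placeholder per list element at execution time, so every value is sent as a bound parameter rather than as SQL text. The sketch below imitates that expansion by hand with the standard-library sqlite3 module (a stand-in; `expand_in` is a hypothetical helper, not SQLAlchemy's implementation) and contrasts it with a hypothetical string-concatenated IN list given a crafted input.

```python
import sqlite3


def expand_in(sql, name, values):
    """Hypothetical helper: rewrite 'IN :name' as 'IN (:name_0, :name_1, ...)'."""
    keys = [f"{name}_{i}" for i in range(len(values))]
    placeholders = ", ".join(f":{k}" for k in keys)
    return sql.replace(f"IN :{name}", f"IN ({placeholders})"), dict(zip(keys, values))


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE cust (city_uuid TEXT, cust_code TEXT)")
conn.executemany("INSERT INTO cust VALUES (?, ?)",
                 [("c1", "a"), ("c1", "b"), ("c2", "z")])

ids = ["a", "x') OR ('1'='1"]  # the second element is a crafted, malicious value

# Unsafe: values are pasted into the SQL text, so the quote escapes the literal
unsafe_sql = ("SELECT cust_code FROM cust WHERE city_uuid = 'c1' AND cust_code IN ('"
              + "', '".join(ids) + "')")
unsafe_rows = conn.execute(unsafe_sql).fetchall()

# Safe: each value is bound separately and compared as plain data
sql, id_params = expand_in(
    "SELECT cust_code FROM cust WHERE city_uuid = :city_uuid AND cust_code IN :ids",
    "ids", ids)
safe_rows = conn.execute(sql, {"city_uuid": "c1", **id_params}).fetchall()
```

The concatenated query returns every row in the table, including another city's, while the bound version matches only `"a"`. The table names in the DAO's f-strings are still interpolated, which is presumably safe only because they come from the DAO's own attributes rather than from caller input.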

Task 5: Remove the pagination overhead from get_product_by_id (Fix 5)

Files:

  • Modify: database/dao/mysql_dao.py:92-103

  • [ ] Step 1: Modify the get_product_by_id method

Replace the get_product_by_id method at lines 92-103 with:

def get_product_by_id(self, city_uuid, product_id):
    """Get the combined-cabinet (拼柜) info for city_uuid and product_id from the table."""
    query = text(f"""
        SELECT *
        FROM {self._product_tablename}
        WHERE city_uuid = :city_uuid
        AND product_code = :product_id
    """)
    params = {"city_uuid": city_uuid, "product_id": product_id}
    # fetch_one returns only the first matching row, or None if nothing matches
    result = self.db_helper.fetch_one(query, params)
    return pd.DataFrame([dict(result._mapping)] if result is not None else [])
  • [ ] Step 2: Commit Fix 5

    git add database/dao/mysql_dao.py
    git commit -m "fix: remove the pagination overhead from get_product_by_id
    
    - Use fetch_one to query a single record
    - Return a single-row DataFrame to keep the interface consistent"
    
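Fix 5 saves round trips: paging through `load_data_with_page` issues a COUNT query, one page query, and a final empty-page query even when at most one row can match. The sketch below counts executed statements with sqlite3's `set_trace_callback` (sqlite3 stands in for MySQL as an assumption; `paged_lookup` and `single_lookup` are hypothetical helpers mirroring the two code paths).

```python
import sqlite3


def paged_lookup(conn, sql, params, page_size=100000):
    """Mirror of the paging loop: COUNT, then pages until one comes back empty."""
    conn.execute(f"SELECT COUNT(*) FROM ({sql}) AS _count_subq", params).fetchone()
    rows, offset = [], 0
    while True:
        page = conn.execute(sql + " LIMIT :limit OFFSET :offset",
                            {**params, "limit": page_size, "offset": offset}).fetchall()
        if not page:
            break
        rows.extend(dict(r) for r in page)
        offset += page_size
    return rows


def single_lookup(conn, sql, params):
    """Mirror of the fetch_one path: one query, zero or one row."""
    row = conn.execute(sql, params).fetchone()
    return [dict(row)] if row is not None else []


conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row  # rows support keys(), so dict(row) works
conn.execute("CREATE TABLE product (city_uuid TEXT, product_code TEXT, name TEXT)")
conn.execute("INSERT INTO product VALUES ('c1', 'p1', 'Alpha')")
conn.commit()

sql = "SELECT * FROM product WHERE city_uuid = :city_uuid AND product_code = :product_id"
params = {"city_uuid": "c1", "product_id": "p1"}

# Record every statement SQLite actually executes
statements = []
conn.set_trace_callback(statements.append)
paged = paged_lookup(conn, sql, params)
paged_count = len(statements)

statements.clear()
single = single_lookup(conn, sql, params)
single_count = len(statements)
conn.set_trace_callback(None)
```

Because `fetch_one` keeps only the first row, this change assumes that `(city_uuid, product_code)` identifies at most one product; if duplicates exist, the paged version would have returned all of them.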

Verification

  • [ ] Step 1: Check all commits

    git log --oneline -5
    

Expected: 5 fix commits (Fix 2, Fixes 1+3, Fix 6, Fix 4, Fix 5); git log lists them newest first.

  • [ ] Step 2: Check the changed files

    git diff HEAD~5 --stat
    

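Expected (line counts will vary; git lists paths alphabetically):

     database/dao/mysql_dao.py | <lines changed>
     database/db/mysql.py      | <lines changed>
     2 files changed, <N> insertions(+), <M> deletions(-)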

  • [ ] Step 3: Final confirmation

All 6 fixes are complete and were applied in dependency order.