# Database Operations Bug Fix Implementation Plan

> **For agentic workers:** REQUIRED: Use superpowers:subagent-driven-development (if subagents available) or superpowers:executing-plans to implement this plan. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Fix 6 bugs in the database operations layer, eliminate SQL injection risk, and improve error handling.

**Architecture:** Surgical fixes that leave the existing architecture unchanged. Fix order: Fix 2 → Fix 1/3 → Fix 6 → Fix 4/5.

**Tech stack:** SQLAlchemy, pandas, pymysql

---

## Chunk 1: mysql.py fixes

### Task 1: Fix the count query and the params side effect in load_data_with_page (Fix 2)

**Files:**
- Modify: `database/db/mysql.py:52-78`

- [ ] **Step 1: Modify the load_data_with_page method**

Replace the `load_data_with_page` method at lines 52-78 with:

```python
    def load_data_with_page(self, query, params, page_size=100000):
        """Load query results page by page."""
        data = pd.DataFrame()

        # Count rows by wrapping the original query in a subquery,
        # instead of rewriting the query with string replacement
        count_query = text(f"SELECT COUNT(*) FROM ({query}) AS _count_subq")
        query += " LIMIT :limit OFFSET :offset"
        query = text(query)

        # Get the total row count (used only for the progress bar)
        total_rows = self.fetch_one(count_query, params)[0]

        page = 1
        with tqdm(total=total_rows, desc="Loading data", unit="rows") as pbar:
            while True:
                offset = (page - 1) * page_size
                # Copy params so the caller's dict is never modified
                page_params = dict(params)
                page_params["limit"] = page_size
                page_params["offset"] = offset
                df = pd.DataFrame(self.fetch_all(query, page_params))
                if df.empty:
                    break
                data = pd.concat([data, df], ignore_index=True)
                pbar.update(len(df))
                page += 1
        return data
```

- [ ] **Step 2: Commit Fix 2**

```bash
git add database/db/mysql.py
git commit -m "fix: correct the count query and params side effect in load_data_with_page

- Count rows by wrapping the original query in a subquery instead of string replacement
- Copy params before each page so the caller's dict is not modified"
```

### Task 2: Fix error handling in connect_database and the fetch methods (Fix 1 and Fix 3)

**Files:**
- Modify: `database/db/mysql.py:33-48, 81-104`

- [ ] **Step 1: Modify the connect_database method**

Replace the `connect_database` method at lines 33-48 with:

```python
    def connect_database(self):
        # Build the connection URL and create the engine
        try:
            conn = f"mysql+pymysql://{self._user}:{self._passwd}@{self._host}:{self._port}/{self._dbname}"
            # create_engine() is lazy: a bad URL or missing driver fails here,
            # but network errors only surface when a connection is first used
            self.engine = create_engine(
                conn,
                pool_size=20,        # connection pool size
                max_overflow=30,     # extra connections allowed beyond pool_size
                pool_recycle=1800,   # recycle connections after 1800 seconds
                pool_pre_ping=True,  # test connections before use to detect stale ones
                isolation_level="READ COMMITTED"  # lower the isolation level
            )
        except Exception as e:
            raise ConnectionAbortedError(f"failed to create connection: {e}") from e
        self._DBSession = sessionmaker(bind=self.engine)
```

- [ ] **Step 2: Modify the fetch_all method**

Replace the `fetch_all` method at lines 81-91 with:

```python
    def fetch_all(self, query, params=None):
        """Execute a SQL query and return all rows."""
        session = self._DBSession()
        try:
            results = session.execute(query, params or {}).fetchall()
            return results
        except SQLAlchemyError as e:
            session.rollback()
            print(f"error: {e}")
            raise  # propagate instead of silently returning None
        finally:
            session.close()
```

- [ ] **Step 3: Modify the fetch_one method**

Replace the `fetch_one` method at lines 93-104 with:

```python
    def fetch_one(self, query, params=None):
        """Execute a SQL query and return a single row."""
        session = self._DBSession()
        try:
            result = session.execute(query, params or {}).fetchone()
            return result
        except SQLAlchemyError as e:
            session.rollback()
            print(f"error: {e}")
            raise  # propagate instead of silently returning None
        finally:
            session.close()
```

- [ ] **Step 4: Commit Fix 1 and Fix 3**

```bash
git add database/db/mysql.py
git commit -m "fix: correct error handling in connect_database and the fetch methods

Fix 1: move create_engine into the try block so engine creation failures are caught
Fix 3: re-raise in the except blocks of fetch_all/fetch_one

Note: Fix 3 is a breaking behavior change; callers now receive an exception instead of None"
```

---

## Chunk 2: mysql_dao.py fixes

### Task 3: Fix the inconsistent interface of get_cust_list and get_product_from_order (Fix 6)

**Files:**
- Modify: `database/dao/mysql_dao.py:251-278`

- [ ] **Step 1: Modify the get_product_from_order method**

Replace the `get_product_from_order` method at lines 251-257 with:

```python
    def get_product_from_order(self, city_uuid):
        query = f"SELECT DISTINCT product_code FROM {self._order_tablename} WHERE city_uuid = :city_uuid"
        params = {"city_uuid": city_uuid}
        data = self.db_helper.load_data_with_page(query, params)
        return data
```

- [ ] **Step 2: Modify the get_cust_list method**

Replace the `get_cust_list` method at lines 272-278 with:

```python
    def get_cust_list(self, city_uuid):
        query = f"SELECT DISTINCT BB_RETAIL_CUSTOMER_CODE FROM {self._cust_tablename} WHERE BA_CITY_ORG_CODE = :city_uuid"
        params = {"city_uuid": city_uuid}
        data = self.db_helper.load_data_with_page(query, params)
        return data
```

- [ ] **Step 3: Commit Fix 6**

```bash
git add database/dao/mysql_dao.py
git commit -m "fix: unify the interface of get_cust_list and get_product_from_order

- Use load_data_with_page instead of calling fetch_all directly
- Consistent with the other query methods"
```

### Task 4: Fix the SQL injection risk in all IN clauses (Fix 4)

**Files:**
- Modify: `database/dao/mysql_dao.py:2, 105-175`

- [ ] **Step 1: Add the bindparam import**

Change line 2 from:

```python
from sqlalchemy import text
```

to:

```python
from sqlalchemy import text, bindparam
```

- [ ] **Step 2: Modify the get_cust_by_ids method**

Replace the `get_cust_by_ids` method at lines 105-120 with:

```python
    def get_cust_by_ids(self, city_uuid, cust_id_list):
        """Query information for the given list of retail customers."""
        if not cust_id_list:
            return None

        # expanding=True renders one bound parameter per list element
        query = text(f"""
            SELECT *
            FROM {self._cust_tablename}
            WHERE BA_CITY_ORG_CODE = :city_uuid
            AND BB_RETAIL_CUSTOMER_CODE IN :ids
        """).bindparams(bindparam("ids", expanding=True))
        params = {"city_uuid": city_uuid, "ids": list(cust_id_list)}
        data = pd.DataFrame(self.db_helper.fetch_all(query, params))
        return data
```

- [ ] **Step 3: Modify the get_shop_by_ids method**

Replace the `get_shop_by_ids` method at lines 122-137 with:

```python
    def get_shop_by_ids(self, city_uuid, cust_id_list):
        """Query information for the given list of retail customers."""
        if not cust_id_list:
            return None

        query = text(f"""
            SELECT *
            FROM {self._shopping_tablename}
            WHERE city_uuid = :city_uuid
            AND cust_code IN :ids
        """).bindparams(bindparam("ids", expanding=True))
        params = {"city_uuid": city_uuid, "ids": list(cust_id_list)}
        data = pd.DataFrame(self.db_helper.fetch_all(query, params))
        return data
```

- [ ] **Step 4: Modify the get_product_by_ids method**

Replace the `get_product_by_ids` method at lines 139-154 with:

```python
    def get_product_by_ids(self, city_uuid, product_id_list):
        """Query information for the given list of product_code values."""
        if not product_id_list:
            return None

        query = text(f"""
            SELECT *
            FROM {self._product_tablename}
            WHERE city_uuid = :city_uuid
            AND product_code IN :ids
        """).bindparams(bindparam("ids", expanding=True))
        params = {"city_uuid": city_uuid, "ids": list(product_id_list)}
        data = pd.DataFrame(self.db_helper.fetch_all(query, params))
        return data
```

- [ ] **Step 5: Modify the get_order_by_product_ids method**

Replace the `get_order_by_product_ids` method at lines 156-175 with:

```python
    def get_order_by_product_ids(self, city_uuid, product_ids):
        """Get all sales records for the given list of cigarette products."""
        if not product_ids:
            return None

        query = text(f"""
            SELECT *
            FROM {self._order_tablename}
            WHERE city_uuid = :city_uuid
            AND product_code IN :ids
        """).bindparams(bindparam("ids", expanding=True))
        params = {"city_uuid": city_uuid, "ids": list(product_ids)}
        data = pd.DataFrame(self.db_helper.fetch_all(query, params))

        # Keep only orders whose customer appears in this city's customer list
        cust_list = self.get_cust_list(city_uuid)
        cust_index = cust_list.set_index("BB_RETAIL_CUSTOMER_CODE")
        data = data.join(cust_index, on="cust_code", how="inner")

        return data
```

- [ ] **Step 6: Commit Fix 4**

```bash
git add database/dao/mysql_dao.py
git commit -m "fix: remove the SQL injection risk in all IN clauses

- Use bindparam(expanding=True) instead of string concatenation
- Methods fixed: get_cust_by_ids, get_shop_by_ids, get_product_by_ids, get_order_by_product_ids
- Query directly with fetch_all and skip pagination (result size is bounded by the input list)"
```

### Task 5: Remove the pagination overhead in get_product_by_id (Fix 5)

**Files:**
- Modify: `database/dao/mysql_dao.py:92-103`

- [ ] **Step 1: Modify the get_product_by_id method**

Replace the `get_product_by_id` method at lines 92-103 with:

```python
    def get_product_by_id(self, city_uuid, product_id):
        """Fetch the matching record from the product table by city_uuid and product_id."""
        query = text(f"""
            SELECT *
            FROM {self._product_tablename}
            WHERE city_uuid = :city_uuid
            AND product_code = :product_id
        """)
        params = {"city_uuid": city_uuid, "product_id": product_id}
        result = self.db_helper.fetch_one(query, params)
        # Return a one-row DataFrame, or an empty one when nothing matches
        return pd.DataFrame([dict(result._mapping)] if result is not None else [])
```

- [ ] **Step 2: Commit Fix 5**

```bash
git add database/dao/mysql_dao.py
git commit -m "fix: remove the pagination overhead in get_product_by_id

- Use fetch_one to query a single record
- Return a single-row DataFrame to keep the interface consistent"
```

---

## Verification

- [ ] **Step 1: Check all commits**

```bash
git log --oneline -5
```

Expected: 5 fix commits (Fix 2, Fix 1+3, Fix 6, Fix 4, Fix 5)

- [ ] **Step 2: Check the modified files**

```bash
git diff HEAD~5 --stat
```

Expected:
```
database/db/mysql.py      | <lines changed>
database/dao/mysql_dao.py | <lines changed>
2 files changed
```

- [ ] **Step 3: Final confirmation**

All 6 fixes are complete and were applied in dependency order.
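
**Optional sanity check for Fix 2**

The two ideas behind Fix 2 (counting through a wrapping subquery, and copying `params` before adding `limit`/`offset`) can be exercised without a MySQL server. The following is only a minimal sketch under stated assumptions: it uses the standard-library `sqlite3` module as a stand-in for the SQLAlchemy session, and `load_all_pages`, `demo`, and the `orders` table are hypothetical names that do not exist in `mysql.py`.

```python
import sqlite3


def load_all_pages(conn, query, params, page_size=2):
    """Fetch every row of `query` page by page without mutating `params`.

    Hypothetical helper mirroring the logic of load_data_with_page.
    """
    # Count by wrapping the original query in a subquery, not by string replacement.
    count_sql = f"SELECT COUNT(*) FROM ({query}) AS _count_subq"
    total_rows = conn.execute(count_sql, params).fetchone()[0]

    paged_sql = f"{query} LIMIT :limit OFFSET :offset"
    rows = []
    offset = 0
    while True:
        # Build a new dict so limit/offset never leak back into the caller's dict.
        page_params = dict(params, limit=page_size, offset=offset)
        page = conn.execute(paged_sql, page_params).fetchall()
        if not page:
            break
        rows.extend(page)
        offset += page_size
    return total_rows, rows


def demo():
    """Run the helper against an in-memory table (assumed schema, for illustration)."""
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE orders (city_uuid TEXT, product_code TEXT)")
    conn.executemany(
        "INSERT INTO orders VALUES (?, ?)",
        [("c1", "p1"), ("c1", "p2"), ("c1", "p2"), ("c1", "p3"), ("c2", "p9")],
    )
    # ORDER BY keeps LIMIT/OFFSET pages stable between requests.
    query = (
        "SELECT DISTINCT product_code FROM orders "
        "WHERE city_uuid = :city_uuid ORDER BY product_code"
    )
    params = {"city_uuid": "c1"}
    total, rows = load_all_pages(conn, query, params)
    conn.close()
    return total, [row[0] for row in rows], params


if __name__ == "__main__":
    print(demo())
```

The sketch adds an `ORDER BY` to the sample query because `LIMIT`/`OFFSET` paging over an unordered result is not guaranteed to return each row exactly once; whether the real DAO queries need the same treatment is a judgment for the implementer.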