@@ -0,0 +1,374 @@
+# Database Operation Bug Fix Implementation Plan
+
+> **For agentic workers:** REQUIRED: Use superpowers:subagent-driven-development (if subagents available) or superpowers:executing-plans to implement this plan. Steps use checkbox (`- [ ]`) syntax for tracking.
+
+**Goal:** Fix 6 bugs in the database operation layer, eliminate SQL injection risks, and improve error handling
+
+**Architecture:** Surgical fixes that leave the existing architecture unchanged. Fix order: Fix 2 → Fix 1/3 → Fix 6 → Fix 4/5
+
+**Tech stack:** SQLAlchemy, pandas, pymysql
+
+---
+
+## Chunk 1: mysql.py fixes
+
+### Task 1: Fix the count query and `params` side effect in load_data_with_page (Fix 2)
+
+**Files:**
+- Modify: `database/db/mysql.py:52-78`
+
+- [ ] **Step 1: Modify the load_data_with_page method**
+
+Replace the `load_data_with_page` method at lines 52-78 with:
+
+```python
+def load_data_with_page(self, query, params, page_size=100000):
+    """Load query results page by page."""
+    data = pd.DataFrame()
+    # Count rows by wrapping the original query in a subquery instead of string replacement
+    count_query = text(f"SELECT COUNT(*) FROM ({query}) AS _count_subq")
+    query += " LIMIT :limit OFFSET :offset"
+    query = text(query)
+
+    # Get the total row count
+    total_rows = self.fetch_one(count_query, params)[0]
+
+    page = 1
+    with tqdm(total=total_rows, desc="Loading data", unit="rows") as pbar:
+        while True:
+            offset = (page - 1) * page_size
+            # Copy params so the caller's dict is not modified
+            page_params = dict(params)
+            page_params["limit"] = page_size
+            page_params["offset"] = offset
+
+            df = pd.DataFrame(self.fetch_all(query, page_params))
+            if df.empty:
+                break
+            data = pd.concat([data, df], ignore_index=True)
+
+            pbar.update(len(df))
+
+            page += 1
+    return data
+```
+
+- [ ] **Step 2: Commit Fix 2**
+
+```bash
+git add database/db/mysql.py
+git commit -m "fix: correct count query and params side effect in load_data_with_page
+
+- Count rows by wrapping the original query in a subquery instead of relying on string replacement
+- Copy params before each page so the caller's dict is not modified"
+```
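The two changes in Task 1 can be tried out without a MySQL server. The sketch below is a stand-in under stated assumptions, not the project's code: it swaps SQLAlchemy and pandas for the standard-library `sqlite3` module, and the `load_pages` helper, the `orders` table, and its sample rows are all hypothetical. It shows the row count coming from a wrapping subquery and the caller's `params` dict staying untouched across pages.

```python
import sqlite3


def load_pages(conn, query, params, page_size=2):
    """Fetch every row of `query` page by page without mutating `params`."""
    # Count rows by wrapping the original query in a subquery (no string replacement).
    count_sql = f"SELECT COUNT(*) FROM ({query}) AS _count_subq"
    total_rows = conn.execute(count_sql, params).fetchone()[0]

    paged_sql = f"{query} LIMIT :limit OFFSET :offset"
    rows = []
    offset = 0
    while offset < total_rows:
        # Build a fresh dict so the caller's dict never gains "limit"/"offset" keys.
        page_params = dict(params, limit=page_size, offset=offset)
        page = conn.execute(paged_sql, page_params).fetchall()
        if not page:
            break
        rows.extend(page)
        offset += len(page)
    return total_rows, rows


# Hypothetical in-memory table standing in for the real order table.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE orders (city_uuid TEXT, product_code TEXT)")
conn.executemany(
    "INSERT INTO orders VALUES (?, ?)",
    [("c1", "p1"), ("c1", "p2"), ("c1", "p2"), ("c1", "p3"), ("c2", "p9")],
)

caller_params = {"city_uuid": "c1"}
total, rows = load_pages(
    conn,
    "SELECT DISTINCT product_code FROM orders "
    "WHERE city_uuid = :city_uuid ORDER BY product_code",
    caller_params,
)
```

After the run, `caller_params` still holds only `city_uuid`. The sample query carries an `ORDER BY` because `LIMIT`/`OFFSET` paging has no guaranteed row order without one; whether the real queries need the same treatment is left to the implementer.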
+
+### Task 2: Fix error handling in connect_database and the fetch methods (Fix 1 and 3)
+
+**Files:**
+- Modify: `database/db/mysql.py:33-48, 81-104`
+
+- [ ] **Step 1: Modify the connect_database method**
+
+Replace the `connect_database` method at lines 33-48 with:
+
+```python
+def connect_database(self):
+    # Create the database connection
+    try:
+        conn = "mysql+pymysql://" + self._user + ":" + self._passwd + "@" + self._host + ":" + str(self._port) + "/" + self._dbname
+
+        # Create the engine with a connection pool.
+        # Note: create_engine() is lazy and opens no connection here, so this try block
+        # catches URL/driver errors; network errors surface on first use of the engine.
+        self.engine = create_engine(
+            conn,
+            pool_size=20,  # connection pool size
+            max_overflow=30,  # extra connections allowed beyond pool_size
+            pool_recycle=1800,  # recycle connections after this many seconds
+            pool_pre_ping=True,  # test connections before use to avoid stale ones
+            isolation_level="READ COMMITTED"  # lower the isolation level
+        )
+    except Exception as e:
+        raise ConnectionAbortedError(f"failed to create connection: {e}") from e
+
+    self._DBSession = sessionmaker(bind=self.engine)
+```
+
+- [ ] **Step 2: Modify the fetch_all method**
+
+Replace the `fetch_all` method at lines 81-91 with:
+
+```python
+def fetch_all(self, query, params=None):
+    """Execute a SQL query and return all rows."""
+    session = self._DBSession()
+    try:
+        results = session.execute(query, params or {}).fetchall()
+        return results
+    except SQLAlchemyError as e:
+        session.rollback()
+        print(f"error: {e}")
+        raise
+    finally:
+        session.close()
+```
+
+- [ ] **Step 3: Modify the fetch_one method**
+
+Replace the `fetch_one` method at lines 93-104 with:
+
+```python
+def fetch_one(self, query, params=None):
+    """Execute a SQL query and return a single row."""
+    session = self._DBSession()
+    try:
+        result = session.execute(query, params or {}).fetchone()
+        return result
+
+    except SQLAlchemyError as e:
+        session.rollback()
+        print(f"error: {e}")
+        raise
+    finally:
+        session.close()
+```
+
+- [ ] **Step 4: Commit Fix 1 and 3**
+
+```bash
+git add database/db/mysql.py
+git commit -m "fix: correct error handling in connect_database and the fetch methods
+
+Fix 1: move create_engine into the try block so connection failures are caught
+Fix 3: add raise to the except blocks of fetch_all/fetch_one
+
+Note: Fix 3 is a breaking behavior change; callers now receive an exception instead of None"
+```
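Because Fix 3 changes what callers observe, a small sketch of the log-then-re-raise pattern may help when reviewing call sites. It rests on assumptions: `sqlite3` from the standard library stands in for the SQLAlchemy session, and `first_value_or_default` is a hypothetical caller that previously would have checked for `None`.

```python
import sqlite3


def fetch_all(db_path, query, params=None):
    """Run a query, log any database error, then re-raise it to the caller."""
    conn = sqlite3.connect(db_path)
    try:
        return conn.execute(query, params or {}).fetchall()
    except sqlite3.Error as e:
        conn.rollback()
        print(f"error: {e}")
        raise  # without this line the function would silently return None
    finally:
        conn.close()  # runs on both the success and the error path


def first_value_or_default(db_path, query, default=None):
    """Hypothetical caller: handles the exception instead of testing for None."""
    try:
        rows = fetch_all(db_path, query)
    except sqlite3.Error:
        return default
    return rows[0][0] if rows else default


ok = fetch_all(":memory:", "SELECT 1 AS x")
fallback = first_value_or_default(
    ":memory:", "SELECT COUNT(*) FROM missing_table", default=-1
)
```

Call sites that relied on a `None` return need an explicit `try`/`except` like the one in the hypothetical caller, or they should let the exception propagate deliberately.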
+
+---
+
+## Chunk 2: mysql_dao.py fixes
+
+### Task 3: Fix the inconsistent interface of get_cust_list and get_product_from_order (Fix 6)
+
+**Files:**
+- Modify: `database/dao/mysql_dao.py:251-278`
+
+- [ ] **Step 1: Modify the get_product_from_order method**
+
+Replace the `get_product_from_order` method at lines 251-257 with:
+
+```python
+def get_product_from_order(self, city_uuid):
+    query = f"SELECT DISTINCT product_code FROM {self._order_tablename} WHERE city_uuid = :city_uuid"
+    params = {"city_uuid": city_uuid}
+
+    data = self.db_helper.load_data_with_page(query, params)
+
+    return data
+```
+
+- [ ] **Step 2: Modify the get_cust_list method**
+
+Replace the `get_cust_list` method at lines 272-278 with:
+
+```python
+def get_cust_list(self, city_uuid):
+    query = f"SELECT DISTINCT BB_RETAIL_CUSTOMER_CODE FROM {self._cust_tablename} WHERE BA_CITY_ORG_CODE = :city_uuid"
+    params = {"city_uuid": city_uuid}
+
+    data = self.db_helper.load_data_with_page(query, params)
+
+    return data
+```
+
+- [ ] **Step 3: Commit Fix 6**
+
+```bash
+git add database/dao/mysql_dao.py
+git commit -m "fix: unify the interface of get_cust_list and get_product_from_order
+
+- Use load_data_with_page instead of calling fetch_all directly
+- Stay consistent with the other query methods"
+```
+
+### Task 4: Fix SQL injection risk in all IN clauses (Fix 4)
+
+**Files:**
+- Modify: `database/dao/mysql_dao.py:2, 105-175`
+
+- [ ] **Step 1: Add the bindparam import**
+
+Change line 2 from:
+```python
+from sqlalchemy import text
+```
+
+to:
+```python
+from sqlalchemy import text, bindparam
+```
+
+- [ ] **Step 2: Modify the get_cust_by_ids method**
+
+Replace the `get_cust_by_ids` method at lines 105-120 with:
+
+```python
+def get_cust_by_ids(self, city_uuid, cust_id_list):
+    """Query records for the given list of retail customers."""
+    if not cust_id_list:
+        return None
+
+    query = text(f"""
+        SELECT *
+        FROM {self._cust_tablename}
+        WHERE BA_CITY_ORG_CODE = :city_uuid
+        AND BB_RETAIL_CUSTOMER_CODE IN :ids
+    """).bindparams(bindparam("ids", expanding=True))
+    params = {"city_uuid": city_uuid, "ids": list(cust_id_list)}
+    data = pd.DataFrame(self.db_helper.fetch_all(query, params))
+
+    return data
+```
+
+- [ ] **Step 3: Modify the get_shop_by_ids method**
+
+Replace the `get_shop_by_ids` method at lines 122-137 with:
+
+```python
+def get_shop_by_ids(self, city_uuid, cust_id_list):
+    """Query records for the given list of retail customers."""
+    if not cust_id_list:
+        return None
+
+    query = text(f"""
+        SELECT *
+        FROM {self._shopping_tablename}
+        WHERE city_uuid = :city_uuid
+        AND cust_code IN :ids
+    """).bindparams(bindparam("ids", expanding=True))
+    params = {"city_uuid": city_uuid, "ids": list(cust_id_list)}
+    data = pd.DataFrame(self.db_helper.fetch_all(query, params))
+
+    return data
+```
+
+- [ ] **Step 4: Modify the get_product_by_ids method**
+
+Replace the `get_product_by_ids` method at lines 139-154 with:
+
+```python
+def get_product_by_ids(self, city_uuid, product_id_list):
+    """Query records for the given list of product_code values."""
+    if not product_id_list:
+        return None
+
+    query = text(f"""
+        SELECT *
+        FROM {self._product_tablename}
+        WHERE city_uuid = :city_uuid
+        AND product_code IN :ids
+    """).bindparams(bindparam("ids", expanding=True))
+    params = {"city_uuid": city_uuid, "ids": list(product_id_list)}
+    data = pd.DataFrame(self.db_helper.fetch_all(query, params))
+
+    return data
+```
+
+- [ ] **Step 5: Modify the get_order_by_product_ids method**
+
+Replace the `get_order_by_product_ids` method at lines 156-175 with:
+
+```python
+def get_order_by_product_ids(self, city_uuid, product_ids):
+    """Get all sales records for the given list of cigarette products."""
+    if not product_ids:
+        return None
+
+    query = text(f"""
+        SELECT *
+        FROM {self._order_tablename}
+        WHERE city_uuid = :city_uuid
+        AND product_code IN :ids
+    """).bindparams(bindparam("ids", expanding=True))
+    params = {"city_uuid": city_uuid, "ids": list(product_ids)}
+    data = pd.DataFrame(self.db_helper.fetch_all(query, params))
+
+    cust_list = self.get_cust_list(city_uuid)
+    cust_index = cust_list.set_index("BB_RETAIL_CUSTOMER_CODE")
+    data = data.join(cust_index, on="cust_code", how="inner")
+
+    return data
+```
+
+- [ ] **Step 6: Commit Fix 4**
+
+```bash
+git add database/dao/mysql_dao.py
+git commit -m "fix: remove SQL injection risk in all IN clauses
+
+- Replace string concatenation with bindparam(expanding=True)
+- Methods fixed: get_cust_by_ids, get_shop_by_ids, get_product_by_ids, get_order_by_product_ids
+- Query directly with fetch_all and skip pagination (result size is bounded by the input list)"
+```
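An expanding bind parameter renders one placeholder per list element at execution time, so every value travels as data and never as SQL text. The sketch below reproduces that idea by hand, under assumptions: it uses the standard-library `sqlite3` module instead of SQLAlchemy, and `expand_in_params`, `get_product_codes`, and the `product` table are hypothetical names for illustration only.

```python
import sqlite3


def expand_in_params(name, values):
    """Return (placeholder_sql, params) with one bound parameter per value."""
    keys = [f"{name}_{i}" for i in range(len(values))]
    placeholders = ", ".join(f":{k}" for k in keys)
    return f"({placeholders})", dict(zip(keys, values))


# Hypothetical in-memory table standing in for the real product table.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE product (city_uuid TEXT, product_code TEXT)")
conn.executemany(
    "INSERT INTO product VALUES (?, ?)",
    [("c1", "p1"), ("c1", "p2"), ("c2", "p1")],
)


def get_product_codes(city_uuid, product_ids):
    """Look up product codes for one city; every id is bound, none is interpolated."""
    if not product_ids:
        return []
    in_sql, in_params = expand_in_params("ids", list(product_ids))
    sql = (
        "SELECT product_code FROM product "
        f"WHERE city_uuid = :city_uuid AND product_code IN {in_sql} "
        "ORDER BY product_code"
    )
    return conn.execute(sql, {"city_uuid": city_uuid, **in_params}).fetchall()


found = get_product_codes("c1", ["p1", "p2", "p3"])
# A hostile value is compared as a plain string and matches nothing.
hostile = get_product_codes("c1", ["p1') OR ('1'='1"])
```

Only the placeholder names are interpolated into the SQL string; the values themselves go through the driver. The sketch returns an empty list for empty input, whereas the planned methods return `None` there and a DataFrame otherwise, so callers of the real methods need to handle both.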
+
+### Task 5: Fix the pagination overhead in get_product_by_id (Fix 5)
+
+**Files:**
+- Modify: `database/dao/mysql_dao.py:92-103`
+
+- [ ] **Step 1: Modify the get_product_by_id method**
+
+Replace the `get_product_by_id` method at lines 92-103 with:
+
+```python
+def get_product_by_id(self, city_uuid, product_id):
+    """Get cabinet-combination (拼柜) info from the table by city_uuid and product_id."""
+    query = text(f"""
+        SELECT *
+        FROM {self._product_tablename}
+        WHERE city_uuid = :city_uuid
+        AND product_code = :product_id
+    """)
+    params = {"city_uuid": city_uuid, "product_id": product_id}
+    result = self.db_helper.fetch_one(query, params)
+    return pd.DataFrame([dict(result._mapping)] if result else [])
+```
+
+- [ ] **Step 2: Commit Fix 5**
+
+```bash
+git add database/dao/mysql_dao.py
+git commit -m "fix: remove pagination overhead from get_product_by_id
+
+- Use fetch_one to query a single record
+- Return a single-row DataFrame to keep the interface consistent"
+```
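The last line of the new `get_product_by_id` turns either one row or no row into the input for a DataFrame. A rough sketch of that branching follows, with assumptions: `sqlite3.Row` stands in for the SQLAlchemy row and its `_mapping`, a plain list of dicts stands in for the DataFrame, and `row_to_records` and the `product` table are hypothetical.

```python
import sqlite3


def row_to_records(row):
    """Return [dict] for a fetched row, or [] when the query matched nothing."""
    return [dict(row)] if row is not None else []


conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row  # rows become mapping-like, so dict(row) works
conn.execute("CREATE TABLE product (city_uuid TEXT, product_code TEXT, name TEXT)")
conn.execute("INSERT INTO product VALUES ('c1', 'p1', 'demo')")

sql = (
    "SELECT * FROM product "
    "WHERE city_uuid = :city_uuid AND product_code = :product_id"
)
hit = row_to_records(
    conn.execute(sql, {"city_uuid": "c1", "product_id": "p1"}).fetchone()
)
miss = row_to_records(
    conn.execute(sql, {"city_uuid": "c1", "product_id": "zz"}).fetchone()
)
```

Two points are worth checking against the real callers. `pd.DataFrame([])` has no columns, so code that indexes a column on an empty result should test `.empty` first. And `fetch_one` returns only the first match, so the switch is behavior-preserving only if `(city_uuid, product_code)` identifies at most one row.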
+
+---
+
+## Verification
+
+- [ ] **Step 1: Check all commits**
+
+```bash
+git log --oneline -5
+```
+
+Expected: 5 fix commits (Fix 2, Fix 1+3, Fix 6, Fix 4, Fix 5)
+
+- [ ] **Step 2: Check the modified files**
+
+```bash
+git diff HEAD~5 --stat
+```
+
+Expected:
+```
+database/db/mysql.py | <lines changed>
+database/dao/mysql_dao.py | <lines changed>
+2 files changed
+```
+
+- [ ] **Step 3: Final confirmation**
+
+All 6 fixes are complete and were applied in dependency order.