Quellcode durchsuchen

docs: add database bugfix implementation plan

Sherlock vor 2 Tagen
Ursprung
Commit
e5d416fd08
1 geänderte Dateien mit 374 neuen und 0 gelöschten Zeilen
  1. 374 0
      docs/superpowers/plans/2026-03-15-database-bugfix.md

+ 374 - 0
docs/superpowers/plans/2026-03-15-database-bugfix.md

@@ -0,0 +1,374 @@
+# 数据库操作 Bug 修复实现计划
+
+> **For agentic workers:** REQUIRED: Use superpowers:subagent-driven-development (if subagents available) or superpowers:executing-plans to implement this plan. Steps use checkbox (`- [ ]`) syntax for tracking.
+
+**目标:** 修复数据库操作层的 6 个 bug,消除 SQL 注入风险,改进错误处理
+
+**架构:** 外科手术式修复,不改变现有架构。修复顺序:修复 2 → 修复 1/3 → 修复 6 → 修复 4/5
+
+**技术栈:** SQLAlchemy, pandas, pymysql
+
+---
+
+## Chunk 1: mysql.py 修复
+
+### 任务 1: 修复 load_data_with_page 的 count 查询和 params 副作用(修复 2)
+
+**文件:**
+- 修改: `database/db/mysql.py:52-78`
+
+- [ ] **步骤 1: 修改 load_data_with_page 方法**
+
+将第 52-78 行的 `load_data_with_page` 方法替换为:
+
+```python
+def load_data_with_page(self, query, params, page_size=100000):
+    """分页查询数据"""
+    data = pd.DataFrame()
+    # 用子查询包裹原始查询来计数,避免字符串替换
+    count_query = text(f"SELECT COUNT(*) FROM ({query}) AS _count_subq")
+    query += " LIMIT :limit OFFSET :offset"
+    query = text(query)
+
+    # 获取总行数
+    total_rows = self.fetch_one(count_query, params)[0]
+
+    page = 1
+    with tqdm(total=total_rows, desc="Loading data", unit="rows") as pbar:
+        while True:
+            offset = (page - 1) * page_size
+            # 复制 params 避免修改调用方的字典
+            page_params = dict(params)
+            page_params["limit"] = page_size
+            page_params["offset"] = offset
+
+            df = pd.DataFrame(self.fetch_all(query, page_params))
+            if df.empty:
+                break
+            data = pd.concat([data, df], ignore_index=True)
+
+            pbar.update(len(df))
+
+            page += 1
+    return data
+```
+
+- [ ] **步骤 2: 提交修复 2**
+
+```bash
+git add database/db/mysql.py
+git commit -m "fix: 修复 load_data_with_page 的 count 查询和 params 副作用
+
+- 用子查询包裹原始查询来计数,不再依赖字符串替换
+- 每次分页前复制 params 避免修改调用方的字典"
+```
+
+### 任务 2: 修复 connect_database 和 fetch 方法的错误处理(修复 1 和 3)
+
+**文件:**
+- 修改: `database/db/mysql.py:33-48, 81-104`
+
+- [ ] **步骤 1: 修改 connect_database 方法**
+
+将第 33-48 行的 `connect_database` 方法替换为:
+
+```python
+def connect_database(self):
+    # 创建数据库连接
+    try:
+        conn = "mysql+pymysql://" + self._user + ":" + self._passwd + "@" + self._host + ":" + str(self._port) + "/" + self._dbname
+
+        # 通过连接池创建engine
+        self.engine = create_engine(
+            conn,
+            pool_size=20, # 设置连接池大小
+            max_overflow=30, # 超过连接池大小时的额外连接数
+            pool_recycle=1800, # 回收连接时间
+            pool_pre_ping=True, # 防止断开连接
+            isolation_level="READ COMMITTED" # 降低隔离级别
+        )
+    except Exception as e:
+        raise ConnectionAbortedError(f"failed to create connection: {e}")
+
+    self._DBSession = sessionmaker(bind=self.engine)
+```
+
+- [ ] **步骤 2: 修改 fetch_all 方法**
+
+将第 81-91 行的 `fetch_all` 方法替换为:
+
+```python
+def fetch_all(self, query, params=None):
+    """执行SQL查询并返回所有结果"""
+    session = self._DBSession()
+    try:
+        results = session.execute(query, params or {}).fetchall()
+        return results
+    except SQLAlchemyError as e:
+        session.rollback()
+        print(f"error: {e}")
+        raise
+    finally:
+        session.close()
+```
+
+- [ ] **步骤 3: 修改 fetch_one 方法**
+
+将第 93-104 行的 `fetch_one` 方法替换为:
+
+```python
+def fetch_one(self, query, params=None):
+    """执行SQL查询并返回单条结果"""
+    session = self._DBSession()
+    try:
+        result = session.execute(query, params or {}).fetchone()
+        return result
+
+    except SQLAlchemyError as e:
+        session.rollback()
+        print(f"error: {e}")
+        raise
+    finally:
+        session.close()
+```
+
+- [ ] **步骤 4: 提交修复 1 和 3**
+
+```bash
+git add database/db/mysql.py
+git commit -m "fix: 修复 connect_database 和 fetch 方法的错误处理
+
+修复 1: 将 create_engine 移入 try 块,确保连接失败被捕获
+修复 3: 在 fetch_all/fetch_one 的 except 块中添加 raise
+
+注意:修复 3 是行为破坏性变更,调用方现在会收到异常而非 None"
+```
+
+---
+
+## Chunk 2: mysql_dao.py 修复
+
+### 任务 3: 修复 get_cust_list 和 get_product_from_order 的接口不一致(修复 6)
+
+**文件:**
+- 修改: `database/dao/mysql_dao.py:251-278`
+
+- [ ] **步骤 1: 修改 get_product_from_order 方法**
+
+将第 251-257 行的 `get_product_from_order` 方法替换为:
+
+```python
+def get_product_from_order(self, city_uuid):
+    query = f"SELECT DISTINCT product_code FROM {self._order_tablename} WHERE city_uuid = :city_uuid"
+    params = {"city_uuid": city_uuid}
+
+    data = self.db_helper.load_data_with_page(query, params)
+
+    return data
+```
+
+- [ ] **步骤 2: 修改 get_cust_list 方法**
+
+将第 272-278 行的 `get_cust_list` 方法替换为:
+
+```python
+def get_cust_list(self, city_uuid):
+    query = f"SELECT DISTINCT BB_RETAIL_CUSTOMER_CODE FROM {self._cust_tablename} WHERE BA_CITY_ORG_CODE = :city_uuid"
+    params = {"city_uuid": city_uuid}
+
+    data = self.db_helper.load_data_with_page(query, params)
+
+    return data
+```
+
+- [ ] **步骤 3: 提交修复 6**
+
+```bash
+git add database/dao/mysql_dao.py
+git commit -m "fix: 统一 get_cust_list 和 get_product_from_order 的接口
+
+- 改用 load_data_with_page 替代直接调用 fetch_all
+- 与其他查询方法保持一致"
+```
+
+### 任务 4: 修复所有 IN 子句的 SQL 注入风险(修复 4)
+
+**文件:**
+- 修改: `database/dao/mysql_dao.py:2, 105-175`
+
+- [ ] **步骤 1: 添加 bindparam 导入**
+
+将第 2 行从:
+```python
+from sqlalchemy import text
+```
+
+修改为:
+```python
+from sqlalchemy import text, bindparam
+```
+
+- [ ] **步骤 2: 修改 get_cust_by_ids 方法**
+
+将第 105-120 行的 `get_cust_by_ids` 方法替换为:
+
+```python
+def get_cust_by_ids(self, city_uuid, cust_id_list):
+    """根据零售户列表查询其信息"""
+    if not cust_id_list:
+        return None
+
+    query = text(f"""
+        SELECT *
+        FROM {self._cust_tablename}
+        WHERE BA_CITY_ORG_CODE = :city_uuid
+        AND BB_RETAIL_CUSTOMER_CODE IN :ids
+    """).bindparams(bindparam("ids", expanding=True))
+    params = {"city_uuid": city_uuid, "ids": list(cust_id_list)}
+    data = pd.DataFrame(self.db_helper.fetch_all(query, params))
+
+    return data
+```
+
+- [ ] **步骤 3: 修改 get_shop_by_ids 方法**
+
+将第 122-137 行的 `get_shop_by_ids` 方法替换为:
+
+```python
+def get_shop_by_ids(self, city_uuid, cust_id_list):
+    """根据零售户列表查询其信息"""
+    if not cust_id_list:
+        return None
+
+    query = text(f"""
+        SELECT *
+        FROM {self._shopping_tablename}
+        WHERE city_uuid = :city_uuid
+        AND cust_code IN :ids
+    """).bindparams(bindparam("ids", expanding=True))
+    params = {"city_uuid": city_uuid, "ids": list(cust_id_list)}
+    data = pd.DataFrame(self.db_helper.fetch_all(query, params))
+
+    return data
+```
+
+- [ ] **步骤 4: 修改 get_product_by_ids 方法**
+
+将第 139-154 行的 `get_product_by_ids` 方法替换为:
+
+```python
+def get_product_by_ids(self, city_uuid, product_id_list):
+    """根据product_code列表查询其信息"""
+    if not product_id_list:
+        return None
+
+    query = text(f"""
+        SELECT *
+        FROM {self._product_tablename}
+        WHERE city_uuid = :city_uuid
+        AND product_code IN :ids
+    """).bindparams(bindparam("ids", expanding=True))
+    params = {"city_uuid": city_uuid, "ids": list(product_id_list)}
+    data = pd.DataFrame(self.db_helper.fetch_all(query, params))
+
+    return data
+```
+
+- [ ] **步骤 5: 修改 get_order_by_product_ids 方法**
+
+将第 156-175 行的 `get_order_by_product_ids` 方法替换为:
+
+```python
+def get_order_by_product_ids(self, city_uuid, product_ids):
+    """获取指定香烟列表的所有售卖记录"""
+    if not product_ids:
+        return None
+
+    query = text(f"""
+        SELECT *
+        FROM {self._order_tablename}
+        WHERE city_uuid = :city_uuid
+        AND product_code IN :ids
+    """).bindparams(bindparam("ids", expanding=True))
+    params = {"city_uuid": city_uuid, "ids": list(product_ids)}
+    data = pd.DataFrame(self.db_helper.fetch_all(query, params))
+
+    cust_list = self.get_cust_list(city_uuid)
+    cust_index = cust_list.set_index("BB_RETAIL_CUSTOMER_CODE")
+    data = data.join(cust_index, on="cust_code", how="inner")
+
+    return data
+```
+
+- [ ] **步骤 6: 提交修复 4**
+
+```bash
+git add database/dao/mysql_dao.py
+git commit -m "fix: 修复所有 IN 子句的 SQL 注入风险
+
+- 用 bindparam(expanding=True) 替代字符串拼接
+- 修复方法: get_cust_by_ids, get_shop_by_ids, get_product_by_ids, get_order_by_product_ids
+- 改用 fetch_all 直接查询,跳过分页(结果集大小由输入列表决定)"
+```
+
+### 任务 5: 修复 get_product_by_id 的分页开销(修复 5)
+
+**文件:**
+- 修改: `database/dao/mysql_dao.py:92-103`
+
+- [ ] **步骤 1: 修改 get_product_by_id 方法**
+
+将第 92-103 行的 `get_product_by_id` 方法替换为:
+
+```python
+def get_product_by_id(self, city_uuid, product_id):
+    """根据city_uuid 和 product_id 从表中获取拼柜信息"""
+    query = text(f"""
+        SELECT *
+        FROM {self._product_tablename}
+        WHERE city_uuid = :city_uuid
+        AND product_code = :product_id
+    """)
+    params = {"city_uuid": city_uuid, "product_id": product_id}
+    result = self.db_helper.fetch_one(query, params)
+    return pd.DataFrame([dict(result._mapping)] if result else [])
+```
+
+- [ ] **步骤 2: 提交修复 5**
+
+```bash
+git add database/dao/mysql_dao.py
+git commit -m "fix: 修复 get_product_by_id 的分页开销
+
+- 改用 fetch_one 查询单条记录
+- 返回单行 DataFrame 保持接口一致"
+```
+
+---
+
+## 验证
+
+- [ ] **步骤 1: 检查所有提交**
+
+```bash
+git log --oneline -5
+```
+
+预期:看到 5 个修复提交(修复 2, 修复 1+3, 修复 6, 修复 4, 修复 5)
+
+- [ ] **步骤 2: 检查修改的文件**
+
+```bash
+git diff HEAD~5 --stat
+```
+
+预期:
+```
+database/db/mysql.py       | 修改行数
+database/dao/mysql_dao.py  | 修改行数
+2 files changed
+```
+
+- [ ] **步骤 3: 最终确认**
+
+所有 6 个修复已完成,按照依赖顺序执行。