yangzeyu 1 год назад
Родитель
Commit
9ee8367266

+ 0 - 16
dao/__init__.py

@@ -1,16 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-from dao.mysql_client import Mysql
-from dao.dao import load_order_data_from_mysql, load_cust_data_from_mysql, load_product_data_from_mysql, get_product_by_id, get_custs_by_ids, get_cust_list_from_database
-from dao.redis_db import Redis
-
-__all__ = [
-    "Mysql",
-    "load_order_data_from_mysql",
-    "load_cust_data_from_mysql",
-    "load_product_data_from_mysql",
-    "Redis",
-    "get_product_by_id",
-    "get_custs_by_ids",
-    "get_cust_list_from_database"
-]

+ 0 - 78
dao/dao.py

@@ -1,78 +0,0 @@
-from dao import Mysql
-
-def load_order_data_from_mysql(city_uuid):
-    """Load tads_brandcul_cust_order rows for city_uuid and return them cleaned, or None when no rows come back."""
-    client = Mysql()
-    # tablename = "yunfu_mock_data"
-    tablename = "tads_brandcul_cust_order"
-    query_text = "*"
-    # city_uuid = "00000000000000000000000011441801"
-    df = client.load_data(tablename, query_text, "city_uuid", city_uuid)
-    # df = client.load_mock_data(tablename, query_text)
-    if len(df) == 0:
-        return None
-    
-    df.drop('stat_month', axis=1, inplace=True)
-    df.drop('city_uuid', axis=1, inplace=True)
-    
-    # Remove duplicate rows and fill missing values with 0
-    df.drop_duplicates(inplace=True)
-    df.fillna(0, inplace=True)
-    df = df.infer_objects(copy=False)
-    return df
-
-def load_cust_data_from_mysql(city_uuid):
-    """Load tads_brandcul_cust_info rows whose BA_CITY_ORG_CODE equals city_uuid; return None when no rows come back."""
-    client = Mysql()
-    tablename = "tads_brandcul_cust_info"
-    query_text = "*"
-    
-    df = client.load_data(tablename, query_text, "BA_CITY_ORG_CODE", city_uuid)
-    if len(df) == 0:
-        return None
-    
-    return df
-
-def get_cust_list_from_database(city_uuid):  # returns the BB_RETAIL_CUSTOMER_CODE column for city_uuid as a list
-    client = Mysql()
-    tablename = "tads_brandcul_cust_info"
-    query_text = "*"
-    
-    df = client.load_data(tablename, query_text, "BA_CITY_ORG_CODE", city_uuid)
-    cust_list = df["BB_RETAIL_CUSTOMER_CODE"].to_list()  # NOTE(review): assumes the column exists even when no rows were loaded — confirm
-    if len(cust_list) == 0:  # cust_list is already an empty list here, so this branch returns an equal value
-        return []
-    
-    return cust_list
-
-def load_product_data_from_mysql(city_uuid):
-    """Load tads_brandcul_product_info rows for city_uuid; return None when no rows come back."""
-    client = Mysql()
-    tablename = "tads_brandcul_product_info"
-    query_text = "*"
-    
-    df = client.load_data(tablename, query_text, "city_uuid", city_uuid)
-    if len(df) == 0:
-        return None
-    
-    return df
-
-def get_product_by_id(city_uuid, product_id):  # thin wrapper over Mysql.get_product_by_id; an empty result becomes None
-    client = Mysql()
-    
-    res = client.get_product_by_id(city_uuid, product_id)
-    if len(res) == 0:
-        return None
-    return res
-
-def get_custs_by_ids(city_uuid, cust_ids):  # thin wrapper over Mysql.get_cust_by_ids; an empty result becomes None
-    client = Mysql()
-    
-    res = client.get_cust_by_ids(city_uuid, cust_ids)
-    if len(res) == 0:  # covers both the [] returned for empty cust_ids and an empty DataFrame
-        return None
-    return res
-
-if __name__ == '__main__':  # manual check: load and print the orders of one hard-coded city
-    data = load_order_data_from_mysql("00000000000000000000000011445301")
-    print(data)

+ 0 - 151
dao/mysql_client.py

@@ -1,151 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-import os
-from sqlalchemy import create_engine, text
-from sqlalchemy.dialects.mysql import pymysql
-from sqlalchemy.orm import sessionmaker
-from sqlalchemy.ext.declarative import declarative_base
-from config import load_config
-import pandas as pd
-import sys
-
-cfgs = load_config()
-
-class Mysql(object):  # MySQL access helper: builds a pooled SQLAlchemy engine from cfgs['mysql'] and runs paged/lookup queries
-    def __init__(self):
-        self._host = cfgs['mysql']['host']
-        self._port = cfgs['mysql']['port']
-        self._user = cfgs['mysql']['user']
-        self._passwd = cfgs['mysql']['passwd']
-        self._dbname = cfgs['mysql']['db']
-        
-        # Create the engine on top of a connection pool
-        self.engine = create_engine(
-            self._connect(self._host, self._port, self._user, self._passwd, self._dbname),
-            pool_size=10, # size of the connection pool
-            max_overflow=20, # extra connections allowed beyond pool_size
-            pool_recycle=3600 # connection recycle time
-        )
-        self._DBSession = sessionmaker(bind=self.engine)
-
-    def _connect(self, host, port, user, pwd, db):  # builds the connection URL string only; nothing is opened here
-        try:
-            client = "mysql+pymysql://" + user + ":" + pwd + "@" + host + ":" + str(port) + "/" + db  # NOTE(review): user/pwd are concatenated unescaped; characters such as "@" or ":" in them would change how the URL parses
-            return client
-        except Exception as e:
-            raise ConnectionError(f"failed to create connection string: {e}")
-        
-    def create_session(self):
-        """Create and return a new database session."""
-        return self._DBSession()
-    
-    def fetch_data_with_pagination(self, tablename, query_text, field_name, city_uuid, page=1, page_size=1000):
-        """Fetch one page of rows from tablename where field_name equals city_uuid, as a DataFrame."""
-        offset = (page - 1) * page_size  # row offset of the requested page (pages are 1-based)
-        query = text(f"SELECT {query_text} FROM {tablename} WHERE {field_name} = :city_uuid LIMIT :limit OFFSET :offset")  # NOTE(review): query_text/tablename/field_name are interpolated into the SQL text; only the values are bound
-    
-        with self.create_session() as session:
-            results = session.execute(query, {"city_uuid": city_uuid, "limit": page_size, "offset": offset}).fetchall()
-            df = pd.DataFrame(results)
-    
-        return df
-    
-    def load_data(self, tablename, query_text, field_name, city_uuid, page=1, page_size=1000):  # pages through the filtered table until an empty page and returns everything as one DataFrame
-        # Start with an empty DataFrame that accumulates every page
-        total_df = pd.DataFrame()
-    
-        try:
-            while True:
-                df = self.fetch_data_with_pagination(tablename, query_text, field_name, city_uuid, page, page_size)
-                if df.empty:
-                    break
-            
-                total_df = pd.concat([total_df, df], ignore_index=True)
-                print(f"Page {page}: Retrieved {len(df)} rows, Total rows so far: {len(total_df)}")
-                page += 1  # move on to the next page
-                
-        except Exception as e:
-            print(f"Error: {e}")
-            return None  # NOTE(review): never reaches the caller; the return in the finally clause below replaces it
-        
-        finally:
-            self.closed()  # the engine is disposed after every load
-            return total_df  # NOTE(review): a return inside finally wins over the except branch, so on error the rows gathered so far are returned instead of None
-        
-    def get_product_by_id(self, city_uuid, product_id):
-        """Fetch rows from tads_brandcul_product_info matching city_uuid and product_code, as a DataFrame."""
-        query = text("""
-            SELECT * 
-            FROM tads_brandcul_product_info 
-            WHERE city_uuid = :city_uuid 
-            AND product_code = :product_id
-        """)
-        
-        with self.create_session() as session:
-            result = session.execute(query, {"city_uuid": city_uuid, "product_id": product_id}).fetchall()
-            result = pd.DataFrame(result)
-        return result
-        
-    def get_cust_by_ids(self, city_uuid, cust_id_list):
-        """Fetch retail-customer rows for city_uuid whose BB_RETAIL_CUSTOMER_CODE is in cust_id_list."""
-        if not cust_id_list:
-            return []  # empty input short-circuits with a list, whereas the normal path returns a DataFrame
-        
-        cust_id_str = ",".join([f"'{cust_id}'" for cust_id in cust_id_list])  # NOTE(review): ids are quoted by hand and interpolated into the SQL text, not bound; unsafe if ids can contain quotes or come from untrusted input
-        
-        query = text(f"""
-            SELECT * 
-            FROM tads_brandcul_cust_info
-            WHERE BA_CITY_ORG_CODE = :city_uuid 
-            AND BB_RETAIL_CUSTOMER_CODE IN ({cust_id_str})
-        """)
-        
-        with self.create_session() as session:
-            results = session.execute(query, {"city_uuid": city_uuid}).fetchall()
-            results = pd.DataFrame(results)
-        
-        return results
-        
-    def load_mock_data(self, tablename, query_text, page=1, page_size=1000):  # same paging loop as load_data but without a city filter
-        # Start with an empty DataFrame that accumulates every page
-        total_df = pd.DataFrame()
-    
-        try:
-            while True:
-                offset = (page - 1) * page_size  # row offset of the current page
-                query = text(f"SELECT {query_text} FROM {tablename} LIMIT :limit OFFSET :offset")
-    
-                with self.create_session() as session:
-                    results = session.execute(query, { "limit": page_size, "offset": offset}).fetchall()
-                    df = pd.DataFrame(results)
-                if df.empty:
-                    break
-            
-                total_df = pd.concat([total_df, df], ignore_index=True)
-                print(f"Page {page}: Retrieved {len(df)} rows, Total rows so far: {len(total_df)}")
-                page += 1  # move on to the next page
-                
-        except Exception as e:
-            print(f"Error: {e}")
-            return None  # NOTE(review): replaced by the return in the finally clause below
-        
-        finally:
-            self.closed()
-            return total_df  # NOTE(review): return inside finally overrides the except branch's None
-    
-    def closed(self):
-        """Dispose of the engine to release its pooled connections."""
-        self.engine.dispose()
-
-
-if __name__ == '__main__':
-    
-    client = Mysql()
-    tablename = "mock_order"
-    
-    # Paging parameters
-    page = 1
-    page_size = 1000
-    
-    query_text = '*'
-    client.load_data("mock_order", query_text, page, page_size)  # NOTE(review): load_data takes (tablename, query_text, field_name, city_uuid, ...); page/page_size land in the field_name/city_uuid slots and one required argument is missing

+ 5 - 1
database/__init__.py

@@ -1,5 +1,9 @@
 from database.db.mysql import MySqlDatabaseHelper
+from database.db.redis_db import RedisDatabaseHelper
+from database.dao.mysql_dao import MySqlDao
 
 __all__ = [
-    "MySqlDatabaseHelper"
+    "MySqlDatabaseHelper",
+    "RedisDatabaseHelper",
+    "MySqlDao"
 ]

+ 67 - 15
database/dao/mysql_dao.py

@@ -1,14 +1,31 @@
 from database import MySqlDatabaseHelper
+from sqlalchemy import text
 
 class MySqlDao:
+    _instance = None
+    
+    def __new__(cls):  # return the one shared MySqlDao, creating it on the first call
+        if not cls._instance:
+            cls._instance = super(MySqlDao, cls).__new__(cls)
+            cls._instance._initialized = False  # lets __init__ tell the first construction from later ones
+        return cls._instance
+    
+    
     def __init__(self):  # runs on every MySqlDao() call; the guard below makes the setup happen once
+        if self._initialized:
+            return
+        
         self.db_helper = MySqlDatabaseHelper()
+        self._product_tablename = "tads_brandcul_product_info"  # table names used by the query methods of this class
+        self._cust_tablename = "tads_brandcul_cust_info"
+        self._order_tablename = "tads_brandcul_cust_order"
+        self._mock_order_tablename = "yunfu_mock_data"
+        
+        self._initialized = True
         
     def load_product_data(self, city_uuid):
         """从数据库中读取商品信息"""
-        tablename = "tads_brandcul_product_info"
-        
-        query = f"SELECT * FROM {tablename} WHERE city_uuid = :city_uuid"
+        query = f"SELECT * FROM {self._product_tablename} WHERE city_uuid = :city_uuid"
         params = {"city_uuid": city_uuid}
         
         data = self.db_helper.load_data_with_page(query, params)
@@ -16,9 +33,7 @@ class MySqlDao:
         
     def load_cust_data(self, city_uuid):
         """从数据库中读取商户信息"""
-        tablename = "tads_brandcul_cust_info"
-        
-        query = f"SELECT * FROM {tablename} WHERE BA_CITY_ORG_CODE = :city_uuid"
+        query = f"SELECT * FROM {self._cust_tablename} WHERE BA_CITY_ORG_CODE = :city_uuid"
         params = {"city_uuid": city_uuid}
         
         data = self.db_helper.load_data_with_page(query, params)
@@ -26,9 +41,7 @@ class MySqlDao:
     
     def load_order_data(self, city_uuid):
         """从数据库中读取订单信息"""
-        tablename = "tads_brandcul_cust_order"
-
-        query = f"SELECT * FROM {tablename} WHERE city_uuid = :city_uuid"
+        query = f"SELECT * FROM {self._order_tablename} WHERE city_uuid = :city_uuid"
         params = {"city_uuid": city_uuid}
         
         data = self.db_helper.load_data_with_page(query, params)
@@ -44,9 +57,7 @@ class MySqlDao:
     
     def load_mock_order_data(self, city_uuid):
         """从数据库中读取mock的订单信息"""
-        tablename = "yunfu_mock_data"
-
-        query = f"SELECT * FROM {tablename}"
+        query = f"SELECT * FROM {self._mock_order_tablename}"
         params = {"city_uuid": city_uuid}
         
         data = self.db_helper.load_data_with_page(query, params)
@@ -57,9 +68,50 @@ class MySqlDao:
         data = data.infer_objects(copy=False)
         
         return data
+    
+    def get_cust_list(self, city_uuid):
+        """Return the BB_RETAIL_CUSTOMER_CODE values of the customers loaded for city_uuid, as a list."""
+        data = self.load_cust_data(city_uuid)
+        cust_list = data["BB_RETAIL_CUSTOMER_CODE"].to_list()  # NOTE(review): assumes load_cust_data always returns a DataFrame with this column — confirm it cannot return None
+        if len(cust_list) == 0:  # cust_list is already an empty list here, so this branch returns an equal value
+            return []
+        
+        return cust_list
+    
+    def get_product_by_id(self, city_uuid, product_id):
+        """Fetch product info from the product table by city_uuid and product_id (matched against product_code)."""
+        query = text(f"""
+            SELECT *
+            FROM {self._product_tablename}
+            WHERE city_uuid = :city_uuid
+            AND product_code = :product_id
+        """)
+        params = {"city_uuid": city_uuid, "product_id": product_id}
+        data = self.db_helper.fetch_one(query, params)  # NOTE(review): return type is whatever fetch_one yields; callers index it with a column list — confirm it is a DataFrame
+        
+        return data
+    
+    def get_cust_by_ids(self, city_uuid, cust_id_list):
+        """Look up customer info for city_uuid restricted to the given retail-customer ids; returns None for an empty list."""
+        if not cust_id_list:
+            return None
+        
+        cust_id_str = ",".join([f"'{cust_id}'" for cust_id in cust_id_list])  # NOTE(review): ids are quoted by hand and interpolated into the SQL text, not bound; unsafe if ids can contain quotes or come from untrusted input
+        query = text(f"""
+            SELECT *
+            FROM {self._cust_tablename}
+            WHERE BA_CITY_ORG_CODE = :city_uuid
+            AND BB_RETAIL_CUSTOMER_CODE IN ({cust_id_str})
+        """)
+        params = {"city_uuid": city_uuid}
+        data = self.db_helper.fetch_all(query, params)
+        
+        return data
         
 if __name__ == "__main__":  # manual check: look up three hard-coded customers and print the size of the result
     dao = MySqlDao()
-    city_uuid = "00000000000000000000000011445301"
-    # city_uuid = "00000000000000000000000011441801"
-    dao.load_mock_order_data(city_uuid)
+    # city_uuid = "00000000000000000000000011445301"
+    city_uuid = "00000000000000000000000011441801"
+    cust_id_list = ["441800100006", "441800100051", "441800100811"]
+    cust_list = dao.get_cust_by_ids(city_uuid, cust_id_list)
+    print(len(cust_list))

+ 1 - 1
database/db/mysql.py

@@ -10,7 +10,7 @@ cfgs = load_config()
 class MySqlDatabaseHelper:
     _instance = None
     
-    def __new__(cls, *args, **kwargs):
+    def __new__(cls):
         if not cls._instance:
             cls._instance = super(MySqlDatabaseHelper, cls).__new__(cls)
             cls._instance._initialized = False

+ 14 - 2
dao/redis_db.py → database/db/redis_db.py

@@ -6,19 +6,31 @@ from config import load_config
 cfgs = load_config()
 
 
-class Redis(object):
+class RedisDatabaseHelper:  # shared single instance holding one redis.StrictRedis client built from cfgs['redis']
+    _instance = None  # the shared instance, created on first construction
+    
+    def __new__(cls):
+        if not cls._instance:
+            cls._instance = super(RedisDatabaseHelper, cls).__new__(cls)
+            cls._instance._initialized = False  # lets __init__ tell the first construction from later ones
+        return cls._instance
+        
     def __init__(self):
+        if self._initialized:  # __init__ runs on every RedisDatabaseHelper() call; build the client only once
+            return
         self.redis = redis.StrictRedis(host=cfgs['redis']['host'],
                                        port=cfgs['redis']['port'],
                                        password=cfgs['redis']['passwd'],
                                        db=cfgs['redis']['db'],
                                        decode_responses=True)
+        
+        self._initialized = True
 
 
 if __name__ == '__main__':
     import random
     # 连接到 Redis 服务器
-    r = Redis().redis
+    r = RedisDatabaseHelper().redis
 
     # 有序集合的键名
     zset_key = 'configs:hotkeys'

+ 6 - 4
models/rank/data/preprocess.py

@@ -1,4 +1,4 @@
-from dao.dao import load_cust_data_from_mysql, load_product_data_from_mysql, load_order_data_from_mysql
+from database import MySqlDao
 from models.rank.data.config import CustConfig, ProductConfig, OrderConfig
 import os
 import pandas as pd
@@ -8,13 +8,15 @@ import numpy as np
 
 class DataProcess():
     def __init__(self, city_uuid, save_path):  # loads customer, product and order data for city_uuid through MySqlDao
+        self._mysql_dao = MySqlDao()
         self._save_res_path = save_path
         print("正在加载cust_info...")
-        self._cust_data = load_cust_data_from_mysql(city_uuid)
+        self._cust_data = self._mysql_dao.load_cust_data(city_uuid)
         print("正在加载product_info...")
-        self._product_data = load_product_data_from_mysql(city_uuid)
+        self._product_data = self._mysql_dao.load_product_data(city_uuid)
         print("正在加载order_info...")
-        self._order_data = load_order_data_from_mysql(city_uuid)
+        # self._order_data = self._mysql_dao.load_cust_data(city_uuid)
+        self._order_data = self._mysql_dao.load_mock_order_data(city_uuid)  # NOTE(review): orders now come from the mock table; the disabled line above calls load_cust_data, not load_order_data
         
     def data_process(self):
         """数据预处理"""

+ 8 - 6
models/rank/gbdt_lr_sort.py

@@ -1,5 +1,6 @@
 import joblib
-from dao import Redis, get_product_by_id, get_custs_by_ids, load_cust_data_from_mysql
+# from dao import Redis, get_product_by_id, get_custs_by_ids, load_cust_data_from_mysql
+from database import RedisDatabaseHelper, MySqlDao
 from models.rank.data import ProductConfig, CustConfig, ImportanceFeaturesMap
 from models.rank.data.utils import one_hot_embedding, sample_data_clear
 import pandas as pd
@@ -9,7 +10,8 @@ from sklearn.preprocessing import StandardScaler
 class GbdtLrModel:
     def __init__(self, model_path):  # loads the model from model_path, then grabs the redis client and the MySQL DAO
         self.load_model(model_path)
-        self.redis = Redis().redis
+        self.redis = RedisDatabaseHelper().redis
+        self._mysql_dao = MySqlDao()
     
     def load_model(self, model_path):
         models = joblib.load(model_path)
@@ -22,13 +24,13 @@ class GbdtLrModel:
     #     self.recall_cust_list = self.redis.zrange(key, 0, -1, withscores=False)
     
     # def load_recall_data(self, city_uuid, product_id):
-    #     self.product_data = get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
-    #     self.custs_data = get_custs_by_ids(city_uuid, self.recall_cust_list)[CustConfig.FEATURE_COLUMNS]
+    #     self.product_data = self._mysql_dao.get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
+    #     self.custs_data = self._mysql_dao.get_cust_by_ids(city_uuid, self.recall_cust_list)[CustConfig.FEATURE_COLUMNS]
         
     def get_cust_and_product_data(self, city_uuid, product_id):
         """Load the feature columns of the given product and of every customer in the city into self.product_data / self.custs_data."""
-        self.product_data = get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
-        self.custs_data = load_cust_data_from_mysql(city_uuid)[CustConfig.FEATURE_COLUMNS]
+        self.product_data = self._mysql_dao.get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
+        self.custs_data = self._mysql_dao.load_cust_data(city_uuid)[CustConfig.FEATURE_COLUMNS]
     
     def generate_feats_map(self, city_uuid, product_id):
         """组合卷烟、商户特征矩阵"""