ソースを参照

Merge branch 'dev'

Sherlock 1 年間 前
コミット
0f45934bd5

+ 1 - 1
app.py

@@ -61,7 +61,7 @@ def run():
     parser.add_argument("--k", type=int, default=20)
     parser.add_argument("--top_n", type=int, default=2000, help='default n * k')
     parser.add_argument("--n_jobs", type=int, default=4)
-    parser.add_argument("--city_uuid", type=str, default='00000000000000000000000011441801', help="City UUID for filtering data")
+    parser.add_argument("--city_uuid", type=str, default='00000000000000000000000011445301', help="City UUID for filtering data")
     
     # 协同过滤推理配置
     parser.add_argument("--product_code", type=int, default=110111)

+ 9 - 2
dao/__init__.py

@@ -1,9 +1,16 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
 from dao.mysql_client import Mysql
-from dao.dao import load_order_data_from_mysql
+from dao.dao import load_order_data_from_mysql, load_cust_data_from_mysql, load_product_data_from_mysql, get_product_by_id, get_custs_by_ids, get_cust_list_from_database
+from dao.redis_db import Redis
 
 __all__ = [
     "Mysql",
-    "load_order_data_from_mysql"
+    "load_order_data_from_mysql",
+    "load_cust_data_from_mysql",
+    "load_product_data_from_mysql",
+    "Redis",
+    "get_product_by_id",
+    "get_custs_by_ids",
+    "get_cust_list_from_database"
 ]

+ 67 - 9
dao/dao.py

@@ -1,20 +1,78 @@
 from dao import Mysql
 
 def load_order_data_from_mysql(city_uuid):
-    """从数据库中读取数据"""
+    """从数据库中读取订单数据"""
     client = Mysql()
-    tablename = "tads_brandcul_cust_order"
+    tablename = "yunfu_mock_data"
+    # tablename = "tads_brandcul_cust_order"
     query_text = "*"
-    
-    df = client.load_data(tablename, query_text, city_uuid)
+    # city_uuid = "00000000000000000000000011441801"
+    # df = client.load_data(tablename, query_text, "city_uuid", city_uuid)
+    df = client.load_mock_data(tablename, query_text)
     if len(df) == 0:
         return None
     
-    df.drop('stat_month', axis=1, inplace=True)
-    df.drop('city_uuid', axis=1, inplace=True)
-    print(df.columns)
+    # df.drop('stat_month', axis=1, inplace=True)
+    # df.drop('city_uuid', axis=1, inplace=True)
     
-     # 去除重复值和填补缺失值
+    # 去除重复值和填补缺失值
     df.drop_duplicates(inplace=True)
     df.fillna(0, inplace=True)
-    return df
+    df = df.infer_objects(copy=False)
+    return df
+
+def load_cust_data_from_mysql(city_uuid):
+    """从数据库中读取商户信息数据"""
+    client = Mysql()
+    tablename = "tads_brandcul_cust_info"
+    query_text = "*"
+    
+    df = client.load_data(tablename, query_text, "BA_CITY_ORG_CODE", city_uuid)
+    if len(df) == 0:
+        return None
+    
+    return df
+
+def get_cust_list_from_database(city_uuid):
+    client = Mysql()
+    tablename = "tads_brandcul_cust_info"
+    query_text = "*"
+    
+    df = client.load_data(tablename, query_text, "BA_CITY_ORG_CODE", city_uuid)
+    cust_list = df["BB_RETAIL_CUSTOMER_CODE"].to_list()
+    if len(cust_list) == 0:
+        return []
+    
+    return cust_list
+
+def load_product_data_from_mysql(city_uuid):
+    """从数据库中读取商品信息"""
+    client = Mysql()
+    tablename = "tads_brandcul_product_info"
+    query_text = "*"
+    
+    df = client.load_data(tablename, query_text, "city_uuid", city_uuid)
+    if len(df) == 0:
+        return None
+    
+    return df
+
+def get_product_by_id(city_uuid, product_id):
+    client = Mysql()
+    
+    res = client.get_product_by_id(city_uuid, product_id)
+    if len(res) == 0:
+        return None
+    return res
+
+def get_custs_by_ids(city_uuid, cust_ids):
+    client = Mysql()
+    
+    res = client.get_cust_by_ids(city_uuid, cust_ids)
+    if len(res) == 0:
+        return None
+    return res
+
+if __name__ == '__main__':
+    data = load_order_data_from_mysql("00000000000000000000000011445301")
+    print(data)

+ 65 - 4
dao/mysql_client.py

@@ -39,10 +39,10 @@ class Mysql(object):
         """创建返回一个新的数据库session"""
         return self._DBSession()
     
-    def fetch_data_with_pagination(self, tablename, query_text, city_uuid, page=1, page_size=1000):
+    def fetch_data_with_pagination(self, tablename, query_text, field_name, city_uuid, page=1, page_size=1000):
         """分页查询数据,并根据 city_uuid 进行过滤"""
         offset = (page - 1) * page_size  # 计算偏移量
-        query = text(f"SELECT {query_text} FROM {tablename} WHERE city_uuid = :city_uuid LIMIT :limit OFFSET :offset")
+        query = text(f"SELECT {query_text} FROM {tablename} WHERE {field_name} = :city_uuid LIMIT :limit OFFSET :offset")
     
         with self.create_session() as session:
             results = session.execute(query, {"city_uuid": city_uuid, "limit": page_size, "offset": offset}).fetchall()
@@ -50,13 +50,74 @@ class Mysql(object):
     
         return df
     
-    def load_data(self, tablename, query_text, city_uuid, page=1, page_size=1000):
+    def load_data(self, tablename, query_text, field_name, city_uuid, page=1, page_size=1000):
         # 创建一个空的DataFrame用于存储所有数据
         total_df = pd.DataFrame()
     
         try:
             while True:
-                df = self.fetch_data_with_pagination(tablename, query_text, city_uuid, page, page_size)
+                df = self.fetch_data_with_pagination(tablename, query_text, field_name, city_uuid, page, page_size)
+                if df.empty:
+                    break
+            
+                total_df = pd.concat([total_df, df], ignore_index=True)
+                print(f"Page {page}: Retrieved {len(df)} rows, Total rows so far: {len(total_df)}")
+                page += 1  # 继续下一页
+                
+        except Exception as e:
+            print(f"Error: {e}")
+            return None
+        
+        finally:
+            self.closed()
+            return total_df
+        
+    def get_product_by_id(self, city_uuid, product_id):
+        """根据 city_uuid 和 product_id 从表中获取品规信息"""
+        query = text("""
+            SELECT * 
+            FROM tads_brandcul_product_info 
+            WHERE city_uuid = :city_uuid 
+            AND product_code = :product_id
+        """)
+        
+        with self.create_session() as session:
+            result = session.execute(query, {"city_uuid": city_uuid, "product_id": product_id}).fetchall()
+            result = pd.DataFrame(result)
+        return result
+        
+    def get_cust_by_ids(self, city_uuid, cust_id_list):
+        """根据 city_uuid 和 cust_id 列表从表中获取零售户信息"""
+        if not cust_id_list:
+            return []
+        
+        cust_id_str = ",".join([f"'{cust_id}'" for cust_id in cust_id_list])
+        
+        query = text(f"""
+            SELECT * 
+            FROM tads_brandcul_cust_info
+            WHERE BA_CITY_ORG_CODE = :city_uuid 
+            AND BB_RETAIL_CUSTOMER_CODE IN ({cust_id_str})
+        """)
+        
+        with self.create_session() as session:
+            results = session.execute(query, {"city_uuid": city_uuid}).fetchall()
+            results = pd.DataFrame(results)
+        
+        return results
+        
+    def load_mock_data(self, tablename, query_text, page=1, page_size=1000):
+        # 创建一个空的DataFrame用于存储所有数据
+        total_df = pd.DataFrame()
+    
+        try:
+            while True:
+                offset = (page - 1) * page_size  # 计算偏移量
+                query = text(f"SELECT {query_text} FROM {tablename} LIMIT :limit OFFSET :offset")
+    
+                with self.create_session() as session:
+                    results = session.execute(query, { "limit": page_size, "offset": offset}).fetchall()
+                    df = pd.DataFrame(results)
                 if df.empty:
                     break
             

+ 110 - 0
gbdt_lr.py

@@ -0,0 +1,110 @@
+import argparse
+import os
+from models.rank import DataProcess, Trainer, GbdtLrModel
+import time
+import pandas as pd
+
+# train_data_path = "./moldes/rank/data/gbdt_data.csv"
+# model_path = "./models/rank/weights"
+
+def train(args):
+    model_dir = os.path.join(args.model_path, args.city_uuid)
+    if not os.path.exists(model_dir):
+        os.makedirs(model_dir)
+    
+    # 准备数据集  
+    print("正在整合训练数据...")
+    processor = DataProcess(args.city_uuid, args.train_data_path)
+    processor.data_process()
+    print("训练数据整合完成!")
+    
+    # 进行训练
+    trainer(args, model_dir)
+
+def trainer(args, model_dir):
+    trainer = Trainer(args.train_data_path)
+    
+    start_time = time.time()
+    trainer.train()
+    end_time = time.time()
+    
+    training_time_hours = (end_time - start_time) / 3600
+    print(f"训练时间: {training_time_hours:.4f} 小时")
+    
+    eval_metrics = trainer.evaluate()
+    
+    # 输出评估结果
+    print("GBDT-LR Evaluation Metrics:")
+    for metric, value in eval_metrics.items():
+        print(f"{metric}: {value:.4f}")
+        
+    # 保存模型
+    trainer.save_model(os.path.join(model_dir, args.model_name))
+
+def recommend_by_product(args):
+    model_dir = os.path.join(args.model_path, args.city_uuid)
+    if not os.path.exists(model_dir):
+        print("暂无该城市的模型,请先进行模型训练")
+        return
+    
+    # 加载模型
+    model = GbdtLrModel(os.path.join(model_dir, args.model_name))
+    recommend_list = model.sort(args.city_uuid, args.product_id)
+    for item in recommend_list[:min(args.last_n, len(recommend_list))]:
+        print(item)
+
+def get_features_importance(args):
+    model_dir = os.path.join(args.model_path, args.city_uuid)
+    if not os.path.exists(model_dir):
+        print("暂无该城市的模型,请先进行模型训练")
+        return
+    
+    # 加载模型
+    model = GbdtLrModel(os.path.join(model_dir, args.model_name))
+    cust_features_importance, product_features_importance = model.generate_feats_importance()
+    
+    # 将字典列表转换为 DataFrame
+    cust_df = pd.DataFrame([
+        {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
+        for item in cust_features_importance
+    ])
+    
+    product_df = pd.DataFrame([
+        {"Features": list(item.keys())[0], "Importance": list(item.values())[0]}
+        for item in product_features_importance
+    ])
+    
+    cust_file_path = os.path.join(model_dir, "cust_features_importance.csv")
+    product_file_path = os.path.join(model_dir, "product_features_importance.csv")
+    cust_df.to_csv(cust_file_path, index=False, encoding='utf-8')
+    product_df.to_csv(product_file_path, index=False, encoding='utf-8')
+        
+def run():
+    parser = argparse.ArgumentParser()
+    
+    parser.add_argument("--run_train", action='store_true')
+    parser.add_argument("--recommend", action='store_true')
+    parser.add_argument("--importance", action='store_true')
+    
+    parser.add_argument("--train_data_path", type=str, default="./models/rank/data/gbdt_data.csv")
+    parser.add_argument("--model_path", type=str, default="./models/rank/weights")
+    parser.add_argument("--model_name", type=str, default='model.pkl')
+    parser.add_argument("--last_n", type=int, default=200)
+    
+    parser.add_argument("--city_uuid", type=str, default='00000000000000000000000011445301')
+    parser.add_argument("--product_id", type=str, default='110102')
+    
+    
+    args = parser.parse_args()
+    
+    if args.run_train:
+        train(args)
+        
+    if args.recommend:
+        recommend_by_product(args)
+        
+    if args.importance:
+        get_features_importance(args)
+        
+if __name__ == "__main__":
+    run()

+ 9 - 0
models/rank/__init__.py

@@ -1,2 +1,11 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
+from models.rank.data.preprocess import DataProcess
+from models.rank.gbdt_lr import Trainer
+from models.rank.gbdt_lr_sort import GbdtLrModel
+
+__all__ = [
+    "DataProcess",
+    "Trainer",
+    "GbdtLrModel"
+]

+ 12 - 0
models/rank/data/__init__.py

@@ -0,0 +1,12 @@
+from models.rank.data.config import CustConfig, ProductConfig, OrderConfig, ImportanceFeaturesMap
+from models.rank.data.dataloader import DataLoader
+from models.rank.data.utils import one_hot_embedding, sample_data_clear
+__all__ = [
+    "CustConfig",
+    "ProductConfig",
+    "OrderConfig",
+    "DataLoader",
+    "one_hot_embedding",
+    "sample_data_clear",
+    "ImportanceFeaturesMap"
+]

+ 292 - 0
models/rank/data/config.py

@@ -0,0 +1,292 @@
+class CustConfig:
+    FEATURE_COLUMNS = [
+        "BB_RETAIL_CUSTOMER_CODE",                     # 零售户代码
+        "BB_RTL_CUST_GRADE_NAME",                      # 零售户分档名称
+        "BB_RTL_CUST_MARKET_TYPE_NAME",                # 零售户市场类型名称
+        "STORE_AREA",                                  # 店铺经营面积
+        "BB_RTL_CUST_BUSINESS_TYPE_NAME",              # 零售户业态名称
+        "OPERATOR_EDU_LEVEL",                          # 零售客户经营者文化程
+        "OPERATOR_AGE",                                # 经营者年龄
+        "BB_RTL_CUST_CHAIN_FLAG",                      # 零售户连锁标识
+        "PRESENT_STAR_TERMINAL",                       # 终端星级
+        "MD04_MG_RTL_CUST_CREDITCLASS_NAME",           # 零售户信用等级名称
+        "MD04_DIR_SAL_STORE_FLAG",                     # 直营店标识
+        "BB_CUSTOMER_MANAGER_SCOPE_NAME",              # 零售户经营范围名称
+        "PRODUCT_INSALE_QTY",                          # 在销品规数
+        "CUST_INVESTMENT",                             # 店铺资源投入建设
+        
+        # "NEW_PRODUCT_MEMBERS_QTY_SAMEPRICE_OCC",       # 新品订货量占同价类比重
+        # "PRODUCT_LISTING_RATE",                        # 品规上架率
+        # "STOCKOUT_DAYS",                              # 断货天数
+        # "YLT_TURNOVER_RATE",                           # 易灵通动销率
+        # "YLT_BAR_PACKAGE_SALE_OCC",                    # 易灵通条包销售占比
+        # "UNPACKING_RATE",                              # 拆包率
+        
+        
+        # "BB_RTL_CUST_POSITION_TYPE_NAME",              # 零售户商圈类型名称
+        
+        # "BB_RTL_CUST_SUB_BUSI_PLACE_NAME",             # 零售户业态细分名称
+        
+        # "BB_RTL_CUST_TERMINAL_LEVEL_NAME",             # 零售户终端层级名称
+        # "BB_RTL_CUST_TERMINALEVEL_NAME",               # 零售户终端层级细分名称
+        # "MD04_MG_SAMPLE_CUST_FLAG",                    # 样本户标识
+        # "MD07_RTL_CUST_IS_SALE_LARGE_FLAG",            # 零售户大户标识
+        # "BB_RTL_CUST_OPERATE_METHOD_NAME",             # 零售户经营方式名称
+        # "BB_RTL_CUST_CGT_OPERATE_SCOPE_NAME",          # 零售户卷烟经营规模名称
+        
+        # "AVERAGE_CONSUMER_FLOW",                       # 月均消费人流
+        # "NEW_PRODUCT_MEMBERS_QTY",                     # 新品消费会员数量
+    ]
+    # 数据清洗规则
+    CLEANING_RULES = {
+        "BB_RTL_CUST_GRADE_NAME":                   {"method": "fillna", "opt": "fill", "value": "十五档", "type": "str"},
+        "BB_RTL_CUST_MARKET_TYPE_NAME":             {"method": "fillna", "opt": "fill", "value": "城网", "type": "str"},
+        "STORE_AREA":                               {"method": "fillna", "opt": "mean", "type": "num"},
+        "BB_RTL_CUST_BUSINESS_TYPE_NAME":           {"method": "fillna", "opt": "fill", "value": "其他", "type": "str"},
+        "OPERATOR_EDU_LEVEL":                       {"method": "fillna", "opt": "fill", "value": "无数据", "type": "str"},
+        "OPERATOR_AGE":                             {"method": "fillna", "opt": "mean", "type": "num"},
+        "BB_RTL_CUST_CHAIN_FLAG":                   {"method": "fillna", "opt": "fill", "value": "否", "type": "str"},
+        "PRESENT_STAR_TERMINAL":                    {"method": "fillna", "opt": "fill", "value": "非星级", "type": "str"},
+        "MD04_MG_RTL_CUST_CREDITCLASS_NAME":        {"method": "fillna", "opt": "fill", "value": "B", "type": "str"},
+        "MD04_DIR_SAL_STORE_FLAG":                  {"method": "fillna", "opt": "fill", "value": "否", "type": "str"},
+        "BB_CUSTOMER_MANAGER_SCOPE_NAME":           {"method": "fillna", "opt": "fill", "value": "否", "type": "str"},
+        "PRODUCT_INSALE_QTY":                       {"method": "fillna", "opt": "mean", "type": "num"},
+        "CUST_INVESTMENT":                          {"method": "fillna", "opt": "mean", "type": "num"}
+        
+        
+        # "NEW_PRODUCT_MEMBERS_QTY_SAMEPRICE_OCC":    {"method": "fillna", "opt": "mean", "type": "num"},
+        # "PRODUCT_LISTING_RATE":                     {"method": "fillna", "opt": "mean", "type": "num"},
+        # "STOCKOUT_DAYS":                            {"method": "fillna", "opt": "mean", "type": "num"},
+        # "YLT_TURNOVER_RATE":                        {"method": "fillna", "opt": "mean", "type": "num"},
+        # "NEW_PRODUCT_MEMBERS_QTY":                  {"method": "fillna", "opt": "mean", "type": "num"},
+        # "PRODUCT_INSALE_QTY":                       {"method": "fillna", "opt": "mean", "type": "num"},
+        # "UNPACKING_RATE":                           {"method": "fillna", "opt": "mean", "type": "num"},
+        
+        
+        
+        
+        # "BB_RTL_CUST_POSITION_TYPE_NAME":           {"method": "fillna", "opt": "fill", "value": "其他", "type": "str"},
+        # "BB_RTL_CUST_SUB_BUSI_PLACE_NAME":          {"method": "fillna", "opt": "fill", "value": "其他", "type": "str"},
+        # "BB_RTL_CUST_TERMINALEVEL_NAME":          {"method": "fillna", "opt": "replace", "value": "BB_RTL_CUST_TERMINAL_LEVEL_NAME", "type": "str"},
+        # "MD04_MG_SAMPLE_CUST_FLAG":                 {"method": "fillna", "value": "N", "opt": "fill"},
+        # "MD07_RTL_CUST_IS_SALE_LARGE_FLAG":         {"method": "fillna", "value": "N", "opt": "fill"},
+        # "BB_RTL_CUST_CGT_OPERATE_SCOPE_NAME":       {"method": "fillna", "value": "中", "opt": "fill"},
+    }
+    
+    ONEHOT_CAT = {
+        "BB_RTL_CUST_GRADE_NAME":                   ['一档', '二档', '三档', '四档', '五档', '六档', '七档', '八档', '九档', '十档', '十一档', '十二档', 
+                                                    '十三档', '十四档', '十五档', '十六档', '十七档', '十八档', '十九档', '二十档', '二十一档', '二十二档', 
+                                                    '二十三档', '二十四档', '二十五档', '二十六档', '二十七档', '二十八档', '二十九档', '三十档'],
+        "BB_RTL_CUST_MARKET_TYPE_NAME":             ["城网", "农网"],
+        "BB_RTL_CUST_BUSINESS_TYPE_NAME":           ["便利店", "超市", "烟草专业店", "娱乐服务类", "其他"],
+        "OPERATOR_EDU_LEVEL":                       [1, 2, 3, 4, 5, 6, 7, "无数据"],
+        "BB_RTL_CUST_CHAIN_FLAG":                   ["是", "否"],
+        "PRESENT_STAR_TERMINAL":                    ["一星", "二星", "三星", "四星", "五星", "非星级"],
+        "MD04_MG_RTL_CUST_CREDITCLASS_NAME":        ["AAA", "AA", "A", "B", "C", "D"],
+        "MD04_DIR_SAL_STORE_FLAG":                  ["是", "否"],
+        "BB_CUSTOMER_MANAGER_SCOPE_NAME":           ["是", "否"],
+        
+        
+        
+        # "BB_RTL_CUST_POSITION_TYPE_NAME":           ["居民区", "商业娱乐区", "交通枢纽区", "旅游景区", "工业区", "集贸区", "院校学区", "办公区", "其他"]
+    }
+    
+    
+    
+class ProductConfig:
+    FEATURE_COLUMNS = [
+        "product_code",                                # 商品编码
+        "direct_retail_price",                         # 建议零售价
+        "is_low_tar",                                  # 是否低焦油烟
+        "tar_qty",                                     # 焦油含量
+        "is_exploding_beads",                          # 是否爆珠
+        "is_shortbranch",                              # 是否短支烟
+        "is_medium",                                   # 是否中支烟
+        "is_tiny",                                     # 是否细支
+        "product_style_code_name",                     # 包装类型名称
+        "org_is_abnormity",                            # 是否异形包装
+        "is_chuangxin",                                # 是否创新品类
+        "is_key_brand",                                # 是否重点品牌
+        "foster_level_hy",                             # 是否行业共育品规
+        "foster_level_sj",                             # 是否省级共育品规
+        "is_cigar",                                    # 是否雪茄型卷烟
+        "co_qty",                                      # 一氧化碳含量
+        "tbc_total_length",                            # 烟支总长度
+        "tbc_length",                                  # 烟支长度
+        "filter_length",                               # 滤嘴长度
+        
+
+        
+        # "adjust_price",                                # 含税调拨价
+        # "notwithtax_adjust_price",                     # 不含税调拨价
+        # "whole_sale_price",                            # 统一批发价
+        # "allot_price",                                 # 调拨价
+        # "direct_whole_price",                          # 批发指导价
+        # "retail_price",                                # 零售价
+        # "price_type_name",                             # 卷烟价类名称
+        # "gear_type_name",                              # 卷烟档位名称
+        # "category_type_name",                          # 卷烟品类名称
+        # "is_high_level",                               # 是否高端烟
+        # "is_upscale_level",                            # 是否高端烟不含高价
+        # "is_high_price",                               # 是否高价烟
+        # "is_low_price",                                # 是否低价烟
+        # "is_encourage",                                # 是否全国鼓励品牌
+        # "is_abnormity",                                # 是否异形包装
+        # "is_intake",                                   # 是否进口烟
+        # "is_short",                                    # 是否紧俏品牌
+        # "is_ordinary_price_type",                      # 是否普一类烟
+        # "source_type",                                 # 来源类型
+        # "chinese_mix",                                 # 中式混合
+        # "sub_price_type_name",                         # 细分卷烟价类名称
+    ]
+    
+    CLEANING_RULES = {
+        "direct_retail_price":                         {"method": "fillna", "opt": "mean", "type": "num"},
+        "is_low_tar":                                  {"method": "fillna", "opt": "fill", "type": "str", "value": "否"},
+        "tar_qty":                                     {"method": "fillna", "opt": "mean", "type": "num"},
+        "is_exploding_beads":                          {"method": "fillna", "opt": "fill", "type": "str", "value": "否"},
+        "is_shortbranch":                              {"method": "fillna", "opt": "fill", "type": "str", "value": "否"},
+        "is_medium":                                   {"method": "fillna", "opt": "fill", "type": "str", "value": "否"},
+        "is_tiny":                                     {"method": "fillna", "opt": "fill", "type": "str", "value": "否"},
+        "product_style_code_name":                     {"method": "fillna", "opt": "fill", "type": "str", "value": "其他"},
+        "org_is_abnormity":                            {"method": "fillna", "opt": "fill", "type": "str", "value": "否"},
+        "is_chuangxin":                                {"method": "fillna", "opt": "fill", "type": "str", "value": "否"},
+        "is_key_brand":                                {"method": "fillna", "opt": "fill", "type": "str", "value": "否"},
+        "foster_level_hy":                             {"method": "fillna", "opt": "fill", "type": "str", "value": "否"},
+        "foster_level_sj":                             {"method": "fillna", "opt": "fill", "type": "str", "value": "否"},
+        "is_cigar":                                    {"method": "fillna", "opt": "fill", "type": "str", "value": "否"},
+        "co_qty":                                      {"method": "fillna", "opt": "mean", "type": "num"},
+        "tbc_total_length":                            {"method": "fillna", "opt": "mean", "type": "num"},
+        "tbc_length":                                  {"method": "fillna", "opt": "mean", "type": "num"},
+        "filter_length":                               {"method": "fillna", "opt": "mean", "type": "num"},
+        
+        
+        # "adjust_price":                                {"method": "fillna", "opt": "mean", "type": "num"},
+        # "notwithtax_adjust_price":                     {"method": "fillna", "opt": "mean", "type": "num"},
+        # "whole_sale_price":                            {"method": "fillna", "opt": "mean", "type": "num"},
+        # "allot_price":                                 {"method": "fillna", "opt": "fill", "type": "num", "value": 0.0},
+        # "direct_whole_price":                          {"method": "fillna", "opt": "mean", "type": "num"},
+        # "retail_price":                                {"method": "fillna", "opt": "mean", "type": "num"},
+        # "price_type_name":                             {"method": "fillna", "opt": "fill", "type": "str", "value": "一类烟"},
+        # "gear_type_name":                              {"method": "fillna", "opt": "fill", "type": "str", "value": "其他"},
+        # "category_type_name":                          {"method": "fillna", "opt": "fill", "type": "str", "value": "其他"},
+        # "is_high_level":                               {"method": "fillna", "opt": "fill", "type": "str", "value": "否"},
+        # "is_upscale_level":                            {"method": "fillna", "opt": "fill", "type": "str", "value": "否"},
+        # "is_high_price":                               {"method": "fillna", "opt": "fill", "type": "str", "value": "否"},
+        # "is_low_price":                                {"method": "fillna", "opt": "fill", "type": "str", "value": "否"},
+        # "is_encourage":                                {"method": "fillna", "opt": "fill", "type": "str", "value": "否"},
+        # "is_abnormity":                                {"method": "fillna", "opt": "fill", "type": "str", "value": "否"},
+        # "is_intake":                                   {"method": "fillna", "opt": "fill", "type": "str", "value": "否"},
+        # "is_short":                                    {"method": "fillna", "opt": "fill", "type": "str", "value": "否"},
+        # "is_ordinary_price_type":                      {"method": "fillna", "opt": "fill", "type": "str", "value": "否"},
+        # "source_type":                                 {"method": "fillna", "opt": "fill", "type": "str", "value": "其他"},
+        # "chinese_mix":                                 {"method": "fillna", "opt": "fill", "type": "str", "value": "否"},
+        # "sub_price_type_name":                         {"method": "fillna", "opt": "fill", "type": "str", "value": "普一类烟"},
+    }
+    
+
+    ONEHOT_CAT = {
+        "is_low_tar":                                  ["是", "否"],
+        "is_exploding_beads":                          ["是", "否"],
+        "is_shortbranch":                              ["是", "否"],
+        "is_medium":                                   ["是", "否"],
+        "is_tiny":                                     ["是", "否"],
+        "product_style_code_name":                     ["条盒硬盒", "条包硬盒", "条盒软盒", "条包软盒", "铁盒", "其他"],
+        "org_is_abnormity":                            ["是", "否"],
+        "is_chuangxin":                                ["是", "否"],
+        "is_key_brand":                                ["是", "否"],
+        "foster_level_hy":                             ["是", "否"],
+        "foster_level_sj":                             ["是", "否"],
+        "is_cigar":                                    ["是", "否"],
+        
+        
+        
+        # "price_type_name":                             ["一类烟", "二类烟", "三类烟", "四类烟", "五类烟", "无价类"],
+        # "gear_type_name":                              ["第一档位", "第二档位", "第三档位", "第四档位", "第五档位", "第六档位", "第七档位", "第八档位", "其他"],
+        # "category_type_name":                          ["第1品类", "第2品类", "第3品类", "第4品类", "第5品类", "第6品类", "第7品类", 
+        #                                                 "第8品类", "第9品类", "第10品类", "第11品类", "第12品类", "第13品类", "其他"],
+        # "is_high_level":                               ["是", "否"],
+        # "is_upscale_level":                            ["是", "否"],
+        # "is_high_price":                               ["是", "否"],
+        # "is_low_price":                                ["是", "否"],
+        # "is_encourage":                                ["是", "否"],
+        # "is_abnormity":                                ["是", "否"],
+        # "is_intake":                                   ["是", "否"],
+        # "is_short":                                    ["是", "否"],
+        # "is_ordinary_price_type":                      ["是", "否"],
+        # "source_type":                                 ["是", "否"],
+        # "chinese_mix":                                 ["是", "否"],
+        # "sub_price_type_name":                         ["高端烟", "高价位烟", "普一类烟", "二类烟", "三类烟", "四类烟", "五类烟", "无价类"],
+    }
+    
+class OrderConfig:
+    FEATURE_COLUMNS = [
+        "BB_RETAIL_CUSTOMER_CODE",                          # 零售户编码
+        "PRODUCT_CODE",                                     # 卷烟编码
+        "MONTH6_SALE_QTY",                                  # 近半年销量(箱)
+        "MONTH6_SALE_AMT",                                  # 近半年销售额(万元)
+        "MONTH6_GROSS_PROFIT_RATE",                         # 近半年毛利率
+        "MONTH6_SALE_QTY_YOY",                              # 销售量同比
+        "MONTH6_SALE_QTY_MOM",                              # 销售量环比
+        "MONTH6_SALE_AMT_YOY",                              # 销售额(购进额)同比
+        "MONTH6_SALE_AMT_MOM",                              # 销售额(购进额)环比
+        "STOCK_QTY",                                        # 库存
+        "ORDER_FULLORDR_RATE",                              # 订足率
+        "FULL_FILLMENT_RATE",                               # 订单满足率
+        "ORDER_FULLORDR_RATE_MOM",                          # 订足率环比
+        "CUSTOMER_REPURCHASE_RATE",                         # 会员重购率   
+        "DEMAND_RATE",                                      # 需求量满足率
+        "LISTING_RATE",                                     # 品规商上架率
+        "PUT_MARKET_FINISH_RATE",                           # 投放完成率
+        "OUT_STOCK_DAYS",                                   # 断货天数
+        "YLT_TURNOVER_RATE",                                # 易灵通动销率
+        "YLT_BAR_PACKAGE_SALE_OCC",                         # 易灵通调包销售占比
+        "UNPACKING_RATE",                                   # 拆包率
+        "POS_PACKAGE_PRICE",                                # pos机单包价格
+    ]
+    
+    WEIGHTS = {
+        "MONTH6_SALE_QTY":                                  0.15,
+        "MONTH6_SALE_QTY_MOM":                              0.2,
+        "ORDER_FULLORDR_RATE":                              0.3,
+        "ORDER_FULLORDR_RATE_MOM":                          0.35,
+    }
+    
+class ImportanceFeaturesMap:
+    CUSTOM_FEATRUES_MAP = {
+        "BB_RTL_CUST_GRADE_NAME":                           "零售户分档名称",
+        "BB_RTL_CUST_MARKET_TYPE_NAME":                     "零售户市场类型名称",
+        "STORE_AREA":                                       "店铺经营面积",
+        "BB_RTL_CUST_BUSINESS_TYPE_NAME":                   "零售户业态名称",
+        "OPERATOR_EDU_LEVEL":                               "零售客户经营者文化程",
+        "OPERATOR_AGE":                                     "经营者年龄",
+        "BB_RTL_CUST_CHAIN_FLAG":                           "零售户连锁标识",
+        "PRESENT_STAR_TERMINAL":                            "终端星级",
+        "MD04_MG_RTL_CUST_CREDITCLASS_NAME":                "零售户信用等级名称",
+        "MD04_DIR_SAL_STORE_FLAG":                          "直营店标识",
+        "BB_CUSTOMER_MANAGER_SCOPE_NAME":                   "零售户经营范围名称",
+        "PRODUCT_INSALE_QTY":                               "在销品规数",
+        "CUST_INVESTMENT":                                  "店铺资源投入建设",
+    }
+    
+    PRODUCT_FEATRUES_MAP = {
+        # ProductConfig 字段映射
+        "direct_retail_price":                              "建议零售价",
+        "is_low_tar":                                       "是否低焦油烟",
+        "tar_qty":                                          "焦油含量",
+        "is_exploding_beads":                               "是否爆珠",
+        "is_shortbranch":                                   "是否短支烟",
+        "is_medium":                                        "是否中支烟",
+        "is_tiny":                                          "是否细支",
+        "product_style_code_name":                          "包装类型名称",
+        "org_is_abnormity":                                 "是否异形包装",
+        "is_chuangxin":                                     "是否创新品类",
+        "is_key_brand":                                     "是否重点品牌",
+        "foster_level_hy":                                  "是否行业共育品规",
+        "foster_level_sj":                                  "是否省级共育品规",
+        "is_cigar":                                         "是否雪茄型卷烟",
+        "co_qty":                                           "一氧化碳含量",
+        "tbc_total_length":                                 "烟支总长度",
+        "tbc_length":                                       "烟支长度",
+        "filter_length":                                    "滤嘴长度",
+    }

+ 62 - 0
models/rank/data/dataloader.py

@@ -0,0 +1,62 @@
+import pandas as pd
+from models.rank.data.config import CustConfig, ProductConfig
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import StandardScaler
+from models.rank.data.utils import one_hot_embedding
+
+class DataLoader:
+    def __init__(self,path):
+        self._gbdt_data_path = path
+        self._load_data()
+    
+    def _load_data(self):
+       
+        self._gbdt_data = pd.read_csv(self._gbdt_data_path, encoding="utf-8")
+        self._gbdt_data.drop('BB_RETAIL_CUSTOMER_CODE', axis=1, inplace=True)
+        self._gbdt_data.drop('product_code', axis=1, inplace=True)
+        
+        self._onehot_feats = {**CustConfig.ONEHOT_CAT, **ProductConfig.ONEHOT_CAT}
+        
+        self._onehot_columns = list(self._onehot_feats.keys())
+        self._numeric_columns = self._gbdt_data.drop(self._onehot_columns + ["label"], axis=1).columns
+        
+        # 将类别数据进行one-hot编码
+        self._gbdt_data = one_hot_embedding(self._gbdt_data, self._onehot_feats)
+        
+    
+    def split_dataset(self):
+        """数据集划分,将数据集划分为训练集、验证集、测试集"""
+        # 1. 分离特征和标签
+        features = self._gbdt_data.drop("label", axis=1)
+        labels = self._gbdt_data["label"]
+        
+        # 2. 划分数据集,80%训练集、20%的测试集
+        X_train, X_test, y_train, y_test = train_test_split(
+            features, labels, 
+            test_size=0.2, 
+            random_state=42, 
+            shuffle=True,
+            stratify=labels,
+        )
+        
+        # 3. 数据标准化(仅对特征进行标准化)
+        scaler = StandardScaler()
+        X_train[self._numeric_columns] = scaler.fit_transform(X_train[self._numeric_columns])
+        X_test[self._numeric_columns] = scaler.fit_transform(X_test[self._numeric_columns])
+        
+        train_dataset = {"data": X_train, "label": y_train}
+        test_dataset = {"data": X_test, "label": y_test}
+        
+        return train_dataset, test_dataset
+    
+if __name__ == '__main__':
+    path = './models/rank/data/gbdt_data.csv'
+    dataloader = DataLoader(path)
+    train_dataset, test_dataset = dataloader.split_dataset()
+    
+    # 打印训练集和测试集的正负样本分布
+    print("训练集正负样本分布:")
+    print(train_dataset["label"].value_counts(normalize=True))
+    
+    print("测试集正负样本分布:")
+    print(test_dataset["label"].value_counts(normalize=True))

+ 163 - 0
models/rank/data/preprocess.py

@@ -0,0 +1,163 @@
+from dao.dao import load_cust_data_from_mysql, load_product_data_from_mysql, load_order_data_from_mysql
+from models.rank.data.config import CustConfig, ProductConfig, OrderConfig
+import os
+import pandas as pd
+from sklearn.preprocessing import MinMaxScaler
+from sklearn.utils import shuffle
+import numpy as np
+
+class DataProcess():
+    def __init__(self, city_uuid, save_path):
+        self._save_res_path = save_path
+        print("正在加载cust_info...")
+        self._cust_data = load_cust_data_from_mysql(city_uuid)
+        print("正在加载product_info...")
+        self._product_data = load_product_data_from_mysql(city_uuid)
+        print("正在加载order_info...")
+        self._order_data = load_order_data_from_mysql(city_uuid)
+        
+    def data_process(self):
+        """数据预处理"""
+        if os.path.exists(self._save_res_path):
+            os.remove(self._save_res_path)
+        
+        # 1. 获取指定的特征组合
+        self._cust_data = self._cust_data[CustConfig.FEATURE_COLUMNS]
+        self._product_data = self._product_data[ProductConfig.FEATURE_COLUMNS]
+        self._order_data = self._order_data[OrderConfig.FEATURE_COLUMNS]
+        
+        # 2. 数据清洗
+        self._clean_cust_data()
+        self._clean_product_data()
+        self._clean_order_data()
+        
+        # # 3. 将零售户信息表与卷烟信息表进行笛卡尔积连接
+        # self._descartes()
+        
+        # # 4. 根据order表中的信息给数据打标签
+        # self._labeled_data()
+        
+        # 3. 根据特征权重给order表中的记录打分
+        self._calculate_score()
+        
+        # 4. 根据中位数打标签
+        self.labeled_data()
+        
+        # 5. 选取训练样本
+        self._generate_train_data()
+        
+    
+    def _clean_cust_data(self):
+        """用户信息表数据清洗"""
+        # 根据配置规则清洗数据
+        for feature, rules, in CustConfig.CLEANING_RULES.items():
+            if rules["type"] == "num":
+                # 先将数值型字符串转换为数值
+                self._cust_data[feature] = pd.to_numeric(self._cust_data[feature], errors="coerce")
+                
+            if rules["method"] == "fillna":
+                if rules["opt"] == "fill":
+                    self._cust_data[feature] = self._cust_data[feature].fillna(rules["value"])
+                elif rules["opt"] == "replace":
+                    self._cust_data[feature] = self._cust_data[feature].fillna(self._cust_data[rules["value"]])
+                elif rules["opt"] == "mean":
+                    self._cust_data[feature] = self._cust_data[feature].fillna(self._cust_data[feature].mean())
+                self._cust_data[feature] = self._cust_data[feature].infer_objects(copy=False)
+    
+    def _clean_product_data(self):
+        """卷烟信息表数据清洗"""
+        for feature, rules, in ProductConfig.CLEANING_RULES.items():
+            if rules["type"] == "num":
+                self._product_data[feature] = pd.to_numeric(self._product_data[feature], errors="coerce")
+            
+            if rules["method"] == "fillna":
+                if rules["opt"] == "fill":
+                    self._product_data[feature] = self._product_data[feature].fillna(rules["value"])
+                elif rules["opt"] == "mean":
+                    self._product_data[feature] = self._product_data[feature].fillna(self._product_data[feature].mean())
+                self._product_data[feature] = self._product_data[feature].infer_objects(copy=False)
+                    
+    def _clean_order_data(self):
+        pass
+    
+    def _calculate_score(self):
+        """计算order记录的fens"""
+        self._order_score = self._order_data.copy()
+        # 对参与算分的特征值进行归一化
+        scaler = MinMaxScaler()
+        self._order_score[list(OrderConfig.WEIGHTS.keys())] = scaler.fit_transform(self._order_score[list(OrderConfig.WEIGHTS.keys())])
+        # 计算加权分数
+        self._order_score["score"] = sum(self._order_score[feat] * weight 
+                          for feat, weight in OrderConfig.WEIGHTS.items())
+    
+    def labeled_data(self):
+        """通过计算分数打标签"""
+        # 按品规分组计算中位数
+        product_medians = self._order_score.groupby("PRODUCT_CODE")["score"].median().reset_index()
+        product_medians.columns = ["PRODUCT_CODE", "median_score"]
+        
+        # 合并中位数到原始订单数据
+        self._order_score = pd.merge(self._order_score, product_medians, on="PRODUCT_CODE")
+        
+        # 生成标签 (1: 大于等于中位数, 0: 小于中位数)
+        self._order_score["label"] = np.where(
+            self._order_score["score"] >= self._order_score["median_score"], 1, 0
+        )
+        self._order_score = self._order_score.sort_values("score", ascending=False)
+        self._order_score = self._order_score[["BB_RETAIL_CUSTOMER_CODE", "PRODUCT_CODE", "label"]]
+        self._order_score.rename(columns={"PRODUCT_CODE": "product_code"}, inplace=True)
+    
+    def _generate_train_data(self):
+        cust_feats = self._cust_data.set_index("BB_RETAIL_CUSTOMER_CODE")
+        product_feats = self._product_data.set_index("product_code")
+        
+        self._train_data = self._order_score.copy()
+        
+        self._train_data = self._train_data.join(cust_feats, on="BB_RETAIL_CUSTOMER_CODE", how="left")
+        self._train_data = self._train_data.join(product_feats, on="product_code", how="left")
+        
+        self._train_data = shuffle(self._train_data, random_state=42)
+
+        self._train_data.to_csv(self._save_res_path, index=False)
+    
+    def _descartes(self):
+        """将零售户信息与卷烟信息进行笛卡尔积连接"""
+        self._cust_data["descartes"] = 1
+        self._product_data["descartes"] = 1
+        
+        self._descartes_data = pd.merge(self._cust_data, self._product_data, on="descartes").drop("descartes", axis=1)
+        
+    def _labeled_data_from_descartes(self):
+        """根据order表信息给descartes_data数据打标签"""
+        # 获取order表中的正样本组合
+        order_combinations = self._order_data[["BB_RETAIL_CUSTOMER_CODE", "PRODUCT_CODE"]].drop_duplicates()
+        order_set = set(zip(order_combinations["BB_RETAIL_CUSTOMER_CODE"], order_combinations["PRODUCT_CODE"]))
+        
+        # 在descartes_data中打标签:正样本为1,负样本为0
+        self._descartes_data['label'] = self._descartes_data.apply(
+            lambda row: 1 if (row['BB_RETAIL_CUSTOMER_CODE'], row['product_code']) in order_set else 0, axis=1)
+    
+    def _generate_train_data_from_descartes(self):
+        """从descartes_data中生成训练数据"""
+        positive_samples = self._descartes_data[self._descartes_data["label"] == 1]
+        negative_samples = self._descartes_data[self._descartes_data["label"] == 0]
+        
+        positive_count = len(positive_samples)
+        negative_count = min(1 * positive_count, len(negative_samples))
+        print(positive_count)
+        print(negative_count)
+        
+        # 随机抽取2倍正样本数量的负样本
+        negative_samples_sampled = negative_samples.sample(n=negative_count, random_state=42)
+        # 合并正负样本
+        self._train_data = pd.concat([positive_samples, negative_samples_sampled], axis=0)
+        self._train_data = self._train_data.sample(frac=1, random_state=42).reset_index(drop=True)
+        
+        # 保存训练数据
+        self._train_data.to_csv(self._save_res_path, index=False)
+    
+if __name__ == '__main__':
+    city_uuid = "00000000000000000000000011445301"
+    save_path = "./models/rank/data/gbdt_data.csv"
+    processor = DataProcess(city_uuid, save_path)
+    processor.data_process()

+ 24 - 0
models/rank/data/utils.py

@@ -0,0 +1,24 @@
+import pandas as pd
+def one_hot_embedding(dataframe, onehout_feat):
+    """对数据的指定特征做embedding编码"""
+    # 先将指定的特征进行Categorical处理
+    for feat, categories in onehout_feat.items():
+        dataframe[feat] = pd.Categorical(dataframe[feat], categories=categories, ordered=False)
+    dataframe = pd.get_dummies(
+        dataframe,
+        columns=list(onehout_feat.keys()),
+        prefix_sep="_",
+        dtype=int,
+    )
+    return dataframe
+
+def sample_data_clear(data, config):
+    for feature, rules, in config.CLEANING_RULES.items():
+        if rules["type"] == "num":
+            data[feature] = pd.to_numeric(data[feature], errors="coerce")
+        if rules["method"] == "fill":
+            if rules["type"] == "str":
+                data[feature] = data[feature].fillna(rules["value"])
+            elif rules["type"] == "num":
+                data[feature] = data[feature].fillna(0.0)
+    return data

+ 123 - 0
models/rank/gbdt_lr.py

@@ -1,2 +1,125 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
+import numpy as np
+from models.rank.data import DataLoader
+from sklearn.ensemble import GradientBoostingClassifier
+from sklearn.linear_model import LogisticRegression
+from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
+from sklearn.model_selection import GridSearchCV
+from sklearn.preprocessing import OneHotEncoder
+import joblib
+import time
+
+class Trainer:
+    def __init__(self, path):
+        self._load_data(path)
+        
+        # 初始化GBDT和LR模型参数
+        self._gbdt_params = {
+            'n_estimators': 100,
+            'learning_rate': 0.01,
+            'max_depth': 6,
+            'subsample': 0.8,
+            'random_state': 42,
+        }
+        self._lr_params = {
+            "max_iter": 1000,
+            'C': 1.0, 
+            'penalty': 'elasticnet', 
+            'l1_ratio': 0.8,  # 添加 l1_ratio 参数,可以根据需要调整
+            'solver': 'saga',
+            'random_state': 42,
+            'class_weight': 'balanced'
+        }
+        
+        # 初始化模型
+        self._gbdt_model = GradientBoostingClassifier(**self._gbdt_params)
+        self._lr_model = LogisticRegression(**self._lr_params)
+        
+        self._onehot_encoder = OneHotEncoder(sparse_output=True, handle_unknown='ignore')
+        
+    def _load_data(self, path):
+        dataloader = DataLoader(path)
+        self._train_dataset, self._test_dataset = dataloader.split_dataset()
+        
+    def train(self):
+        """模型训练"""
+        print("开始训练GBDT模型...")
+        # 训练GBDT模型
+        self._gbdt_model.fit(self._train_dataset["data"], self._train_dataset["label"])
+        
+        # 获取GBDT的每棵树的分数(决策值)
+        gbdt_train_preds = self._gbdt_model.apply(self._train_dataset["data"])[:, :, 0]  # 仅取每棵树的叶节点输出
+        
+        gbdt_feats_encoded = self._onehot_encoder.fit_transform(gbdt_train_preds)
+        
+        print("开始训练LR模型...")
+        # 使用决策树输出作为LR的输入特征
+        self._lr_model.fit(gbdt_feats_encoded, self._train_dataset["label"])
+        
+    def predict(self, X):
+        # 获取GBDT模型的预测分数
+        gbdt_preds = self._gbdt_model.apply(X)[:, :, 0]
+        
+        gbdt_feats_encoded = self._onehot_encoder.transform(gbdt_preds)
+        
+        # 使用训练好的LR模型输出概率
+        return self._lr_model.predict(gbdt_feats_encoded)
+    
+    def predict_proba(self, X):
+        # 获取GBDT模型的预测分数
+        gbdt_preds = self._gbdt_model.apply(X)[:, :, 0]
+        
+        gbdt_feats_encoded = self._onehot_encoder.transform(gbdt_preds)
+        
+        # 使用训练好的LR模型输出概率
+        return self._lr_model.predict_proba(gbdt_feats_encoded)
+        
+    def evaluate(self):
+        # 对测试集进行预测
+        y_pred = self.predict(self._test_dataset["data"])
+        y_pred_proba = self.predict_proba(self._test_dataset["data"])[:, 1]  # 获取正类的概率
+        
+        # 计算各类评估指标
+        accuracy = accuracy_score(self._test_dataset["label"], y_pred)
+        precision = precision_score(self._test_dataset["label"], y_pred)
+        recall = recall_score(self._test_dataset["label"], y_pred)
+        f1 = f1_score(self._test_dataset["label"], y_pred)
+        roc_auc = roc_auc_score(self._test_dataset["label"], y_pred_proba)    
+        
+        return {
+            'accuracy': accuracy,
+            'precision': precision,
+            'recall': recall,
+            'f1_score': f1,
+            'roc_auc': roc_auc
+        }
+        
+    def save_model(self, model_path):
+        """将模型保存到本地"""
+        models = {"gbdt_model": self._gbdt_model, "lr_model": self._lr_model, "onehot_encoder": self._onehot_encoder}
+        joblib.dump(models, model_path)
+    
+     
+if __name__ == "__main__":
+    gbdt_data_path = "./models/rank/data/gbdt_data.csv"
+    trainer = Trainer(gbdt_data_path)
+    
+    start_time = time.time()
+    trainer.train()
+    end_time = time.time()
+    
+    training_time_hours = (end_time - start_time) / 3600
+    print(f"训练时间: {training_time_hours:.4f} 小时")
+    
+    eval_metrics = trainer.evaluate()
+    
+    # 输出评估结果
+    print("GBDT-LR Evaluation Metrics:")
+    for metric, value in eval_metrics.items():
+        print(f"{metric}: {value:.4f}")
+        
+    # 保存模型
+    model_path = "./models/rank/weights/model.pkl"
+    trainer.save_model(model_path)
+    

+ 121 - 0
models/rank/gbdt_lr_sort.py

@@ -0,0 +1,121 @@
+import joblib
+from dao import Redis, get_product_by_id, get_custs_by_ids, load_cust_data_from_mysql
+from models.rank.data import ProductConfig, CustConfig, ImportanceFeaturesMap
+from models.rank.data.utils import one_hot_embedding, sample_data_clear
+import pandas as pd
+from sklearn.preprocessing import StandardScaler
+
+
+class GbdtLrModel:
+    def __init__(self, model_path):
+        self.load_model(model_path)
+        self.redis = Redis().redis
+    
+    def load_model(self, model_path):
+        models = joblib.load(model_path)
+        self.gbdt_model, self.lr_model, self.onehot_encoder = models["gbdt_model"], models["lr_model"], models["onehot_encoder"]
+        
+    
+    # def get_recall_list(self, city_uuid, product_id):
+    #     """根据卷烟id获取召回的商铺列表"""
+    #     key = f"fc:{city_uuid}:{product_id}"
+    #     self.recall_cust_list = self.redis.zrange(key, 0, -1, withscores=False)
+    
+    # def load_recall_data(self, city_uuid, product_id):
+    #     self.product_data = get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
+    #     self.custs_data = get_custs_by_ids(city_uuid, self.recall_cust_list)[CustConfig.FEATURE_COLUMNS]
+        
+    def get_cust_and_product_data(self, city_uuid, product_id):
+        """从商户数据库中获取指定城市所有商户的id"""
+        self.product_data = get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
+        self.custs_data = load_cust_data_from_mysql(city_uuid)[CustConfig.FEATURE_COLUMNS]
+    
+    def generate_feats_map(self, city_uuid, product_id):
+        """组合卷烟、商户特征矩阵"""
+        # self.get_recall_list(city_uuid, product_id)
+        # self.load_recall_data(city_uuid, product_id)
+        
+        self.get_cust_and_product_data(city_uuid, product_id)
+        # 做数据清洗
+        self.product_data = sample_data_clear(self.product_data, ProductConfig)
+        self.custs_data = sample_data_clear(self.custs_data, CustConfig)
+        
+        # 笛卡尔积联合
+        self.custs_data["descartes"] = 1
+        self.product_data["descartes"] = 1
+        self.feats_map = pd.merge(self.custs_data, self.product_data, on="descartes").drop("descartes", axis=1)
+        self.recall_cust_list = self.feats_map["BB_RETAIL_CUSTOMER_CODE"].to_list()
+        self.feats_map.drop('BB_RETAIL_CUSTOMER_CODE', axis=1, inplace=True)
+        self.feats_map.drop('product_code', axis=1, inplace=True)
+        
+        # onehot编码
+        onehot_feats = {**CustConfig.ONEHOT_CAT, **ProductConfig.ONEHOT_CAT}
+        onehot_columns = list(onehot_feats.keys())
+        numeric_columns = self.feats_map.drop(onehot_columns, axis=1).columns
+        self.feats_map = one_hot_embedding(self.feats_map, onehot_feats)
+        
+        # 数字特征归一化
+        scaler = StandardScaler()
+        self.feats_map[numeric_columns] = scaler.fit_transform(self.feats_map[numeric_columns])
+    
+    def sort(self, city_uuid, product_id):
+        self.generate_feats_map(city_uuid, product_id)
+        
+        gbdt_preds = self.gbdt_model.apply(self.feats_map)[:, :, 0]
+        gbdt_feats_encoded = self.onehot_encoder.transform(gbdt_preds)
+        scores = self.lr_model.predict_proba(gbdt_feats_encoded)[:, 1]
+        
+        self.recommend_list = []
+        for cust_id, score in zip(self.recall_cust_list, scores):
+            self.recommend_list.append({cust_id: float(score)})
+            
+        self.recommend_list = sorted(self.recommend_list, key=lambda x: list(x.values())[0], reverse=True)
+        # for res in self.recommend_list[:200]:
+        #     print(res)
+        return self.recommend_list
+    
+    def generate_feats_importance(self):
+        """生成特征重要性"""
+        # 获取GBDT模型的特征重要性
+        feats_importance = self.gbdt_model.feature_importances_
+        
+        # 获取特征名称
+        feats_names = self.gbdt_model.feature_names_in_
+        
+        importance_dict = dict(zip(feats_names, feats_importance))
+        
+        onehot_feats = {**CustConfig.ONEHOT_CAT, **ProductConfig.ONEHOT_CAT}
+        for feat, categories in onehot_feats.items():
+            related_columns = [col for col in feats_names if col.startswith(feat)]
+            if related_columns:
+                # 合并类别重要性
+                combined_importance = sum(importance_dict[col] for col in related_columns)
+                # 删除onehot类别列
+                for col in related_columns:
+                    del importance_dict[col]
+                # 添加合并后的重要性
+                importance_dict[feat] = combined_importance
+        
+        # 排序
+        sorted_importance = sorted(importance_dict.items(), key=lambda x: x[1], reverse=True)
+        
+        # 输出特征重要性
+        cust_features_importance = []
+        product_features_importance = []
+        for feat, importance in sorted_importance:
+            if feat in list(ImportanceFeaturesMap.CUSTOM_FEATRUES_MAP.keys()):
+                cust_features_importance.append({ImportanceFeaturesMap.CUSTOM_FEATRUES_MAP[feat]: float(importance)})
+            if feat in list(ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP.keys()):
+                product_features_importance.append({ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP[feat]: float(importance)})
+        return cust_features_importance, product_features_importance
+    
+if __name__ == "__main__":
+    model_path = "./models/rank/weights/model.pkl"
+    city_uuid = "00000000000000000000000011445301"
+    product_id = "110102"
+    gbdt_sort = GbdtLrModel(model_path)
+    gbdt_sort.sort(city_uuid, product_id)
+    
+    importances = gbdt_sort.generate_feats_importance()
+    for importance in importances:
+        print(importance)

+ 2 - 1
models/recall/hot_recall.py

@@ -56,9 +56,10 @@ class HotRecallModel:
     def to_redis(self, rec_content_score, city_uuid):
         hotkey_name = rec_content_score["key"]
         rec_item_id = f"hot:{city_uuid}:{str(hotkey_name)}" # 修正 rec_item_id 拼接方式
-        
+        print("自动清除历史id前数量", self._redis_db.redis.zcard(rec_item_id))
         # 清空 sorted set 数据,确保不会影响后续的存储
         self._redis_db.redis.delete(rec_item_id)
+        print("自动清除历史id后数量", self._redis_db.redis.zcard(rec_item_id))
          
         res = {}
 

BIN
requirements.txt


+ 24 - 2
烟草模型部署文档.md

@@ -50,7 +50,29 @@ redis:
         3. 启动协同过滤  
         4. 启动系统过滤推理
 
-## 3、模型docker运行配置说明:
+## 3、GBDT LR模型训练推理启动
+
+### gbdt_lr.py
+
+```
+    parser.add_argument("--run_train", action='store_true')
+    parser.add_argument("--recommend", action='store_true')
+    parser.add_argument("--importance", action='store_true')
+
+    parser.add_argument("--city_uuid", type=str, default='00000000000000000000000011445301')
+    parser.add_argument("--product_id", type=str, default='110102')
+```
+
+### gbdt_lr总共3个功能:
+
+1\. 启动gbdt_lr训练  python -m gbdt_lr --run_train --city_uuid "00000000000000000000000011445301"  
+        2. 根据城市id和product_id进行推荐,需要指定city_uuid、product_id。      python -m gbdt_lr --recommend --city_uuid "00000000000000000000000011445301" --product_id '110102'  
+        3. 获取指定城市的特征重要性指标。  python -m gbdt_lr --importance --city_uuid "00000000000000000000000011445301"    
+注意:在数据准备阶段,会将训练数据保存到./models/rank/data/gbdt_data.csv中  
+模型文件会存放在 ./models/rank/weights/city_uuid/model.pkl  
+重要性指标会存放在 ./models/rank/weights/下,分别是商户指标重要性和卷烟指标重要性  
+
+## 4、模型docker运行配置说明:
 
 ### docker镜像是:registry.cn-hangzhou.aliyuncs.com/hexiaoshi/brandcultivation:0.0.1
 
@@ -58,7 +80,7 @@ redis:
 docker run --name BrandCultivation -d -v /export/brandcultivation/crontab:/etc/cron.d/crontab -v /export/brandcultivation/database_config.yaml:/app/config/database_config.yaml  registry.cn-hangzhou.aliyuncs.com/hexiaoshi/brandcultivation:0.0.1
 ```
 
-## 4、模型kubernetes运行配置说明
+## 5、模型kubernetes运行配置说明
 
 yaml文件如下: