Просмотр исходного кода

更新打标规则和gbdt_lr模型推理流程

yangzeyu 1 год назад
Родитель
Сommit
edff1af8d4

+ 3 - 2
dao/__init__.py

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
 from dao.mysql_client import Mysql
-from dao.dao import load_order_data_from_mysql, load_cust_data_from_mysql, load_product_data_from_mysql, get_product_by_id, get_custs_by_ids
+from dao.dao import load_order_data_from_mysql, load_cust_data_from_mysql, load_product_data_from_mysql, get_product_by_id, get_custs_by_ids, get_cust_list_from_database
 from dao.redis_db import Redis
 
 __all__ = [
@@ -11,5 +11,6 @@ __all__ = [
     "load_product_data_from_mysql",
     "Redis",
     "get_product_by_id",
-    "get_custs_by_ids"
+    "get_custs_by_ids",
+    "get_cust_list_from_database"
 ]

+ 19 - 6
dao/dao.py

@@ -3,20 +3,21 @@ from dao import Mysql
 def load_order_data_from_mysql(city_uuid):
     """从数据库中读取订单数据"""
     client = Mysql()
-    tablename = "tads_brandcul_cust_order"
+    tablename = "yunfu_mock_data"
     query_text = "*"
-    city_uuid = "00000000000000000000000011441801"
-    df = client.load_data(tablename, query_text, "city_uuid", city_uuid)
-    # df = client.load_mock_data(tablename, query_text)
+    # city_uuid = "00000000000000000000000011441801"
+    # df = client.load_data(tablename, query_text, "city_uuid", city_uuid)
+    df = client.load_mock_data(tablename, query_text)
     if len(df) == 0:
         return None
     
-    df.drop('stat_month', axis=1, inplace=True)
-    df.drop('city_uuid', axis=1, inplace=True)
+    # df.drop('stat_month', axis=1, inplace=True)
+    # df.drop('city_uuid', axis=1, inplace=True)
     
     # 去除重复值和填补缺失值
     df.drop_duplicates(inplace=True)
     df.fillna(0, inplace=True)
+    df = df.infer_objects(copy=False)
     return df
 
 def load_cust_data_from_mysql(city_uuid):
@@ -31,6 +32,18 @@ def load_cust_data_from_mysql(city_uuid):
     
     return df
 
+def get_cust_list_from_database(city_uuid):
+    client = Mysql()
+    tablename = "tads_brandcul_cust_info"
+    query_text = "*"
+    
+    df = client.load_data(tablename, query_text, "BA_CITY_ORG_CODE", city_uuid)
+    cust_list = df["BB_RETAIL_CUSTOMER_CODE"].to_list()
+    if len(cust_list) == 0:
+        return []
+    
+    return cust_list
+
 def load_product_data_from_mysql(city_uuid):
     """从数据库中读取商品信息"""
     client = Mysql()

+ 15 - 3
models/rank/data/dataloader.py

@@ -1,6 +1,5 @@
 import pandas as pd
 from models.rank.data.config import CustConfig, ProductConfig
-from sklearn.preprocessing import OneHotEncoder
 from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import StandardScaler
 from models.rank.data.utils import one_hot_embedding
@@ -32,7 +31,13 @@ class DataLoader:
         labels = self._gbdt_data["label"]
         
         # 2. 划分数据集,80%训练集、20%的测试集
-        X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.2, random_state=42, shuffle=True)
+        X_train, X_test, y_train, y_test = train_test_split(
+            features, labels, 
+            test_size=0.2, 
+            random_state=42, 
+            shuffle=True,
+            stratify=labels,
+        )
         
         # 3. 数据标准化(仅对特征进行标准化)
         scaler = StandardScaler()
@@ -47,4 +52,11 @@ class DataLoader:
 if __name__ == '__main__':
     path = './models/rank/data/gbdt_data.csv'
     dataloader = DataLoader(path)
-    dataloader.split_dataset()
+    train_dataset, test_dataset = dataloader.split_dataset()
+    
+    # 打印训练集和测试集的正负样本分布
+    print("训练集正负样本分布:")
+    print(train_dataset["label"].value_counts(normalize=True))
+    
+    print("测试集正负样本分布:")
+    print(test_dataset["label"].value_counts(normalize=True))

+ 21 - 5
models/rank/data/preprocess.py

@@ -3,6 +3,7 @@ from models.rank.data.config import CustConfig, ProductConfig, OrderConfig
 import os
 import pandas as pd
 from sklearn.preprocessing import MinMaxScaler
+from sklearn.utils import shuffle
 import numpy as np
 
 class DataProcess():
@@ -40,7 +41,7 @@ class DataProcess():
         self._calculate_score()
         
         # 4. 根据中位数打标签
-        self.labeled_data_by_score()
+        self.labeled_data()
         
         # 5. 选取训练样本
         self._generate_train_data()
@@ -61,6 +62,7 @@ class DataProcess():
                     self._cust_data[feature] = self._cust_data[feature].fillna(self._cust_data[rules["value"]])
                 elif rules["opt"] == "mean":
                     self._cust_data[feature] = self._cust_data[feature].fillna(self._cust_data[feature].mean())
+                self._cust_data[feature] = self._cust_data[feature].infer_objects(copy=False)
     
     def _clean_product_data(self):
         """卷烟信息表数据清洗"""
@@ -73,6 +75,7 @@ class DataProcess():
                     self._product_data[feature] = self._product_data[feature].fillna(rules["value"])
                 elif rules["opt"] == "mean":
                     self._product_data[feature] = self._product_data[feature].fillna(self._product_data[feature].mean())
+                self._product_data[feature] = self._product_data[feature].infer_objects(copy=False)
                     
     def _clean_order_data(self):
         pass
@@ -87,7 +90,7 @@ class DataProcess():
         self._order_score["score"] = sum(self._order_score[feat] * weight 
                           for feat, weight in OrderConfig.WEIGHTS.items())
     
-    def labeled_data_by_score(self):
+    def labeled_data(self):
         """通过计算分数打标签"""
         # 按品规分组计算中位数
         product_medians = self._order_score.groupby("PRODUCT_CODE")["score"].median().reset_index()
@@ -102,7 +105,20 @@ class DataProcess():
         )
         self._order_score = self._order_score.sort_values("score", ascending=False)
         self._order_score = self._order_score[["BB_RETAIL_CUSTOMER_CODE", "PRODUCT_CODE", "label"]]
-        self._order_score.to_csv("./models/rank/data/train.csv")
+        self._order_score.rename(columns={"PRODUCT_CODE": "product_code"}, inplace=True)
+    
+    def _generate_train_data(self):
+        cust_feats = self._cust_data.set_index("BB_RETAIL_CUSTOMER_CODE")
+        product_feats = self._product_data.set_index("product_code")
+        
+        self._train_data = self._order_score.copy()
+        
+        self._train_data = self._train_data.join(cust_feats, on="BB_RETAIL_CUSTOMER_CODE", how="left")
+        self._train_data = self._train_data.join(product_feats, on="product_code", how="left")
+        
+        self._train_data = shuffle(self._train_data, random_state=42)
+        
+        self._train_data.to_csv(self._save_res_path, index=False)
     
     def _descartes(self):
         """将零售户信息与卷烟信息进行笛卡尔积连接"""
@@ -111,7 +127,7 @@ class DataProcess():
         
         self._descartes_data = pd.merge(self._cust_data, self._product_data, on="descartes").drop("descartes", axis=1)
         
-    def _labeled_data(self):
+    def _labeled_data_from_descartes(self):
         """根据order表信息给descartes_data数据打标签"""
         # 获取order表中的正样本组合
         order_combinations = self._order_data[["BB_RETAIL_CUSTOMER_CODE", "PRODUCT_CODE"]].drop_duplicates()
@@ -121,7 +137,7 @@ class DataProcess():
         self._descartes_data['label'] = self._descartes_data.apply(
             lambda row: 1 if (row['BB_RETAIL_CUSTOMER_CODE'], row['product_code']) in order_set else 0, axis=1)
     
-    def _generate_train_data(self):
+    def _generate_train_data_from_descartes(self):
         """从descartes_data中生成训练数据"""
         positive_samples = self._descartes_data[self._descartes_data["label"] == 1]
         negative_samples = self._descartes_data[self._descartes_data["label"] == 0]

+ 22 - 13
models/rank/gbdt_lr_sort.py

@@ -1,5 +1,5 @@
 import joblib
-from dao import Redis, get_product_by_id, get_custs_by_ids
+from dao import Redis, get_product_by_id, get_custs_by_ids, load_cust_data_from_mysql
 from models.rank.data import ProductConfig, CustConfig
 from models.rank.data.utils import one_hot_embedding, sample_data_clear
 import pandas as pd
@@ -16,19 +16,26 @@ class GbdtLrModel:
         self.gbdt_model, self.lr_model, self.onehot_encoder = models["gbdt_model"], models["lr_model"], models["onehot_encoder"]
         
     
-    def get_recall_list(self, city_uuid, product_id):
-        """根据卷烟id获取召回的商铺列表"""
-        key = f"fc:{city_uuid}:{product_id}"
-        self.recall_cust_list = self.redis.zrange(key, 0, -1, withscores=False)
+    # def get_recall_list(self, city_uuid, product_id):
+    #     """根据卷烟id获取召回的商铺列表"""
+    #     key = f"fc:{city_uuid}:{product_id}"
+    #     self.recall_cust_list = self.redis.zrange(key, 0, -1, withscores=False)
     
-    def load_recall_data(self, city_uuid, product_id):
+    # def load_recall_data(self, city_uuid, product_id):
+    #     self.product_data = get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
+    #     self.custs_data = get_custs_by_ids(city_uuid, self.recall_cust_list)[CustConfig.FEATURE_COLUMNS]
+        
+    def get_cust_and_product_data(self, city_uuid, product_id):
+        """从商户数据库中获取指定城市所有商户的id"""
         self.product_data = get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
-        self.custs_data = get_custs_by_ids(city_uuid, self.recall_cust_list)[CustConfig.FEATURE_COLUMNS]
+        self.custs_data = load_cust_data_from_mysql(city_uuid)[CustConfig.FEATURE_COLUMNS]
     
     def generate_feats_map(self, city_uuid, product_id):
         """组合卷烟、商户特征矩阵"""
-        self.get_recall_list(city_uuid, product_id)
-        self.load_recall_data(city_uuid, product_id)
+        # self.get_recall_list(city_uuid, product_id)
+        # self.load_recall_data(city_uuid, product_id)
+        
+        self.get_cust_and_product_data(city_uuid, product_id)
         # 做数据清洗
         self.product_data = sample_data_clear(self.product_data, ProductConfig)
         self.custs_data = sample_data_clear(self.custs_data, CustConfig)
@@ -65,6 +72,7 @@ class GbdtLrModel:
         self.recommend_list = sorted(self.recommend_list, key=lambda x: list(x.values())[0], reverse=True)
         for res in self.recommend_list[:200]:
             print(res)
+        return self.recommend_list
     
     def generate_feats_importance(self):
         """生成特征重要性"""
@@ -102,7 +110,8 @@ if __name__ == "__main__":
     city_uuid = "00000000000000000000000011445301"
     product_id = "110102"
     gbdt_sort = GbdtLrModel(model_path)
-    # gbdt_sort.sort(city_uuid, product_id)
-    importances = gbdt_sort.generate_feats_importance()
-    for importance in importances:
-        print(importance)
+    gbdt_sort.sort(city_uuid, product_id)
+    
+    # importances = gbdt_sort.generate_feats_importance()
+    # for importance in importances:
+    #     print(importance)

BIN
models/rank/weights/model.pkl