Kaynağa Gözat

gbdt-lr流程封装

Sherlock 1 yıl önce
ebeveyn
işleme
079b03bf0a
6 değiştirilmiş dosya ile 113 ekleme ve 6 silme
  1. 2 2
      app.py
  2. 6 2
      dao/__init__.py
  3. 16 0
      dao/dao.py
  4. 34 0
      dao/mysql_client.py
  5. 13 2
      models/rank/gbdt_lr.py
  6. 42 0
      models/rank/gbdt_lr_sort.py

+ 2 - 2
app.py

@@ -59,9 +59,9 @@ def run():
     # parser.add_argument("--similarity_matrix_path", type=str, default="./models/recall/itemCF/matrix/similarity.csv")
     parser.add_argument("--n", type=int, default=100)
     parser.add_argument("--k", type=int, default=20)
-    parser.add_argument("--top_n", type=int, default=2000, help='default n * k')
+    parser.add_argument("--top_n", type=int, default=200, help='default n * k')
     parser.add_argument("--n_jobs", type=int, default=4)
-    parser.add_argument("--city_uuid", type=str, default='00000000000000000000000011441801', help="City UUID for filtering data")
+    parser.add_argument("--city_uuid", type=str, default='00000000000000000000000011445301', help="City UUID for filtering data")
     
     # 协同过滤推理配置
     parser.add_argument("--product_code", type=int, default=110111)

+ 6 - 2
dao/__init__.py

@@ -1,11 +1,15 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
 from dao.mysql_client import Mysql
-from dao.dao import load_order_data_from_mysql, load_cust_data_from_mysql, load_product_data_from_mysql
+from dao.dao import load_order_data_from_mysql, load_cust_data_from_mysql, load_product_data_from_mysql, get_product_by_id, get_custs_by_ids
+from dao.redis_db import Redis
 
 __all__ = [
     "Mysql",
     "load_order_data_from_mysql",
     "load_cust_data_from_mysql",
-    "load_product_data_from_mysql"
+    "load_product_data_from_mysql",
+    "Redis",
+    "get_product_by_id",
+    "get_custs_by_ids"
 ]

+ 16 - 0
dao/dao.py

@@ -43,6 +43,22 @@ def load_product_data_from_mysql(city_uuid):
     
     return df
 
+def get_product_by_id(city_uuid, product_id):
+    """Fetch the product rows for one product code within a city.
+
+    Delegates to ``Mysql.get_product_by_id`` on a freshly created client.
+
+    Returns:
+        The client's result, or ``None`` when that result has length zero.
+    """
+    rows = Mysql().get_product_by_id(city_uuid, product_id)
+    # len() rather than truthiness: the client hands back a DataFrame,
+    # whose truth value is ambiguous.
+    return rows if len(rows) != 0 else None
+
+def get_custs_by_ids(city_uuid, cust_ids):
+    """Fetch the customer rows for a collection of customer ids within a city.
+
+    Delegates to ``Mysql.get_cust_by_ids`` on a freshly created client.
+
+    Returns:
+        The client's result, or ``None`` when that result has length zero.
+    """
+    found = Mysql().get_cust_by_ids(city_uuid, cust_ids)
+    has_rows = len(found) != 0
+    return found if has_rows else None
+
 if __name__ == '__main__':
     data = load_order_data_from_mysql("00000000000000000000000011445301")
     print(data)

+ 34 - 0
dao/mysql_client.py

@@ -72,6 +72,40 @@ class Mysql(object):
             self.closed()
             return total_df
         
+    def get_product_by_id(self, city_uuid, product_id):
+        """Look up product rows by city and product code.
+
+        Runs a parameterized ``SELECT *`` against
+        ``tads_brandcul_product_info`` filtered on ``city_uuid`` and
+        ``product_code``.
+
+        Returns:
+            A ``pandas.DataFrame`` built from the fetched rows.
+        """
+        bind_values = {"city_uuid": city_uuid, "product_id": product_id}
+        query = text("""
+            SELECT * 
+            FROM tads_brandcul_product_info 
+            WHERE city_uuid = :city_uuid 
+            AND product_code = :product_id
+        """)
+
+        with self.create_session() as session:
+            fetched = session.execute(query, bind_values).fetchall()
+            frame = pd.DataFrame(fetched)
+        return frame
+        
+    def get_cust_by_ids(self, city_uuid, cust_id_list):
+        """Fetch retailer (customer) rows for a list of customer codes in a city.
+
+        The customer codes are sent as an expanding bound parameter instead of
+        being spliced into the SQL text, so a code containing a quote can
+        neither break the statement nor inject SQL.
+
+        Args:
+            city_uuid: Value matched against ``BA_CITY_ORG_CODE``.
+            cust_id_list: Customer codes matched against
+                ``BB_RETAIL_CUSTOMER_CODE``. ``bytes`` items are decoded as
+                UTF-8; everything else is converted with ``str``.
+
+        Returns:
+            A ``pandas.DataFrame`` built from the fetched rows. An empty
+            DataFrame is returned when ``cust_id_list`` is empty, so both
+            paths yield the same type.
+        """
+        # Local import keeps this method self-contained; the module already
+        # depends on SQLAlchemy for ``text``.
+        from sqlalchemy import bindparam
+
+        if not cust_id_list:
+            return pd.DataFrame()
+
+        # The codes were previously compared as quoted string literals, so
+        # keep comparing them as strings.
+        cust_ids = [
+            cust_id.decode("utf-8") if isinstance(cust_id, bytes) else str(cust_id)
+            for cust_id in cust_id_list
+        ]
+
+        query = text("""
+            SELECT * 
+            FROM tads_brandcul_cust_info 
+            WHERE BA_CITY_ORG_CODE = :city_uuid 
+            AND BB_RETAIL_CUSTOMER_CODE IN :cust_ids
+        """).bindparams(bindparam("cust_ids", expanding=True))
+
+        with self.create_session() as session:
+            results = session.execute(
+                query, {"city_uuid": city_uuid, "cust_ids": cust_ids}
+            ).fetchall()
+            results = pd.DataFrame(results)
+
+        return results
+        
     def load_mock_data(self, tablename, query_text, page=1, page_size=1000):
         # 创建一个空的DataFrame用于存储所有数据
         total_df = pd.DataFrame()

+ 13 - 2
models/rank/gbdt_lr.py

@@ -7,6 +7,7 @@ from sklearn.linear_model import LogisticRegression
 from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
 from sklearn.model_selection import GridSearchCV
 from sklearn.preprocessing import OneHotEncoder
+import joblib
 
 class Trainer:
     def __init__(self, path):
@@ -24,7 +25,8 @@ class Trainer:
             "max_iter": 1000,
             'C': 1.0, 
             'penalty': 'l2', 
-            'solver': 'liblinear',
+            # 'l1_ratio': 0.5,  # 添加 l1_ratio 参数,可以根据需要调整
+            'solver': 'sag',
             'random_state': 42,
             'class_weight': 'balanced'
         }
@@ -82,7 +84,7 @@ class Trainer:
         precision = precision_score(self._test_dataset["label"], y_pred)
         recall = recall_score(self._test_dataset["label"], y_pred)
         f1 = f1_score(self._test_dataset["label"], y_pred)
-        roc_auc = roc_auc_score(self._test_dataset["label"], y_pred_proba)
+        roc_auc = roc_auc_score(self._test_dataset["label"], y_pred_proba)    
         
         return {
             'accuracy': accuracy,
@@ -91,6 +93,11 @@ class Trainer:
             'f1_score': f1,
             'roc_auc': roc_auc
         }
+        
+    def save_model(self, model_path):
+        """Save the trained models to local disk with joblib.
+
+        The GBDT model, the LR model and the one-hot encoder are written as a
+        single dict under the keys ``gbdt_model``, ``lr_model`` and
+        ``onehot_encoder``.
+
+        Args:
+            model_path: Destination passed to ``joblib.dump``. When it is a
+                filesystem path, its parent directory is created if missing.
+        """
+        import os
+
+        # joblib.dump does not create missing directories; without this a
+        # fresh checkout fails with FileNotFoundError.
+        if isinstance(model_path, (str, os.PathLike)):
+            target_dir = os.path.dirname(model_path)
+            if target_dir:
+                os.makedirs(target_dir, exist_ok=True)
+
+        models = {"gbdt_model": self._gbdt_model, "lr_model": self._lr_model, "onehot_encoder": self._onehot_encoder}
+        joblib.dump(models, model_path)
     
      
 if __name__ == "__main__":
@@ -103,4 +110,8 @@ if __name__ == "__main__":
     print("GBDT-LR Evaluation Metrics:")
     for metric, value in eval_metrics.items():
         print(f"{metric}: {value:.4f}")
+        
+    # 保存模型
+    model_path = "./models/rank/weights/model.pkl"
+    trainer.save_model(model_path)
     

+ 42 - 0
models/rank/gbdt_lr_sort.py

@@ -0,0 +1,42 @@
+import joblib
+from dao import Redis, get_product_by_id, get_custs_by_ids
+from models.rank.data import ProductConfig, CustConfig
+class GbdtLrSort:
+    """Load a saved GBDT+LR bundle and gather recall data for ranking.
+
+    NOTE(review): ``sort`` and ``generate_feats_importance`` are still stubs.
+    """
+
+    def __init__(self, model_path):
+        """Load the model bundle from ``model_path`` and grab the Redis handle."""
+        self.load_model(model_path)
+        self.redis = Redis().redis
+
+    def load_model(self, model_path):
+        """Load the joblib bundle and expose each component as an attribute.
+
+        The bundle must be a mapping with the keys ``gbdt_model``,
+        ``lr_model`` and ``onehot_encoder``.
+        """
+        # NOTE(review): joblib.load unpickles its input; only load trusted files.
+        models = joblib.load(model_path)
+        # Each component gets its own attribute; previously all three were
+        # packed into ``gbdt_model`` as a tuple.
+        self.gbdt_model = models["gbdt_model"]
+        self.lr_model = models["lr_model"]
+        self.onehot_encoder = models["onehot_encoder"]
+
+    def get_recall_list(self, city_uuid, product_id):
+        """Read the recalled shop list for a cigarette id from Redis.
+
+        Stores every member of the sorted set ``fc:{city_uuid}:{product_id}``
+        (without scores) in ``self.recall_cust_list``.
+        """
+        key = f"fc:{city_uuid}:{product_id}"
+        self.recall_cust_list = self.redis.zrange(key, 0, -1, withscores=False)
+
+    def load_recall_data(self, city_uuid, product_id):
+        """Load feature columns for the product and for the recalled customers.
+
+        Requires ``get_recall_list`` to have populated
+        ``self.recall_cust_list`` first.
+
+        Raises:
+            ValueError: If the product or the recalled customers are not
+                found (the DAO helpers return ``None`` in that case).
+        """
+        product_data = get_product_by_id(city_uuid, product_id)
+        if product_data is None:
+            raise ValueError(
+                f"no product found for city_uuid={city_uuid!r}, product_id={product_id!r}"
+            )
+        self.product_data = product_data[ProductConfig.FEATURE_COLUMNS]
+
+        custs_data = get_custs_by_ids(city_uuid, self.recall_cust_list)
+        if custs_data is None:
+            raise ValueError(
+                f"no recalled customers found for city_uuid={city_uuid!r}, product_id={product_id!r}"
+            )
+        self.custs_data = custs_data[CustConfig.FEATURE_COLUMNS]
+        print(self.product_data)
+
+    def generate_feats_map(self, city_uuid, product_id):
+        """Assemble the cigarette and merchant feature matrices."""
+        self.get_recall_list(city_uuid, product_id)
+        self.load_recall_data(city_uuid, product_id)
+        # TODO: data cleaning
+
+    def sort(self, city_uuid, product_id):
+        """Not implemented yet."""
+        pass
+
+    def generate_feats_importance(self):
+        """Not implemented yet."""
+        pass
+    
+if __name__ == "__main__":
+    # Manual smoke run: load the saved bundle and build the feature inputs
+    # for one hard-coded city / product pair.
+    model_path = "./models/rank/weights/model.pkl"
+    city_uuid, product_id = "00000000000000000000000011445301", "110102"
+
+    gbdt_sort = GbdtLrSort(model_path)
+    gbdt_sort.generate_feats_map(city_uuid, product_id)