浏览代码

封装推理流程

Sherlock 11 月之前
父节点
当前提交
08799dee8b
共有 6 个文件被更改,包括 162 次插入和 78 次删除
  1. +23 −6
      database/dao/mysql_dao.py
  2. +73 −9
      inference.py
  3. +1 −1
      models/rank/__init__.py
  4. +7 −6
      models/rank/data/utils.py
  5. +31 −54
      models/rank/gbdt_lr_inference.py
  6. +27 −2
      utils/result_process.py

+ 23 - 6
database/dao/mysql_dao.py

@@ -81,14 +81,14 @@ class MySqlDao:
     
     def get_product_by_id(self, city_uuid, product_id):
         """根据city_uuid 和 product_id 从表中获取拼柜信息"""
-        query = text(f"""
+        query = f"""
             SELECT *
             FROM {self._product_tablename}
             WHERE city_uuid = :city_uuid
             AND product_code = :product_id
-        """)
+        """
         params = {"city_uuid": city_uuid, "product_id": product_id}
-        data = self.db_helper.fetch_one(query, params)
+        data = self.db_helper.load_data_with_page(query, params)
         
         return data
     
@@ -98,14 +98,31 @@ class MySqlDao:
             return None
         
         cust_id_str = ",".join([f"'{cust_id}'" for cust_id in cust_id_list])
-        query = text(f"""
+        query = f"""
             SELECT *
             FROM {self._cust_tablename}
             WHERE BA_CITY_ORG_CODE = :city_uuid
             AND BB_RETAIL_CUSTOMER_CODE IN ({cust_id_str})
-        """)
+        """
+        params = {"city_uuid": city_uuid}
+        data = self.db_helper.load_data_with_page(query, params)
+        
+        return data
+    
+    def get_shop_by_ids(self, city_uuid, cust_id_list):
+        """根据零售户列表查询其信息"""
+        if not cust_id_list:
+            return None
+        
+        cust_id_str = ",".join([f"'{cust_id}'" for cust_id in cust_id_list])
+        query = f"""
+            SELECT *
+            FROM {self._shopping_tablename}
+            WHERE city_uuid = :city_uuid
+            AND cust_code IN ({cust_id_str})
+        """
         params = {"city_uuid": city_uuid}
-        data = self.db_helper.fetch_all(query, params)
+        data = self.db_helper.load_data_with_page(query, params)
         
         return data
     

+ 73 - 9
inference.py

@@ -1,23 +1,87 @@
-def itemcf_inference():
+
+from database import RedisDatabaseHelper, MySqlDao
+from models.rank.data.config import CustConfig, ProductConfig, ShopConfig
+from models.rank.data.utils import sample_data_clear
+from models.rank.gbdt_lr_inference import GbdtLrModel
+from utils.result_process import split_relation_subtable
+import pandas as pd
+
+redis = RedisDatabaseHelper().redis
+dao = MySqlDao()
+gbdtlr_model = GbdtLrModel("./models/rank/weights/00000000000000000000000011445301/gbdtlr_model.pkl")
+
+def get_itemcf_recall(city_uuid, product_id):
+    """协同召回"""
+    key = f"fc:{city_uuid}:{product_id}"
+    recall_list = redis.zrevrange(key, 0, -1, withscores=False)
+    return recall_list
+
+def get_hot_recall(city_uuid):
+    """热度召回"""
+    key = f"hot:{city_uuid}:sale_qty"
+    recall_list = redis.zrevrange(key, 0, -1, withscores=False)
+    return recall_list
+
+def get_recall_cust(city_uuid, product_id, recall_count):
+    """根据协同过滤和热度召回召回商户"""
+    itemcf_recall_list = get_itemcf_recall(city_uuid, product_id)
+    hot_recall_list = get_hot_recall(city_uuid)
     
-    pass
+    result = list(dict.fromkeys(itemcf_recall_list))
+    
+    # 如果结果不足,从hot_recall中补齐
+    if len(result) < recall_count:
+        hot_recall_set = set(hot_recall_list) - set(result)
+        additional_items = [item for item in hot_recall_list if item in hot_recall_set]
+        needed = recall_count - len(result)
+        result.extend(additional_items[:needed])
+    return result[:recall_count]
 
-def hotrecall_inference():
+def generate_recommend_sample(city_uuid, product_id):
+    """生成预测数据集"""
+    recall_count = 300
+    cust_list = get_recall_cust(city_uuid, product_id, recall_count)
     
-    pass
+    product_data = dao.get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
+    filter_dict = product_data.to_dict("records")[0]
+    cust_data = dao.get_cust_by_ids(city_uuid, cust_list)[CustConfig.FEATURE_COLUMNS]
+    shop_data = dao.get_shop_by_ids(city_uuid, cust_list)[ShopConfig.FEATURE_COLUMNS]
+    
+    product_data = sample_data_clear(product_data, ProductConfig)
+    cust_data = sample_data_clear(cust_data, CustConfig)
+    shop_data = sample_data_clear(shop_data, ShopConfig)
+    
+    cust_feats = shop_data.set_index("cust_code")
+    cust_data = cust_data.join(cust_feats, on="BB_RETAIL_CUSTOMER_CODE", how="inner")
+    
+    feats_map = gbdtlr_model.generate_feats_map(product_data, cust_data)
+    
+    return feats_map, filter_dict, cust_list
 
-def gbdt_lr_inference():
+def get_recommend_list(city_uuid, product_id):
+    feats_sample, _, cust_list = generate_recommend_sample(city_uuid, product_id)
+    recommend_list = gbdtlr_model.get_recommend_list(feats_sample, cust_list)
+    return recommend_list
     
+
+def gbdt_lr_inference(city_uuid, product_id):
     pass
 
-def generate_features_shap():
+def generate_features_shap(city_uuid, product_id):
+    feats_sample, filter_dict, _ = generate_recommend_sample(city_uuid, product_id)
+    result = gbdtlr_model.generate_shap_interance(feats_sample)
+    split_relation_subtable(result, filter_dict, "./data")
     
-    pass
 
 def generate_delivery_strategy():
     
     pass
 
 def run():
-    
-    pass
+    pass
+
+if __name__ == '__main__':
+    recommend_list = get_recommend_list("00000000000000000000000011445301", "350139")
+    recommend_list = pd.DataFrame(recommend_list)
+    recommend_list.to_csv("./data/推荐商户表.csv", index=False)
+    

+ 1 - 1
models/rank/__init__.py

@@ -2,7 +2,7 @@
 # -*- coding:utf-8 -*-
 from models.rank.data.preprocess import DataProcess
 from models.rank.gbdt_lr import Trainer
-from models.rank.gbdt_lr_sort import GbdtLrModel
+from models.rank.gbdt_lr_inference import GbdtLrModel
 
 __all__ = [
     "DataProcess",

+ 7 - 6
models/rank/data/utils.py

@@ -13,12 +13,13 @@ def one_hot_embedding(dataframe, onehout_feat):
     return dataframe
 
 def sample_data_clear(data, config):
-    for feature, rules, in config.CLEANING_RULES.items():
+    for feature, rules, in config.CLEANING_RULES.items():    
         if rules["type"] == "num":
             data[feature] = pd.to_numeric(data[feature], errors="coerce")
-        if rules["method"] == "fill":
-            if rules["type"] == "str":
-                data[feature] = data[feature].fillna(rules["value"])
-            elif rules["type"] == "num":
-                data[feature] = data[feature].fillna(0.0)
+        if rules["method"] == "fillna":
+            if rules["opt"] == "fill":
+                data[feature] = data[feature].fillna(rules["value"]).infer_objects(copy=False)
+            elif rules["opt"] == "mean":
+                data[feature] = data[feature].fillna(data[feature].mean()).infer_objects(copy=False)
+            data[feature] = data[feature].infer_objects(copy=False)
     return data

+ 31 - 54
models/rank/gbdt_lr_inference.py

@@ -24,64 +24,50 @@ class GbdtLrModel:
         models = joblib.load(model_path)
         self.gbdt_model, self.lr_model, self.onehot_encoder = models["gbdt_model"], models["lr_model"], models["onehot_encoder"]
         
-    
-    # def get_recall_list(self, city_uuid, product_id):
-    #     """根据卷烟id获取召回的商铺列表"""
-    #     key = f"fc:{city_uuid}:{product_id}"
-    #     self.recall_cust_list = self.redis.zrange(key, 0, -1, withscores=False)
-    
-    # def load_recall_data(self, city_uuid, product_id):
-    #     self.product_data = self._mysql_dao.get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
-    #     self.custs_data = self._mysql_dao.get_cust_by_ids(city_uuid, self.recall_cust_list)[CustConfig.FEATURE_COLUMNS]
-        
     def get_cust_and_product_data(self, city_uuid, product_id):
         """从商户数据库中获取指定城市所有商户的id"""
         self.product_data = self._mysql_dao.get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
         self.custs_data = self._mysql_dao.load_cust_data(city_uuid)[CustConfig.FEATURE_COLUMNS]
     
-    def generate_feats_map(self, city_uuid, product_id):
+    def generate_feats_map(self, product_data, cust_data):
         """组合卷烟、商户特征矩阵"""
-        # self.get_recall_list(city_uuid, product_id)
-        # self.load_recall_data(city_uuid, product_id)
-        
-        self.get_cust_and_product_data(city_uuid, product_id)
-        # 做数据清洗
-        self.product_data = sample_data_clear(self.product_data, ProductConfig)
-        self.custs_data = sample_data_clear(self.custs_data, CustConfig)
-        
         # 笛卡尔积联合
-        self.custs_data["descartes"] = 1
-        self.product_data["descartes"] = 1
-        self.feats_map = pd.merge(self.custs_data, self.product_data, on="descartes").drop("descartes", axis=1)
-        self.recall_cust_list = self.feats_map["BB_RETAIL_CUSTOMER_CODE"].to_list()
-        self.feats_map.drop('BB_RETAIL_CUSTOMER_CODE', axis=1, inplace=True)
-        self.feats_map.drop('product_code', axis=1, inplace=True)
+        cust_data["descartes"] = 1
+        product_data["descartes"] = 1
+        feats_map = pd.merge(cust_data, product_data, on="descartes").drop("descartes", axis=1)
+        # recall_cust_list = feats_map["BB_RETAIL_CUSTOMER_CODE"].to_list()
+        feats_map.drop('BB_RETAIL_CUSTOMER_CODE', axis=1, inplace=True)
+        feats_map.drop('product_code', axis=1, inplace=True)
         
         # onehot编码
-        onehot_feats = {**CustConfig.ONEHOT_CAT, **ProductConfig.ONEHOT_CAT}
+        onehot_feats = {**CustConfig.ONEHOT_CAT, **ProductConfig.ONEHOT_CAT, **ShopConfig.ONEHOT_CAT}
         onehot_columns = list(onehot_feats.keys())
-        numeric_columns = self.feats_map.drop(onehot_columns, axis=1).columns
-        self.feats_map = one_hot_embedding(self.feats_map, onehot_feats)
+        numeric_columns = feats_map.drop(onehot_columns, axis=1).columns
+        feats_map = one_hot_embedding(feats_map, onehot_feats)
         
         # 数字特征归一化
-        scaler = StandardScaler()
-        self.feats_map[numeric_columns] = scaler.fit_transform(self.feats_map[numeric_columns])
+        if len(numeric_columns) != 0:
+            scaler = StandardScaler()
+            feats_map[numeric_columns] = scaler.fit_transform(feats_map[numeric_columns])
+            
+        return feats_map
     
-    def recommend_sort(self, city_uuid, product_id):
-        self.generate_feats_map(city_uuid, product_id)
-        
-        gbdt_preds = self.gbdt_model.apply(self.feats_map)[:, :, 0]
+    def get_recommend_list(self, recommend_sample, recall_list):
+        gbdt_preds = self.gbdt_model.apply(recommend_sample)[:, :, 0]
         gbdt_feats_encoded = self.onehot_encoder.transform(gbdt_preds)
         scores = self.lr_model.predict_proba(gbdt_feats_encoded)[:, 1]
         
-        self.recommend_list = []
-        for cust_id, score in zip(self.recall_cust_list, scores):
-            self.recommend_list.append({cust_id: float(score)})
-            
-        self.recommend_list = sorted(self.recommend_list, key=lambda x: list(x.values())[0], reverse=True)
-        # for res in self.recommend_list[:200]:
-        #     print(res)
-        return self.recommend_list
+        recommend_list = []
+        for cust_id, score in zip(recall_list, scores):
+            recommend_list.append({cust_id: float(score)})
+            recommend_list.append({"cust_code": cust_id, "recommend_score": score})
+            
+        recommend_list = sorted(
+            [item for item in recommend_list if "recommend_score" in item],
+            key=lambda x: x["recommend_score"],
+            reverse=True
+        )
+        return recommend_list
     
     def generate_feats_importance(self):
         """生成特征重要性"""
@@ -148,7 +134,6 @@ class GbdtLrModel:
                 batch_interaction = self._explanier.shap_interaction_values(batch_data)
                 fp[i:i+len(batch_interaction)] = batch_interaction.astype(np.float32)
                 fp.flush()  # 确保数据写入磁盘
-            print("SHAP交互值计算并存储完成")
             
             # 分批计算均值
             mean_interaction = np.zeros((n_features, n_features), dtype=np.float32)
@@ -157,7 +142,6 @@ class GbdtLrModel:
                 mean_interaction += batch.sum(axis=0)  # 按批累加
             
             mean_interaction /= n_samples  # 计算最终均值
-            print("均值计算完成")
             
             # 构建交互矩阵DataFrame
             interaction_df = pd.DataFrame(
@@ -165,7 +149,6 @@ class GbdtLrModel:
                 index=data.columns,
                 columns=data.columns
             )
-            print("交互矩阵构建完成")
             
             # 分离卷烟和商户特征
             product_feats = [
@@ -179,23 +162,19 @@ class GbdtLrModel:
                 for feat, categories in {**CustConfig.ONEHOT_CAT, **ShopConfig.ONEHOT_CAT}.items()
                 for item in categories
             ]
-            print("特征分离完成")
             
             # 提取交叉区块
             cross_matrix = interaction_df.loc[product_feats, cust_feats] 
-            print("交叉区块提取完成")
             
             # 转换为长格式
             stacked = cross_matrix.stack().reset_index()
             stacked.columns = ['product_feat', 'cust_feat', 'relation']
-            print("转换为长格式完成")
             
             # 过滤掉零值或NaN的配对
             filtered = stacked[
                 (stacked['relation'].abs() > 1e-6) &  # 排除极小值
                 (~stacked['relation'].isna())         # 排除NaN
             ].copy()
-            print("过滤完成")
             
             # 排序结果
             results = (
@@ -203,7 +182,6 @@ class GbdtLrModel:
                 .sort_values('relation', ascending=False)
                 .to_dict('records')
             )
-            print("排序完成")
             
             # 替换特征名称
             feats_name_map = {
@@ -223,8 +201,6 @@ class GbdtLrModel:
                 cust_infos = cust_f.split("_")
                 item["cust_feat"] = f"{feats_name_map['_'.join(cust_infos[:-1])]}({cust_infos[-1]})"
             
-            print("名称替换完成")
-            
             # 返回最终结果
             return pd.DataFrame(results, columns=['product_feat', 'cust_feat', 'relation'])
         
@@ -258,8 +234,9 @@ if __name__ == "__main__":
     # ])
     # product_df.to_csv("./data/product_feats.csv", index=False)
     data, _ = DataLoader("./data/gbdt/train_data.csv").split_dataset()
-    # data = data["data"].sample(n=1000, replace=True, random_state=42)
-    data = data["data"]
+    data = data["data"].sample(n=300, replace=True, random_state=42)
+    data.to_csv("./data/data.csv", index=False)
+    # data = data["data"]
     result = gbdt_sort.generate_shap_interance(data)
     print("保存结果")
     result.to_csv("./data/feats_interaction.csv", index=False, encoding='utf-8-sig')

+ 27 - 2
utils/result_process.py

@@ -1,6 +1,22 @@
 import os
-def split_relation_subtable(data, save_dir):
+import pandas as pd
+from database import MySqlDao
+from models.rank.data.config import ImportanceFeaturesMap, ProductConfig
+
+def filter_data(data, filter_dict):
+    
+    product_content = []
+    for key, value in filter_dict.items():
+        if key != 'product_code':
+            product_content.append(f"{ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP[key]}({value})")
+    
+    data = data[data['product_feat'].isin(product_content)]
+    return data
+
+def split_relation_subtable(data, product_data, save_dir):
     """拆分卷烟商户特征相关性子表"""
+    data = filter_data(data, product_data).copy()
+    data.to_csv(os.path.join(save_dir, "feats_interaction.csv"), index=False, encoding='utf-8-sig')
     data['group_key'] = data["product_feat"].str.extract(r'^([^(]+)')
     grouped = data.groupby('group_key')
     sub_tables = {
@@ -9,4 +25,13 @@ def split_relation_subtable(data, save_dir):
     }
     
     for name, sub_data in sub_tables.items():
-        sub_data.to_csv(os.path.join(save_dir, f"{name}.csv"), index=False)
+        sub_data.to_csv(os.path.join(save_dir, f"{name}.csv"), index=False, encoding='utf-8-sig')
+        
+if __name__ == "__main__":
+    dao = MySqlDao()
+    
+    save_dir = "./data"
+    data = pd.read_csv("./data/feats_interaction.csv")
+    product_data = dao.get_product_by_id("00000000000000000000000011445301", "430201")[ProductConfig.FEATURE_COLUMNS]
+    filter_dict = product_data.to_dict("records")[0]
+    split_relation_subtable(data, filter_dict, save_dir)