Эх сурвалжийг харах

特征相关性结果后处理

yangzeyu 11 сар өмнө
parent
commit
4d14332433

+ 12 - 0
database/dao/mysql_dao.py

@@ -126,6 +126,18 @@ class MySqlDao:
         
         return data
     
+    def get_order_by_product(self, city_uuid, product_id):
+        query = f"""
+            SELECT *
+            FROM {self._order_tablename}
+            WHERE city_uuid = :city_uuid
+            AND product_code = :product_id
+        """
+        params = {"city_uuid": city_uuid, "product_id": product_id}
+        data = self.db_helper.load_data_with_page(query, params)
+        
+        return data
+    
     def data_preprocess(self, data: pd.DataFrame):
         
         data.drop(["cust_uuid", "longitude", "latitude", "range_radius"], axis=1, inplace=True)

+ 6 - 6
inference.py

@@ -3,7 +3,7 @@ from database import RedisDatabaseHelper, MySqlDao
 from models.rank.data.config import CustConfig, ProductConfig, ShopConfig
 from models.rank.data.utils import sample_data_clear
 from models.rank.gbdt_lr_inference import GbdtLrModel
-from utils.result_process import split_relation_subtable
+from utils.result_process import split_relation_subtable, generate_report
 import pandas as pd
 
 redis = RedisDatabaseHelper().redis
@@ -70,7 +70,7 @@ def gbdt_lr_inference(city_uuid, product_id):
 def generate_features_shap(city_uuid, product_id):
     feats_sample, filter_dict, _ = generate_recommend_sample(city_uuid, product_id)
     result = gbdtlr_model.generate_shap_interance(feats_sample)
-    split_relation_subtable(result, filter_dict, "./data")
+    generate_report(result, filter_dict, "./data")
     
 
 def generate_delivery_strategy():
@@ -81,7 +81,7 @@ def run():
     pass
 
 if __name__ == '__main__':
-    recommend_list = get_recommend_list("00000000000000000000000011445301", "350139")
-    recommend_list = pd.DataFrame(recommend_list)
-    recommend_list.to_csv("./data/推荐商户表.csv", index=False)
-    
+    generate_features_shap("00000000000000000000000011445301", "350139")
+    # recommend_list = get_recommend_list("00000000000000000000000011445301", "530246")
+    # recommend_list = pd.DataFrame(recommend_list)
+    # recommend_list.to_csv("./data/recommend_list.csv", index=False, encoding="utf-8-sig")

+ 1 - 1
models/rank/gbdt_lr_inference.py

@@ -138,7 +138,7 @@ class GbdtLrModel:
             # 分批计算均值
             mean_interaction = np.zeros((n_features, n_features), dtype=np.float32)
             for i in tqdm(range(0, n_samples, batch_size), desc="计算均值..."):
-                batch = np.abs(fp[i:i+batch_size])  # 读取批数据并取绝对值
+                batch = fp[i:i+batch_size]  # 读取批数据（不再取绝对值，保留交互项的正负号）
                 mean_interaction += batch.sum(axis=0)  # 按批累加
             
             mean_interaction /= n_samples  # 计算最终均值

+ 28 - 9
utils/result_process.py

@@ -3,6 +3,7 @@ import pandas as pd
 from database import MySqlDao
 from models.rank.data.config import ImportanceFeaturesMap, ProductConfig
 
+dao = MySqlDao()
 def filter_data(data, filter_dict):
     
     product_content = []
@@ -13,9 +14,9 @@ def filter_data(data, filter_dict):
     data = data[data['product_feat'].isin(product_content)]
     return data
 
-def split_relation_subtable(data, product_data, save_dir):
+def split_relation_subtable(data, filter_dict, save_dir):
     """拆分卷烟商户特征相关性子表"""
-    data = filter_data(data, product_data).copy()
+    data = filter_data(data, filter_dict).copy()
     data.to_csv(os.path.join(save_dir, "feats_interaction.csv"), index=False, encoding='utf-8-sig')
     data['group_key'] = data["product_feat"].str.extract(r'^([^(]+)')
     grouped = data.groupby('group_key')
@@ -27,11 +28,29 @@ def split_relation_subtable(data, product_data, save_dir):
     for name, sub_data in sub_tables.items():
         sub_data.to_csv(os.path.join(save_dir, f"{name}.csv"), index=False, encoding='utf-8-sig')
         
-if __name__ == "__main__":
-    dao = MySqlDao()
+def generate_report(data, filter_dict, save_dir):
+    """根据总表筛选结果"""
+    # 1. 筛选商户相关性排序结果
+    data = filter_data(data, filter_dict).copy()
+    data.to_csv(os.path.join(save_dir, "feats_interaction.csv"), index=False, encoding='utf-8-sig')
+    group_sums = data.groupby("cust_feat")["relation"].sum()
+    # 筛选出relation总和为正（严格大于0）的cust_feat
+    valid_cust_feats = group_sums[group_sums > 0].index.tolist()
+    cust_relation = data[data["cust_feat"].isin(valid_cust_feats)]
+    cust_relation = cust_relation.reset_index(drop=True)
+    
+    # 2. 品规信息
+    cust_relation[:20].to_csv(os.path.join(save_dir, "cust_relation.csv"), index=False, encoding='utf-8-sig')
+    with open(os.path.join(save_dir, "product_info.csv"), "w", encoding='utf-8-sig') as f:
+        for key, value in filter_dict.items():
+            if key != 'product_code':
+                f.write(f"{ImportanceFeaturesMap.PRODUCT_FEATRUES_MAP[key]}, {value}\n")
     
-    save_dir = "./data"
-    data = pd.read_csv("./data/feats_interaction.csv")
-    product_data = dao.get_product_by_id("00000000000000000000000011445301", "430201")[ProductConfig.FEATURE_COLUMNS]
-    filter_dict = product_data.to_dict("records")[0]
-    split_relation_subtable(data, filter_dict, save_dir)
+        
+def get_cust_list_from_history_order(city_uuid, product_code):
+    order_data = dao.get_order_by_product(city_uuid, product_code)
+    return order_data
+        
+if __name__ == "__main__":
+    order_data = get_cust_list_from_history_order("00000000000000000000000011445301", "350139")
+    order_data.to_csv("./data/history.csv", index=False)