浏览代码

评估验证推荐效果

yangzeyu 11 个月之前
父节点
当前提交
7ef5da0898
共有 3 个文件被更改，包括 8 次插入和 6 次删除
  1. 2 0
      database/dao/mysql_dao.py
  2. 2 4
      inference.py
  3. 4 2
      utils/result_process.py

+ 2 - 0
database/dao/mysql_dao.py

@@ -20,6 +20,8 @@ class MySqlDao:
         self._product_tablename = "tads_brandcul_product_info_f"
         self._cust_tablename = "tads_brandcul_cust_info_f"
         self._order_tablename = "tads_brandcul_consumer_order"
+        # self._order_tablename = "tads_brandcul_consumer_order"
+        # self._eval_order_name = "tads_brandcul_consumer_order_check"
         self._mock_order_tablename = "yunfu_mock_data"
         self._shopping_tablename = "tads_brandcul_cust_info_lbs_f"
         # self._shopping_tablename = "yunfu_shopping_mock_data"

+ 2 - 4
inference.py

@@ -26,8 +26,6 @@ def get_recall_cust(city_uuid, product_id, recall_count):
     """根据协同过滤和热度召回召回商户"""
     itemcf_recall_list = get_itemcf_recall(city_uuid, product_id)
     hot_recall_list = get_hot_recall(city_uuid)
-    for i in hot_recall_list:
-        print(i)
     
     result = list(dict.fromkeys(itemcf_recall_list))
     
@@ -85,10 +83,10 @@ def run():
     pass
 
 if __name__ == '__main__':
-    # generate_features_shap("00000000000000000000000011445301", "420202", delivery_count=5000)
+    # generate_features_shap("00000000000000000000000011445301", "350139", delivery_count=5000)
     # recommend_list = get_recommend_list("00000000000000000000000011445301", "420202")
     # recommend_list = pd.DataFrame(recommend_list)
     # recommend_list.to_csv("./data/recommend_list.csv", index=False, encoding="utf-8-sig")
-    data = dao.get_order_by_cust("00000000000000000000000011445301", "445323105795")
+    data = dao.get_order_by_cust("00000000000000000000000011445301", "445381107139")
     data = data.groupby(["cust_code", "product_code", "product_name"], as_index=False)["sale_qty"].sum()
     data.to_csv("./data/cust.csv", index=False)

+ 4 - 2
utils/result_process.py

@@ -32,12 +32,14 @@ def generate_report(city_uuid, data, filter_dict, recommend_data, delivery_count
     """根据总表筛选结果"""
     # 1. 筛选商户相关性排序结果
     data = filter_data(data, filter_dict).copy()
-    data.to_csv(os.path.join(save_dir, "feats_interaction.csv"), index=False, encoding='utf-8-sig')
+    # data.to_csv(os.path.join(save_dir, "feats_interaction.csv"), index=False, encoding='utf-8-sig')
     group_sums = data.groupby("cust_feat")["relation"].sum()
     # 筛选出总和非负的cust_feat
     valid_cust_feats = group_sums[group_sums > 0].index.tolist()
     cust_relation = data[data["cust_feat"].isin(valid_cust_feats)]
     cust_relation = cust_relation.reset_index(drop=True)
+    cust_relation.to_csv(os.path.join(save_dir, "feats_interaction.csv"), index=False, encoding='utf-8-sig')
+    
     
     # 2. 品规信息
     cust_relation[:20].to_csv(os.path.join(save_dir, "cust_relation.csv"), index=False, encoding='utf-8-sig')
@@ -103,7 +105,7 @@ def get_cust_list_from_history_order(city_uuid, product_code):
     cust_ids = recommend_data.set_index("cust_code")
     
     # 执行合并操作
-    merge_data = order_data.join(cust_ids, on="cust_code", how="inner")
+    merge_data = order_data.join(cust_ids, on="cust_code", how="left")
     merge_data = merge_data[["cust_code", "cust_name", "product_code", "product_name", "sale_qty", "sale_amt", "推荐序号", "匹配评分"]]
     return merge_data