Przeglądaj źródła

更新使用item2vec进行冷启动

Sherlock1011 11 miesięcy temu
rodzic
commit
2ace484206

+ 48 - 1
database/dao/mysql_dao.py

@@ -19,7 +19,7 @@ class MySqlDao:
         self.db_helper = MySqlDatabaseHelper()
         self._product_tablename = "tads_brandcul_product_info_f"
         self._cust_tablename = "tads_brandcul_cust_info_f"
-        self._order_tablename = "tads_brandcul_consumer_order_check"
+        self._order_tablename = "tads_brandcul_consumer_order"
         # self._order_tablename = "tads_brandcul_consumer_order"
         # self._eval_order_name = "tads_brandcul_consumer_order_check"
         self._mock_order_tablename = "yunfu_mock_data"
@@ -128,6 +128,40 @@ class MySqlDao:
         
         return data
     
+    def get_product_by_ids(self, city_uuid, product_id_list):
+        """根据product_code列表查询其信息"""
+        if not product_id_list:
+            return None
+        
+        product_id_str = ",".join([f"'{product_id}'" for product_id in product_id_list])
+        query = f"""
+            SELECT *
+            FROM {self._product_tablename}
+            WHERE city_uuid = :city_uuid
+            AND product_code IN ({product_id_str})
+        """
+        params = {"city_uuid": city_uuid}
+        data = self.db_helper.load_data_with_page(query, params)
+        
+        return data
+    
+    def get_order_by_product_ids(self, city_uuid, product_ids):
+        """获取指定香烟列表的所有售卖记录"""
+        if not product_ids:
+            return None
+        
+        product_ids_str = ",".join([f"'{product_code}'" for product_code in product_ids])
+        query = f"""
+            SELECT *
+            FROM {self._order_tablename}
+            WHERE city_uuid = :city_uuid
+            AND product_code IN ({product_ids_str})
+        """
+        params = {"city_uuid": city_uuid}
+        data = self.db_helper.load_data_with_page(query, params)
+        
+        return data
+    
     def get_order_by_product(self, city_uuid, product_id):
         query = f"""
             SELECT *
@@ -152,6 +186,19 @@ class MySqlDao:
         
         return data
     
+    def get_order_by_cust_and_product(self, city_uuid, cust_id, product_id):
+        query = f"""
+            SELECT *
+            FROM {self._order_tablename}
+            WHERE city_uuid = :city_uuid
+            AND cust_code = :cust_id
+            AND product_code =:product_id
+        """
+        params = {"city_uuid": city_uuid, "cust_id": cust_id, "product_id": product_id}
+        data = self.db_helper.load_data_with_page(query, params)
+        
+        return data
+    
     def data_preprocess(self, data: pd.DataFrame):
         
         data.drop(["cust_uuid", "longitude", "latitude", "range_radius"], axis=1, inplace=True)

+ 5 - 1
models/item2vec/__init__.py

@@ -1,5 +1,9 @@
 from models.item2vec.preprocess import Item2VecDataProcess
+from models.item2vec.item2vec import Item2Vec
+from models.item2vec.inference import Item2VecModel
 
 __all__ = [
-    "Item2VecDataProcess"
+    "Item2VecDataProcess",
+    "Item2Vec",
+    "Item2VecModel"
 ]

+ 63 - 0
models/item2vec/inference.py

@@ -0,0 +1,63 @@
+from database.dao.mysql_dao import MySqlDao
+from models.item2vec import Item2Vec
+from models.rank.data.config import OrderConfig, ProductConfig
+from models.rank.data.utils import sample_data_clear
+import pandas as pd
+
+class Item2VecModel:
+    def __init__(self, city_uuid):
+        self._dao = MySqlDao()
+        self._city_uuid = city_uuid
+        self._item2vec_model = Item2Vec(city_uuid)
+        
+    def generate_product_similarity_map(self, product_code):
+        """根据product_code生成卷烟相似度矩阵"""
+        product = self._dao.get_product_by_id(self._city_uuid, product_code)[ProductConfig.FEATURE_COLUMNS]
+        product = sample_data_clear(product, ProductConfig)
+        
+        similarity_map = self._item2vec_model.get_similarity_map(product)
+        similarity_map = pd.DataFrame(similarity_map)
+        product_list = self._dao.load_product_data(self._city_uuid)[ProductConfig.FEATURE_COLUMNS]
+        similarity_map = similarity_map.merge(product_list, on="product_code", how="inner")
+        # self._similarity_map = self._similarity_map.query(f"product_code != {product_code}")
+        return similarity_map
+        
+    def get_similarity_list(self, product_code, top=40):
+        """获取与指卷烟最相似的top k个卷烟"""
+        similarity_map = self.generate_product_similarity_map(product_code)
+        similarity_list = similarity_map["product_code"].to_list()
+        # similarity_list.remove(product_code)
+        similarity_list = similarity_list[:top]
+        return similarity_list
+    
+    def get_recommend_cust_list(self, product_code, top=50):
+        """获取推荐的商户列表"""
+        product_list = self.get_similarity_list(product_code)
+        order_data = self._dao.get_order_by_product_ids(self._city_uuid, product_list)[OrderConfig.FEATURE_COLUMNS]
+        order_data["sale_qty"] = order_data["sale_qty"].fillna(0)
+        order_data = order_data.groupby(["cust_code", "product_code"], as_index=False)["sale_qty"].sum()
+        
+        
+        # 按照卷烟分组,取每款卷烟售卖最好的前50个商户
+        order_data = (
+            order_data
+            .sort_values(["product_code", "sale_qty"], ascending=[True, False])
+            .groupby("product_code")
+            .head(top)
+        )
+        
+        recommend_cust = order_data.groupby(["cust_code"], as_index=False)["sale_qty"].sum()
+        recommend_cust = recommend_cust.sort_values(["sale_qty"], ascending=[False])
+        recommend_cust.to_csv("./data/recommend.csv", index=False)
+        
+        
+        
+if __name__ == "__main__":
+    city_uuid = "00000000000000000000000011445301"
+    product_id = "420202"
+    
+    model = Item2VecModel(city_uuid)
+    model.get_recommend_cust_list(product_id)
+    # dao = MySqlDao()
+    # data = dao.get_order_by_cust_and_product(city_uuid, "445300108802", "340223")[OrderConfig.FEATURE_COLUMNS]
+    # data.to_csv("./data/result.csv", index=False)

+ 4 - 4
models/item2vec/item2vec.py

@@ -64,8 +64,8 @@ class Item2Vec:
             
             similarity_map.append(
                 {
-                    "product_code": product['product_code'], 
-                    "target_product_code": target_product_code,
+                    "target_product_code": product['product_code'], 
+                    "product_code": target_product_code,
                     "similarity": similarity
                 }
             )
@@ -81,12 +81,12 @@ if __name__ == "__main__":
     dao = MySqlDao()
     city_uuid = "00000000000000000000000011445301"
     product_id = "420202"
-    order_data = dao.load_order_data(city_uuid)
+    
     product = dao.get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
     product = sample_data_clear(product, ProductConfig)
     model = Item2Vec(city_uuid)
     sims = model.get_similarity_map(product)
     sims = pd.DataFrame(sims)
     product_info = dao.load_product_data(city_uuid)[ProductConfig.FEATURE_COLUMNS]
-    sims = sims.merge(product_info, left_on="target_product_code", right_on="product_code", how="inner")
+    sims = sims.merge(product_info, on="product_code", how="inner")
     sims.to_csv("./data/product_similarity.csv", index=False)

+ 3 - 1
models/item2vec/preprocess.py

@@ -8,7 +8,9 @@ class Item2VecDataProcess:
     def __init__(self, city_uuid):
         self._mysql_dao = MySqlDao()
         print("item2vec: 正在加载product_info...")
-        self._product_data = self._mysql_dao.load_product_data(city_uuid)
+        # self._product_data = self._mysql_dao.load_product_data(city_uuid)
+        product_ids = self._mysql_dao.load_order_data(city_uuid)["product_code"].unique().tolist()
+        self._product_data = self._mysql_dao.get_product_by_ids(city_uuid, product_ids)
         self._data_process()
         
     def _data_process(self):

BIN
models/rank/weights/00000000000000000000000011445301/gbdtlr_model.pkl


+ 3 - 3
utils/result_process.py

@@ -99,16 +99,16 @@ def get_cust_list_from_history_order(city_uuid, product_code):
     
     # 读取推荐数据
     recommend_data = pd.read_csv('./data/recommend_report.csv')
-    
+    recommend_data = recommend_data.drop(columns=["sale_qty"])
     # 确保recommend_data中的cust_code也是字符串类型
     recommend_data["cust_code"] = recommend_data["cust_code"].astype(str)
     cust_ids = recommend_data.set_index("cust_code")
     
     # 执行合并操作
     merge_data = order_data.join(cust_ids, on="cust_code", how="left")
-    merge_data = merge_data[["cust_code", "cust_name", "product_code", "product_name", "sale_qty", "sale_amt", "推荐序号", "匹配评分"]]
+    merge_data = merge_data[["cust_code", "cust_name", "product_code", "product_name", "sale_qty", "推荐序号"]]
     return merge_data
         
 if __name__ == "__main__":
-    order_data = get_cust_list_from_history_order("00000000000000000000000011445301", "420202")
+    order_data = get_cust_list_from_history_order("00000000000000000000000011445301", "350355")
     order_data.to_csv("./data/eval.csv", index=False)