Переглянути джерело

修改:协同过滤在train时再加载数据

Sherlock1011 1 рік тому
батько
коміт
cf34209d1a
1 змінених файлів з 9 додано та 8 видалено
  1. 9 8
      models/recall/itemCF/ItemCF.py

+ 9 - 8
models/recall/itemCF/ItemCF.py

@@ -7,15 +7,16 @@ from joblib import Parallel, delayed
 import joblib
 
 class ItemCF:
-    def __init__(self, score_path, similatity_path):
+    def __init__(self):
         self._recommendations = {}
+        
+    def train(self, score_path, similatity_path, n=100, k=10, top_n=100, n_jobs=4):
         self._score_df = pd.read_csv(score_path)
         self._similarity_df = pd.read_csv(similatity_path, index_col=0)
         self._similarity_matrix = csr_matrix(self._similarity_df.values)
         self._shop_index = {shop: idx for idx, shop in enumerate(self._similarity_df.index)}
         self._index_shop = {idx: shop for idx, shop in enumerate(self._similarity_df.index)}
         
-    def train(self, n=100, k=10, top_n=100, n_jobs=4):
         def process_product(product_code, scores):
             # 获取热度最高的n个商户
             top_n_shops = scores.nlargest(n, "SCORE")["BB_RETAIL_CUSTOMER_CODE"].values
@@ -51,7 +52,7 @@ class ItemCF:
         
         # 并行处理每个品规
         results = Parallel(n_jobs=n_jobs)(delayed(process_product)(product_code, scores) 
-                                          for product_code, scores in tqdm(self._score_df.groupby("PRODUCT_CODE"), desc="正在计算候选得分..."))
+                                          for product_code, scores in tqdm(self._score_df.groupby("PRODUCT_CODE"), desc="train:正在计算候选得分"))
         
         # 存储结果
         self._recommendations = {product_code: sorted_candidates for product_code, sorted_candidates in results}
@@ -62,7 +63,7 @@ class ItemCF:
         存储格式为 fc:product_code,其中商户 ID 作为成员,得分作为分数
         """
         redis_db = Redis()
-        for product_code, recommendations in self._recommendations.items():
+        for product_code, recommendations in tqdm(self._recommendations.items(), desc="train:正在存储推荐结果"):
             redis_key = f"fc:{product_code}"
             zset_data = {}
             for rec in recommendations:
@@ -91,12 +92,12 @@ class ItemCF:
 if __name__ == "__main__":
     score_path = "./models/recall/itemCF/matrix/score.csv"
     similarity_path = "./models/recall/itemCF/matrix/similarity.csv"
-    itemcf_model = ItemCF(score_path, similarity_path)
-    itemcf_model.train(n_jobs=4)
-    recommend_list = itemcf_model.inference(110102)
+    itemcf_model = ItemCF()
+    # itemcf_model.train(score_path, similarity_path, n_jobs=4)
+    recommend_list = itemcf_model.inference(110111)
     # itemcf_model.to_redis_zset()
     # print(len(recommend_list))
-    # print(recommend_list)
+    print(recommend_list)
     # joblib.dump(itemcf_model, "itemCF.model")
     
     # model = joblib.load("./itemCF.model")