瀏覽代碼

封装运行流程到app.py,并重构部分代码

Sherlock1011 1 年之前
父節點
當前提交
c2a1bc2990

+ 88 - 6
app.py

@@ -1,6 +1,88 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-# 0:执行所有模块
-# 1:只执行热度找回
-# 2:协同过滤
-# --config 
+import argparse
+from dao import load_order_data_from_mysql
+from dao.redis_db import Redis
+from models import HotRecallModel, UserItemScore, ItemCFModel, calculate_similarity_and_save_results
+import os
+
+def run_hot_recall(order_data):
+    """运行热度召回算法"""
+    hot_model = HotRecallModel(order_data)
+    hot_model.calculate_all_hot_score()
+    print("热度召回已完成!")
+
+def run_itemcf(args):
+    # """运行协同过滤算法"""
+    # if os.path.exists(args.interst_score_path) and os.path.exists(args.similarity_matrix_path):
+    #     os.remove(args.interst_score_path)
+    #     os.remove(args.similarity_matrix_path)
+    # # n_jobs = 4
+    
+    # # 计算user-score-item数据
+    # cal_interest_scores_model = UserItemScore()
+    # scores = cal_interest_scores_model.score(order_data)
+    # scores.to_csv(args.interst_score_path, index=False, encoding="utf-8")
+    # print("Interest Scores cal done!")
+    
+    # # 计算商户共现矩阵及相似度矩阵
+    # calculate_similarity_and_save_results(order_data, args.similarity_matrix_path)
+    # print("Shops similarity matrix cal done!")
+    
+    # 运行协同过滤召回
+    itemcf_model = ItemCFModel()
+    itemcf_model.train(args.interst_score_path, args.similarity_matrix_path, args.n, args.k, args.top_n, args.n_jobs)
+    print("协同过滤已完成!")
+
+def run_itemcf_inference(product_code):
+        """
+        从 Redis 中读取推荐结果,并返回 {shop_id: score} 的列表
+        """
+        redis_db = Redis()
+        redis_key = f"fc:{product_code}"
+        recommendations = redis_db.redis.zrange(redis_key, 0, -1, withscores=True, desc=True)
+        
+        # 将推荐结果转换为 {shop_id: score} 的字典列表
+        result = [{shop_id: float(score)} for shop_id, score in recommendations]
+        
+        return result
+
+def run():
+    parser = argparse.ArgumentParser()
+    
+    # 运行方式
+    parser.add_argument("--run_all", action='store_true')
+    parser.add_argument("--run_hot", action='store_true')
+    parser.add_argument("--run_itemcf", action='store_true')
+    parser.add_argument("--run_itemcf_inference", action='store_true')
+    
+    # 协同过滤相关配置
+    parser.add_argument("--interst_score_path", type=str, default="./models/recall/itemCF/matrix/score.csv")
+    parser.add_argument("--similarity_matrix_path", type=str, default="./models/recall/itemCF/matrix/similarity.csv")
+    parser.add_argument("--n", type=int, default=100)
+    parser.add_argument("--k", type=int, default=10)
+    parser.add_argument("--top_n", type=int, default=200, help='default n * k')
+    parser.add_argument("--n_jobs", type=int, default=4)
+    
+    # 协同过滤推理配置
+    parser.add_argument("--product_code", type=int, default=110111)
+    
+    args = parser.parse_args()
+    
+    if args.run_all:
+        order_data = load_order_data_from_mysql()
+        run_hot_recall(order_data)
+        run_itemcf(order_data, args)
+        
+    elif args.run_hot:
+        order_data = load_order_data_from_mysql()
+        run_hot_recall(order_data)
+        
+    elif args.run_itemcf:
+        # order_data = load_order_data_from_mysql()
+        run_itemcf(args)
+        
+    elif args.run_itemcf_inference:
+        recomments = run_itemcf_inference(args.product_code)
+        print(recomments)
+    
+if __name__ == "__main__":
+    run()

+ 3 - 1
dao/__init__.py

@@ -1,7 +1,9 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
 from dao.mysql_client import Mysql
+from dao.dao import load_order_data_from_mysql
 
 __all__ = [
-    "Mysql"
+    "Mysql",
+    "load_order_data_from_mysql"
 ]

+ 14 - 0
dao/dao.py

@@ -0,0 +1,14 @@
+from dao import Mysql
+
+def load_order_data_from_mysql():
+    """从数据库中读取数据"""
+    client = Mysql()
+    tablename = "mock_order"
+    query_text = "*"
+    
+    df = client.load_data(tablename, query_text)
+    
+     # 去除重复值和填补缺失值
+    df.drop_duplicates(inplace=True)
+    df.fillna(0, inplace=True)
+    return df

+ 1 - 1
dao/redis_db.py

@@ -21,7 +21,7 @@ if __name__ == '__main__':
     r = Redis().redis
 
     # 有序集合的键名
-    zset_key = 'hotkeys'
+    zset_key = 'configs:hotkeys'
 
     data_list = ['ORDER_FULLORDR_RATE', 'MONTH6_SALE_QTY_YOY', 'MONTH6_SALE_QTY_MOM', 'MONTH6_SALE_QTY']
 

+ 10 - 0
models/__init__.py

@@ -1,2 +1,12 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
+from models.recall.hot_recall import HotRecallModel
+from models.recall.itemCF.calculate_similarity_matrix import calculate_similarity_and_save_results
+from models.recall.itemCF.user_item_score import UserItemScore
+from models.recall.itemCF.ItemCF import ItemCFModel
+__all__ = [
+    "HotRecallModel",
+    "UserItemScore",
+    "calculate_similarity_and_save_results",
+    "ItemCFModel"
+]

+ 0 - 2
models/recall/__init__.py

@@ -1,2 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-

+ 3 - 6
models/recall/hot_recall.py

@@ -10,20 +10,17 @@
 import pandas as pd
 from dao.redis_db import Redis
 from dao.mysql_client import Mysql
-import random
 from tqdm import tqdm
-import joblib
 
-random.seed(12345)
 class HotRecallModel:
-    def __init__(self):
+    def __init__(self, order_data):
         self._redis_db = Redis()
         self._hotkeys = self.get_hotkeys()
-        self._order_data = self._load_data_from_dataset()
+        self._order_data = order_data
 
 
     def get_hotkeys(self):
-        info = self._redis_db.redis.zrange("hotkeys", 0, -1, withscores=True)
+        info = self._redis_db.redis.zrange("configs:hotkeys", 0, -1, withscores=True)
         hotkeys = []
         for item, _ in info:
             hotkeys.append(item)

+ 9 - 19
models/recall/itemCF/ItemCF.py

@@ -6,7 +6,7 @@ from scipy.sparse import csr_matrix
 from joblib import Parallel, delayed
 import joblib
 
-class ItemCF:
+class ItemCFModel:
     def __init__(self):
         self._recommendations = {}
         
@@ -53,9 +53,10 @@ class ItemCF:
         # 并行处理每个品规
         results = Parallel(n_jobs=n_jobs)(delayed(process_product)(product_code, scores) 
                                           for product_code, scores in tqdm(self._score_df.groupby("PRODUCT_CODE"), desc="train:正在计算候选得分"))
-        
+        print(len(results))
         # 存储结果
         self._recommendations = {product_code: sorted_candidates for product_code, sorted_candidates in results}
+        self.to_redis_zset()
     
     def to_redis_zset(self):
         """
@@ -76,31 +77,20 @@ class ItemCF:
             
             redis_db.redis.zadd(redis_key, zset_data)
     
-    def inference(self, product_code):
-        """
-        从 Redis 中读取推荐结果,并返回 {shop_id: score} 的列表
-        """
-        redis_db = Redis()
-        redis_key = f"fc:{product_code}"
-        recommendations = redis_db.redis.zrange(redis_key, 0, -1, withscores=True, desc=True)
-        
-        # 将推荐结果转换为 {shop_id: score} 的字典列表
-        result = [{shop_id: float(score)} for shop_id, score in recommendations]
-        
-        return result
-    
 if __name__ == "__main__":
     score_path = "./models/recall/itemCF/matrix/score.csv"
     similarity_path = "./models/recall/itemCF/matrix/similarity.csv"
-    itemcf_model = ItemCF()
+    # itemcf_model = ItemCFModel()
     # itemcf_model.train(score_path, similarity_path, n_jobs=4)
-    recommend_list = itemcf_model.inference(110111)
+    # recommend_list = itemcf_model.inference(110111)
     # itemcf_model.to_redis_zset()
     # print(len(recommend_list))
-    print(recommend_list)
+    # print(recommend_list)
     # joblib.dump(itemcf_model, "itemCF.model")
     
     # model = joblib.load("./itemCF.model")
     # recommend_list = model.inference(110102)
     # print(len(recommend_list))
-    # print(recommend_list)
+    # print(recommend_list)
+    data = pd.read_csv(similarity_path, index_col=0)
+    print(data)

+ 9 - 18
models/recall/itemCF/calculate_co_occurrence_matrix.py → models/recall/itemCF/calculate_similarity_matrix.py

@@ -1,3 +1,4 @@
+from dao import load_order_data_from_mysql
 import pandas as pd
 import numpy as np
 
@@ -5,18 +6,6 @@ from itertools import combinations
 from dao.mysql_client import Mysql
 from tqdm import tqdm
 
-def load_data_from_dataset():
-    """从数据库中读取数据"""
-    client = Mysql()
-    tablename = "mock_order"
-    query_text = "*"
-    
-    df = client.load_data(tablename, query_text)
-    
-     # 去除重复值和填补缺失值
-    df.drop_duplicates(inplace=True)
-    df.fillna(0, inplace=True)
-    return df
 
 def build_co_occurence_matrix(order_data):
     """
@@ -53,7 +42,6 @@ def calculate_similarity_matrix(co_occurrence_matrix, order_data, shops_to_index
     """
     # 计算每个商铺售卖品规的总次数
     shop_counts = order_data.groupby("BB_RETAIL_CUSTOMER_CODE").size()
-    num_shops = len(shops_to_index)
     
     # 将商户售卖次数转换为数组
     counts = np.array([shop_counts[shop] for shop in shops_to_index.keys()])
@@ -76,14 +64,17 @@ def save_matrix(matrix, shops, save_path):
     matrix_df = pd.DataFrame(matrix, index=shops, columns=shops)
     matrix_df.to_csv(save_path, index=True, encoding="utf-8")
     
+def calculate_similarity_and_save_results(order_data, similarity_matrix_save_path):
+    co_occurrence_matrix, shops, shops_to_index = build_co_occurence_matrix(order_data)
+    similarity_matrix = calculate_similarity_matrix(co_occurrence_matrix, order_data, shops_to_index)
+    save_matrix(similarity_matrix, shops, similarity_matrix_save_path)
+    
 if __name__ == "__main__":
     co_occurrence_save_path = "./models/recall/itemCF/matrix/occurrence.csv"
     similarity_matrix_save_path = "./models/recall/itemCF/matrix/similarity.csv"
-    order_data = load_data_from_dataset()
+    # 从数据库中读取订单数据
+    order_data = load_order_data_from_mysql()
     
-    co_occurrence_matrix, shops, shops_to_index = build_co_occurence_matrix(order_data)
+    calculate_similarity_and_save_results(order_data, similarity_matrix_save_path)
     
-    # save_matrix(co_occurrence_matrix, shops, co_occurrence_save_path)
-    similarity_matrix = calculate_similarity_matrix(co_occurrence_matrix, order_data, shops_to_index)
-    save_matrix(similarity_matrix, shops, similarity_matrix_save_path)
     

+ 6 - 18
models/recall/itemCF/user_item_score.py

@@ -7,9 +7,9 @@
 @author     : Sherlock1011 & Min1027
 @Version     : 1.0
 '''
-import joblib
 
-from dao.mysql_client import Mysql
+
+from dao import load_order_data_from_mysql
 from decimal import Decimal
 
 # 算法封装成一个类
@@ -61,32 +61,20 @@ class UserItemScore:
         df_result = df_result.sort_values(by=["PRODUCT_CODE", "SCORE"], ascending=[True, False])
 
         # 选择要保存的列
-        # df_result[['PRODUCT_CODE', 'BB_RETAIL_CUSTOMER_CODE', 'SCORE']].to_csv("./models/recall/itemCF/matrix/score.csv", index=False, encoding="utf-8")
         return df_result[['PRODUCT_CODE', 'BB_RETAIL_CUSTOMER_CODE', 'SCORE']]
-
-def load_data_from_dataset():
-    """从数据库中读取数据"""
-    client = Mysql()
-    tablename = "mock_order"
-    query_text = "*"
-    
-    df = client.load_data(tablename, query_text)
-    
-     # 去除重复值和填补缺失值
-    df.drop_duplicates(inplace=True)
-    df.fillna(0, inplace=True)
-    return df
  
 if __name__ == "__main__":
     # 创建一个 ItemCF 类的实例
     item_cf_algorithm = UserItemScore()
     
     # 读取数据
-    order_data = load_data_from_dataset()
+    order_data = load_order_data_from_mysql()
 
     # 调用算法
     scores = item_cf_algorithm.score(order_data)
     
+    scores_path = "./models/recall/itemCF/matrix/score.csv"
+    
     # 保存评分结果到csv文件
-    scores.to_csv("./models/recall/itemCF/matrix/score.csv", index=False, encoding="utf-8")
+    scores.to_csv(scores_path, index=False, encoding="utf-8")