3 Revīzijas 2163257cfe ... b808e7c4dd

Autors SHA1 Ziņojums Datums
  huanghongbo b808e7c4dd 更改类名 1 gadu atpakaļ
  huanghongbo 721b876bda Merge branch 'dev' of http://1.12.234.191:12000/huanghongbo/BrandCultivation into dev 1 gadu atpakaļ
  huanghongbo 7140ae55d6 热度召回增加redis存储相关的信息 1 gadu atpakaļ
2 mainītis faili ar 122 papildinājumiem un 2 dzēšanām
  1. 120 0
      models/recall/item2vec.py
  2. 2 2
      models/recall/itemCF/ShopScore.py

+ 120 - 0
models/recall/item2vec.py

@@ -0,0 +1,120 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+import gensim
+from dao.mysql_client import Mysql
+
+class Item2Vec(object):
+    def __init__(self):
+        mysql_client = Mysql()
+        # 创建会话
+        self.session = mysql_client.create_session()
+def load_item_sequences_from_mysql():
+    try:
+        conn = mysql.connector.connect(
+            host='localhost',
+            user='your_username',
+            password='your_password',
+            database='your_database'
+        )
+        cursor = conn.cursor()
+        query = "SELECT user_id, sequence FROM item_sequences"
+        cursor.execute(query)
+        for row in cursor:
+            user_id, sequence_str = row
+            sequence = sequence_str.split(',')
+            yield user_id, sequence
+        cursor.close()
+        conn.close()
+    except mysql.connector.Error as err:
+        print(f"数据库连接或查询出错: {err}")
+
+
+def load_item_attributes_from_mysql():
+    try:
+        conn = mysql.connector.connect(
+            host='localhost',
+            user='your_username',
+            password='your_password',
+            database='your_database'
+        )
+        cursor = conn.cursor()
+        query = "SELECT item, attributes FROM item_attributes"
+        cursor.execute(query)
+        item_attributes = {}
+        for item, attributes_str in cursor:
+            attributes = attributes_str.split(',')
+            item_attributes[item] = attributes
+        cursor.close()
+        conn.close()
+        return item_attributes
+    except mysql.connector.Error as err:
+        print(f"数据库连接或查询出错: {err}")
+
+
+def load_user_attributes_from_mysql():
+    try:
+        conn = mysql.connector.connect(
+            host='localhost',
+            user='your_username',
+            password='your_password',
+            database='your_database'
+        )
+        cursor = conn.cursor()
+        query = "SELECT user_id, taste, cigarette_length, cigarette_type, packaging_color FROM user_attributes"
+        cursor.execute(query)
+        for row in cursor:
+            user_id, taste, cigarette_length, cigarette_type, packaging_color = row
+            user_attrs = [attr for attr in [taste, cigarette_length, cigarette_type, packaging_color] if attr]
+            yield user_id, user_attrs
+        cursor.close()
+        conn.close()
+    except mysql.connector.Error as err:
+        print(f"数据库连接或查询出错: {err}")
+
+
+def combine_user_item_attributes(item_sequences, item_attributes):
+    user_attributes = {user_id: attrs for user_id, attrs in load_user_attributes_from_mysql()}
+    for user_id, sequence in item_sequences:
+        user_attrs = user_attributes.get(user_id, [])
+        combined_sequence = user_attrs.copy()
+        for item in sequence:
+            combined_sequence.append(item)
+            combined_sequence.extend(item_attributes.get(item, []))
+        yield combined_sequence
+
+
+def train_item2vec(item_sequences, vector_size=100, window=5, min_count=10, workers=4):
+    model = gensim.models.Word2Vec(sentences=item_sequences, vector_size=vector_size, window=window,
+                                   min_count=min_count, workers=workers)
+    return model
+
+
+def get_item_vector(item, model):
+    try:
+        return model.wv[item]
+    except KeyError:
+        print(f"物品 {item} 未在模型中找到。")
+        return None
+
+
+def find_similar_items(item, model, topn=5):
+    try:
+        similar_items = model.wv.most_similar(item, topn=topn)
+        filtered_similar_items = [(item, score) for item, score in similar_items if not item.startswith(('attr', 'user_'))]
+        return filtered_similar_items
+    except KeyError:
+        print(f"物品 {item} 未在模型中找到。")
+        return None
+
+
+if __name__ == "__main__":
+    item_sequences = load_item_sequences_from_mysql()
+    item_attributes = load_item_attributes_from_mysql()
+    combined_sequences = combine_user_item_attributes(item_sequences, item_attributes)
+    item2vec_model = train_item2vec(combined_sequences)
+    item_vector = get_item_vector('item1', item2vec_model)
+    if item_vector is not None:
+        print(f"物品 'item1' 的向量表示: {item_vector}")
+    similar_items = find_similar_items('item1', item2vec_model, topn=3)
+    if similar_items is not None:
+        print(f"与物品 'item1' 最相似的 3 个物品: {similar_items}")

+ 2 - 2
models/recall/itemCF/ShopScore.py

@@ -15,7 +15,7 @@ from dao.mysql_client import Mysql
 from decimal import Decimal
 
 # 算法封装成一个类
-class ItemCFModel:
+class ShopScore:
     """TODO 1. 将结果保存到redis数据库中"""
     def __init__(self):
         self.weights = {
@@ -81,7 +81,7 @@ def load_data_from_dataset():
  
 if __name__ == "__main__":
     # 创建一个 ItemCF 类的实例
-    item_cf_algorithm = ItemCFModel()
+    item_cf_algorithm = ShopScore()
     
     # 读取数据
     order_data = load_data_from_dataset()