Browse Source

热度召回增加redis存储相关的信息

huanghongbo 1 năm trước cách đây
mục cha
commit
7140ae55d6
1 tập tin đã thay đổi với 120 bổ sung0 xóa
  1. 120 0
      models/recall/item2vec.py

+ 120 - 0
models/recall/item2vec.py

@@ -0,0 +1,120 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+import gensim
+from dao.mysql_client import Mysql
+
+class Item2Vec(object):
+    def __init__(self):
+        mysql_client = Mysql()
+        # 创建会话
+        self.session = mysql_client.create_session()
+def load_item_sequences_from_mysql():
+    try:
+        conn = mysql.connector.connect(
+            host='localhost',
+            user='your_username',
+            password='your_password',
+            database='your_database'
+        )
+        cursor = conn.cursor()
+        query = "SELECT user_id, sequence FROM item_sequences"
+        cursor.execute(query)
+        for row in cursor:
+            user_id, sequence_str = row
+            sequence = sequence_str.split(',')
+            yield user_id, sequence
+        cursor.close()
+        conn.close()
+    except mysql.connector.Error as err:
+        print(f"数据库连接或查询出错: {err}")
+
+
+def load_item_attributes_from_mysql():
+    try:
+        conn = mysql.connector.connect(
+            host='localhost',
+            user='your_username',
+            password='your_password',
+            database='your_database'
+        )
+        cursor = conn.cursor()
+        query = "SELECT item, attributes FROM item_attributes"
+        cursor.execute(query)
+        item_attributes = {}
+        for item, attributes_str in cursor:
+            attributes = attributes_str.split(',')
+            item_attributes[item] = attributes
+        cursor.close()
+        conn.close()
+        return item_attributes
+    except mysql.connector.Error as err:
+        print(f"数据库连接或查询出错: {err}")
+
+
+def load_user_attributes_from_mysql():
+    try:
+        conn = mysql.connector.connect(
+            host='localhost',
+            user='your_username',
+            password='your_password',
+            database='your_database'
+        )
+        cursor = conn.cursor()
+        query = "SELECT user_id, taste, cigarette_length, cigarette_type, packaging_color FROM user_attributes"
+        cursor.execute(query)
+        for row in cursor:
+            user_id, taste, cigarette_length, cigarette_type, packaging_color = row
+            user_attrs = [attr for attr in [taste, cigarette_length, cigarette_type, packaging_color] if attr]
+            yield user_id, user_attrs
+        cursor.close()
+        conn.close()
+    except mysql.connector.Error as err:
+        print(f"数据库连接或查询出错: {err}")
+
+
+def combine_user_item_attributes(item_sequences, item_attributes):
+    user_attributes = {user_id: attrs for user_id, attrs in load_user_attributes_from_mysql()}
+    for user_id, sequence in item_sequences:
+        user_attrs = user_attributes.get(user_id, [])
+        combined_sequence = user_attrs.copy()
+        for item in sequence:
+            combined_sequence.append(item)
+            combined_sequence.extend(item_attributes.get(item, []))
+        yield combined_sequence
+
+
+def train_item2vec(item_sequences, vector_size=100, window=5, min_count=10, workers=4):
+    model = gensim.models.Word2Vec(sentences=item_sequences, vector_size=vector_size, window=window,
+                                   min_count=min_count, workers=workers)
+    return model
+
+
+def get_item_vector(item, model):
+    try:
+        return model.wv[item]
+    except KeyError:
+        print(f"物品 {item} 未在模型中找到。")
+        return None
+
+
+def find_similar_items(item, model, topn=5):
+    try:
+        similar_items = model.wv.most_similar(item, topn=topn)
+        filtered_similar_items = [(item, score) for item, score in similar_items if not item.startswith(('attr', 'user_'))]
+        return filtered_similar_items
+    except KeyError:
+        print(f"物品 {item} 未在模型中找到。")
+        return None
+
+
+if __name__ == "__main__":
+    item_sequences = load_item_sequences_from_mysql()
+    item_attributes = load_item_attributes_from_mysql()
+    combined_sequences = combine_user_item_attributes(item_sequences, item_attributes)
+    item2vec_model = train_item2vec(combined_sequences)
+    item_vector = get_item_vector('item1', item2vec_model)
+    if item_vector is not None:
+        print(f"物品 'item1' 的向量表示: {item_vector}")
+    similar_items = find_similar_items('item1', item2vec_model, topn=3)
+    if similar_items is not None:
+        print(f"与物品 'item1' 最相似的 3 个物品: {similar_items}")