|
@@ -0,0 +1,120 @@
|
|
|
|
|
+#!/usr/bin/env python3
|
|
|
|
|
+# -*- coding:utf-8 -*-
|
|
|
|
|
+import gensim
|
|
|
|
|
+from dao.mysql_client import Mysql
|
|
|
|
|
+
|
|
|
|
|
class Item2Vec:
    """Holder for a database session obtained from the project's MySQL client.

    NOTE(review): nothing else in this file references this class or its
    ``session`` attribute; the module-level loader functions open their own
    connections instead -- confirm whether this class is still needed.
    """

    def __init__(self):
        # Create a session through the project's Mysql client wrapper.
        self.session = Mysql().create_session()
|
|
|
|
|
def load_item_sequences_from_mysql(*, host='localhost', user='your_username',
                                   password='your_password',
                                   database='your_database'):
    """Stream ``(user_id, sequence)`` pairs from the ``item_sequences`` table.

    This is a generator, so the connection is opened on first iteration and
    is always released again: after the last row, after a database error, or
    when the caller stops iterating early.

    Args:
        host: MySQL server host.
        user: MySQL user name.
        password: MySQL password.
        database: Schema to query.
            The defaults are the placeholder values that used to be
            hard-coded; pass real settings as keyword arguments.

    Yields:
        tuple: ``(user_id, sequence)`` where ``sequence`` is the list of
        whitespace-stripped, non-empty tokens from the comma-separated
        ``sequence`` column. A NULL or empty column yields an empty list.

    Database errors (``mysql.connector.Error``) are printed rather than
    raised; the generator then simply ends.
    """
    # Function-scope import: the file never imports ``mysql.connector`` at
    # module level, so the bare name ``mysql`` would raise NameError here and
    # in the ``except`` clause below.
    import mysql.connector

    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(
            host=host,
            user=user,
            password=password,
            database=database,
        )
        cursor = conn.cursor()
        cursor.execute("SELECT user_id, sequence FROM item_sequences")
        for user_id, sequence_str in cursor:
            # ``or ""`` guards against NULL; dropping blanks avoids feeding
            # an empty-string token into the vocabulary.
            tokens = (token.strip() for token in (sequence_str or "").split(','))
            yield user_id, [token for token in tokens if token]
    except mysql.connector.Error as err:
        print(f"数据库连接或查询出错: {err}")
    finally:
        # Release the cursor first, then the connection.
        for resource in (cursor, conn):
            if resource is None:
                continue
            try:
                resource.close()
            except mysql.connector.Error:
                # Best-effort cleanup: closing may itself fail (for example
                # if rows are still unread after an early exit).
                pass
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def load_item_attributes_from_mysql(*, host='localhost', user='your_username',
                                    password='your_password',
                                    database='your_database'):
    """Load the ``item -> attributes`` mapping from the ``item_attributes`` table.

    Args:
        host: MySQL server host.
        user: MySQL user name.
        password: MySQL password.
        database: Schema to query.
            The defaults are the placeholder values that used to be
            hard-coded; pass real settings as keyword arguments.

    Returns:
        dict | None: Maps each ``item`` to the list of whitespace-stripped,
        non-empty tokens from its comma-separated ``attributes`` column (a
        NULL or empty column gives an empty list). Returns ``None`` when a
        ``mysql.connector.Error`` occurs; the error is printed, not raised.
    """
    # Function-scope import: the file never imports ``mysql.connector`` at
    # module level, so the bare name ``mysql`` would raise NameError here and
    # in the ``except`` clause below.
    import mysql.connector

    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(
            host=host,
            user=user,
            password=password,
            database=database,
        )
        cursor = conn.cursor()
        cursor.execute("SELECT item, attributes FROM item_attributes")
        item_attributes = {}
        for item, attributes_str in cursor:
            # ``or ""`` guards against NULL; blanks are dropped so that no
            # empty-string attribute token is produced.
            tokens = (token.strip() for token in (attributes_str or "").split(','))
            item_attributes[item] = [token for token in tokens if token]
        return item_attributes
    except mysql.connector.Error as err:
        print(f"数据库连接或查询出错: {err}")
        # Kept as ``None`` (previously an implicit fall-through) so callers
        # that test for failure keep working.
        return None
    finally:
        # Release the cursor first, then the connection, on every exit path.
        for resource in (cursor, conn):
            if resource is None:
                continue
            try:
                resource.close()
            except mysql.connector.Error:
                # Best-effort cleanup; a failing close must not mask the result.
                pass
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def load_user_attributes_from_mysql(*, host='localhost', user='your_username',
                                    password='your_password',
                                    database='your_database'):
    """Stream ``(user_id, attributes)`` pairs from the ``user_attributes`` table.

    This is a generator, so the connection is opened on first iteration and
    is always released again: after the last row, after a database error, or
    when the caller stops iterating early.

    Args:
        host: MySQL server host.
        user: MySQL user name.
        password: MySQL password.
        database: Schema to query.
            The defaults are the placeholder values that used to be
            hard-coded; pass real settings as keyword arguments.

    Yields:
        tuple: ``(user_id, user_attrs)`` where ``user_attrs`` holds the
        truthy values of the ``taste``, ``cigarette_length``,
        ``cigarette_type`` and ``packaging_color`` columns, in that order.
        Falsy values (NULL, empty string, zero) are left out.

    Database errors (``mysql.connector.Error``) are printed rather than
    raised; the generator then simply ends.
    """
    # Function-scope import: the file never imports ``mysql.connector`` at
    # module level, so the bare name ``mysql`` would raise NameError here and
    # in the ``except`` clause below.
    import mysql.connector

    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(
            host=host,
            user=user,
            password=password,
            database=database,
        )
        cursor = conn.cursor()
        cursor.execute(
            "SELECT user_id, taste, cigarette_length, cigarette_type, packaging_color FROM user_attributes"
        )
        for user_id, *attribute_columns in cursor:
            yield user_id, [attr for attr in attribute_columns if attr]
    except mysql.connector.Error as err:
        print(f"数据库连接或查询出错: {err}")
    finally:
        # Release the cursor first, then the connection.
        for resource in (cursor, conn):
            if resource is None:
                continue
            try:
                resource.close()
            except mysql.connector.Error:
                # Best-effort cleanup: closing may itself fail (for example
                # if rows are still unread after an early exit).
                pass
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def combine_user_item_attributes(item_sequences, item_attributes, user_attributes=None):
    """Build one training sentence per user from items and attribute tokens.

    Each sentence starts with the user's attribute tokens, followed by every
    item of the user's sequence with that item's attribute tokens placed
    directly after it.

    Args:
        item_sequences: Iterable of ``(user_id, sequence)`` pairs.
        item_attributes: Mapping ``item -> list of attribute tokens``.
            ``None`` is treated as an empty mapping, because
            ``load_item_attributes_from_mysql`` returns ``None`` on failure.
        user_attributes: Optional mapping ``user_id -> list of attribute
            tokens``. When omitted, it is loaded with
            ``load_user_attributes_from_mysql()`` on first iteration.

    Yields:
        list: The combined token sequence for each entry of ``item_sequences``.
    """
    if user_attributes is None:
        user_attributes = dict(load_user_attributes_from_mysql())
    if item_attributes is None:
        item_attributes = {}

    for user_id, sequence in item_sequences:
        # Copy so the caller's attribute lists are never mutated.
        combined_sequence = list(user_attributes.get(user_id, ()))
        for item in sequence:
            combined_sequence.append(item)
            combined_sequence.extend(item_attributes.get(item, ()))
        yield combined_sequence
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def train_item2vec(item_sequences, vector_size=100, window=5, min_count=10, workers=4):
    """Train a gensim ``Word2Vec`` model on the given token sequences.

    Args:
        item_sequences: Iterable of token lists (one "sentence" each). A
            generator or other one-shot iterator is accepted and is copied
            into a list first.
        vector_size: Forwarded to ``Word2Vec`` as ``vector_size``.
        window: Forwarded to ``Word2Vec`` as ``window``.
        min_count: Forwarded to ``Word2Vec`` as ``min_count``.
        workers: Forwarded to ``Word2Vec`` as ``workers``.

    Returns:
        The ``gensim.models.Word2Vec`` instance created from the corpus.
    """
    from collections.abc import Iterator

    # Word2Vec iterates the corpus more than once (vocabulary scan, then the
    # training epochs), so a one-shot iterator such as the generator produced
    # by ``combine_user_item_attributes`` must be materialised first.
    if isinstance(item_sequences, Iterator):
        item_sequences = list(item_sequences)

    return gensim.models.Word2Vec(
        sentences=item_sequences,
        vector_size=vector_size,
        window=window,
        min_count=min_count,
        workers=workers,
    )
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def get_item_vector(item, model):
    """Look up the embedding stored for ``item`` in ``model.wv``.

    Args:
        item: Key to look up.
        model: Object exposing a ``wv`` attribute that supports ``wv[item]``.

    Returns:
        The value of ``model.wv[item]``, or ``None`` (after printing a
        message) when the lookup raises ``KeyError``.
    """
    try:
        vector = model.wv[item]
    except KeyError:
        print(f"物品 {item} 未在模型中找到。")
        vector = None
    return vector
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def find_similar_items(item, model, topn=5):
    """Return up to ``topn`` items most similar to ``item``, skipping attribute tokens.

    The full neighbour ranking is requested and filtered *before* it is cut
    to ``topn``. Filtering after the cut (the previous behaviour) returned
    fewer than ``topn`` items, possibly none, whenever attribute tokens
    ranked among the nearest neighbours.

    Tokens starting with ``'attr'`` or ``'user_'`` are treated as attribute
    tokens. NOTE(review): nothing in this file adds those prefixes, so this
    presumably relies on how values are stored in the database -- confirm.

    Args:
        item: Key whose neighbours are wanted.
        model: Trained model exposing ``wv.most_similar`` and ``len(wv)``.
        topn: Maximum number of items to return.

    Returns:
        list | None: ``(item, score)`` pairs in the order returned by
        ``most_similar``; ``[]`` when ``topn`` is below 1; ``None`` (after
        printing a message) when ``item`` is not in the model.
    """
    if isinstance(topn, int) and topn < 1:
        return []

    excluded_prefixes = ('attr', 'user_')
    try:
        keyed_vectors = model.wv
        ranked = keyed_vectors.most_similar(item, topn=len(keyed_vectors))
    except KeyError:
        print(f"物品 {item} 未在模型中找到。")
        return None

    # Non-string keys cannot carry a prefix, so they are kept as items.
    kept = [
        (candidate, score)
        for candidate, score in ranked
        if not (isinstance(candidate, str) and candidate.startswith(excluded_prefixes))
    ]
    return kept[:topn]
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def main():
    """Train an item2vec model from MySQL data and print results for ``'item1'``."""
    item_sequences = load_item_sequences_from_mysql()
    # The attribute loader returns None when the query fails; fall back to an
    # empty mapping so the combining step can still run.
    item_attributes = load_item_attributes_from_mysql() or {}

    # Materialise the corpus: Word2Vec needs several passes over it, and the
    # generator returned by combine_user_item_attributes can be read only once.
    combined_sequences = list(combine_user_item_attributes(item_sequences, item_attributes))
    item2vec_model = train_item2vec(combined_sequences)

    item_vector = get_item_vector('item1', item2vec_model)
    if item_vector is not None:
        print(f"物品 'item1' 的向量表示: {item_vector}")

    similar_items = find_similar_items('item1', item2vec_model, topn=3)
    if similar_items is not None:
        print(f"与物品 'item1' 最相似的 3 个物品: {similar_items}")


if __name__ == "__main__":
    main()
|