#!/usr/bin/env python3
# -*- coding:utf-8 -*-
"""Train an item2vec (gensim Word2Vec) model on user item sequences in MySQL.

Each training "sentence" is a user's attribute tokens followed by that user's
item sequence, where every item is immediately followed by its own attribute
tokens.
"""
import os
from contextlib import closing

import gensim
import mysql.connector  # used by all loaders below; previously never imported

from dao.mysql_client import Mysql


# Connection settings shared by every loader. The defaults are the original
# placeholder values; environment variables may override them so that real
# credentials do not have to be committed to source.
DB_CONFIG = {
    'host': os.environ.get('MYSQL_HOST', 'localhost'),
    'user': os.environ.get('MYSQL_USER', 'your_username'),
    'password': os.environ.get('MYSQL_PASSWORD', 'your_password'),
    'database': os.environ.get('MYSQL_DATABASE', 'your_database'),
}


class Item2Vec(object):
    """Holds a session created by the project's ``Mysql`` DAO.

    NOTE(review): nothing in this file uses this class or its session; the
    loader functions below open their own ``mysql.connector`` connections.
    """

    def __init__(self):
        mysql_client = Mysql()
        # Create the session
        self.session = mysql_client.create_session()


def _connect():
    """Open a new ``mysql.connector`` connection configured by ``DB_CONFIG``."""
    return mysql.connector.connect(**DB_CONFIG)


def _split_csv(value):
    """Split a comma-separated column value into stripped, non-empty tokens.

    ``None`` (SQL NULL) and empty strings yield ``[]`` instead of raising
    ``AttributeError`` on ``.split``.
    """
    if not value:
        return []
    return [token.strip() for token in value.split(',') if token.strip()]


def load_item_sequences_from_mysql():
    """Yield ``(user_id, sequence)`` for each row of ``item_sequences``.

    ``sequence`` is the list of item tokens parsed from the comma-separated
    ``sequence`` column. On a database error the error is printed and the
    iteration ends. The cursor and connection are closed even if the consumer
    stops iterating early.
    """
    try:
        with closing(_connect()) as conn, closing(conn.cursor()) as cursor:
            cursor.execute("SELECT user_id, sequence FROM item_sequences")
            for user_id, sequence_str in cursor:
                yield user_id, _split_csv(sequence_str)
    except mysql.connector.Error as err:
        print(f"数据库连接或查询出错: {err}")


def load_item_attributes_from_mysql():
    """Return ``{item: [attribute tokens]}`` built from ``item_attributes``.

    Keys are converted to ``str`` because the item tokens in the sequences are
    strings (produced by splitting text), so lookups must use string keys.
    On a database error the error is printed and ``{}`` is returned, so
    callers can still call ``.get`` on the result.
    """
    item_attributes = {}
    try:
        with closing(_connect()) as conn, closing(conn.cursor()) as cursor:
            cursor.execute("SELECT item, attributes FROM item_attributes")
            for item, attributes_str in cursor:
                item_attributes[str(item)] = _split_csv(attributes_str)
    except mysql.connector.Error as err:
        print(f"数据库连接或查询出错: {err}")
        return {}
    return item_attributes


def load_user_attributes_from_mysql():
    """Yield ``(user_id, [attribute tokens])`` for each row of ``user_attributes``.

    The tokens are the non-empty values of the ``taste``, ``cigarette_length``,
    ``cigarette_type`` and ``packaging_color`` columns, in that order. They are
    converted to ``str`` because Word2Vec tokens are compared and filtered as
    strings. On a database error the error is printed and the iteration ends.
    """
    try:
        with closing(_connect()) as conn, closing(conn.cursor()) as cursor:
            cursor.execute(
                "SELECT user_id, taste, cigarette_length, cigarette_type, "
                "packaging_color FROM user_attributes"
            )
            for user_id, *attrs in cursor:
                # Falsy values (NULL, empty string) are dropped.
                yield user_id, [str(attr) for attr in attrs if attr]
    except mysql.connector.Error as err:
        print(f"数据库连接或查询出错: {err}")


def combine_user_item_attributes(item_sequences, item_attributes):
    """Yield one training sentence per ``(user_id, sequence)`` pair.

    Each sentence is the user's attribute tokens (looked up from
    ``user_attributes`` on first iteration), followed by every item of the
    sequence, with each item immediately followed by its attribute tokens
    from ``item_attributes``. ``item_attributes`` may be ``None``, which is
    treated as "no item attributes".
    """
    item_attributes = item_attributes or {}
    user_attributes = dict(load_user_attributes_from_mysql())
    for user_id, sequence in item_sequences:
        combined_sequence = list(user_attributes.get(user_id, []))
        for item in sequence:
            combined_sequence.append(item)
            combined_sequence.extend(item_attributes.get(item, []))
        yield combined_sequence


def train_item2vec(item_sequences, vector_size=100, window=5, min_count=10, workers=4):
    """Train and return a gensim ``Word2Vec`` model on ``item_sequences``.

    Word2Vec scans the corpus several times: once to build the vocabulary, then
    once per epoch. gensim rejects a one-shot generator with ``TypeError``, so
    one-shot iterators are materialised into a list first. Re-iterable corpora
    (lists, gensim corpus objects) are passed through unchanged.
    """
    if iter(item_sequences) is item_sequences:
        item_sequences = list(item_sequences)
    model = gensim.models.Word2Vec(
        sentences=item_sequences,
        vector_size=vector_size,
        window=window,
        min_count=min_count,
        workers=workers,
    )
    return model


def get_item_vector(item, model):
    """Return the embedding vector for ``item``.

    If ``item`` is not in the model's vocabulary, print a message and return
    ``None``.
    """
    try:
        return model.wv[item]
    except KeyError:
        print(f"物品 {item} 未在模型中找到。")
        return None


def find_similar_items(item, model, topn=5):
    """Return up to ``topn`` ``(token, similarity)`` pairs most similar to ``item``.

    Tokens starting with ``'attr'`` or ``'user_'`` are excluded. The whole
    vocabulary is ranked before filtering, so excluded tokens do not push the
    result count below ``topn`` when enough other tokens exist. If ``item`` is
    not in the vocabulary, print a message and return ``None``.

    NOTE(review): ``combine_user_item_attributes`` inserts raw attribute values
    without any ``attr``/``user_`` prefix. Confirm that the DB stores them
    prefixed; otherwise this filter removes nothing.
    """
    try:
        ranked = model.wv.most_similar(item, topn=len(model.wv.index_to_key))
    except KeyError:
        print(f"物品 {item} 未在模型中找到。")
        return None
    excluded_prefixes = ('attr', 'user_')
    filtered_similar_items = [
        (token, score)
        for token, score in ranked
        if not token.startswith(excluded_prefixes)
    ]
    return filtered_similar_items[:topn]


if __name__ == "__main__":
    item_sequences = load_item_sequences_from_mysql()
    item_attributes = load_item_attributes_from_mysql()
    combined_sequences = combine_user_item_attributes(item_sequences, item_attributes)
    item2vec_model = train_item2vec(combined_sequences)

    item_vector = get_item_vector('item1', item2vec_model)
    if item_vector is not None:
        print(f"物品 'item1' 的向量表示: {item_vector}")

    similar_items = find_similar_items('item1', item2vec_model, topn=3)
    if similar_items is not None:
        print(f"与物品 'item1' 最相似的 3 个物品: {similar_items}")