| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120 |
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
from contextlib import closing

import gensim
import mysql.connector

from dao.mysql_client import Mysql
class Item2Vec:
    """Hold a session created through the project's ``Mysql`` client.

    Attributes:
        session: the object returned by ``Mysql().create_session()``.

    NOTE(review): nothing visible in this file uses this class or its
    session; confirm whether it is still needed.
    """

    def __init__(self):
        client = Mysql()
        # Create the session via the MySQL client.
        self.session = client.create_session()
def load_item_sequences_from_mysql(*, host='localhost', user='your_username',
                                   password='your_password',
                                   database='your_database'):
    """Stream ``(user_id, sequence)`` pairs from the ``item_sequences`` table.

    Each row's ``sequence`` column is split on ``','`` into a list of item
    tokens. A NULL or empty column yields an empty list. Without that guard,
    NULL would raise AttributeError and '' would produce a bogus '' token.

    The keyword-only connection arguments default to the placeholder values
    that used to be hard-coded here, so existing no-argument calls behave as
    before. Pass real settings instead of editing the source.

    Yields:
        tuple: ``(user_id, sequence)``, where ``sequence`` is a list of str.

    Database errors (``mysql.connector.Error``) are printed and end the stream
    early (best-effort, as before). The cursor and connection are closed in
    every case: normal exhaustion, an error, or the generator being closed
    before it is fully consumed.
    """
    try:
        with closing(mysql.connector.connect(
                host=host, user=user, password=password, database=database,
        )) as conn:
            with closing(conn.cursor()) as cursor:
                cursor.execute("SELECT user_id, sequence FROM item_sequences")
                for user_id, sequence_str in cursor:
                    sequence = sequence_str.split(',') if sequence_str else []
                    yield user_id, sequence
    except mysql.connector.Error as err:
        print(f"数据库连接或查询出错: {err}")
def load_item_attributes_from_mysql(*, host='localhost', user='your_username',
                                    password='your_password',
                                    database='your_database'):
    """Load an ``{item: [attribute, ...]}`` map from ``item_attributes``.

    Each row's ``attributes`` column is split on ``','``. A NULL or empty
    column maps the item to an empty list.

    The keyword-only connection arguments default to the placeholder values
    that used to be hard-coded here, so existing no-argument calls behave as
    before.

    Returns:
        dict: item -> list of attribute tokens. On a database error
        (``mysql.connector.Error``) the message is printed and an empty dict
        is returned. This replaces the previous implicit ``None``, which made
        callers using ``.get`` crash. The cursor and connection are always
        closed.
    """
    try:
        with closing(mysql.connector.connect(
                host=host, user=user, password=password, database=database,
        )) as conn:
            with closing(conn.cursor()) as cursor:
                cursor.execute("SELECT item, attributes FROM item_attributes")
                return {
                    item: attributes_str.split(',') if attributes_str else []
                    for item, attributes_str in cursor
                }
    except mysql.connector.Error as err:
        print(f"数据库连接或查询出错: {err}")
        return {}
def load_user_attributes_from_mysql(*, host='localhost', user='your_username',
                                    password='your_password',
                                    database='your_database'):
    """Stream ``(user_id, [attribute, ...])`` pairs from ``user_attributes``.

    The columns taste, cigarette_length, cigarette_type and packaging_color
    are read in that order. Falsy values (NULL, '', 0) are dropped. The rest
    are converted to ``str``, so a numeric column cannot later break
    string-only token handling (e.g. ``str.startswith`` filters).

    NOTE(review): the values are used as raw tokens without any prefix, so a
    prefix-based filter such as 'attr'/'user_' will not recognise them unless
    the stored values already carry such prefixes. Confirm against the data.

    The keyword-only connection arguments default to the placeholder values
    that used to be hard-coded here.

    Yields:
        tuple: ``(user_id, attrs)``, where ``attrs`` is a list of str.

    Database errors (``mysql.connector.Error``) are printed and end the stream
    early. The cursor and connection are always closed, including when the
    generator is closed before it is fully consumed.
    """
    try:
        with closing(mysql.connector.connect(
                host=host, user=user, password=password, database=database,
        )) as conn:
            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    "SELECT user_id, taste, cigarette_length, cigarette_type, "
                    "packaging_color FROM user_attributes"
                )
                for user_id, *attrs in cursor:
                    yield user_id, [str(attr) for attr in attrs if attr]
    except mysql.connector.Error as err:
        print(f"数据库连接或查询出错: {err}")
def combine_user_item_attributes(item_sequences, item_attributes,
                                 user_attributes=None):
    """Build training sentences that mix user, item and attribute tokens.

    For every ``(user_id, sequence)`` pair, the yielded sentence is the user's
    attribute tokens, followed by each item with that item's attribute tokens
    placed directly after it.

    Args:
        item_sequences: iterable of ``(user_id, [item, ...])`` pairs.
        item_attributes: mapping ``item -> [attribute, ...]``. ``None`` is
            treated as an empty mapping, so a failed attribute load does not
            crash the loop.
        user_attributes: optional mapping ``user_id -> [attribute, ...]``, or
            an iterable of such pairs. When omitted, it is loaded lazily via
            ``load_user_attributes_from_mysql()`` (the original behaviour).

    Yields:
        list: one combined token sequence per input pair.
    """
    if user_attributes is None:
        user_attributes = load_user_attributes_from_mysql()
    user_attributes = dict(user_attributes)
    item_attributes = item_attributes or {}
    for user_id, sequence in item_sequences:
        # Fresh list per sentence, so the lookup table is never mutated.
        # list() also accepts tuples, where .copy() would not.
        combined_sequence = list(user_attributes.get(user_id, []))
        for item in sequence:
            combined_sequence.append(item)
            combined_sequence.extend(item_attributes.get(item, []))
        yield combined_sequence
def train_item2vec(item_sequences, vector_size=100, window=5, min_count=10, workers=4):
    """Train a gensim Word2Vec model on token sequences.

    Args:
        item_sequences: iterable of token lists. Word2Vec walks the corpus
            several times (once to build the vocabulary, then once per
            epoch). One-shot iterators such as generators are therefore
            materialised into a list first. Re-iterable containers are
            passed through unchanged.
        vector_size, window, min_count, workers: forwarded unchanged to
            ``gensim.models.Word2Vec``. With the default ``min_count=10``,
            tokens seen fewer than 10 times get no vector.

    Returns:
        gensim.models.Word2Vec: the trained model.
    """
    # For an iterator, iter(x) returns x itself. A generator would be
    # exhausted after the vocabulary pass, and gensim 4 rejects generator
    # corpora for multi-pass training, so take a snapshot instead.
    if iter(item_sequences) is item_sequences:
        item_sequences = list(item_sequences)
    return gensim.models.Word2Vec(
        sentences=item_sequences, vector_size=vector_size, window=window,
        min_count=min_count, workers=workers,
    )
def get_item_vector(item, model):
    """Look up the embedding of *item* in ``model.wv``.

    Returns the stored vector. If ``model.wv[item]`` raises ``KeyError``, a
    message is printed and ``None`` is returned.
    """
    try:
        vector = model.wv[item]
    except KeyError:
        print(f"物品 {item} 未在模型中找到。")
        vector = None
    return vector
def find_similar_items(item, model, topn=5, exclude_prefixes=('attr', 'user_')):
    """Return up to *topn* ``(token, score)`` pairs most similar to *item*.

    Neighbours whose token starts with any of *exclude_prefixes* (by default
    'attr' and 'user_') are dropped. Filtering happens on the full ranked
    neighbour list before truncation. Excluded tokens ranking highly
    therefore no longer shrink the result below *topn*.

    Args:
        item: query token.
        model: object exposing ``wv.most_similar`` and ``len(wv)`` (e.g. a
            gensim Word2Vec model).
        topn: maximum number of pairs to return.
        exclude_prefixes: iterable of string prefixes to filter out.

    Returns:
        list of ``(token, score)`` in the order ranked by ``most_similar``,
        or ``None`` (after printing a message) when *item* is not in the
        model's vocabulary.
    """
    try:
        # Rank the whole vocabulary. Filtering only the first `topn`
        # neighbours would silently return fewer results than requested.
        ranked = model.wv.most_similar(item, topn=len(model.wv))
    except KeyError:
        print(f"物品 {item} 未在模型中找到。")
        return None
    prefixes = tuple(exclude_prefixes)
    # str() guards against non-str vocabulary keys.
    kept = [(token, score) for token, score in ranked
            if not str(token).startswith(prefixes)]
    return kept[:topn]
if __name__ == "__main__":
    # Demo pipeline: load data from MySQL, train embeddings, query 'item1'.
    item_sequences = load_item_sequences_from_mysql()
    # Treat a missing (None) attribute map as "no attributes", so the
    # combining step does not fail on `.get`.
    item_attributes = load_item_attributes_from_mysql() or {}
    # Word2Vec iterates its corpus several times (vocabulary build plus each
    # epoch), so materialise the one-shot generator chain into a list.
    combined_sequences = list(
        combine_user_item_attributes(item_sequences, item_attributes))
    if not combined_sequences:
        # Nothing was loaded (e.g. the database query failed). gensim refuses
        # to train on an empty vocabulary, so stop here with a message.
        print("未加载到任何训练序列，跳过模型训练。")
    else:
        item2vec_model = train_item2vec(combined_sequences)
        item_vector = get_item_vector('item1', item2vec_model)
        if item_vector is not None:
            print(f"物品 'item1' 的向量表示: {item_vector}")
        similar_items = find_similar_items('item1', item2vec_model, topn=3)
        if similar_items is not None:
            print(f"与物品 'item1' 最相似的 3 个物品: {similar_items}")
|