item2vec.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import gensim
  4. from dao.mysql_client import Mysql
  5. class Item2Vec(object):
  6. def __init__(self):
  7. mysql_client = Mysql()
  8. # 创建会话
  9. self.session = mysql_client.create_session()
  10. def load_item_sequences_from_mysql():
  11. try:
  12. conn = mysql.connector.connect(
  13. host='localhost',
  14. user='your_username',
  15. password='your_password',
  16. database='your_database'
  17. )
  18. cursor = conn.cursor()
  19. query = "SELECT user_id, sequence FROM item_sequences"
  20. cursor.execute(query)
  21. for row in cursor:
  22. user_id, sequence_str = row
  23. sequence = sequence_str.split(',')
  24. yield user_id, sequence
  25. cursor.close()
  26. conn.close()
  27. except mysql.connector.Error as err:
  28. print(f"数据库连接或查询出错: {err}")
  29. def load_item_attributes_from_mysql():
  30. try:
  31. conn = mysql.connector.connect(
  32. host='localhost',
  33. user='your_username',
  34. password='your_password',
  35. database='your_database'
  36. )
  37. cursor = conn.cursor()
  38. query = "SELECT item, attributes FROM item_attributes"
  39. cursor.execute(query)
  40. item_attributes = {}
  41. for item, attributes_str in cursor:
  42. attributes = attributes_str.split(',')
  43. item_attributes[item] = attributes
  44. cursor.close()
  45. conn.close()
  46. return item_attributes
  47. except mysql.connector.Error as err:
  48. print(f"数据库连接或查询出错: {err}")
  49. def load_user_attributes_from_mysql():
  50. try:
  51. conn = mysql.connector.connect(
  52. host='localhost',
  53. user='your_username',
  54. password='your_password',
  55. database='your_database'
  56. )
  57. cursor = conn.cursor()
  58. query = "SELECT user_id, taste, cigarette_length, cigarette_type, packaging_color FROM user_attributes"
  59. cursor.execute(query)
  60. for row in cursor:
  61. user_id, taste, cigarette_length, cigarette_type, packaging_color = row
  62. user_attrs = [attr for attr in [taste, cigarette_length, cigarette_type, packaging_color] if attr]
  63. yield user_id, user_attrs
  64. cursor.close()
  65. conn.close()
  66. except mysql.connector.Error as err:
  67. print(f"数据库连接或查询出错: {err}")
  68. def combine_user_item_attributes(item_sequences, item_attributes):
  69. user_attributes = {user_id: attrs for user_id, attrs in load_user_attributes_from_mysql()}
  70. for user_id, sequence in item_sequences:
  71. user_attrs = user_attributes.get(user_id, [])
  72. combined_sequence = user_attrs.copy()
  73. for item in sequence:
  74. combined_sequence.append(item)
  75. combined_sequence.extend(item_attributes.get(item, []))
  76. yield combined_sequence
  77. def train_item2vec(item_sequences, vector_size=100, window=5, min_count=10, workers=4):
  78. model = gensim.models.Word2Vec(sentences=item_sequences, vector_size=vector_size, window=window,
  79. min_count=min_count, workers=workers)
  80. return model
  81. def get_item_vector(item, model):
  82. try:
  83. return model.wv[item]
  84. except KeyError:
  85. print(f"物品 {item} 未在模型中找到。")
  86. return None
  87. def find_similar_items(item, model, topn=5):
  88. try:
  89. similar_items = model.wv.most_similar(item, topn=topn)
  90. filtered_similar_items = [(item, score) for item, score in similar_items if not item.startswith(('attr', 'user_'))]
  91. return filtered_similar_items
  92. except KeyError:
  93. print(f"物品 {item} 未在模型中找到。")
  94. return None
  95. if __name__ == "__main__":
  96. item_sequences = load_item_sequences_from_mysql()
  97. item_attributes = load_item_attributes_from_mysql()
  98. combined_sequences = combine_user_item_attributes(item_sequences, item_attributes)
  99. item2vec_model = train_item2vec(combined_sequences)
  100. item_vector = get_item_vector('item1', item2vec_model)
  101. if item_vector is not None:
  102. print(f"物品 'item1' 的向量表示: {item_vector}")
  103. similar_items = find_similar_items('item1', item2vec_model, topn=3)
  104. if similar_items is not None:
  105. print(f"与物品 'item1' 最相似的 3 个物品: {similar_items}")