# item2vec.py — Word2Vec-based product similarity.
from pathlib import Path

import joblib
import numpy as np
import pandas as pd
from gensim.models import Word2Vec
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

from database.dao.mysql_dao import MySqlDao
from models.item2vec import Item2VecDataProcess
from models.rank.data.config import ProductConfig
from models.rank.data.utils import sample_data_clear
  11. class Item2Vec:
  12. def __init__(self, city_uuid):
  13. self._load_data(city_uuid)
  14. self._load_model()
  15. def _load_data(self, city_uuid):
  16. """加载特征sentence"""
  17. data_processor = Item2VecDataProcess(city_uuid)
  18. self._tokens_map = data_processor.generate_tokens()
  19. self._tokens = [item["token"] for item in self._tokens_map]
  20. def _load_model(self):
  21. self._model = Word2Vec(
  22. self._tokens,
  23. vector_size=64,
  24. window=3,
  25. min_count=1,
  26. sg=1, # skip-gram
  27. workers=1, # 固定为1,保证多线程不引入随机性
  28. seed=123456, # 固定随机种子,确保结果可复现
  29. epochs=20,
  30. sample=0.0000001
  31. )
  32. def token_to_vector(self, tokens):
  33. """将token转换为vector"""
  34. vector = [self._model.wv[token] for token in tokens if token in self._model.wv]
  35. return np.mean(vector, axis=0) if vector else np.zeros(self._model.vector_size)
  36. def item_to_token(self, item):
  37. token = []
  38. for col in ProductConfig.FEATURE_COLUMNS:
  39. if col == 'product_code':
  40. continue
  41. else:
  42. token.append(f"{item[col].strip()}")
  43. return token
  44. def get_similarity_map(self, product):
  45. """获取目标卷烟与所有卷烟的相似度"""
  46. product = product.squeeze().to_dict()
  47. token = self.item_to_token(product)
  48. vector = self.token_to_vector(token).reshape(1, -1)
  49. similarity_map = []
  50. for item in tqdm(self._tokens_map, desc="正在计算卷烟相似度..."):
  51. target_product_code = item["product_code"]
  52. torget_token = item["token"]
  53. target_vectot = self.token_to_vector(torget_token).reshape(1, -1)
  54. similarity = cosine_similarity(vector, target_vectot)[0][0]
  55. similarity_map.append(
  56. {
  57. "target_product_code": product['product_code'],
  58. "product_code": target_product_code,
  59. "similarity": similarity
  60. }
  61. )
  62. similarity_map.sort(key=lambda x: x["similarity"], reverse=True)
  63. return similarity_map
  64. if __name__ == "__main__":
  65. dao = MySqlDao()
  66. city_uuid = "00000000000000000000000011445301"
  67. product_id = "350139"
  68. product = dao.get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
  69. product = sample_data_clear(product, ProductConfig)
  70. model = Item2Vec(city_uuid)
  71. sims = model.get_similarity_map(product)
  72. sims = pd.DataFrame(sims)
  73. product_info = dao.load_product_data(city_uuid)[ProductConfig.FEATURE_COLUMNS]
  74. sims = sims.merge(product_info, on="product_code", how="inner")
  75. sims.to_csv("./data/product_similarity.csv", index=False)