# item2vec.py
from pathlib import Path

import joblib
import numpy as np
import pandas as pd
from gensim.models import Word2Vec
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

from database.dao.mysql_dao import MySqlDao
from models.item2vec import Item2VecDataProcess
from models.rank.data.config import ProductConfig
from models.rank.data.utils import sample_data_clear
  11. class Item2Vec:
  12. def __init__(self, city_uuid):
  13. self._load_data(city_uuid)
  14. self._load_model()
  15. def _load_data(self, city_uuid):
  16. """加载特征sentence"""
  17. data_processor = Item2VecDataProcess(city_uuid)
  18. self._tokens_map = data_processor.generate_tokens()
  19. self._tokens = [item["token"] for item in self._tokens_map]
  20. def _load_model(self):
  21. self._model = Word2Vec(
  22. self._tokens,
  23. vector_size=64,
  24. window=3,
  25. min_count=1,
  26. sg=1, # skip-gram
  27. workers=4,
  28. epochs=20,
  29. sample=0.0000001
  30. )
  31. def token_to_vector(self, tokens):
  32. """将token转换为vector"""
  33. vector = [self._model.wv[token] for token in tokens if token in self._model.wv]
  34. return np.mean(vector, axis=0) if vector else np.zeros(self._model.vector_size)
  35. def item_to_token(self, item):
  36. token = []
  37. for col in ProductConfig.FEATURE_COLUMNS:
  38. if col == 'product_code':
  39. continue
  40. else:
  41. token.append(f"{item[col].strip()}")
  42. return token
  43. def get_similarity_map(self, product):
  44. """获取目标卷烟与所有卷烟的相似度"""
  45. product = product.squeeze().to_dict()
  46. token = self.item_to_token(product)
  47. vector = self.token_to_vector(token).reshape(1, -1)
  48. similarity_map = []
  49. for item in tqdm(self._tokens_map, desc="正在计算卷烟相似度..."):
  50. target_product_code = item["product_code"]
  51. torget_token = item["token"]
  52. target_vectot = self.token_to_vector(torget_token).reshape(1, -1)
  53. similarity = cosine_similarity(vector, target_vectot)[0][0]
  54. similarity_map.append(
  55. {
  56. "product_code": product['product_code'],
  57. "target_product_code": target_product_code,
  58. "similarity": similarity
  59. }
  60. )
  61. similarity_map.sort(key=lambda x: x["similarity"], reverse=True)
  62. return similarity_map
  63. if __name__ == "__main__":
  64. dao = MySqlDao()
  65. city_uuid = "00000000000000000000000011445301"
  66. product_id = "420202"
  67. order_data = dao.load_order_data(city_uuid)
  68. product = dao.get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
  69. product = sample_data_clear(product, ProductConfig)
  70. model = Item2Vec(city_uuid)
  71. sims = model.get_similarity_map(product)
  72. sims = pd.DataFrame(sims)
  73. product_info = dao.load_product_data(city_uuid)[ProductConfig.FEATURE_COLUMNS]
  74. sims = sims.merge(product_info, left_on="target_product_code", right_on="product_code", how="inner")
  75. sims.to_csv("./data/product_similarity.csv", index=False)