from pathlib import Path

import joblib
import numpy as np
import pandas as pd
from gensim.models import Word2Vec
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

from database.dao.mysql_dao import MySqlDao
from models.item2vec import Item2VecDataProcess
from models.rank.data.config import ProductConfig
from models.rank.data.utils import sample_data_clear
class Item2Vec:
    """Embed products with Word2Vec over their feature tokens and rank all
    products in a city by cosine similarity to a query product.

    Each product is represented as the mean of the embeddings of its feature
    tokens (all ``ProductConfig.FEATURE_COLUMNS`` except ``product_code``).
    """

    def __init__(self, city_uuid):
        """Load the city's tokenized product data and train the embedding model.

        Args:
            city_uuid: Identifier of the city whose products are embedded.
        """
        self._load_data(city_uuid)
        self._load_model()

    def _load_data(self, city_uuid):
        """Load the feature sentences (token lists) for every product in the city."""
        data_processor = Item2VecDataProcess(city_uuid)
        # Each entry holds a product_code and its feature token list
        # (presumably one dict per product; confirm against Item2VecDataProcess).
        self._tokens_map = data_processor.generate_tokens()
        self._tokens = [item["token"] for item in self._tokens_map]

    def _load_model(self):
        """Train a skip-gram Word2Vec model on the product token sentences."""
        self._model = Word2Vec(
            self._tokens,
            vector_size=64,
            window=3,
            min_count=1,  # keep every token: the product vocabulary is small
            sg=1,  # skip-gram
            workers=4,
            epochs=20,
            sample=0.0000001,
        )

    def token_to_vector(self, tokens):
        """Return the mean embedding of *tokens*.

        Tokens missing from the vocabulary are skipped; if none remain, a zero
        vector is returned (np.mean on an empty list would yield NaN).
        """
        vectors = [self._model.wv[token] for token in tokens if token in self._model.wv]
        return np.mean(vectors, axis=0) if vectors else np.zeros(self._model.vector_size)

    def item_to_token(self, item):
        """Build the feature token list for a product row.

        The ``product_code`` key column is excluded; every other feature value
        is coerced to a stripped string so non-string values don't raise.
        """
        return [
            str(item[col]).strip()
            for col in ProductConfig.FEATURE_COLUMNS
            if col != 'product_code'
        ]

    def get_similarity_map(self, product):
        """Compute the similarity of *product* to every product in the city.

        Args:
            product: Single-row DataFrame (or Series) of the query product's
                feature columns, including ``product_code``.

        Returns:
            List of dicts sorted by ``similarity`` descending, each with
            ``target_product_code`` (the query product), ``product_code``
            (the candidate product) and ``similarity``.
        """
        product = product.squeeze().to_dict()
        query_code = product['product_code']
        # Hoist the query embedding out of the loop; reshape to (1, n_features)
        # because cosine_similarity expects 2-D inputs.
        query_vector = self.token_to_vector(self.item_to_token(product)).reshape(1, -1)

        similarity_map = []
        for item in tqdm(self._tokens_map, desc="正在计算卷烟相似度..."):
            candidate_vector = self.token_to_vector(item["token"]).reshape(1, -1)
            similarity = cosine_similarity(query_vector, candidate_vector)[0][0]

            similarity_map.append(
                {
                    "target_product_code": query_code,
                    "product_code": item["product_code"],
                    "similarity": similarity,
                }
            )
        similarity_map.sort(key=lambda x: x["similarity"], reverse=True)

        return similarity_map
-
-
-
-
- if __name__ == "__main__":
- dao = MySqlDao()
- city_uuid = "00000000000000000000000011445301"
- product_id = "350139"
-
- product = dao.get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
- product = sample_data_clear(product, ProductConfig)
- model = Item2Vec(city_uuid)
- sims = model.get_similarity_map(product)
- sims = pd.DataFrame(sims)
- product_info = dao.load_product_data(city_uuid)[ProductConfig.FEATURE_COLUMNS]
- sims = sims.merge(product_info, on="product_code", how="inner")
- sims.to_csv("./data/product_similarity.csv", index=False)