Explorar o código

增加item2vec

Sherlock hai 11 meses
pai
achega
a2a78d4c67

+ 5 - 0
models/item2vec/__init__.py

@@ -0,0 +1,5 @@
+from models.item2vec.preprocess import Item2VecDataProcess
+
+__all__ = [
+    "Item2VecDataProcess"
+]

+ 34 - 0
models/item2vec/item2vec.py

@@ -0,0 +1,34 @@
+import joblib
+from models.item2vec import Item2VecDataProcess
+from gensim.models import Word2Vec
+class Item2Vec:
+    def __init__(self, city_uuid):
+        self._load_data(city_uuid)
+    
+    def _load_data(self, city_uuid):
+        """加载特征sentence"""
+        data_processor = Item2VecDataProcess(city_uuid)
+        self._sentences = data_processor.generate_sentence()
+        
+    def train(self):
+        self._model = Word2Vec(
+            self._sentences,
+            vector_size=64,
+            window=4,
+            min_count=1,
+            sg=1, # skip-gram
+            workers=4,
+            epochs=20
+        )
+        
+    def save_model(self, model_path):
+        joblib.dump(self._model, model_path)
+        
+        
+if __name__ == "__main__":
+    city_uuid = "00000000000000000000000011445301"
+    model = Item2Vec(city_uuid)
+    print("开始训练Item2Vec...")
+    model.train()
+    
+    

+ 42 - 0
models/item2vec/preprocess.py

@@ -0,0 +1,42 @@
+
+from database.dao.mysql_dao import MySqlDao
+from models.rank.data.config import ProductConfig
+from models.rank.data.utils import sample_data_clear
+
+
+class Item2VecDataProcess:
+    def __init__(self, city_uuid):
+        self._mysql_dao = MySqlDao()
+        print("item2vec: 正在加载product_info...")
+        self._product_data = self._mysql_dao.load_product_data(city_uuid)
+        self._data_process()
+        
+    def _data_process(self):
+        """数据预处理"""
+        # 获取指定的特征
+        self._product_data = self._product_data[ProductConfig.FEATURE_COLUMNS]
+        # 数据清洗
+        self._product_data = sample_data_clear(self._product_data, ProductConfig)
+        
+    def tokenize_features(self, row):
+        """根据每款烟的特征生成sentence"""
+        tokens = []
+        
+        for col in ProductConfig.FEATURE_COLUMNS:
+            if col == 'product_code':
+                continue
+            if col in ["direct_retail_price", "tbc_total_length"]:
+                tokens.append(f"{col}_{row[col].replace('-', '_')}")
+            else:
+                tokens.append(f"{col}_{row[col]}")
+        
+        return tokens
+    
+    def generate_sentence(self):
+        sentcens = self._product_data.apply(self.tokenize_features, axis=1).tolist()
+        return sentcens
+        
+if __name__ == "__main__":
+    city_uuid = "00000000000000000000000011445301"
+    processor = Item2VecDataProcess(city_uuid)
+    processor.generate_sentence()

+ 0 - 120
models/recall/item2vec.py

@@ -1,120 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-import gensim
-from dao.mysql_client import Mysql
-
-class Item2Vec(object):
-    def __init__(self):
-        mysql_client = Mysql()
-        # 创建会话
-        self.session = mysql_client.create_session()
-def load_item_sequences_from_mysql():
-    try:
-        conn = mysql.connector.connect(
-            host='localhost',
-            user='your_username',
-            password='your_password',
-            database='your_database'
-        )
-        cursor = conn.cursor()
-        query = "SELECT user_id, sequence FROM item_sequences"
-        cursor.execute(query)
-        for row in cursor:
-            user_id, sequence_str = row
-            sequence = sequence_str.split(',')
-            yield user_id, sequence
-        cursor.close()
-        conn.close()
-    except mysql.connector.Error as err:
-        print(f"数据库连接或查询出错: {err}")
-
-
-def load_item_attributes_from_mysql():
-    try:
-        conn = mysql.connector.connect(
-            host='localhost',
-            user='your_username',
-            password='your_password',
-            database='your_database'
-        )
-        cursor = conn.cursor()
-        query = "SELECT item, attributes FROM item_attributes"
-        cursor.execute(query)
-        item_attributes = {}
-        for item, attributes_str in cursor:
-            attributes = attributes_str.split(',')
-            item_attributes[item] = attributes
-        cursor.close()
-        conn.close()
-        return item_attributes
-    except mysql.connector.Error as err:
-        print(f"数据库连接或查询出错: {err}")
-
-
-def load_user_attributes_from_mysql():
-    try:
-        conn = mysql.connector.connect(
-            host='localhost',
-            user='your_username',
-            password='your_password',
-            database='your_database'
-        )
-        cursor = conn.cursor()
-        query = "SELECT user_id, taste, cigarette_length, cigarette_type, packaging_color FROM user_attributes"
-        cursor.execute(query)
-        for row in cursor:
-            user_id, taste, cigarette_length, cigarette_type, packaging_color = row
-            user_attrs = [attr for attr in [taste, cigarette_length, cigarette_type, packaging_color] if attr]
-            yield user_id, user_attrs
-        cursor.close()
-        conn.close()
-    except mysql.connector.Error as err:
-        print(f"数据库连接或查询出错: {err}")
-
-
-def combine_user_item_attributes(item_sequences, item_attributes):
-    user_attributes = {user_id: attrs for user_id, attrs in load_user_attributes_from_mysql()}
-    for user_id, sequence in item_sequences:
-        user_attrs = user_attributes.get(user_id, [])
-        combined_sequence = user_attrs.copy()
-        for item in sequence:
-            combined_sequence.append(item)
-            combined_sequence.extend(item_attributes.get(item, []))
-        yield combined_sequence
-
-
-def train_item2vec(item_sequences, vector_size=100, window=5, min_count=10, workers=4):
-    model = gensim.models.Word2Vec(sentences=item_sequences, vector_size=vector_size, window=window,
-                                   min_count=min_count, workers=workers)
-    return model
-
-
-def get_item_vector(item, model):
-    try:
-        return model.wv[item]
-    except KeyError:
-        print(f"物品 {item} 未在模型中找到。")
-        return None
-
-
-def find_similar_items(item, model, topn=5):
-    try:
-        similar_items = model.wv.most_similar(item, topn=topn)
-        filtered_similar_items = [(item, score) for item, score in similar_items if not item.startswith(('attr', 'user_'))]
-        return filtered_similar_items
-    except KeyError:
-        print(f"物品 {item} 未在模型中找到。")
-        return None
-
-
-if __name__ == "__main__":
-    item_sequences = load_item_sequences_from_mysql()
-    item_attributes = load_item_attributes_from_mysql()
-    combined_sequences = combine_user_item_attributes(item_sequences, item_attributes)
-    item2vec_model = train_item2vec(combined_sequences)
-    item_vector = get_item_vector('item1', item2vec_model)
-    if item_vector is not None:
-        print(f"物品 'item1' 的向量表示: {item_vector}")
-    similar_items = find_similar_items('item1', item2vec_model, topn=3)
-    if similar_items is not None:
-        print(f"与物品 'item1' 最相似的 3 个物品: {similar_items}")