inference.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. from database.dao.mysql_dao import MySqlDao
  2. from models.item2vec import Item2Vec
  3. from models.rank.data.config import OrderConfig, ProductConfig
  4. from models.rank.data.utils import sample_data_clear
  5. import pandas as pd
  6. class Item2VecModel:
  7. def __init__(self, city_uuid):
  8. self._dao = MySqlDao()
  9. self._city_uuid = city_uuid
  10. self._item2vec_model = Item2Vec(city_uuid)
  11. def generate_product_similarity_map(self, product_code):
  12. """根据product_code生成卷烟相似度矩阵"""
  13. product = self._dao.get_product_by_id(self._city_uuid, product_code)[ProductConfig.FEATURE_COLUMNS]
  14. product = sample_data_clear(product, ProductConfig)
  15. similarity_map = self._item2vec_model.get_similarity_map(product)
  16. similarity_map = pd.DataFrame(similarity_map)
  17. product_list = self._dao.load_product_data(self._city_uuid)[ProductConfig.FEATURE_COLUMNS]
  18. similarity_map = similarity_map.merge(product_list, on="product_code", how="inner")
  19. # self._similarity_map = self._similarity_map.query(f"product_code != {product_code}")
  20. return similarity_map
  21. def get_similarity_list(self, product_code, top=40):
  22. """获取与指卷烟最相似的top k个卷烟"""
  23. similarity_map = self.generate_product_similarity_map(product_code)
  24. similarity_list = similarity_map["product_code"].to_list()
  25. # similarity_list.remove(product_code)
  26. similarity_list = similarity_list[:top]
  27. return similarity_list
  28. def get_recommend_cust_list(self, product_code, top=50):
  29. """获取推荐的商户列表"""
  30. product_list = self.get_similarity_list(product_code)
  31. order_data = self._dao.get_order_by_product_ids(self._city_uuid, product_list)[OrderConfig.FEATURE_COLUMNS]
  32. order_data["sale_qty"] = order_data["sale_qty"].fillna(0)
  33. order_data = order_data.groupby(["cust_code", "product_code"], as_index=False)["sale_qty"].sum()
  34. # 按照卷烟分组,取每款卷烟售卖最好的前50个商户
  35. order_data = (
  36. order_data
  37. .sort_values(["product_code", "sale_qty"], ascending=[True, False])
  38. .groupby("product_code")
  39. .head(top)
  40. )
  41. recommend_cust = order_data.groupby(["cust_code"], as_index=False)["sale_qty"].sum()
  42. recommend_cust = recommend_cust.sort_values(["sale_qty"], ascending=[False])
  43. recommend_cust.to_csv("./data/recommend.csv", index=False)
  44. if __name__ == "__main__":
  45. city_uuid = "00000000000000000000000011445301"
  46. product_id = "420202"
  47. model = Item2VecModel(city_uuid)
  48. model.get_recommend_cust_list(product_id)
  49. # dao = MySqlDao()
  50. # data = dao.get_order_by_cust_and_product(city_uuid, "445300108802", "340223")[OrderConfig.FEATURE_COLUMNS]
  51. # data.to_csv("./data/result.csv", index=False)