inference.py 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. from database.dao.mysql_dao import MySqlDao
  2. from models.item2vec import Item2Vec
  3. from models.rank.data.config import OrderConfig, ProductConfig
  4. from models.rank.data.utils import sample_data_clear
  5. import numpy as np
  6. import pandas as pd
  7. from sklearn.preprocessing import StandardScaler
  8. class Item2VecModel:
  9. def __init__(self, city_uuid):
  10. self._dao = MySqlDao()
  11. self._city_uuid = city_uuid
  12. self._item2vec_model = Item2Vec(city_uuid)
  13. def generate_product_similarity_map(self, product_code):
  14. """根据product_code生成卷烟相似度矩阵"""
  15. product = self._dao.get_product_by_id(self._city_uuid, product_code)[ProductConfig.FEATURE_COLUMNS]
  16. product = sample_data_clear(product, ProductConfig)
  17. similarity_map = self._item2vec_model.get_similarity_map(product)
  18. similarity_map = pd.DataFrame(similarity_map)
  19. product_list = self._dao.load_product_data(self._city_uuid)[ProductConfig.FEATURE_COLUMNS]
  20. similarity_map = similarity_map.merge(product_list, on="product_code", how="inner")
  21. # self._similarity_map = self._similarity_map.query(f"product_code != {product_code}")
  22. return similarity_map
  23. def get_similarity_list(self, product_code, top=40):
  24. """获取与指卷烟最相似的top k个卷烟"""
  25. similarity_map = self.generate_product_similarity_map(product_code)
  26. similarity_map.to_excel("./data/product_similarity.xlsx", index=False)
  27. similarity_list = similarity_map["product_code"].to_list()
  28. similarity_list = similarity_list[:top]
  29. return similarity_list
  30. def get_recommend_cust_list(self, product_code, top=100):
  31. """获取推荐的商户列表"""
  32. product_list = self.get_similarity_list(product_code)
  33. order_data = self._dao.get_order_by_product_ids(self._city_uuid, product_list)[OrderConfig.FEATURE_COLUMNS]
  34. order_data["sale_qty"] = order_data["sale_qty"].fillna(0)
  35. order_data = order_data.groupby(["cust_code", "product_code"], as_index=False)["sale_qty"].sum()
  36. # 按照卷烟分组,取每款卷烟售卖最好的前50个商户
  37. order_data = (
  38. order_data
  39. .sort_values(["product_code", "sale_qty"], ascending=[True, False])
  40. .groupby("product_code")
  41. .head(top)
  42. )
  43. recommend_cust = (
  44. order_data.groupby(["cust_code"], as_index=False)["sale_qty"].sum()
  45. .query("sale_qty > 0")
  46. .sort_values(["sale_qty"], ascending=[False])
  47. )
  48. # 对销量进行归一化
  49. scaler = StandardScaler()
  50. normalized = scaler.fit_transform(recommend_cust["sale_qty"].values.reshape(-1, 1))
  51. recommend_cust["sale_qty"] = ((1 / (1 + np.exp(-normalized))) * 100).flatten()
  52. recommend_cust = recommend_cust.rename(columns={"sale_qty": "recommend_score"})
  53. # recommend_cust.to_csv("./data/item2vec_recommend.csv", index=False)
  54. return recommend_cust
  55. if __name__ == "__main__":
  56. city_uuid = "00000000000000000000000011445301"
  57. product_id = "350139"
  58. model = Item2VecModel(city_uuid)
  59. model.get_similarity_list(product_id)
  60. # dao = MySqlDao()
  61. # data = dao.get_order_by_cust_and_product(city_uuid, "445300108802", "340223")[OrderConfig.FEATURE_COLUMNS]
  62. # data.to_csv("./data/result.csv", index=False)