inference.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. from database import RedisDatabaseHelper, MySqlDao
  2. from models.item2vec import Item2VecModel
  3. from models.rank.data.config import CustConfig, ProductConfig, ShopConfig, OrderConfig
  4. from models.rank.data.utils import sample_data_clear
  5. from models.rank.gbdt_lr_inference import GbdtLrModel
  6. from utils.result_process import get_cust_list_from_history_order, split_relation_subtable, generate_report
  7. import pandas as pd
  8. redis = RedisDatabaseHelper().redis
  9. dao = MySqlDao()
  10. gbdtlr_model = GbdtLrModel("./models/rank/weights/00000000000000000000000011445301/gbdtlr_model.pkl")
  11. item2vec = Item2VecModel("00000000000000000000000011445301")
  12. def get_itemcf_recall(city_uuid, product_id):
  13. """协同召回"""
  14. key = f"fc:{city_uuid}:{product_id}"
  15. recall_list = redis.zrevrange(key, 0, -1, withscores=False)
  16. return recall_list
  17. def get_hot_recall(city_uuid):
  18. """热度召回"""
  19. key = f"hot:{city_uuid}:sale_qty"
  20. recall_list = redis.zrevrange(key, 0, -1, withscores=False)
  21. return recall_list
  22. def get_recall_cust(city_uuid, product_id, recall_count):
  23. """根据协同过滤和热度召回召回商户
  24. """
  25. itemcf_recall_list = get_itemcf_recall(city_uuid, product_id)
  26. hot_recall_list = get_hot_recall(city_uuid)
  27. result = list(dict.fromkeys(itemcf_recall_list))
  28. # 如果结果不足,从hot_recall中补齐
  29. if len(result) < recall_count:
  30. hot_recall_set = set(hot_recall_list) - set(result)
  31. additional_items = [item for item in hot_recall_list if item in hot_recall_set]
  32. needed = recall_count - len(result)
  33. result.extend(additional_items[:needed])
  34. return result[:recall_count]
  35. def generate_recommend_sample(city_uuid, product_id):
  36. """生成预测数据集"""
  37. product_in_order = dao.get_product_from_order(city_uuid)["product_code"].unique().tolist()
  38. if product_id in product_in_order:
  39. recall_count = 1000
  40. cust_list = get_recall_cust(city_uuid, product_id, recall_count)
  41. else:
  42. cust_list = item2vec.get_recommend_cust_list(product_id)["cust_code"].to_list()
  43. # 获取卷烟的信息
  44. product_data = dao.get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
  45. filter_dict = product_data.to_dict("records")[0]
  46. cust_data = dao.get_cust_by_ids(city_uuid, cust_list)[CustConfig.FEATURE_COLUMNS]
  47. shop_data = dao.get_shop_by_ids(city_uuid, cust_list)[ShopConfig.FEATURE_COLUMNS]
  48. product_data = sample_data_clear(product_data, ProductConfig)
  49. cust_data = sample_data_clear(cust_data, CustConfig)
  50. shop_data = sample_data_clear(shop_data, ShopConfig)
  51. cust_feats = shop_data.set_index("cust_code")
  52. cust_data = cust_data.join(cust_feats, on="BB_RETAIL_CUSTOMER_CODE", how="inner")
  53. feats_map = gbdtlr_model.generate_feats_map(product_data, cust_data)
  54. return feats_map, filter_dict, cust_list
  55. def get_recommend_list_by_gbdt_lr(city_uuid, product_id):
  56. """根据gbdt-lr进行打分并获得推荐列表,适用于推荐历史订单中存在的卷烟"""
  57. feats_sample, _, cust_list = generate_recommend_sample(city_uuid, product_id)
  58. recommend_list = gbdtlr_model.get_recommend_list(feats_sample, cust_list)
  59. return recommend_list
  60. def gbdt_lr_inference(city_uuid, product_id):
  61. pass
  62. def generate_features_shap(city_uuid, product_id, delivery_count):
  63. feats_sample, filter_dict, cust_list = generate_recommend_sample(city_uuid, product_id)
  64. if product_id in dao.get_product_from_order(city_uuid)["product_code"].unique().tolist():
  65. # 如果推荐商品为新卷烟,走iterm2vec
  66. recommend_data = gbdtlr_model.get_recommend_list(feats_sample, cust_list)
  67. else:
  68. recommend_data = item2vec.get_recommend_cust_list(product_id).to_dict("records")
  69. result = gbdtlr_model.generate_shap_interance(feats_sample)
  70. generate_report(city_uuid, result, filter_dict, recommend_data, delivery_count, "./data")
  71. def eval(city_uuid, product_code):
  72. """推荐效果验证"""
  73. eval_report = get_cust_list_from_history_order(city_uuid, product_code)
  74. eval_report.to_csv("./data/eval.csv", index=False)
  75. def generate_delivery_strategy():
  76. pass
  77. def run():
  78. pass
  79. if __name__ == '__main__':
  80. # generate_features_shap("00000000000000000000000011445301", "350139", delivery_count=5000)
  81. eval("00000000000000000000000011445301", "350355")
  82. # recommend_list = get_recommend_list_by_gbdt_lr("00000000000000000000000011445301", "350139")
  83. # recommend_list = pd.DataFrame(recommend_list)
  84. # recommend_list.to_csv("./data/recommend_list.csv", index=False, encoding="utf-8-sig")
  85. # 拿龙军数据
  86. # data = dao.get_order_by_cust("00000000000000000000000011445301", "445323105795")
  87. # data = data.groupby(["cust_code", "product_code", "product_name"], as_index=False)["sale_qty"].sum()
  88. # data.to_csv("./data/cust.csv", index=False)
  89. # city_uuid = "00000000000000000000000011445301"
  90. # order_data = dao.get_order_by_cust("00000000000000000000000011445301", "445323105795")
  91. # order_data["sale_qty"] = order_data["sale_qty"].fillna(0)
  92. # order_data = order_data.infer_objects(copy=False)
  93. # order_data = order_data.groupby(["cust_code", "product_code", "product_name"], as_index=False)["sale_qty"].sum()
  94. # cust_data = dao.load_cust_data(city_uuid)[CustConfig.FEATURE_COLUMNS]
  95. # sample_data_clear(cust_data, CustConfig)
  96. # shop_data = dao.load_shopping_data(city_uuid)[ShopConfig.FEATURE_COLUMNS]
  97. # sample_data_clear(shop_data, ShopConfig)
  98. # cust_ids = shop_data.set_index("cust_code")
  99. # cust_data = cust_data.join(cust_ids, on="BB_RETAIL_CUSTOMER_CODE", how="inner")
  100. # product_data = dao.load_product_data(city_uuid)[ProductConfig.FEATURE_COLUMNS]
  101. # sample_data_clear(product_data, ProductConfig)
  102. # order_data = order_data.merge(product_data, on="product_code", how="inner")
  103. # order_data = order_data.merge(cust_data, left_on='cust_code', right_on='BB_RETAIL_CUSTOMER_CODE', how="inner")
  104. # result = gbdtlr_model.inference_from_sample(order_data)
  105. # result.to_csv("./data/junlong.csv", index=False)