inference.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. from database import RedisDatabaseHelper, MySqlDao
  2. from models.rank.data.config import CustConfig, ProductConfig, ShopConfig, OrderConfig
  3. from models.rank.data.utils import sample_data_clear
  4. from models.rank.gbdt_lr_inference import GbdtLrModel
  5. from utils.result_process import split_relation_subtable, generate_report
  6. import pandas as pd
  7. redis = RedisDatabaseHelper().redis
  8. dao = MySqlDao()
  9. gbdtlr_model = GbdtLrModel("./models/rank/weights/00000000000000000000000011445301/gbdtlr_model.pkl")
  10. def get_itemcf_recall(city_uuid, product_id):
  11. """协同召回"""
  12. key = f"fc:{city_uuid}:{product_id}"
  13. recall_list = redis.zrevrange(key, 0, -1, withscores=False)
  14. return recall_list
  15. def get_hot_recall(city_uuid):
  16. """热度召回"""
  17. key = f"hot:{city_uuid}:sale_qty"
  18. recall_list = redis.zrevrange(key, 0, -1, withscores=False)
  19. return recall_list
  20. def get_recall_cust(city_uuid, product_id, recall_count):
  21. """根据协同过滤和热度召回召回商户"""
  22. itemcf_recall_list = get_itemcf_recall(city_uuid, product_id)
  23. hot_recall_list = get_hot_recall(city_uuid)
  24. result = list(dict.fromkeys(itemcf_recall_list))
  25. # 如果结果不足,从hot_recall中补齐
  26. if len(result) < recall_count:
  27. hot_recall_set = set(hot_recall_list) - set(result)
  28. additional_items = [item for item in hot_recall_list if item in hot_recall_set]
  29. needed = recall_count - len(result)
  30. result.extend(additional_items[:needed])
  31. return result[:recall_count]
  32. def generate_recommend_sample(city_uuid, product_id):
  33. """生成预测数据集"""
  34. recall_count = 300
  35. cust_list = get_recall_cust(city_uuid, product_id, recall_count)
  36. product_data = dao.get_product_by_id(city_uuid, product_id)[ProductConfig.FEATURE_COLUMNS]
  37. filter_dict = product_data.to_dict("records")[0]
  38. cust_data = dao.get_cust_by_ids(city_uuid, cust_list)[CustConfig.FEATURE_COLUMNS]
  39. shop_data = dao.get_shop_by_ids(city_uuid, cust_list)[ShopConfig.FEATURE_COLUMNS]
  40. product_data = sample_data_clear(product_data, ProductConfig)
  41. cust_data = sample_data_clear(cust_data, CustConfig)
  42. shop_data = sample_data_clear(shop_data, ShopConfig)
  43. cust_feats = shop_data.set_index("cust_code")
  44. cust_data = cust_data.join(cust_feats, on="BB_RETAIL_CUSTOMER_CODE", how="inner")
  45. feats_map = gbdtlr_model.generate_feats_map(product_data, cust_data)
  46. return feats_map, filter_dict, cust_list
  47. def get_recommend_list(city_uuid, product_id):
  48. feats_sample, _, cust_list = generate_recommend_sample(city_uuid, product_id)
  49. recommend_list = gbdtlr_model.get_recommend_list(feats_sample, cust_list)
  50. return recommend_list
  51. def gbdt_lr_inference(city_uuid, product_id):
  52. pass
  53. def generate_features_shap(city_uuid, product_id, delivery_count):
  54. feats_sample, filter_dict, cust_list = generate_recommend_sample(city_uuid, product_id)
  55. result = gbdtlr_model.generate_shap_interance(feats_sample)
  56. recommend_data = gbdtlr_model.get_recommend_list(feats_sample, cust_list)
  57. generate_report(city_uuid, result, filter_dict, recommend_data, delivery_count, "./data")
  58. def generate_delivery_strategy():
  59. pass
  60. def run():
  61. pass
  62. if __name__ == '__main__':
  63. # generate_features_shap("00000000000000000000000011445301", "420202", delivery_count=5000)
  64. # recommend_list = get_recommend_list("00000000000000000000000011445301", "420202")
  65. # recommend_list = pd.DataFrame(recommend_list)
  66. # recommend_list.to_csv("./data/recommend_list.csv", index=False, encoding="utf-8-sig")
  67. # 拿龙军数据
  68. # data = dao.get_order_by_cust("00000000000000000000000011445301", "445323105795")
  69. # data = data.groupby(["cust_code", "product_code", "product_name"], as_index=False)["sale_qty"].sum()
  70. # data.to_csv("./data/cust.csv", index=False)
  71. city_uuid = "00000000000000000000000011445301"
  72. order_data = dao.get_order_by_cust("00000000000000000000000011445301", "445323105795")
  73. order_data["sale_qty"] = order_data["sale_qty"].fillna(0)
  74. order_data = order_data.infer_objects(copy=False)
  75. order_data = order_data.groupby(["cust_code", "product_code", "product_name"], as_index=False)["sale_qty"].sum()
  76. cust_data = dao.load_cust_data(city_uuid)[CustConfig.FEATURE_COLUMNS]
  77. sample_data_clear(cust_data, CustConfig)
  78. shop_data = dao.load_shopping_data(city_uuid)[ShopConfig.FEATURE_COLUMNS]
  79. sample_data_clear(shop_data, ShopConfig)
  80. cust_ids = shop_data.set_index("cust_code")
  81. cust_data = cust_data.join(cust_ids, on="BB_RETAIL_CUSTOMER_CODE", how="inner")
  82. product_data = dao.load_product_data(city_uuid)[ProductConfig.FEATURE_COLUMNS]
  83. sample_data_clear(product_data, ProductConfig)
  84. order_data = order_data.merge(product_data, on="product_code", how="inner")
  85. order_data = order_data.merge(cust_data, left_on='cust_code', right_on='BB_RETAIL_CUSTOMER_CODE', how="inner")
  86. result = gbdtlr_model.inference_from_sample(order_data)
  87. result.to_csv("./data/junlong.csv", index=False)