recommend.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. from database import MySqlDao
  2. from fastapi import APIRouter, BackgroundTasks, HTTPException, status
  3. from .request_body import RecommendRequest
  4. from core import get_logger
  5. from models import Recommend
  6. import os
  7. from utils import FileStreamUtils, ReportUtils
# Logger for this API module, registered under the "api.recommend" name.
logger = get_logger("api.recommend")
# Single DAO instance shared by the /recommend endpoint and the background
# report task defined in this module.
dao = MySqlDao()
# Router on which the /recommend endpoint is registered.
router = APIRouter()
  11. @router.post("/recommend")
  12. async def recommend(request: RecommendRequest, backgroundTasks: BackgroundTasks):
  13. """推荐接口"""
  14. logger.info(f"Recommend request: city={request.city_uuid}, product={request.product_code}, core_custs={len(request.cust_code_list)}")
  15. gbdtlr_model_path = os.path.join("./models/rank/weights", request.city_uuid, "gbdtlr_model.pkl")
  16. if not os.path.exists(gbdtlr_model_path):
  17. logger.warning(f"Model not found: {gbdtlr_model_path}")
  18. raise HTTPException(
  19. status_code=status.HTTP_404_NOT_FOUND,
  20. detail="该城市的模型未训练,请先进行训练",
  21. )
  22. recommend_model = Recommend(request.city_uuid)
  23. products_in_order = dao.get_product_from_order(request.city_uuid)["product_code"].unique().tolist()
  24. if request.product_code in products_in_order:
  25. logger.info(f"Using GBDT-LR model for existing product {request.product_code}")
  26. recommend_list = recommend_model.get_recommend_list_by_gbdtlr(
  27. request.product_code, cust_code_list=request.cust_code_list
  28. )
  29. else:
  30. logger.info(f"Using Item2Vec model for new product {request.product_code}")
  31. recommend_list = recommend_model.get_recommend_list_by_item2vec(
  32. request.product_code, cust_code_list=request.cust_code_list
  33. )
  34. request_data = []
  35. for index, data in enumerate(recommend_list):
  36. request_data.append(
  37. {
  38. "id": index + 1,
  39. "cust_code": data["cust_code"],
  40. "recommend_score": data["recommend_score"],
  41. }
  42. )
  43. logger.info(f"Recommend completed: {len(request_data)} customers recommended")
  44. backgroundTasks.add_task(generate_and_upload_report, request)
  45. return {"code": 200, "msg": "success", "data": {"recommendationInfo": request_data}}
  46. def generate_and_upload_report(request: RecommendRequest):
  47. """生成并上传报告到阿里云文件数据库"""
  48. logger.info(f"Background task started: generating report for {request.city_uuid}/{request.product_code}")
  49. try:
  50. report_util = ReportUtils(request.city_uuid, request.product_code)
  51. report_util.generate_all_data(request.cust_code_list)
  52. reports_dir = os.path.join("./data/reports", request.city_uuid, request.product_code)
  53. report_files = ["卷烟信息表", "品规商户特征关系表", "相似卷烟表", "商户售卖推荐表"]
  54. file_id_map = FileStreamUtils.upload_files(reports_dir, report_files)
  55. if file_id_map is None:
  56. logger.error(f"Report upload failed for {request.city_uuid}/{request.product_code}")
  57. return
  58. data_dict = {
  59. "cultivacation_id": request.cultivacation_id,
  60. "city_uuid": request.city_uuid,
  61. "limit_cycle_name": request.limit_cycle_name,
  62. "product_code": request.product_code,
  63. "product_info_table": file_id_map.get("卷烟信息表"),
  64. "relation_table": file_id_map.get("品规商户特征关系表"),
  65. "similarity_product_table": file_id_map.get("相似卷烟表"),
  66. "recommend_table": file_id_map.get("商户售卖推荐表"),
  67. }
  68. dao.insert_report(data_dict)
  69. logger.info(f"Background task completed: report uploaded for {request.city_uuid}/{request.product_code}")
  70. except Exception as e:
  71. logger.error(f"Background task failed: {e}", exc_info=True)