recommend.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
import logging
import os

from fastapi import APIRouter, BackgroundTasks

from database import MySqlDao
from models import Recommend
from utils import FileStreamUtils, ReportUtils

from .request_body import RecommendRequest
  7. dao = MySqlDao()
  8. router = APIRouter()
  9. @router.post("/recommend")
  10. async def recommend(request: RecommendRequest, backgroundTasks: BackgroundTasks):
  11. """推荐接口"""
  12. gbdtlr_model_path = os.path.join("./models/rank/weights", request.city_uuid, "gbdtlr_model.pkl")
  13. if not os.path.exists(gbdtlr_model_path):
  14. return {"code": 200, "msg": "model not defined", "data": {"recommendationInfo": "该城市的模型未训练,请先进行训练"}}
  15. # 初始化模型
  16. recommend_model = Recommend(request.city_uuid)
  17. # 判断该品规是否是新品规
  18. products_in_oreder = dao.get_product_from_order(request.city_uuid)["product_code"].unique().tolist()
  19. if request.product_code in products_in_oreder:
  20. recommend_list = recommend_model.get_recommend_list_by_gbdtlr(request.product_code, recall_count=request.recall_cust_count)
  21. else:
  22. print("走这了")
  23. recommend_list = recommend_model.get_recommend_list_by_item2vec(request.product_code, recall_count=request.recall_cust_count)
  24. recommend_data = recommend_model.get_recommend_and_delivery(recommend_list, delivery_count=request.delivery_count)
  25. request_data = []
  26. for index, data in enumerate(recommend_data):
  27. id = index + 1
  28. request_data.append(
  29. {
  30. "id": id,
  31. "cust_code": data["cust_code"],
  32. "recommend_score": data["recommend_score"],
  33. "delivery_count": data["delivery_count"]
  34. }
  35. )
  36. # 异步执行报告生成任务
  37. backgroundTasks.add_task(
  38. generate_and_upload_report,
  39. request
  40. )
  41. return {"code": 200, "msg": "success", "data": {"recommendationInfo": request_data}}
  42. def generate_and_upload_report(request: RecommendRequest):
  43. """生成并上传报告到阿里云文件数据库"""
  44. # 生成相关报告
  45. report_util = ReportUtils(request.city_uuid, request.product_code)
  46. report_util.generate_all_data(request.recall_cust_count, request.delivery_count)
  47. # 上传报告
  48. reports_dir = os.path.join('./data/reports', request.city_uuid, request.product_code)
  49. report_files = [
  50. '卷烟信息表',
  51. '品规商户特征关系表',
  52. '相似卷烟表',
  53. '商户售卖推荐表'
  54. ]
  55. file_id_map = FileStreamUtils.upload_files(reports_dir, report_files)
  56. # 将返回的file_id保存到数据库中
  57. data_dict = {
  58. 'cultivacation_id': request.cultivacation_id,
  59. 'city_uuid': request.city_uuid,
  60. 'limit_cycle_name': request.limit_cycle_name,
  61. 'product_code': request.product_code,
  62. 'product_info_table': file_id_map.get('卷烟信息表'),
  63. 'relation_table': file_id_map.get('品规商户特征关系表'),
  64. 'similarity_product_table': file_id_map.get('相似卷烟表'),
  65. 'recommend_table': file_id_map.get('商户售卖推荐表'),
  66. }
  67. dao.insert_report(data_dict)