# api.py — FastAPI service exposing the brand-cultivation recommendation and
# report-download endpoints.

import os
from typing import List, Dict
from urllib.parse import quote

import uvicorn
from fastapi import FastAPI, Request, status, BackgroundTasks
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from database.dao.mysql_dao import MySqlDao
from models import Recommend
from utils import ReportUtils
  11. app = FastAPI()
  12. dao = MySqlDao()
  13. # 添加全局异常处理器
  14. @app.exception_handler(RequestValidationError)
  15. async def validation_exception_handler(request: Request, exc: RequestValidationError):
  16. return JSONResponse(
  17. status_code=status.HTTP_400_BAD_REQUEST,
  18. content={
  19. "code": 400,
  20. "msg": "请求参数错误",
  21. "data": {
  22. "detail": exc.errors(),
  23. "body": exc.body
  24. }
  25. },
  26. )
# Request body for POST /brandcultivation/api/v1/recommend.
class RecommendRequest(BaseModel):
    city_uuid: str  # City ID
    product_code: str  # Cigarette product code
    recall_cust_count: int  # Number of merchants to recommend (recall size)
    delivery_count: int  # Delivery quantity for the product spec
# Request body carrying a city and a product. NOTE(review): not referenced by
# the endpoints visible in this file (get_report_files takes query parameters);
# confirm whether it is still used elsewhere.
class ReportRequest(BaseModel):
    city_uuid: str  # City ID
    product_code: str  # Cigarette product code
  36. @app.post("/brandcultivation/api/v1/recommend")
  37. async def recommend(request: RecommendRequest, backgroundTasks: BackgroundTasks):
  38. gbdtlr_model_path = os.path.join("./models/rank/weights", request.city_uuid, "gbdtlr_model.pkl")
  39. if not os.path.exists(gbdtlr_model_path):
  40. return {"code": 200, "msg": "model not defined", "data": {"recommendationInfo": "该城市的模型未训练,请先进行训练"}}
  41. # 初始化模型
  42. recommend_model = Recommend(request.city_uuid)
  43. # 判断该品规是否是新品规
  44. products_in_oreder = dao.get_product_from_order(request.city_uuid)["product_code"].unique().tolist()
  45. if request.product_code in products_in_oreder:
  46. recommend_list = recommend_model.get_recommend_list_by_gbdtlr(request.product_code, recall_count=request.recall_cust_count)
  47. else:
  48. recommend_list = recommend_model.get_recommend_list_by_item2vec(request.product_code, recall_count=request.recall_cust_count)
  49. recommend_data = recommend_model.get_recommend_and_delivery(recommend_list, delivery_count=request.delivery_count)
  50. request_data = []
  51. for index, data in enumerate(recommend_data):
  52. id = index + 1
  53. request_data.append(
  54. {
  55. "id": id,
  56. "cust_code": data["cust_code"],
  57. "recommend_score": data["recommend_score"],
  58. "delivery_count": data["delivery_count"]
  59. }
  60. )
  61. # 异步执行报告生成任务
  62. backgroundTasks.add_task(
  63. generate_report,
  64. request.city_uuid,
  65. request.product_code,
  66. request.recall_cust_count,
  67. request.delivery_count
  68. )
  69. return {"code": 200, "msg": "success", "data": {"recommendationInfo": request_data}}
  70. def generate_report(city_uuid, product_id, recall_count, delivery_count):
  71. """生成报告"""
  72. report_util = ReportUtils(city_uuid, product_id)
  73. report_util.generate_all_data(recall_count, delivery_count)
  74. REPORT_FILES = [
  75. "卷烟信息表.xlsx",
  76. "品规商户特征关系表.xlsx",
  77. "商户售卖推荐表.xlsx",
  78. "相似卷烟表.xlsx"
  79. ]
  80. @app.get("/brandcultivation/api/v1/download_report")
  81. async def get_report_files(city_uuid: str, product_code: str) -> JSONResponse:
  82. report_dir = os.path.join("./data/reports", city_uuid, product_code)
  83. # 检查报告是否存在
  84. if not os.path.exists(report_dir):
  85. return JSONResponse(
  86. status_code=200,
  87. content={
  88. "code": 200,
  89. "msg": "report directory not found",
  90. "data": {"reportDownloadInfo": "该品规报告还未生成!"}
  91. }
  92. )
  93. # 收集所有存在的文件
  94. available_files:List[Dict[str, str]] = []
  95. base_url = "http://127.0.0.1:7960"
  96. for filename in REPORT_FILES:
  97. file_path = os.path.join(report_dir, filename)
  98. if os.path.exists(file_path):
  99. file_url = f"{base_url}/brandcultivation/api/v1/download_file/{city_uuid}/{product_code}/{filename}"
  100. available_files.append({
  101. "filename": filename,
  102. "download_url": file_url
  103. })
  104. return JSONResponse(
  105. status_code=200,
  106. content={
  107. "code": 200,
  108. "msg": "success",
  109. "data": {
  110. "reportDownloadInfo": available_files
  111. }
  112. }
  113. )
  114. if __name__ == "__main__":
  115. uvicorn.run(app, host="0.0.0.0", port=7960)