api.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
import os

import uvicorn
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from database.dao.mysql_dao import MySqlDao
from models import Recommend
from utils import ReportUtils
  10. app = FastAPI()
  11. dao = MySqlDao()
  12. # 添加全局异常处理器
  13. @app.exception_handler(RequestValidationError)
  14. async def validation_exception_handler(request: Request, exc: RequestValidationError):
  15. return JSONResponse(
  16. status_code=status.HTTP_400_BAD_REQUEST,
  17. content={
  18. "code": 400,
  19. "msg": "请求参数错误",
  20. "data": {
  21. "detail": exc.errors(),
  22. "body": exc.body
  23. }
  24. },
  25. )
  26. # 定义请求体
  27. class RecommendRequest(BaseModel):
  28. city_uuid: str # 城市id
  29. product_code: str # 卷烟编码
  30. recall_cust_count: int # 推荐的商户数量
  31. delivery_count: int # 投放的品规数量
  32. @app.post("/brandcultivation/api/v1/recommend")
  33. def recommend(request: RecommendRequest):
  34. gbdtlr_model_path = os.path.join("./models/rank/weights", request.city_uuid, "gbdtlr_model.pkl")
  35. if not os.path.exists(gbdtlr_model_path):
  36. return {"code": 200, "msg": "model not defined", "data": {"recommendationInfo": "该城市的模型未训练,请先进行训练"}}
  37. # 初始化模型
  38. recommend_model = Recommend(request.city_uuid)
  39. # 判断该品规是否是新品规
  40. products_in_oreder = dao.get_product_from_order(request.city_uuid)["product_code"].unique().tolist()
  41. if request.product_code in products_in_oreder:
  42. recommend_list = recommend_model.get_recommend_list_by_gbdtlr(request.product_code, recall_count=request.recall_cust_count)
  43. else:
  44. recommend_list = recommend_model.get_recommend_list_by_item2vec(request.product_code, recall_count=request.recall_cust_count)
  45. recommend_data = recommend_model.get_recommend_and_delivery(recommend_list, delivery_count=request.delivery_count)
  46. request_data = []
  47. for index, data in enumerate(recommend_data):
  48. id = index + 1
  49. request_data.append(
  50. {
  51. "id": id,
  52. "cust_code": data["cust_code"],
  53. "recommend_score": data["recommend_score"],
  54. "delivery_count": data["delivery_count"]
  55. }
  56. )
  57. generate_report(request.city_uuid, request.product_code, request.recall_cust_count, request.delivery_count)
  58. return {"code": 200, "msg": "success", "data": {"recommendationInfo": request_data}}
  59. def generate_report(city_uuid, product_id, recall_count, delivery_count):
  60. """生成报告"""
  61. report_util = ReportUtils(city_uuid, product_id)
  62. report_util.generate_all_data(recall_count, delivery_count)
  63. if __name__ == "__main__":
  64. uvicorn.run(app, host="0.0.0.0", port=7960)