api.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
import os

import uvicorn
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from database.dao.mysql_dao import MySqlDao
from models import Recommend
# FastAPI application; the exception handler and route below are registered on it.
app = FastAPI()
# Module-level data-access object, created once at import time and shared by
# every request handled by this process (presumably wraps a MySQL connection —
# NOTE(review): confirm MySqlDao's connection/thread-safety behavior).
dao = MySqlDao()
  11. # 添加全局异常处理器
  12. @app.exception_handler(RequestValidationError)
  13. async def validation_exception_handler(request: Request, exc: RequestValidationError):
  14. return JSONResponse(
  15. status_code=status.HTTP_400_BAD_REQUEST,
  16. content={
  17. "code": 400,
  18. "msg": "请求参数错误",
  19. "data": {
  20. "detail": exc.errors(),
  21. "body": exc.body
  22. }
  23. },
  24. )
# Request body for POST /brandcultivation/api/v1/recommend.
# (Field names, types and order form the public request schema.)
class RecommendRequest(BaseModel):
    city_uuid: str  # City ID; also used as the per-city model weights directory name
    product_code: str  # Cigarette product code
    recall_cust_count: int  # Number of merchants to recommend
    delivery_count: int  # Quantity of the product to allocate (passed to get_recommend_and_delivery)
  31. @app.post("/brandcultivation/api/v1/recommend")
  32. def recommend(request: RecommendRequest):
  33. gbdtlr_model_path = os.path.join("./models/rank/weights", request.city_uuid, "gbdtlr_model.pkl")
  34. if not os.path.exists(gbdtlr_model_path):
  35. return {"code": 200, "msg": "model not defined", "data": {"recommendationInfo": "该城市的模型未训练,请先进行训练"}}
  36. # 初始化模型
  37. recommend_model = Recommend(request.city_uuid)
  38. # 判断该品规是否是新品规
  39. products_in_oreder = dao.get_product_from_order(request.city_uuid)["product_code"].unique().tolist()
  40. if request.product_code in products_in_oreder:
  41. recommend_list = recommend_model.get_recommend_list_by_gbdtlr(request.product_code, recall_count=request.recall_cust_count)
  42. else:
  43. recommend_list = recommend_model.get_recommend_list_by_item2vec(request.product_code, recall_count=request.recall_cust_count)
  44. recommend_data = recommend_model.get_recommend_and_delivery(recommend_list, delivery_count=request.delivery_count)
  45. request_data = []
  46. for index, data in enumerate(recommend_data):
  47. id = index + 1
  48. request_data.append(
  49. {
  50. "id": id,
  51. "cust_code": data["cust_code"],
  52. "recommend_score": data["recommend_score"],
  53. "delivery_count": data["delivery_count"]
  54. }
  55. )
  56. return {"code": 200, "msg": "success", "data": {"recommendationInfo": request_data}}
  57. def generate_report(city_uuid, product_id, recall_count, delivery_count):
  58. """生成报告"""
  59. if __name__ == "__main__":
  60. uvicorn.run(app, host="0.0.0.0", port=8000)