| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133 |
import os
from typing import Dict, List
from urllib.parse import quote

import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from database.dao.mysql_dao import MySqlDao
from models import Recommend
from utils import ReportUtils
# Application object that the route and exception-handler decorators in this
# module register against.
app: FastAPI = FastAPI()
# Data-access object created once at import time; the recommend endpoint
# queries order data through it.
dao: MySqlDao = MySqlDao()
# Global handler for request-validation failures.
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Turn a request-validation failure into the API's JSON error envelope.

    Args:
        request: The request that failed validation. Not used in the body;
            it is part of the handler signature.
        exc: The validation error raised for the request.

    Returns:
        A ``JSONResponse`` with HTTP status 400 whose ``data`` field carries
        the validation errors (``detail``) and the submitted body (``body``).
    """
    # exc.errors() and exc.body may hold values that the plain JSON encoder
    # used by JSONResponse cannot serialize. Passing them through
    # jsonable_encoder first (as the FastAPI documentation's own example
    # does) keeps this handler from failing while it builds the error reply.
    error_data = jsonable_encoder({"detail": exc.errors(), "body": exc.body})
    return JSONResponse(
        status_code=status.HTTP_400_BAD_REQUEST,
        content={
            "code": 400,
            "msg": "请求参数错误",
            "data": error_data,
        },
    )
# Request body accepted by the recommend endpoint.
class RecommendRequest(BaseModel):
    # City id.
    city_uuid: str
    # Cigarette product code.
    product_code: str
    # Number of merchants to recommend.
    recall_cust_count: int
    # Quantity of the product spec to deliver.
    delivery_count: int
-
# Request model identifying a report by city and product.
# NOTE(review): no reference to this model is visible in this part of the
# file; confirm whether it is still used elsewhere.
class ReportRequest(BaseModel):
    # City id.
    city_uuid: str
    # Cigarette product code.
    product_code: str
-
@app.post("/brandcultivation/api/v1/recommend")
async def recommend(request: RecommendRequest, backgroundTasks: BackgroundTasks):
    """Recommend merchants for a product and queue report generation.

    When no ``gbdtlr_model.pkl`` weights file exists for the city, a
    "model not defined" payload is returned and nothing else happens.
    Otherwise the product is recalled with ``get_recommend_list_by_gbdtlr``
    if its code appears in the city's order data, and with
    ``get_recommend_list_by_item2vec`` if it does not. The recall result is
    passed to ``get_recommend_and_delivery`` and every returned row is
    numbered starting from 1.

    Args:
        request: City id, product code, number of merchants to recall and
            delivery count.
        backgroundTasks: Task queue on which ``generate_report`` is
            scheduled before the response is returned.

    Returns:
        A dict with ``code``, ``msg`` and ``data``; on success
        ``data["recommendationInfo"]`` is the list of numbered rows.
    """
    # NOTE(review): the dao and model calls below look synchronous; if they
    # block, they run on the event loop inside this async handler — confirm.
    weights_file = os.path.join("./models/rank/weights", request.city_uuid, "gbdtlr_model.pkl")
    if not os.path.exists(weights_file):
        return {"code": 200, "msg": "model not defined", "data": {"recommendationInfo": "该城市的模型未训练,请先进行训练"}}

    # Set up the recommender for this city.
    recommender = Recommend(request.city_uuid)

    # A product whose code is absent from the order data is treated as new
    # and goes through the item2vec recall instead of the GBDT-LR one.
    ordered_codes = dao.get_product_from_order(request.city_uuid)["product_code"].unique().tolist()
    if request.product_code in ordered_codes:
        recall = recommender.get_recommend_list_by_gbdtlr
    else:
        recall = recommender.get_recommend_list_by_item2vec
    candidates = recall(request.product_code, recall_count=request.recall_cust_count)

    deliveries = recommender.get_recommend_and_delivery(candidates, delivery_count=request.delivery_count)
    rows = [
        {
            "id": position,
            "cust_code": entry["cust_code"],
            "recommend_score": entry["recommend_score"],
            "delivery_count": entry["delivery_count"],
        }
        for position, entry in enumerate(deliveries, start=1)
    ]

    # Report generation is handed to the background-task queue rather than
    # being run inline in this handler.
    backgroundTasks.add_task(
        generate_report,
        request.city_uuid,
        request.product_code,
        request.recall_cust_count,
        request.delivery_count,
    )

    return {"code": 200, "msg": "success", "data": {"recommendationInfo": rows}}
def generate_report(city_uuid, product_id, recall_count, delivery_count):
    """Generate the report data for one city/product pair.

    Builds ``ReportUtils(city_uuid, product_id)`` and calls its
    ``generate_all_data(recall_count, delivery_count)``. Returns ``None``.
    """
    ReportUtils(city_uuid, product_id).generate_all_data(recall_count, delivery_count)
-
# File names that the download_report endpoint looks for inside a report
# directory, in the order they are listed in its response.
REPORT_FILES: List[str] = [
    "卷烟信息表.xlsx",
    "品规商户特征关系表.xlsx",
    "商户售卖推荐表.xlsx",
    "相似卷烟表.xlsx",
]
@app.get("/brandcultivation/api/v1/download_report")
async def get_report_files(city_uuid: str, product_code: str) -> JSONResponse:
    """List download links for the report files that exist for a product.

    Args:
        city_uuid: City id; first path component under ``./data/reports``.
        product_code: Product code; second path component.

    Returns:
        A ``JSONResponse`` with HTTP status 200. If the report directory does
        not exist, ``data["reportDownloadInfo"]`` is a "not generated yet"
        message. Otherwise it is a list of ``{"filename", "download_url"}``
        dicts, one per entry of ``REPORT_FILES`` found in the directory.
    """
    # NOTE(review): city_uuid and product_code come straight from the query
    # string and are joined into a filesystem path without validation.
    report_dir = os.path.join("./data/reports", city_uuid, product_code)

    # Only an actual directory counts as a generated report; a stray file at
    # that path is treated the same as a missing report.
    if not os.path.isdir(report_dir):
        return JSONResponse(
            status_code=200,
            content={
                "code": 200,
                "msg": "report directory not found",
                "data": {"reportDownloadInfo": "该品规报告还未生成!"},
            },
        )

    # NOTE(review): the base URL is a hard-coded loopback address, so these
    # links only resolve on the host running the service — confirm intended.
    base_url = "http://127.0.0.1:7960"
    # Percent-encode every path segment: the file names are non-ASCII and the
    # two ids are caller-supplied, so raw interpolation can yield a malformed
    # or ambiguous URL (spaces, '?', '#', '/').
    url_prefix = "/".join(
        [
            f"{base_url}/brandcultivation/api/v1/download_file",
            quote(city_uuid, safe=""),
            quote(product_code, safe=""),
        ]
    )

    # Keep only the expected report files that are present as regular files.
    available_files: List[Dict[str, str]] = [
        {
            "filename": filename,
            "download_url": f"{url_prefix}/{quote(filename, safe='')}",
        }
        for filename in REPORT_FILES
        if os.path.isfile(os.path.join(report_dir, filename))
    ]

    return JSONResponse(
        status_code=200,
        content={
            "code": 200,
            "msg": "success",
            "data": {
                "reportDownloadInfo": available_files,
            },
        },
    )
-
# Start the service with uvicorn only when this file is run as a script;
# it listens on all interfaces, port 7960.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7960)
|