import os
from typing import Dict, List

import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from database.dao.mysql_dao import MySqlDao
from models import Recommend
from utils import ReportUtils

app = FastAPI()
dao = MySqlDao()

# Base URL embedded in the download links returned by get_report_files.
# The default matches the address used by the __main__ entry point; set the
# REPORT_BASE_URL environment variable when clients reach the service through
# another host name, port or reverse proxy (a loopback URL is useless to them).
REPORT_BASE_URL = os.environ.get("REPORT_BASE_URL", "http://127.0.0.1:7960").rstrip("/")

# Report file names looked up (in this order) inside a report directory.
REPORT_FILES = [
    "卷烟信息表.xlsx",
    "品规商户特征关系表.xlsx",
    "商户售卖推荐表.xlsx",
    "相似卷烟表.xlsx"
]


# Global handler for request-validation errors.
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Return request-validation failures as HTTP 400 in the code/msg/data envelope.

    ``exc.errors()`` and ``exc.body`` may hold values the JSON encoder cannot
    serialise directly (e.g. exception objects inside an error's ``ctx``, or
    raw bytes when the body is not valid JSON), so both are passed through
    ``jsonable_encoder`` first; otherwise the handler itself would fail.
    """
    return JSONResponse(
        status_code=status.HTTP_400_BAD_REQUEST,
        content={
            "code": 400,
            "msg": "请求参数错误",
            "data": jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
        },
    )


def _is_safe_path_segment(value: str) -> bool:
    """Return True if *value* can be used as exactly one directory name.

    city_uuid and product_code come from the client and are joined onto fixed
    base directories. A path separator, "." / "..", a NUL byte, or a drive
    prefix would let a request address paths outside those directories
    (os.path.join discards the base when a later component is absolute).
    """
    return (
        bool(value)
        and value not in (".", "..")
        and "/" not in value
        and "\\" not in value
        and "\x00" not in value
        and not os.path.splitdrive(value)[0]
    )


def _unsafe_params(**params: str) -> List[str]:
    """Return the names of the keyword arguments that are not safe path segments."""
    return [name for name, value in params.items() if not _is_safe_path_segment(value)]


def _invalid_path_params_response(location: str, names: List[str]) -> JSONResponse:
    """Build a 400 response, in the validation handler's envelope, naming the bad parameters.

    :param location: where the parameters came from ("body" or "query").
    :param names: names of the rejected parameters.
    """
    return JSONResponse(
        status_code=status.HTTP_400_BAD_REQUEST,
        content={
            "code": 400,
            "msg": "请求参数错误",
            "data": {
                "detail": [
                    {"loc": [location, name], "msg": "must be a single path segment"}
                    for name in names
                ]
            },
        },
    )


# Request bodies.
class RecommendRequest(BaseModel):
    city_uuid: str  # city id
    product_code: str  # cigarette product code
    recall_cust_count: int  # number of merchants to recommend
    delivery_count: int  # number of product units (品规) to deliver


class ReportRequest(BaseModel):
    city_uuid: str  # city id
    product_code: str  # cigarette product code


@app.post("/brandcultivation/api/v1/recommend")
async def recommend(request: RecommendRequest, backgroundTasks: BackgroundTasks):
    """Recommend merchants for a cigarette product in a city.

    If the city has no ``gbdtlr_model.pkl`` under ``./models/rank/weights/<city_uuid>``,
    a "model not trained" message is returned. Otherwise, products that appear
    in the city's order data are ranked with ``get_recommend_list_by_gbdtlr``
    and products that do not (new products) with
    ``get_recommend_list_by_item2vec``. Report generation for the same
    parameters is scheduled as a background task.

    Returns ``{"code", "msg", "data": {"recommendationInfo": [...]}}`` where each
    entry carries a 1-based ``id`` plus ``cust_code``, ``recommend_score`` and
    ``delivery_count``; a 400 response if city_uuid/product_code are not safe
    path segments.

    NOTE(review): this handler is ``async`` but the DAO query and model calls
    below are synchronous, so they block the event loop for the whole request.
    Declaring it with plain ``def`` would let FastAPI run it in a worker
    thread; only do that after confirming MySqlDao / Recommend are safe to use
    from several threads concurrently.
    """
    bad_params = _unsafe_params(city_uuid=request.city_uuid, product_code=request.product_code)
    if bad_params:
        return _invalid_path_params_response("body", bad_params)

    gbdtlr_model_path = os.path.join("./models/rank/weights", request.city_uuid, "gbdtlr_model.pkl")
    if not os.path.exists(gbdtlr_model_path):
        return {"code": 200, "msg": "model not defined", "data": {"recommendationInfo": "该城市的模型未训练,请先进行训练"}}

    # Initialise the per-city model.
    recommend_model = Recommend(request.city_uuid)

    # Decide whether this product is new: a product absent from the city's
    # order data is treated as new and uses the item2vec path.
    ordered_products = set(
        dao.get_product_from_order(request.city_uuid)["product_code"].unique().tolist()
    )
    if request.product_code in ordered_products:
        recommend_list = recommend_model.get_recommend_list_by_gbdtlr(
            request.product_code, recall_count=request.recall_cust_count
        )
    else:
        recommend_list = recommend_model.get_recommend_list_by_item2vec(
            request.product_code, recall_count=request.recall_cust_count
        )
    recommend_data = recommend_model.get_recommend_and_delivery(
        recommend_list, delivery_count=request.delivery_count
    )

    recommendation_info = [
        {
            "id": rank,
            "cust_code": item["cust_code"],
            "recommend_score": item["recommend_score"],
            "delivery_count": item["delivery_count"],
        }
        for rank, item in enumerate(recommend_data, start=1)
    ]

    # Generate the report asynchronously, after the response is sent.
    backgroundTasks.add_task(
        generate_report,
        request.city_uuid,
        request.product_code,
        request.recall_cust_count,
        request.delivery_count,
    )

    return {"code": 200, "msg": "success", "data": {"recommendationInfo": recommendation_info}}


def generate_report(city_uuid, product_id, recall_count, delivery_count):
    """Generate the report data for one city/product via ``ReportUtils``.

    Scheduled as a background task by ``recommend``; ``product_id`` receives
    the request's product_code, and the recall/delivery counts are the ones
    used for the recommendation.
    """
    report_util = ReportUtils(city_uuid, product_id)
    report_util.generate_all_data(recall_count, delivery_count)


@app.get("/brandcultivation/api/v1/download_report")
async def get_report_files(city_uuid: str, product_code: str) -> JSONResponse:
    """List download links for the report files of a city/product.

    Returns a "report not generated yet" message when
    ``./data/reports/<city_uuid>/<product_code>`` does not exist; otherwise one
    ``{"filename", "download_url"}`` entry for each name in REPORT_FILES that is
    present in that directory (files not present are skipped). Returns a 400
    response if city_uuid/product_code are not safe path segments.

    NOTE(review): no route visible here serves
    ``/brandcultivation/api/v1/download_file/...``; confirm another module or
    service handles those links.
    """
    bad_params = _unsafe_params(city_uuid=city_uuid, product_code=product_code)
    if bad_params:
        return _invalid_path_params_response("query", bad_params)

    report_dir = os.path.join("./data/reports", city_uuid, product_code)

    # Check whether the report has been generated at all.
    if not os.path.exists(report_dir):
        return JSONResponse(
            status_code=200,
            content={
                "code": 200,
                "msg": "report directory not found",
                "data": {"reportDownloadInfo": "该品规报告还未生成!"}
            }
        )

    # Collect every report file that currently exists.
    available_files: List[Dict[str, str]] = [
        {
            "filename": filename,
            "download_url": (
                f"{REPORT_BASE_URL}/brandcultivation/api/v1/download_file/"
                f"{city_uuid}/{product_code}/{filename}"
            ),
        }
        for filename in REPORT_FILES
        if os.path.exists(os.path.join(report_dir, filename))
    ]

    return JSONResponse(
        status_code=200,
        content={
            "code": 200,
            "msg": "success",
            "data": {
                "reportDownloadInfo": available_files
            }
        }
    )


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7960)