Quellcode durchsuchen

异步存储报告

Sherlock vor 11 Monaten
Ursprung
Commit
ec1db7d838
3 geänderte Dateien mit 95 neuen und 12 gelöschten Zeilen
  1. 62 3
      api.py
  2. 32 8
      api_test.py
  3. 1 1
      utils/report_utils.py

+ 62 - 3
api.py

@@ -1,4 +1,4 @@
-from fastapi import FastAPI, HTTPException, Request, status
+from fastapi import FastAPI, Request, status, BackgroundTasks
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse
 from database.dao.mysql_dao import MySqlDao
@@ -7,6 +7,7 @@ import os
 from pydantic import BaseModel
 import uvicorn
 from utils import ReportUtils
+from typing import List, Dict
 
 app = FastAPI()
 dao = MySqlDao()
@@ -33,8 +34,12 @@ class RecommendRequest(BaseModel):
     recall_cust_count: int      # 推荐的商户数量
     delivery_count: int              # 投放的品规数量
     
+class ReportRequest(BaseModel):
+    city_uuid: str              # 城市id
+    product_code: str           # 卷烟编码
+    
 @app.post("/brandcultivation/api/v1/recommend")
-def recommend(request: RecommendRequest):
+async def recommend(request: RecommendRequest, backgroundTasks: BackgroundTasks):
     gbdtlr_model_path = os.path.join("./models/rank/weights", request.city_uuid, "gbdtlr_model.pkl")
     if not os.path.exists(gbdtlr_model_path):
         return {"code": 200, "msg": "model not defined", "data": {"recommendationInfo": "该城市的模型未训练,请先进行训练"}}
@@ -61,7 +66,14 @@ def recommend(request: RecommendRequest):
             }
         )
     
-    generate_report(request.city_uuid, request.product_code, request.recall_cust_count, request.delivery_count)
+    # 异步执行报告生成任务
+    backgroundTasks.add_task(
+        generate_report,
+        request.city_uuid,
+        request.product_code,
+        request.recall_cust_count,
+        request.delivery_count
+    )
     
     return {"code": 200, "msg": "success", "data": {"recommendationInfo": request_data}}
 
@@ -69,6 +81,53 @@ def generate_report(city_uuid, product_id, recall_count, delivery_count):
     """生成报告"""
     report_util = ReportUtils(city_uuid, product_id)
     report_util.generate_all_data(recall_count, delivery_count)
+    
+REPORT_FILES = [
+    "卷烟信息表.xlsx",
+    "品规商户特征关系表.xlsx",
+    "商户售卖推荐表.xlsx",
+    "相似卷烟表.xlsx"
+]
+
+@app.get("/brandcultivation/api/v1/download_report")
+async def get_report_files(city_uuid: str, product_code: str) -> JSONResponse:
+    report_dir = os.path.join("./data/reports", city_uuid, product_code)
+    
+    # 检查报告是否存在
+    if not os.path.exists(report_dir):
+        return JSONResponse(
+            status_code=200,
+            content={
+                "code": 200,
+                "msg": "report directory not found",
+                "data": {"reportDownloadInfo": "该品规报告还未生成!"}
+            }
+        )
+    
+    # 收集所有存在的文件
+    available_files:List[Dict[str, str]] = []
+    base_url = "http://127.0.0.1:7960"
+    
+    for filename in REPORT_FILES:
+        file_path = os.path.join(report_dir, filename)
+        if os.path.exists(file_path):
+            file_url = f"{base_url}/brandcultivation/api/v1/download_file/{city_uuid}/{product_code}/{filename}"
+            available_files.append({
+                "filename": filename,
+                "download_url": file_url
+            })
+            
+    return JSONResponse(
+        status_code=200,
+        content={
+            "code": 200,
+            "msg": "success",
+            "data": {
+                "reportDownloadInfo": available_files
+            }
+        }
+    )
+    
 
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=7960)

+ 32 - 8
api_test.py

@@ -1,14 +1,38 @@
 import requests
 import json
 
-url = "http://127.0.0.1:7960/brandcultivation/api/v1/recommend"
-payload = {
+# url = "http://127.0.0.1:7960/brandcultivation/api/v1/recommend"
+# payload = {
+#     "city_uuid": "00000000000000000000000011445301",
+#     "product_code": "440298",
+#     "recall_cust_count": 500,
+#     "delivery_count": 5000
+# }
+# headers = {'Content-Type': 'application/json'}
+
+# response = requests.post(url, data=json.dumps(payload), headers=headers)
+# print(response.json())
+
+# 2. 然后调用报告下载接口
+download_url = "http://127.0.0.1:7960/brandcultivation/api/v1/download_report"
+download_payload = {
     "city_uuid": "00000000000000000000000011445301",
-    "product_code": "440298",
-    "recall_cust_count": 500,
-    "delivery_count": 5000
+    "product_code": "440298"
 }
-headers = {'Content-Type': 'application/json'}
+download_headers = {'Content-Type': 'application/json'}
+
+print("\n调用报告下载接口...")
+download_response = requests.get(
+    download_url,
+    params=download_payload,  # 注意GET请求使用params而不是data
+    headers=download_headers
+)
+print(json.dumps(download_response.json(), indent=2))
 
-response = requests.post(url, data=json.dumps(payload), headers=headers)
-print(response.json())
+if download_response.json().get("code") == 200:
+    for file_info in download_response.json()["data"]["reportDownloadInfo"]:
+        print(f"\n下载文件: {file_info['filename']}")
+        file_response = requests.get(file_info["download_url"])
+        with open(file_info["filename"], "wb") as f:
+            f.write(file_response.content)
+        print(f"已保存到: {file_info['filename']}")

+ 1 - 1
utils/report_utils.py

@@ -132,4 +132,4 @@ class ReportUtils:
         self.generate_product_report()
         self.generate_recommend_report(recall_count, delivery_count)
         self.generate_similarity_product_report()
-        self.generate_eval_data()
+        # self.generate_eval_data()