api.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
import logging
import os
from typing import Dict, List

import requests
import uvicorn
from fastapi import BackgroundTasks, FastAPI, HTTPException, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from config import load_service_config
from database.dao.mysql_dao import MySqlDao
from models import Recommend
from utils import ReportUtils
# Module-level objects created once at import time and shared by the handlers below.
# NOTE(review): MySqlDao() / load_service_config() may have side effects (e.g. opening a
# DB connection, reading files) — keep this order unless that is confirmed not to matter.
app = FastAPI()
dao = MySqlDao()  # used by ``recommend`` to look up the products that appear in orders
cfgs = load_service_config()  # service config; ``upload_file`` reads cfgs["aliyun"]["upload_url"]
  16. # 添加全局异常处理器
  17. @app.exception_handler(RequestValidationError)
  18. async def validation_exception_handler(request: Request, exc: RequestValidationError):
  19. return JSONResponse(
  20. status_code=status.HTTP_400_BAD_REQUEST,
  21. content={
  22. "code": 400,
  23. "msg": "请求参数错误",
  24. "data": {
  25. "detail": exc.errors(),
  26. "body": exc.body
  27. }
  28. },
  29. )
# Request body definitions
class RecommendRequest(BaseModel):
    """Request body of POST /brandcultivation/api/v1/recommend."""

    city_uuid: str  # city id
    product_code: str  # cigarette product code
    recall_cust_count: int  # number of customers (merchants) to recommend
    delivery_count: int  # number of product units to deliver (投放的品规数量)
class ReportRequest(BaseModel):
    """Request body of POST /brandcultivation/api/v1/report."""

    city_uuid: str  # city id
    product_code: str  # cigarette product code
  39. @app.post("/brandcultivation/api/v1/recommend")
  40. async def recommend(request: RecommendRequest, backgroundTasks: BackgroundTasks):
  41. gbdtlr_model_path = os.path.join("./models/rank/weights", request.city_uuid, "gbdtlr_model.pkl")
  42. if not os.path.exists(gbdtlr_model_path):
  43. return {"code": 200, "msg": "model not defined", "data": {"recommendationInfo": "该城市的模型未训练,请先进行训练"}}
  44. # 初始化模型
  45. recommend_model = Recommend(request.city_uuid)
  46. # 判断该品规是否是新品规
  47. products_in_oreder = dao.get_product_from_order(request.city_uuid)["product_code"].unique().tolist()
  48. if request.product_code in products_in_oreder:
  49. recommend_list = recommend_model.get_recommend_list_by_gbdtlr(request.product_code, recall_count=request.recall_cust_count)
  50. else:
  51. recommend_list = recommend_model.get_recommend_list_by_item2vec(request.product_code, recall_count=request.recall_cust_count)
  52. recommend_data = recommend_model.get_recommend_and_delivery(recommend_list, delivery_count=request.delivery_count)
  53. request_data = []
  54. for index, data in enumerate(recommend_data):
  55. id = index + 1
  56. request_data.append(
  57. {
  58. "id": id,
  59. "cust_code": data["cust_code"],
  60. "recommend_score": data["recommend_score"],
  61. "delivery_count": data["delivery_count"]
  62. }
  63. )
  64. # 异步执行报告生成任务
  65. backgroundTasks.add_task(
  66. generate_report,
  67. request.city_uuid,
  68. request.product_code,
  69. request.recall_cust_count,
  70. request.delivery_count
  71. )
  72. return {"code": 200, "msg": "success", "data": {"recommendationInfo": request_data}}
  73. def generate_report(city_uuid, product_id, recall_count, delivery_count):
  74. """生成报告"""
  75. report_util = ReportUtils(city_uuid, product_id)
  76. report_util.generate_all_data(recall_count, delivery_count)
  77. repots_dir = os.path.join('./data/reports', city_uuid, product_id)
  78. upload_file(repots_dir)
  79. def upload_file(reports_dir):
  80. """上传报告文件"""
  81. base_url = cfgs["aliyun"]["upload_url"]
  82. files = [
  83. "卷烟信息表.xlsx",
  84. "品规商户特征关系表.xlsx",
  85. "相似卷烟表.xlsx"
  86. ]
  87. # 设置请求头
  88. headers = {
  89. "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
  90. "Accept": "*/*",
  91. }
  92. # 设置Cookie
  93. cookies = {
  94. "ecp_token": "ZXlKamIyUmxJam9pTURZMVltUXpZbUV5T0dWaFl6ZzRPREJqTkRjNU5ERXpaV0k0T1dRd1pERWlMQ0p6WTI5d1pTSTZJbkJ5YjNSbFkzUmxaQ0lzSW1Oc2FXVnVkQ0k2SW1Oc2FXVnVkRjlzYjI1bmFta2lmUT09",
  95. "acw_tc": "0a03176c17452016143048775e73c8e9556c4c3ee2dbc80f405a9314a5d0e5",
  96. "cna": "09T8H5zH8SoCASeqHG9RyLbu",
  97. "isg": "BDAwNyXyI9cQD__-0PkIodjTAfiCeRTD-vUtYyqmIw3d5bGP0Y0nUayWOe2F3syb",
  98. "dd-ztna-token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VySWQiOiIwNTAwMzQzMDM4MjYxNDMwOTQiLCJjb3JwSWQiOiJkaW5nYzc2YTJhMzdiMjRhODYwYTRhYzVkNjk4MDg2NGQzMzUiLCJkZXB0SWRzIjoiW1s4Mzk0NzIxNjAsODM2MjQ1MTUzLDFdXSIsInJvbGVJZHMiOiJbMjEwNDk3MDY1OF0iLCJ1bmlvbklkIjoiVXFyYTB4U0VQcWNxdFNhQTFqSzRzZ2lFaUUiLCJleHAiOjE3NDU4MDY0MjIsImlhdCI6MTc0NTIwMTYyMn0.S0K3UWF0fBP5fF6TQaUkxe8I7piSQjrXqG3DdviqTzdO9Y0J7wsydcRvfp-OuHUkXbY92diIBGRHBr5Bb0eYgmCZ946Q8CHk12WXWvJmMDPrQ4i2C1J7W7uzUuPYJiWQLGQd40iCxNKfZIGq2ML-6r1i_k0iA7L_f6En1e9b5gLIs0GaluUwhCTTpAHBsmb-FquQER2Um9igARbMT6aaNrJzQbcPPWntQ3Pz_65PgjtTUMiFtBZW-YeAm2iB9JtYo_SJUsoX0d_oYo-6D4IfgZtkYKZZxbLj5rWnhoaJHm5acELAO8otQNZYQra0PfRIaKTT1vNkr1L5rShIxZk3yQ"
  99. }
  100. files_id = {}
  101. for file in files:
  102. file_path = os.path.join(reports_dir, file)
  103. try:
  104. with open(file_path, 'rb') as f:
  105. files = {'file': (os.path.basename(file_path), f)}
  106. response = requests.post(
  107. base_url,
  108. headers=headers,
  109. files=files,
  110. verify=True
  111. )
  112. if response.json()["success"]:
  113. file_id = response.json()["data"]["file_info"]["fileid"]
  114. files_id[os.path.basename(file_path).split('.')[0]] = file_id
  115. except requests.exceptions.RequestException as e:
  116. print("请求出错:", e)
  117. except Exception as e:
  118. print("发生错误:", e)
  119. files_id_str = ""
  120. if files_id:
  121. for filename, file_id in files_id.items():
  122. files_id_str += f"{filename},{file_id}\n"
  123. else:
  124. files_id_str = "failed"
  125. with open(os.path.join(reports_dir, "files_id.txt"), 'w', encoding="utf-8") as file:
  126. file.write(files_id_str)
  127. @app.post("/brandcultivation/api/v1/report")
  128. async def get_file_id(request: ReportRequest):
  129. files_id_path = os.path.join("./data/reports", request.city_uuid, request.product_code, "files_id.txt")
  130. if not os.path.exists(files_id_path):
  131. raise HTTPException(
  132. status_code=status.HTTP_404_NOT_FOUND,
  133. detail="Reports not found"
  134. )
  135. with open(files_id_path, 'r', encoding="utf-8") as file:
  136. lines = file.readlines()
  137. if lines[0].strip() == "failed":
  138. raise HTTPException(
  139. status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
  140. detail="reports upload failed"
  141. )
  142. request_data = []
  143. for index, line in enumerate(lines):
  144. filename, file_id = line.strip().split(",")
  145. request_data.append(
  146. {
  147. "id": index+1,
  148. "filename": filename,
  149. "file_id": file_id
  150. }
  151. )
  152. return {"code": 200, "msg": "success", "data": {"reportInfo": request_data}}
  153. if __name__ == "__main__":
  154. uvicorn.run(app, host="0.0.0.0", port=7960)
  155. # report_dir = "./data/reports/00000000000000000000000011445301/440298"
  156. # upload_file(report_dir)