api.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
import logging
import os
import uuid
from typing import List, Dict

import requests
import uvicorn
from fastapi import FastAPI, Request, status, BackgroundTasks, HTTPException
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from config import load_service_config
from database.dao.mysql_dao import MySqlDao
from models import Recommend
from utils import ReportUtils
  13. app = FastAPI()
  14. dao = MySqlDao()
  15. cfgs = load_service_config()
  16. # 添加全局异常处理器
  17. @app.exception_handler(RequestValidationError)
  18. async def validation_exception_handler(request: Request, exc: RequestValidationError):
  19. return JSONResponse(
  20. status_code=status.HTTP_400_BAD_REQUEST,
  21. content={
  22. "code": 400,
  23. "msg": "请求参数错误",
  24. "data": {
  25. "detail": exc.errors(),
  26. "body": exc.body
  27. }
  28. },
  29. )
  30. # 定义请求体
  31. class RecommendRequest(BaseModel):
  32. city_uuid: str # 城市id
  33. product_code: str # 卷烟编码
  34. recall_cust_count: int # 推荐的商户数量
  35. delivery_count: int # 投放的品规数量
  36. class ReportRequest(BaseModel):
  37. city_uuid: str # 城市id
  38. product_code: str # 卷烟编码
  39. class EvalReportRequest(BaseModel):
  40. city_uuid: str # 城市id
  41. product_code: str # 卷烟编码
  42. start_time: str # 开始投放时间
  43. end_time: str # 结束投放时间
  44. @app.post("/brandcultivation/api/v1/recommend")
  45. async def recommend(request: RecommendRequest, backgroundTasks: BackgroundTasks):
  46. gbdtlr_model_path = os.path.join("./models/rank/weights", request.city_uuid, "gbdtlr_model.pkl")
  47. if not os.path.exists(gbdtlr_model_path):
  48. return {"code": 200, "msg": "model not defined", "data": {"recommendationInfo": "该城市的模型未训练,请先进行训练"}}
  49. # 初始化模型
  50. recommend_model = Recommend(request.city_uuid)
  51. # 判断该品规是否是新品规
  52. products_in_oreder = dao.get_product_from_order(request.city_uuid)["product_code"].unique().tolist()
  53. if request.product_code in products_in_oreder:
  54. recommend_list = recommend_model.get_recommend_list_by_gbdtlr(request.product_code, recall_count=request.recall_cust_count)
  55. else:
  56. recommend_list = recommend_model.get_recommend_list_by_item2vec(request.product_code, recall_count=request.recall_cust_count)
  57. recommend_data = recommend_model.get_recommend_and_delivery(recommend_list, delivery_count=request.delivery_count)
  58. request_data = []
  59. for index, data in enumerate(recommend_data):
  60. id = index + 1
  61. request_data.append(
  62. {
  63. "id": id,
  64. "cust_code": data["cust_code"],
  65. "recommend_score": data["recommend_score"],
  66. "delivery_count": data["delivery_count"]
  67. }
  68. )
  69. # 异步执行报告生成任务
  70. backgroundTasks.add_task(
  71. generate_report,
  72. request.city_uuid,
  73. request.product_code,
  74. request.recall_cust_count,
  75. request.delivery_count
  76. )
  77. return {"code": 200, "msg": "success", "data": {"recommendationInfo": request_data}}
  78. def generate_report(city_uuid, product_id, recall_count, delivery_count):
  79. """生成报告"""
  80. report_util = ReportUtils(city_uuid, product_id)
  81. report_util.generate_all_data(recall_count, delivery_count)
  82. repots_dir = os.path.join('./data/reports', city_uuid, product_id)
  83. upload_file(repots_dir)
  84. def upload_file(reports_dir):
  85. """上传报告文件"""
  86. base_url = cfgs["aliyun"]["upload_url"]
  87. files = [
  88. "卷烟信息表.xlsx",
  89. "品规商户特征关系表.xlsx",
  90. "相似卷烟表.xlsx"
  91. ]
  92. # 设置请求头
  93. headers = {
  94. "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
  95. "Accept": "*/*",
  96. }
  97. # 设置Cookie
  98. cookies = {
  99. "expires_in": "10800000",
  100. "ecp_token": "ZXlKamIyUmxJam9pT1dZNFltRTNOVGhrTW1WbVpUYzRaR05oWkRZME5ERTRPREU1TkRObU9EY2lMQ0p6WTI5d1pTSTZJbkJ5YjNSbFkzUmxaQ0lzSW1Oc2FXVnVkQ0k2SW1Oc2FXVnVkRjlzYjI1bmFta2lmUT09",
  101. "acw_tc": "0a067c4317466919842786603e2653311180a2a4fa0fd05acb99cf0458b890",
  102. "isg": "BDAwNyXyI9cQD__-0PkIodjTAfiCeRTD-vUtYyqmIw3d5bGP0Y0nUayWOe2F3syb",
  103. "dd-ztna-token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VySWQiOiIwNTAwMzQzMDM4MjYxNDMwOTQiLCJjb3JwSWQiOiJkaW5nYzc2YTJhMzdiMjRhODYwYTRhYzVkNjk4MDg2NGQzMzUiLCJkZXB0SWRzIjoiW1s4Mzk0NzIxNjAsODM2MjQ1MTUzLDFdXSIsInJvbGVJZHMiOiJbMjEwNDk3MDY1OF0iLCJ1bmlvbklkIjoiVXFyYTB4U0VQcWNxdFNhQTFqSzRzZ2lFaUUiLCJleHAiOjE3NDcxODQ3MjksImlhdCI6MTc0NjU3OTkyOX0.dfv612-LwnIdoKL2G73gg7LYy8SBmvr3Zaan97Q5wGUbEFdWw0JUqOQQ1jdeom_Nd9FNCHlkZM32DvwyUrNnvbg1QQy2JeYEpAgysG4h0MT_OghE6-xGVQBIkg72GPTo_cvdMYG9SMfaCo5H-73zFfwMFASFoXCDoIPha6NioIskOJMmvQVsDkHtRXYh_gv0XaJxSWirDWhKC9vxPGaIwDff8doHwPdi9uO-tO9LFy9RXdyIsBXWem31rBSD3D6FmqZLZjOOZhCKMym1VenfIKC10Oa1zm8-Y8bGyMHG0LO_68AJstKYT4alJoBVDHXpMp3zvSXXQB6da_fIthQD4A"
  104. }
  105. files_id = {}
  106. for file in files:
  107. file_path = os.path.join(reports_dir, file)
  108. try:
  109. with open(file_path, 'rb') as f:
  110. files = {'file': (os.path.basename(file_path), f)}
  111. response = requests.post(
  112. base_url,
  113. headers=headers,
  114. files=files,
  115. verify=True
  116. )
  117. if response.json()["success"]:
  118. file_id = response.json()["data"]["file_info"]["fileid"]
  119. files_id[os.path.basename(file_path).split('.')[0]] = file_id
  120. except requests.exceptions.RequestException as e:
  121. print("请求出错:", e)
  122. except Exception as e:
  123. print("发生错误:", e)
  124. files_id_str = ""
  125. if files_id:
  126. for filename, file_id in files_id.items():
  127. files_id_str += f"{filename},{file_id}\n"
  128. else:
  129. files_id_str = "failed"
  130. with open(os.path.join(reports_dir, "files_id.txt"), 'w', encoding="utf-8") as file:
  131. file.write(files_id_str)
  132. @app.post("/brandcultivation/api/v1/report")
  133. async def get_file_id(request: ReportRequest):
  134. files_id_path = os.path.join("./data/reports", request.city_uuid, request.product_code, "files_id.txt")
  135. if not os.path.exists(files_id_path):
  136. raise HTTPException(
  137. status_code=status.HTTP_404_NOT_FOUND,
  138. detail="Reports not found"
  139. )
  140. with open(files_id_path, 'r', encoding="utf-8") as file:
  141. lines = file.readlines()
  142. if lines[0].strip() == "failed":
  143. raise HTTPException(
  144. status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
  145. detail="reports upload failed"
  146. )
  147. request_data = []
  148. for index, line in enumerate(lines):
  149. filename, file_id = line.strip().split(",")
  150. request_data.append(
  151. {
  152. "id": index+1,
  153. "filename": filename,
  154. "file_id": file_id
  155. }
  156. )
  157. return {"code": 200, "msg": "success", "data": {"reportInfo": request_data}}
  158. @app.post("/brandcultivation/api/v1/eval_report")
  159. async def get_eval_report(request: EvalReportRequest):
  160. """获取验证报告"""
  161. reports_dir = os.path.join('./data/reports', request.city_uuid, request.product_code)
  162. # 首先生成验证报告
  163. report_util = ReportUtils(request.city_uuid, request.product_code)
  164. report_util.generate_eval_data(request.start_time, request.end_time)
  165. # 其次上传验证报告到阿里云
  166. base_url = cfgs["aliyun"]["upload_url"]
  167. headers = {
  168. "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
  169. "Accept": "*/*",
  170. }
  171. cookies = {
  172. "expires_in": "10800000",
  173. "ecp_token": "ZXlKamIyUmxJam9pT1dZNFltRTNOVGhrTW1WbVpUYzRaR05oWkRZME5ERTRPREU1TkRObU9EY2lMQ0p6WTI5d1pTSTZJbkJ5YjNSbFkzUmxaQ0lzSW1Oc2FXVnVkQ0k2SW1Oc2FXVnVkRjlzYjI1bmFta2lmUT09",
  174. "acw_tc": "0a067c4317466919842786603e2653311180a2a4fa0fd05acb99cf0458b890",
  175. "isg": "BDAwNyXyI9cQD__-0PkIodjTAfiCeRTD-vUtYyqmIw3d5bGP0Y0nUayWOe2F3syb",
  176. "dd-ztna-token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VySWQiOiIwNTAwMzQzMDM4MjYxNDMwOTQiLCJjb3JwSWQiOiJkaW5nYzc2YTJhMzdiMjRhODYwYTRhYzVkNjk4MDg2NGQzMzUiLCJkZXB0SWRzIjoiW1s4Mzk0NzIxNjAsODM2MjQ1MTUzLDFdXSIsInJvbGVJZHMiOiJbMjEwNDk3MDY1OF0iLCJ1bmlvbklkIjoiVXFyYTB4U0VQcWNxdFNhQTFqSzRzZ2lFaUUiLCJleHAiOjE3NDcxODQ3MjksImlhdCI6MTc0NjU3OTkyOX0.dfv612-LwnIdoKL2G73gg7LYy8SBmvr3Zaan97Q5wGUbEFdWw0JUqOQQ1jdeom_Nd9FNCHlkZM32DvwyUrNnvbg1QQy2JeYEpAgysG4h0MT_OghE6-xGVQBIkg72GPTo_cvdMYG9SMfaCo5H-73zFfwMFASFoXCDoIPha6NioIskOJMmvQVsDkHtRXYh_gv0XaJxSWirDWhKC9vxPGaIwDff8doHwPdi9uO-tO9LFy9RXdyIsBXWem31rBSD3D6FmqZLZjOOZhCKMym1VenfIKC10Oa1zm8-Y8bGyMHG0LO_68AJstKYT4alJoBVDHXpMp3zvSXXQB6da_fIthQD4A"
  177. }
  178. eval_file_path = os.path.join(reports_dir, "投放验证报告.xlsx")
  179. try:
  180. with open(eval_file_path, 'rb') as f:
  181. files = {'file': (os.path.basename(eval_file_path), f)}
  182. response = requests.post(
  183. base_url,
  184. headers=headers,
  185. files=files,
  186. verify=True
  187. )
  188. if response.json()["success"]:
  189. file_id = response.json()["data"]["file_info"]["fileid"]
  190. return {"code": 200, "msg": "success", "data": {"reportInfo": {"id": 1, "filename": "投放验证报告", "file_id": file_id}}}
  191. else:
  192. raise HTTPException(
  193. status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
  194. detail="reports upload failed"
  195. )
  196. except requests.exceptions.RequestException as e:
  197. print("请求出错:", e)
  198. except Exception as e:
  199. print("发生错误:", e)
  200. # 最后将file_id返回给前端
  201. if __name__ == "__main__":
  202. uvicorn.run(app, host="0.0.0.0", port=7960)
  203. # report_dir = "./data/reports/00000000000000000000000011445301/440298"
  204. # upload_file(report_dir)