| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
- from database import MySqlDao
- from fastapi import APIRouter, BackgroundTasks
- from .request_body import RecommendRequest
- from models import Recommend
- import os
- from utils import FileStreamUtils, ReportUtils
# Module-wide data-access object, created once at import time. It is shared by
# the /recommend handler (order lookup) and the background report task
# (report insert) defined in this module.
dao = MySqlDao()
# Router on which this module's endpoints are registered; presumably mounted on
# the FastAPI app elsewhere — TODO confirm.
router = APIRouter()
def _resolve_model_path(city_uuid):
    """Return the GBDT+LR weights path for *city_uuid*, or None if it escapes the weights root.

    ``city_uuid`` comes straight from the request body. Without this check, an
    absolute path (which makes ``os.path.join`` discard the root) or a ``..``
    component could point the model lookup outside ``./models/rank/weights``.
    """
    weights_root = os.path.abspath("./models/rank/weights")
    model_path = os.path.abspath(os.path.join(weights_root, city_uuid, "gbdtlr_model.pkl"))
    try:
        inside_root = os.path.commonpath([weights_root, model_path]) == weights_root
    except ValueError:
        # commonpath raises when the paths cannot be compared (e.g. different drives on Windows).
        inside_root = False
    return model_path if inside_root else None


@router.post("/recommend")
async def recommend(request: RecommendRequest, backgroundTasks: BackgroundTasks):
    """Recommendation endpoint.

    Ranks customers for ``request.product_code`` in ``request.city_uuid``. It then
    schedules report generation/upload as a background task.

    Args:
        request: Request body. It carries the city, the product code, the recall
            size (``recall_cust_count``) and the delivery count.
        backgroundTasks: FastAPI background-task container used to schedule
            ``generate_and_upload_report``.

    Returns:
        A ``{"code", "msg", "data"}`` envelope. ``data.recommendationInfo`` is one of:
        - a list of ``{"id", "cust_code", "recommend_score", "delivery_count"}`` dicts
          with 1-based ``id``;
        - a message string when no trained model exists for the city.
    """
    # NOTE(review): the model loading, DAO query and ranking below are plain
    # (non-awaited) calls, so they run on the event-loop thread. If they block,
    # every other request on this worker stalls. Consider a plain `def` endpoint
    # or a thread-pool offload once the thread-safety of `dao`/`Recommend` is confirmed.
    model_path = _resolve_model_path(request.city_uuid)
    if model_path is None or not os.path.exists(model_path):
        # No trained model for this city, or the city id is not a safe path component.
        return {"code": 200, "msg": "model not defined", "data": {"recommendationInfo": "该城市的模型未训练,请先进行训练"}}

    # Initialise the per-city recommendation model.
    recommend_model = Recommend(request.city_uuid)

    # A product that already appears in order history is ranked with GBDT+LR.
    # A new product (never ordered) falls back to the item2vec-based ranking.
    ordered_products = dao.get_product_from_order(request.city_uuid)["product_code"].unique().tolist()
    if request.product_code in ordered_products:
        rank_customers = recommend_model.get_recommend_list_by_gbdtlr
    else:
        rank_customers = recommend_model.get_recommend_list_by_item2vec
    recommend_list = rank_customers(request.product_code, recall_count=request.recall_cust_count)

    recommend_data = recommend_model.get_recommend_and_delivery(recommend_list, delivery_count=request.delivery_count)

    # 1-based rank ids; `rank` avoids shadowing the builtin `id`.
    recommendation_info = [
        {
            "id": rank,
            "cust_code": item["cust_code"],
            "recommend_score": item["recommend_score"],
            "delivery_count": item["delivery_count"],
        }
        for rank, item in enumerate(recommend_data, start=1)
    ]

    # Report generation is deferred via FastAPI BackgroundTasks (runs after the response is returned).
    backgroundTasks.add_task(generate_and_upload_report, request)

    return {"code": 200, "msg": "success", "data": {"recommendationInfo": recommendation_info}}
def generate_and_upload_report(request: RecommendRequest):
    """Generate the reports for one request, upload them, and record their file ids.

    According to the original docstring, the upload target is the Alibaba Cloud
    file database. The returned file ids are stored, together with the request's
    identifying fields, via ``dao.insert_report``.

    Args:
        request: The same body that was passed to the /recommend endpoint.
    """
    city_uuid, product_code = request.city_uuid, request.product_code

    # Build the report data for this city/product.
    ReportUtils(city_uuid, product_code).generate_all_data(request.recall_cust_count, request.delivery_count)

    # Maps each report-table column to the name of the report file whose uploaded id it stores.
    # This single mapping drives both the upload list and the stored record.
    column_to_report = {
        'product_info_table': '卷烟信息表',
        'relation_table': '品规商户特征关系表',
        'similarity_product_table': '相似卷烟表',
        'recommend_table': '商户售卖推荐表',
    }

    # Presumably ReportUtils wrote the files into this directory — TODO confirm.
    file_ids = FileStreamUtils.upload_files(
        os.path.join('./data/reports', city_uuid, product_code),
        list(column_to_report.values()),
    )

    record = {
        'cultivacation_id': request.cultivacation_id,
        'city_uuid': city_uuid,
        'limit_cycle_name': request.limit_cycle_name,
        'product_code': product_code,
    }
    # A report missing from the upload result is stored as None.
    record.update((column, file_ids.get(report)) for column, report in column_to_report.items())
    dao.insert_report(record)
|