"""Recommendation API routes.

Exposes ``POST /recommend``, which ranks customers for a product in a city,
and schedules generation and upload of the related report files as a
background task.
"""
import os

from fastapi import APIRouter, BackgroundTasks

from database import MySqlDao
from models import Recommend
from utils import FileStreamUtils, ReportUtils

from .request_body import RecommendRequest

dao = MySqlDao()
router = APIRouter()

# Root of the per-city trained ranking models:
# <_MODEL_WEIGHTS_DIR>/<city_uuid>/gbdtlr_model.pkl
_MODEL_WEIGHTS_DIR = "./models/rank/weights"

# Root of the generated reports: <_REPORTS_ROOT>/<city_uuid>/<product_code>/
_REPORTS_ROOT = './data/reports'

# Report file names handed to FileStreamUtils.upload_files, mapped to the key
# under which the returned file id is passed to dao.insert_report. Keeping
# both in one table prevents the upload list and the DB row from drifting.
# Iteration order matches the original upload order.
_REPORT_FILE_COLUMNS = {
    '卷烟信息表': 'product_info_table',  # cigarette information table
    '品规商户特征关系表': 'relation_table',  # product / merchant feature relation table
    '相似卷烟表': 'similarity_product_table',  # similar cigarettes table
    '商户售卖推荐表': 'recommend_table',  # merchant sales recommendation table
}


def _is_within_directory(base_dir: str, target: str) -> bool:
    """Return True if *target* lies inside *base_dir* after lexical normalization.

    ``os.path.abspath`` collapses ``..`` components without following
    symlinks, so symlinked per-city model directories are still accepted,
    while ``../`` sequences or absolute paths injected via request fields are
    rejected.
    """
    base = os.path.abspath(base_dir)
    try:
        return os.path.commonpath([base, os.path.abspath(target)]) == base
    except ValueError:  # e.g. paths on different drives on Windows
        return False


@router.post("/recommend")
async def recommend(request: RecommendRequest, backgroundTasks: BackgroundTasks):
    """Recommendation endpoint: rank customers for ``request.product_code``.

    Returns ``{"code": 200, "msg": "model not defined", ...}`` when no trained
    GBDT+LR model file exists for ``request.city_uuid`` (or the city id would
    resolve outside the model directory). Otherwise, products that already
    appear in the city's order data are ranked with
    ``Recommend.get_recommend_list_by_gbdtlr``. Products absent from that data
    are treated as new and ranked with
    ``Recommend.get_recommend_list_by_item2vec``. Each returned entry carries a
    1-based ``id`` rank, ``cust_code``, ``recommend_score`` and
    ``delivery_count``. Report generation/upload is scheduled through
    ``BackgroundTasks``.
    """
    model_path = os.path.join(_MODEL_WEIGHTS_DIR, request.city_uuid, "gbdtlr_model.pkl")
    # city_uuid is client-supplied: refuse values that would make the model
    # lookup escape the weights directory (treated the same as "no model").
    if not _is_within_directory(_MODEL_WEIGHTS_DIR, model_path) or not os.path.exists(model_path):
        return {"code": 200, "msg": "model not defined", "data": {"recommendationInfo": "该城市的模型未训练,请先进行训练"}}

    # NOTE(review): this handler is `async def`, but none of the calls below
    # are awaited (DB query, model construction/inference), so they run on the
    # event loop thread and block other requests meanwhile. Consider a plain
    # `def` endpoint (FastAPI runs those in a threadpool) once MySqlDao's
    # thread-safety is confirmed.
    recommend_model = Recommend(request.city_uuid)

    # A product already present in the city's orders uses the GBDT+LR ranking;
    # otherwise it is treated as a new product and ranked via item2vec.
    ordered_products = dao.get_product_from_order(request.city_uuid)["product_code"].unique().tolist()
    if request.product_code in ordered_products:
        recommend_list = recommend_model.get_recommend_list_by_gbdtlr(
            request.product_code, recall_count=request.recall_cust_count
        )
    else:
        recommend_list = recommend_model.get_recommend_list_by_item2vec(
            request.product_code, recall_count=request.recall_cust_count
        )

    recommend_data = recommend_model.get_recommend_and_delivery(
        recommend_list, delivery_count=request.delivery_count
    )

    # "id" is the 1-based position of each entry in recommend_data.
    recommendation_info = [
        {
            "id": rank,
            "cust_code": item["cust_code"],
            "recommend_score": item["recommend_score"],
            "delivery_count": item["delivery_count"],
        }
        for rank, item in enumerate(recommend_data, start=1)
    ]

    # Report generation and upload run as a background task.
    backgroundTasks.add_task(generate_and_upload_report, request)

    return {"code": 200, "msg": "success", "data": {"recommendationInfo": recommendation_info}}


def generate_and_upload_report(request: RecommendRequest):
    """Generate the reports for *request*, upload them, and record their file ids.

    Reports are generated by ``ReportUtils`` under
    ``./data/reports/<city_uuid>/<product_code>``, uploaded to the Alibaba
    Cloud file store via ``FileStreamUtils.upload_files``, and the returned
    file ids are saved with ``dao.insert_report``.
    """
    # Generate the report data.
    report_util = ReportUtils(request.city_uuid, request.product_code)
    report_util.generate_all_data(request.recall_cust_count, request.delivery_count)

    # Upload the reports.
    # NOTE(review): product_code is client-supplied and joined into this path
    # unchecked; confirm RecommendRequest validates it.
    reports_dir = os.path.join(_REPORTS_ROOT, request.city_uuid, request.product_code)
    file_id_map = FileStreamUtils.upload_files(reports_dir, list(_REPORT_FILE_COLUMNS))

    # Persist the returned file ids.
    data_dict = {
        'cultivacation_id': request.cultivacation_id,
        'city_uuid': request.city_uuid,
        'limit_cycle_name': request.limit_cycle_name,
        'product_code': request.product_code,
    }
    # A report absent from file_id_map is stored as None.
    # NOTE(review): confirm whether a failed upload should abort the insert.
    data_dict.update(
        {column: file_id_map.get(name) for name, column in _REPORT_FILE_COLUMNS.items()}
    )
    dao.insert_report(data_dict)