"""Gradio back-end helpers for a brand-infringement ("keyword traffic") checker.

Product records are read through ``MongoDao("vbrand-ec")``; the infringement
verdict is delegated to ``Agent.brand_key_word_judgement``.
"""
import json
import logging
from io import BytesIO

import gradio as gr
import requests
from PIL import Image

from agent.agent import Agent
from db import MongoDao

logger = logging.getLogger(__name__)

agent = Agent()

# Browser-like request headers used when downloading product images.
# NOTE(review): presumably needed so the AliExpress image host does not reject
# the default ``requests`` User-Agent / missing Referer — confirm.
headers = {
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/110.0.0.0 Safari/537.36",
    "Referer": "https://www.aliexpress.com/",
    "Accept-Language": "en-US,en;q=0.9",
}

dao = MongoDao("vbrand-ec")

# Seconds to wait for an image download. Without a timeout, ``requests`` can
# block a Gradio worker indefinitely on a stalled connection.
IMAGE_TIMEOUT = 10


def get_merchant_list():
    """Return the ``title`` of every product record, in DAO order.

    The titles are shown as the selectable product list in the UI and are
    also the lookup key used by :func:`get_cust_info`.
    """
    return [item["title"] for item in dao.get_fields_data(["title"])]


def load_image(image_url, timeout=IMAGE_TIMEOUT):
    """Download ``image_url`` and decode it into a PIL image.

    Args:
        image_url: URL of the image to fetch.
        timeout: Seconds to wait for the HTTP response (default ``IMAGE_TIMEOUT``).

    Returns:
        A fully decoded ``PIL.Image.Image``.

    Raises:
        requests.RequestException: On network errors, timeouts, or non-2xx
            HTTP status codes.
        OSError: If the payload cannot be decoded as an image (PIL's
            ``UnidentifiedImageError`` is an ``OSError`` subclass).
    """
    response = requests.get(image_url, headers=headers, timeout=timeout)
    # Fail on HTTP errors here; otherwise an HTML error page reaches PIL and
    # surfaces as a misleading "cannot identify image file" error.
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))
    # Image.open only parses the header; force full decoding now so corrupt
    # data fails inside this function rather than later while rendering.
    image.load()
    return image


def get_cust_info(title):
    """Look up a product by its exact title and build its detail view.

    Args:
        title: Product title (the value selected in / typed into the UI).

    Returns:
        A ``(text, image)`` pair for the product-info textbox and the image
        component. When no record matches, or ``title`` is empty, returns
        ``("title不正确", None)``. ``image`` is also ``None`` when the record
        has no image or it cannot be downloaded/decoded.
    """
    # An empty/None title would be sent as {"title": None}; if the DAO passes
    # that to MongoDB unchanged (presumably), it also matches documents that
    # lack a title entirely, so reject it up front.
    if not title:
        return "title不正确", None

    record = dao.get_one_record_by_query({"title": title})
    if record is None:
        return "title不正确", None

    res = f"""
    商品名称:\t{record['title']}\n
    平台:\t{record['platFormName']}\n
    品牌:\t{record['brandName']}\n
    价格:\t{record['price']}\n
    链接:\t{record['url']}\n
    """

    # The textual details are useful on their own, so image problems degrade
    # to "no picture" instead of failing the whole handler (or app startup,
    # since this function is also called at import time).
    images = record.get("image") or []
    if not images:
        return res, None
    try:
        image = load_image(images[0])
    except (requests.RequestException, OSError):
        logger.warning("Failed to load image %s for %r", images[0], title, exc_info=True)
        image = None
    return res, image


def check_infringement(title, brandname):
    """Ask the agent whether product ``title`` uses ``brandname`` as a traffic keyword.

    Args:
        title: Product title selected in the UI.
        brandname: Brand chosen in the brand dropdown.

    Returns:
        A text block for the infringement-result textbox containing the
        agent's ``key_word_flag`` verdict, or a short message when no product
        is selected or the agent's reply cannot be parsed.
    """
    if not title:
        return "请先选择商户"

    # NOTE(review): the agent's verdict is authoritative. A local comparison of
    # the record's brandName against ``brandname`` used to be computed here,
    # but its flag was always overwritten by this verdict, so it is not done.
    raw_judgement = agent.brand_key_word_judgement(brandname, title)
    try:
        # The agent is expected to return a JSON object with "key_word_flag".
        key_word_flag = json.loads(raw_judgement)["key_word_flag"]
    except (json.JSONDecodeError, KeyError, TypeError):
        # LLM output is not guaranteed to be valid JSON of the right shape.
        logger.exception("Unparseable agent judgement for %r: %r", title, raw_judgement)
        key_word_flag = "无法解析判定结果"

    result = f"""
    关键词引流: {key_word_flag}
    """
    return result


def search_by_cust_id(cust_id):
    """Handle the search button: show details for the typed product title.

    Args:
        cust_id: Raw search-box text. Despite the name, it is matched against
            product titles by :func:`get_cust_info`.

    Returns:
        ``(None, None)`` for blank or whitespace-only input (clearing both
        outputs); otherwise the ``(text, image)`` pair from
        :func:`get_cust_info`.
    """
    if not cust_id or not cust_id.strip():
        return None, None
    return get_cust_info(cust_id)
merchant_list_titles = get_merchant_list()

# Single source for the brand offered in the brand dropdown, so its choices,
# its default value and the initial State value cannot drift apart.
DEFAULT_BRAND = "李宁"

# Pre-select the first product (if any) so the page opens with data shown.
default_merchant = merchant_list_titles[0] if merchant_list_titles else None
if default_merchant is None:
    # Skip the lookup entirely: a {"title": None} query would (in MongoDB)
    # also match records that lack a title and show an arbitrary product.
    default_cust_info, default_image = "暂无商户数据", None
else:
    default_cust_info, default_image = get_cust_info(default_merchant)

with gr.Blocks() as demo:
    gr.Markdown("## 侵权识别系统", elem_id="header")
    with gr.Row():
        with gr.Column():
            # Holds the currently selected brand; kept in sync with
            # brand_dropdown by the change handler below.
            brand_state = gr.State(value=DEFAULT_BRAND)
            brand_dropdown = gr.Dropdown(
                [DEFAULT_BRAND], label="品牌选择", value=DEFAULT_BRAND, interactive=True
            )
            search_box = gr.Textbox(label="搜索商户")
            search_button = gr.Button("搜索")
            merchant_list = gr.Dropdown(
                merchant_list_titles,
                label="商户列表",
                value=default_merchant,  # pre-select the first product
                interactive=True,
            )
            check_button = gr.Button("查询侵权")
        with gr.Column():
            with gr.Row():
                image_display = gr.Image(
                    label="商品图片", interactive=False, type="pil", value=default_image
                )
                product_info = gr.Textbox(
                    label="商品信息",
                    interactive=False,
                    value=default_cust_info,  # pre-filled with the default product's details
                )
            infringement_result = gr.Textbox(label="侵权识别结果", interactive=False)

    # Event wiring (must happen inside the Blocks context).
    brand_dropdown.change(fn=lambda x: x, inputs=brand_dropdown, outputs=brand_state)
    search_button.click(
        search_by_cust_id, inputs=search_box, outputs=[product_info, image_display]
    )
    merchant_list.change(
        get_cust_info, inputs=merchant_list, outputs=[product_info, image_display]
    )
    check_button.click(
        check_infringement, inputs=[merchant_list, brand_state], outputs=infringement_result
    )

# NOTE(review): "0.0.0.0" listens on every network interface, exposing the
# app (and its DB/agent access) to the network — confirm this is intended.
demo.launch(server_name="0.0.0.0", server_port=7860)