|
|
@@ -1,34 +1,91 @@
|
|
|
import gradio as gr
|
|
|
from db import MongoDao
|
|
|
+import requests
|
|
|
+from PIL import Image
|
|
|
+from io import BytesIO
|
|
|
+from agent.agent import Agent
|
|
|
+import json
|
|
|
|
|
|
-dao = MongoDao("obrand-ec")
|
|
|
+agent = Agent()
|
|
|
+
|
|
|
+headers = {
|
|
|
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/110.0.0.0 Safari/537.36",
|
|
|
+ "Referer": "https://www.aliexpress.com/",
|
|
|
+ "Accept-Language": "en-US,en;q=0.9",
|
|
|
+}
|
|
|
+
|
|
|
+dao = MongoDao("vbrand-ec")
|
|
|
|
|
|
def get_merchant_list():
|
|
|
""" 返回商户列表,显示 title,存储 outId """
|
|
|
- merchant_data = dao.get_fields_data(["outId", "title"])[:50]
|
|
|
- merchant_dict = {m["title"]: m["outId"] for m in merchant_data}
|
|
|
- return merchant_dict
|
|
|
+ merchant_data = [item["title"] for item in dao.get_fields_data(["title"])]
|
|
|
+ # merchant_dict = {m["title"]: m["outId"] for m in merchant_data}
|
|
|
+ return merchant_data
|
|
|
|
|
|
-def display_outid(title):
|
|
|
- """ 根据商户名称返回对应的 outId """
|
|
|
- merchant_dict = get_merchant_list()
|
|
|
- return merchant_dict.get(title, "未找到对应 ID")
|
|
|
+# def display_cust_info(title):
|
|
|
+# """ 根据商户名称返回对应的 outId """
|
|
|
+# merchant_dict = get_merchant_list()
|
|
|
+# cust_id = merchant_dict.get(title, "-1")
|
|
|
+# if cust_id != '-1':
|
|
|
+# return get_cust_info(cust_id)
|
|
|
+# else:
|
|
|
+# return "未找到对应 ID", None
|
|
|
+
|
|
|
+def load_image(image_url):
|
|
|
+ response = requests.get(image_url, headers=headers)
|
|
|
+ image = Image.open(BytesIO(response.content))
|
|
|
+ return image
|
|
|
|
|
|
-def check_infringement(merchant):
|
|
|
+def get_cust_info(title):
|
|
|
+ record = dao.get_one_record_by_query({"title": title})
|
|
|
+ if record == None:
|
|
|
+ return "title不正确", None
|
|
|
+ res = f"""
|
|
|
+ 商品名称:\t{record["title"]}\n
|
|
|
+ 平台:\t{record["platFormName"]}\n
|
|
|
+ 品牌:\t{record["brandName"]}\n
|
|
|
+ 价格:\t{record["price"]}\n
|
|
|
+ 链接:\t{record["url"]}\n
|
|
|
+ """
|
|
|
+ image_url = record["image"][0]
|
|
|
+ image = load_image(image_url)
|
|
|
+ return res, image
|
|
|
+
|
|
|
+def check_infringement(title, brandname):
|
|
|
""" 模拟侵权检测逻辑 """
|
|
|
- return f"商户 {merchant} 的侵权检测结果:未发现侵权"
|
|
|
+ if brandname not in title:
|
|
|
+ title = brandname + title
|
|
|
+
|
|
|
+ key_word_judgement = json.loads(agent.brand_key_word_judgement(brandname, title))
|
|
|
+
|
|
|
+ result = f"""
|
|
|
+ 关键词引流: {key_word_judgement["key_word_flag"]}
|
|
|
+
|
|
|
+ """
|
|
|
+ return result
|
|
|
+
|
|
|
+def search_by_cust_id(cust_id):
|
|
|
+ if cust_id == '':
|
|
|
+ return None, None
|
|
|
+ else:
|
|
|
+ return get_cust_info(cust_id)
|
|
|
+
|
|
|
+# merchant_dict = get_merchant_list()
|
|
|
+# merchant_list_titles = list(merchant_dict.keys())
|
|
|
+merchant_list_titles = get_merchant_list()
|
|
|
|
|
|
-merchant_dict = get_merchant_list()
|
|
|
-merchant_list_titles = list(merchant_dict.keys())
|
|
|
+# 确保商户列表不为空
|
|
|
+default_merchant = merchant_list_titles[0] if merchant_list_titles else None
|
|
|
+default_cust_info, default_image = get_cust_info(default_merchant)
|
|
|
|
|
|
with gr.Blocks() as demo:
|
|
|
gr.Markdown("## 侵权识别系统", elem_id="header")
|
|
|
|
|
|
with gr.Row():
|
|
|
- # 左侧部分
|
|
|
with gr.Column():
|
|
|
+ brand_state = gr.State(value="李宁")
|
|
|
brand_dropdown = gr.Dropdown(
|
|
|
- ["李宁", "耐克", "阿迪达斯", "彪马"],
|
|
|
+ ["李宁"],
|
|
|
label="品牌选择",
|
|
|
value="李宁",
|
|
|
interactive=True)
|
|
|
@@ -39,21 +96,31 @@ with gr.Blocks() as demo:
|
|
|
merchant_list = gr.Dropdown(
|
|
|
merchant_list_titles,
|
|
|
label="商户列表",
|
|
|
+ value=default_merchant, # 设置默认值
|
|
|
interactive=True
|
|
|
)
|
|
|
|
|
|
check_button = gr.Button("查询侵权")
|
|
|
|
|
|
- # 右侧部分
|
|
|
with gr.Column():
|
|
|
with gr.Row():
|
|
|
- image_display = gr.Image(label="商品图片", interactive=False)
|
|
|
- product_info = gr.Textbox(label="商品信息", interactive=False)
|
|
|
+ image_display = gr.Image(label="商品图片", interactive=False, type='pil', value=default_image)
|
|
|
+ product_info = gr.Textbox(
|
|
|
+ label="商品信息",
|
|
|
+ interactive=False,
|
|
|
+ value=default_cust_info # 预填充默认商户信息
|
|
|
+ )
|
|
|
|
|
|
infringement_result = gr.Textbox(label="侵权识别结果", interactive=False)
|
|
|
|
|
|
# 事件绑定
|
|
|
- merchant_list.change(display_outid, inputs=merchant_list, outputs=product_info)
|
|
|
- check_button.click(check_infringement, inputs=merchant_list, outputs=infringement_result)
|
|
|
+ brand_dropdown.change(
|
|
|
+ fn=lambda x: x,
|
|
|
+ inputs=brand_dropdown,
|
|
|
+ outputs=brand_state
|
|
|
+ )
|
|
|
+ search_button.click(search_by_cust_id, inputs=search_box, outputs=[product_info, image_display])
|
|
|
+ merchant_list.change(get_cust_info, inputs=merchant_list, outputs=[product_info, image_display])
|
|
|
+ check_button.click(check_infringement, inputs=[merchant_list, brand_state], outputs=infringement_result)
|
|
|
|
|
|
-demo.launch(share=True)
|
|
|
+demo.launch(server_name = "0.0.0.0", server_port = 7860)
|