webui.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
import json
import logging
from io import BytesIO

import gradio as gr
import requests
from PIL import Image

from agent.agent import Agent
from db import MongoDao
  8. agent = Agent()
  9. headers = {
  10. "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/110.0.0.0 Safari/537.36",
  11. "Referer": "https://www.aliexpress.com/",
  12. "Accept-Language": "en-US,en;q=0.9",
  13. }
  14. dao = MongoDao("vbrand-ec")
  15. def get_merchant_list():
  16. """ 返回商户列表,显示 title,存储 outId """
  17. merchant_data = [item["title"] for item in dao.get_fields_data(["title"])]
  18. # merchant_dict = {m["title"]: m["outId"] for m in merchant_data}
  19. return merchant_data
  20. # def display_cust_info(title):
  21. # """ 根据商户名称返回对应的 outId """
  22. # merchant_dict = get_merchant_list()
  23. # cust_id = merchant_dict.get(title, "-1")
  24. # if cust_id != '-1':
  25. # return get_cust_info(cust_id)
  26. # else:
  27. # return "未找到对应 ID", None
  28. def load_image(image_url):
  29. response = requests.get(image_url, headers=headers)
  30. image = Image.open(BytesIO(response.content))
  31. return image
  32. def get_cust_info(title):
  33. record = dao.get_one_record_by_query({"title": title})
  34. if record == None:
  35. return "title不正确", None
  36. res = f"""
  37. 商品名称:\t{record["title"]}\n
  38. 平台:\t{record["platFormName"]}\n
  39. 品牌:\t{record["brandName"]}\n
  40. 价格:\t{record["price"]}\n
  41. 链接:\t{record["url"]}\n
  42. """
  43. image_url = record["image"][0]
  44. image = load_image(image_url)
  45. return res, image
  46. def check_infringement(title, brandname):
  47. """Simulated infringement check for one product.

  Args:
      title: Product title (in the UI, the merchant dropdown's value). It is
          substring-tested for ``brandname`` and used to look up the product
          record via ``dao.get_one_record_by_query``.
      brandname: Brand being checked (in the UI, the value held in
          ``brand_state``).

  Returns:
      A short text report containing the "keyword traffic diversion" flag.
  """
  # NOTE(review): indentation was lost in this copy of the file, so the nesting
  # below cannot be recovered from here. There is only one `else:`; it is
  # unclear whether it pairs with `if brandname not in title` or with
  # `if actual_brandname not in brandname`, and at which level the two agent
  # lines after `key_word_falg = False` sit. Confirm against the original file.
  # NOTE(review): under several possible nestings `key_word_falg` is never
  # assigned on some path (e.g. when `brandname` IS in `title` and the agent
  # call is not reached), which would raise UnboundLocalError at the f-string.
  # Note the containment tests are substring checks, and the second one is
  # reversed: the record's brand is tested as a substring of `brandname`.
  48. if brandname not in title:
  # NOTE(review): `record` can be None (see the check in get_cust_info); it is
  # not checked before `record["brandName"]`.
  49. record = dao.get_one_record_by_query({"title": title})
  50. actual_brandname = record["brandName"]
  51. if actual_brandname not in brandname:
  52. key_word_falg = True
  53. else:
  54. key_word_falg = False
  # The agent's reply is parsed as JSON; its "key_word_flag" field becomes the flag.
  55. key_word_judgement = json.loads(agent.brand_key_word_judgement(brandname, title))
  56. key_word_falg = key_word_judgement["key_word_flag"]
  57. result = f"""
  58. 关键词引流: {key_word_falg}
  59. """
  60. return result
  61. def search_by_cust_id(cust_id):
  62. if cust_id == '':
  63. return None, None
  64. else:
  65. return get_cust_info(cust_id)
  66. # merchant_dict = get_merchant_list()
  67. # merchant_list_titles = list(merchant_dict.keys())
  68. merchant_list_titles = get_merchant_list()
  69. # 确保商户列表不为空
  70. default_merchant = merchant_list_titles[0] if merchant_list_titles else None
  71. default_cust_info, default_image = get_cust_info(default_merchant)
  # Build the Gradio page: brand/merchant selection controls, the product
  # image and info display, the infringement result box, and their events.
  # NOTE(review): indentation was lost in this copy, so which Row/Column each
  # component belongs to (e.g. whether check_button is in the first Column, and
  # whether product_info / infringement_result sit inside the inner Row) cannot
  # be read from here. Presumably the event bindings are inside the
  # `with gr.Blocks()` context, as Gradio requires. Confirm against the original.
  72. with gr.Blocks() as demo:
  73. gr.Markdown("## 侵权识别系统", elem_id="header")
  74. with gr.Row():
  75. with gr.Column():
  # brand_state mirrors brand_dropdown (via the .change binding below) and is
  # what check_infringement receives as `brandname`.
  76. brand_state = gr.State(value="李宁")
  77. brand_dropdown = gr.Dropdown(
  78. ["李宁"],
  79. label="品牌选择",
  80. value="李宁",
  81. interactive=True)
  82. search_box = gr.Textbox(label="搜索商户")
  83. search_button = gr.Button("搜索")
  84. merchant_list = gr.Dropdown(
  85. merchant_list_titles,
  86. label="商户列表",
  87. value=default_merchant, # pre-select the default merchant
  88. interactive=True
  89. )
  90. check_button = gr.Button("查询侵权")
  91. with gr.Column():
  92. with gr.Row():
  93. image_display = gr.Image(label="商品图片", interactive=False, type='pil', value=default_image)
  94. product_info = gr.Textbox(
  95. label="商品信息",
  96. interactive=False,
  97. value=default_cust_info # pre-fill with the default merchant's info
  98. )
  99. infringement_result = gr.Textbox(label="侵权识别结果", interactive=False)
  100. # Event bindings
  101. brand_dropdown.change(
  102. fn=lambda x: x,
  103. inputs=brand_dropdown,
  104. outputs=brand_state
  105. )
  # NOTE(review): search_button only refreshes product_info/image_display; it
  # does not change merchant_list, so check_button still checks the product
  # selected in the dropdown rather than the one just searched for.
  106. search_button.click(search_by_cust_id, inputs=search_box, outputs=[product_info, image_display])
  107. merchant_list.change(get_cust_info, inputs=merchant_list, outputs=[product_info, image_display])
  108. check_button.click(check_infringement, inputs=[merchant_list, brand_state], outputs=infringement_result)
  # NOTE(review): no `if __name__ == "__main__":` guard is visible, so importing
  # this module starts the server. server_name "0.0.0.0" listens on all network
  # interfaces (port 7860), exposing the app beyond localhost.
  109. demo.launch(server_name = "0.0.0.0", server_port = 7860)