# webui_ori.py
  1. import gradio as gr
  2. from db import MongoDao
  3. import requests
  4. from PIL import Image
  5. from io import BytesIO
  6. from agent.agent import Agent
  7. import json
# Shared Agent instance; check_infringement calls its judgement methods.
agent = Agent()
# HTTP headers sent by load_image when downloading product images.
headers = {
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/110.0.0.0 Safari/537.36",
    "Referer": "https://www.aliexpress.com/",
    "Accept-Language": "en-US,en;q=0.9",
}
# DAO constructed for "vbrand-ec"; queried by title for product records.
products_dao = MongoDao("vbrand-ec")
# DAO constructed for "ProductStandard"; queried by brand for licensed products.
license_dao = MongoDao("ProductStandard")
  16. def get_merchant_list():
  17. """ 返回商户列表,显示 title,存储 outId """
  18. merchant_data = [item["title"] for item in products_dao.get_fields_data(["title"])]
  19. # merchant_dict = {m["title"]: m["outId"] for m in merchant_data}
  20. return merchant_data
  21. # def display_cust_info(title):
  22. # """ 根据商户名称返回对应的 outId """
  23. # merchant_dict = get_merchant_list()
  24. # cust_id = merchant_dict.get(title, "-1")
  25. # if cust_id != '-1':
  26. # return get_cust_info(cust_id)
  27. # else:
  28. # return "未找到对应 ID", None
  29. def load_image(image_url):
  30. response = requests.get(image_url, headers=headers)
  31. image = Image.open(BytesIO(response.content))
  32. return image
  33. def get_cust_info(title):
  34. record = products_dao.get_one_record_by_query({"title": title})
  35. if record == None:
  36. return "title不正确", None
  37. res = f"""
  38. 商品名称:\t{record["title"]}\n
  39. 平台:\t{record["platFormName"]}\n
  40. 品牌:\t{record["brandName"]}\n
  41. 价格:\t{record["price"]}\n
  42. 链接:\t{record["url"]}\n
  43. """
  44. image_url = record["image"][0]
  45. image = load_image(image_url)
  46. return res, image
  47. def get_license_list():
  48. """获取品牌方授权商品列表"""
  49. license_list = []
  50. records = license_dao.get_records_by_query({"BrandName":"李宁"})
  51. for record in records:
  52. if "ProductSeries" not in record.keys():
  53. record["ProductSeries"] = ""
  54. product = f"{record['ProductSeries']} {record['Category']}"
  55. if product not in license_list:
  56. license_list.append(product)
  57. return license_list
  58. def check_infringement(title, brandname):
  59. """ 模拟侵权检测逻辑 """
  60. record = products_dao.get_one_record_by_query({"title": title})
  61. image_url = record["image"][0]
  62. if brandname not in title:
  63. actual_brandname = record["brandName"]
  64. if actual_brandname not in brandname:
  65. key_word_falg = True
  66. else:
  67. key_word_falg = False
  68. key_word_judgement = json.loads(agent.brand_key_word_judgement(brandname, title))
  69. license_judgement = json.loads(agent.license_product_judgement(title, license_list_str))
  70. logo_judgement = json.loads(agent.image_logo_judgement("./logo/lining.jpg", image_url))
  71. key_word_falg = key_word_judgement["key_word_flag"]
  72. license_judgement_flag = license_judgement["in_list"]
  73. # license_flag = license_judgement["in_list"]
  74. result = f"""
  75. 关键词引流: {key_word_falg}
  76. 是否为授权生产产品: {license_judgement_flag}
  77. 产品LOGO图像判定:
  78. 图像中是否包含logo: {logo_judgement["is_contain_logo"]}
  79. """
  80. if logo_judgement["is_contain_logo"]:
  81. result += f"""是否是指定品牌LOGO: {logo_judgement["is_jugement_logo"]}
  82. LOGO名称: {logo_judgement["brand_name"]}
  83. """
  84. return result
  85. def search_by_cust_id(cust_id):
  86. if cust_id == '':
  87. return None, None
  88. else:
  89. return get_cust_info(cust_id)
  90. # merchant_dict = get_merchant_list()
  91. # merchant_list_titles = list(merchant_dict.keys())
  92. merchant_list_titles = get_merchant_list()
  93. # 确保商户列表不为空
  94. default_merchant = merchant_list_titles[0] if merchant_list_titles else None
  95. default_cust_info, default_image = get_cust_info(default_merchant)
  96. license_list = get_license_list()
  97. license_list_str = ""
  98. for product in license_list:
  99. license_list_str += f"{product}\n"
# Build the Gradio UI, wire the event handlers, then start the server.
# NOTE(review): the nesting of the layout contexts below was reconstructed
# from flattened text - confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("## 侵权识别系统", elem_id="header")
    with gr.Row():
        with gr.Column():
            # Holds the selected brand; passed to check_infringement below.
            brand_state = gr.State(value="李宁")
            brand_dropdown = gr.Dropdown(
                ["李宁"],
                label="品牌选择",
                value="李宁",
                interactive=True)
            search_box = gr.Textbox(label="搜索商户")
            search_button = gr.Button("搜索")
            merchant_list = gr.Dropdown(
                merchant_list_titles,
                label="商户列表",
                value=default_merchant,  # set the default value
                interactive=True
            )
            check_button = gr.Button("查询侵权")
        with gr.Column():
            with gr.Row():
                image_display = gr.Image(label="商品图片", interactive=False, type='pil', value=default_image)
                product_info = gr.Textbox(
                    label="商品信息",
                    interactive=False,
                    value=default_cust_info  # pre-fill with the default merchant's info
                )
            infringement_result = gr.Textbox(label="侵权识别结果", interactive=False)
    # Event bindings
    # Mirror the dropdown selection into brand_state.
    brand_dropdown.change(
        fn=lambda x: x,
        inputs=brand_dropdown,
        outputs=brand_state
    )
    # The search box text is passed to search_by_cust_id, which forwards it
    # to get_cust_info (a lookup by title).
    search_button.click(search_by_cust_id, inputs=search_box, outputs=[product_info, image_display])
    merchant_list.change(get_cust_info, inputs=merchant_list, outputs=[product_info, image_display])
    check_button.click(check_infringement, inputs=[merchant_list, brand_state], outputs=infringement_result)
# server_name "0.0.0.0" with port 7860, as passed to launch().
demo.launch(server_name = "0.0.0.0", server_port = 7860)