webui.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. import gradio as gr
  2. from db import MongoDao
  3. import requests
  4. from PIL import Image
  5. from io import BytesIO
  6. from agent.agent import Agent
  7. import json
  8. agent = Agent()
  9. headers = {
  10. "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/110.0.0.0 Safari/537.36",
  11. "Referer": "https://www.aliexpress.com/",
  12. "Accept-Language": "en-US,en;q=0.9",
  13. }
  14. products_dao = MongoDao("vbrand-ec")
  15. license_dao = MongoDao("ProductStandard")
  16. def get_merchant_list():
  17. """ 返回商户列表,显示 title,存储 outId """
  18. merchant_data = [item["title"] for item in products_dao.get_fields_data(["title"])]
  19. # merchant_dict = {m["title"]: m["outId"] for m in merchant_data}
  20. return merchant_data
  21. # def display_cust_info(title):
  22. # """ 根据商户名称返回对应的 outId """
  23. # merchant_dict = get_merchant_list()
  24. # cust_id = merchant_dict.get(title, "-1")
  25. # if cust_id != '-1':
  26. # return get_cust_info(cust_id)
  27. # else:
  28. # return "未找到对应 ID", None
  29. def load_image(image_url):
  30. response = requests.get(image_url, headers=headers)
  31. image = Image.open(BytesIO(response.content))
  32. return image
  33. def get_cust_info(title):
  34. record = products_dao.get_one_record_by_query({"title": title})
  35. if record == None:
  36. return "title不正确", None
  37. res = f"""
  38. 商品名称:\t{record["title"]}\n
  39. 平台:\t{record["platFormName"]}\n
  40. 品牌:\t{record["brandName"]}\n
  41. 价格:\t{record["price"]}\n
  42. 链接:\t{record["url"]}\n
  43. """
  44. image_url = record["image"][0]
  45. image = load_image(image_url)
  46. return res, image
  47. def get_license_list():
  48. """获取品牌方授权商品列表"""
  49. license_list = []
  50. records = license_dao.get_records_by_query({"BrandName":"李宁"})
  51. for record in records:
  52. if "ProductSeries" not in record.keys():
  53. record["ProductSeries"] = ""
  54. # license_list.append(
  55. # {
  56. # "产品名称":record["ProductTitle"],
  57. # "产品分类":record["Category"],
  58. # "产品系列":record["ProductSeries"]
  59. # }
  60. # )
  61. product = f"{record['ProductSeries']} {record['Category']}"
  62. if product not in license_list:
  63. license_list.append(product)
  64. return license_list
  65. def check_infringement(title, brandname):
  66. """ 模拟侵权检测逻辑 """
  67. record = products_dao.get_one_record_by_query({"title": title})
  68. image_url = record["image"][0]
  69. if brandname not in title:
  70. actual_brandname = record["brandName"]
  71. if actual_brandname not in brandname:
  72. key_word_falg = True
  73. else:
  74. key_word_falg = False
  75. key_word_judgement = json.loads(agent.brand_key_word_judgement(brandname, title))
  76. license_judgement = json.loads(agent.license_product_judgement(title, license_list_str))
  77. logo_judgement = json.loads(agent.image_logo_judgement("./logo/lining.jpg", image_url))
  78. key_word_falg = key_word_judgement["key_word_flag"]
  79. license_judgement_flag = license_judgement["in_list"]
  80. # license_flag = license_judgement["in_list"]
  81. result = f"""
  82. 关键词引流: {key_word_falg}
  83. 是否为授权生产产品: {license_judgement_flag}
  84. 产品LOGO图像判定:
  85. 图像中是否包含logo: {logo_judgement["is_contain_logo"]}
  86. """
  87. if logo_judgement["is_contain_logo"]:
  88. result +=
  89. f"""是否是指定品牌LOGO: {logo_judgement["is_jugement_logo"]}
  90. LOGO名称: {logo_judgement["brand_name"]}
  91. """
  92. return result
  93. def search_by_cust_id(cust_id):
  94. if cust_id == '':
  95. return None, None
  96. else:
  97. return get_cust_info(cust_id)
  98. # merchant_dict = get_merchant_list()
  99. # merchant_list_titles = list(merchant_dict.keys())
  100. merchant_list_titles = get_merchant_list()
  101. # 确保商户列表不为空
  102. default_merchant = merchant_list_titles[0] if merchant_list_titles else None
  103. default_cust_info, default_image = get_cust_info(default_merchant)
  104. license_list = get_license_list()
  105. license_list_str = ""
  106. for product in license_list:
  107. license_list_str += f"{product}\n"
  108. with gr.Blocks() as demo:
  109. gr.Markdown("## 侵权识别系统", elem_id="header")
  110. with gr.Row():
  111. with gr.Column():
  112. brand_state = gr.State(value="李宁")
  113. brand_dropdown = gr.Dropdown(
  114. ["李宁"],
  115. label="品牌选择",
  116. value="李宁",
  117. interactive=True)
  118. search_box = gr.Textbox(label="搜索商户")
  119. search_button = gr.Button("搜索")
  120. merchant_list = gr.Dropdown(
  121. merchant_list_titles,
  122. label="商户列表",
  123. value=default_merchant, # 设置默认值
  124. interactive=True
  125. )
  126. check_button = gr.Button("查询侵权")
  127. with gr.Column():
  128. with gr.Row():
  129. image_display = gr.Image(label="商品图片", interactive=False, type='pil', value=default_image)
  130. product_info = gr.Textbox(
  131. label="商品信息",
  132. interactive=False,
  133. value=default_cust_info # 预填充默认商户信息
  134. )
  135. infringement_result = gr.Textbox(label="侵权识别结果", interactive=False)
  136. # 事件绑定
  137. brand_dropdown.change(
  138. fn=lambda x: x,
  139. inputs=brand_dropdown,
  140. outputs=brand_state
  141. )
  142. search_button.click(search_by_cust_id, inputs=search_box, outputs=[product_info, image_display])
  143. merchant_list.change(get_cust_info, inputs=merchant_list, outputs=[product_info, image_display])
  144. check_button.click(check_infringement, inputs=[merchant_list, brand_state], outputs=infringement_result)
  145. demo.launch(server_name = "0.0.0.0", server_port = 7860)