Kaynağa Gözat

重构界面代码

yangzeyu 11 ay önce
ebeveyn
işleme
ca2bf21f9c
8 değiştirilmiş dosya ile 278 ekleme ve 146 silme
  1. 6 0
      data/__init__.py
  2. 13 0
      data/bean/BrandInfo.py
  3. 12 0
      data/bean/ProductInfo.py
  4. 2 1
      db/dao/dao.py
  5. 3 2
      utils/__init__.py
  6. 24 0
      utils/service.py
  7. 47 143
      webui.py
  8. 171 0
      webui_ori.py

+ 6 - 0
data/__init__.py

@@ -0,0 +1,6 @@
+from data.bean.BrandInfo import BrandInfo
+from data.bean.ProductInfo import ProductInfo
+__all__ = [
+    "BrandInfo",
+    "ProductInfo"
+]

+ 13 - 0
data/bean/BrandInfo.py

@@ -0,0 +1,13 @@
+from db import MongoDao
+from data.bean.ProductInfo import ProductInfo
+class BrandInfo:
+    
+    def __init__(self, brand_name):
+        self._dao = MongoDao("vbrand-ec")
+        self.brand_name = brand_name
+        self.product_list = [ProductInfo(record) for record in self._dao.get_records_by_query({"brandName": brand_name})]
+        
+if __name__ == "__main__":
+    brand_info = BrandInfo("李宁")
+    for i in brand_info.product_list:
+        print(i.image)

+ 12 - 0
data/bean/ProductInfo.py

@@ -0,0 +1,12 @@
+class ProductInfo:
+    def __init__(self, record):
+        self.id = record.get("_id", "null")
+        self.title = record.get("title", "null")
+        self.sizes = record.get("sizes", "null")
+        self.colors = record.get("colors", "null")
+        self.url = record.get("url", "null")
+        self.image = record.get("image", [])
+        self.price = record.get("price", "null")
+        self.shopTitle = record.get("shopTitle", "null")
+        self.platFormName = record.get("platFormName", "null")
+        

+ 2 - 1
db/dao/dao.py

@@ -11,13 +11,14 @@ class MongoDao:
         return res
     
     def get_records_by_query(self, query):
+        """根据查询返回多个结果"""
         collections = self.db_client.find_many(self.collection_name, query)
         records = [collection for collection in collections]
         return records
     
     def get_one_field_data(self, field):
-        fields = [field]
         """获取指定key的所有数据,返回列表"""
+        fields = [field]
         field_records = self.db_client.find_fields(self.collection_name, fields)
         return [record[field] for record in field_records]
     

+ 3 - 2
utils/__init__.py

@@ -1,5 +1,6 @@
 from utils.utils import image_to_base
-
+from utils.service import Service
 __all__ =[
-    "image_to_base"
+    "image_to_base",
+    "Service"
 ]

+ 24 - 0
utils/service.py

@@ -0,0 +1,24 @@
+from db import MongoDao
+from data import BrandInfo, ProductInfo
+
+product_dao = MongoDao("vbrand-ec")
+
+class Service:
+    @staticmethod
+    def get_brand_list():
+        """获取数据库中存储的品牌列表"""
+        brand_list = product_dao.get_one_field_data("brandName")
+        brand_list = list(dict.fromkeys(brand_list))
+        
+        return brand_list
+    
+    @staticmethod
+    def load_products_by_brand(brand_name):
+        """加载选定品牌的所有商品"""
+        brand_info = BrandInfo(brand_name)
+        return [product.title for product in brand_info.product_list], brand_info
+
+
+if __name__ == "__main__":
+    brand_list = Service.get_brand_list()
+    print(brand_list)

+ 47 - 143
webui.py

@@ -5,168 +5,72 @@ from PIL import Image
 from io import BytesIO
 from agent.agent import Agent
 import json
+import pandas as pd
+from utils import Service
 
-agent = Agent()
+brand_list = Service.get_brand_list()
 
-headers = {
-    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/110.0.0.0 Safari/537.36",
-    "Referer": "https://www.aliexpress.com/",
-    "Accept-Language": "en-US,en;q=0.9",
-}
-
-products_dao = MongoDao("vbrand-ec")
-license_dao = MongoDao("ProductStandard")
-
-def get_merchant_list():
-    """ 返回商户列表,显示 title,存储 outId """
-    merchant_data = [item["title"] for item in products_dao.get_fields_data(["title"])]
-    # merchant_dict = {m["title"]: m["outId"] for m in merchant_data}
-    return merchant_data
-
-# def display_cust_info(title):
-#     """ 根据商户名称返回对应的 outId """
-#     merchant_dict = get_merchant_list()
-#     cust_id = merchant_dict.get(title, "-1")
-#     if cust_id != '-1':
-#         return get_cust_info(cust_id)
-#     else:
-#         return "未找到对应 ID", None
-    
-def load_image(image_url):
-    response = requests.get(image_url, headers=headers)
-    image = Image.open(BytesIO(response.content))
-    return image
-
-def get_cust_info(title):
-    record = products_dao.get_one_record_by_query({"title": title})
-    if record == None:
-        return "title不正确", None
-    res = f"""
-            商品名称:\t{record["title"]}\n
-            平台:\t{record["platFormName"]}\n
-            品牌:\t{record["brandName"]}\n
-            价格:\t{record["price"]}\n
-            链接:\t{record["url"]}\n
-    """
-    image_url = record["image"][0]
-    image = load_image(image_url)
-    return res, image
-
-def get_license_list():
-    """获取品牌方授权商品列表"""
-    license_list = []
-    records = license_dao.get_records_by_query({"BrandName":"李宁"})
-    for record in records:
-        if "ProductSeries" not in record.keys():
-            record["ProductSeries"] = ""
-        # license_list.append(
-        #     {
-        #         "产品名称":record["ProductTitle"],
-        #         "产品分类":record["Category"],
-        #         "产品系列":record["ProductSeries"]
-        #     }
-        # )
-        product = f"{record['ProductSeries']} {record['Category']}"
-        if product not in license_list:
-            license_list.append(product)
-    return license_list
-    
-
-def check_infringement(title, brandname):
-    """ 模拟侵权检测逻辑 """
-    record = products_dao.get_one_record_by_query({"title": title})
-    image_url = record["image"][0]
-    
-    if brandname not in title:
-        actual_brandname = record["brandName"]
-        if actual_brandname not in brandname:
-            key_word_falg = True
-        else:
-            key_word_falg = False
-    
-    key_word_judgement = json.loads(agent.brand_key_word_judgement(brandname, title))
-    license_judgement = json.loads(agent.license_product_judgement(title, license_list_str))
-    logo_judgement = json.loads(agent.image_logo_judgement("./logo/lining.jpg", image_url))
-   
-    key_word_falg = key_word_judgement["key_word_flag"]
-    license_judgement_flag = license_judgement["in_list"]
-    # license_flag = license_judgement["in_list"]
-    result = f"""
-        关键词引流: {key_word_falg}
-        是否为授权生产产品: {license_judgement_flag}
-        
-        产品LOGO图像判定:
-        图像中是否包含logo: {logo_judgement["is_contain_logo"]}
-    """
-    
-    if logo_judgement["is_contain_logo"]:
-        result += 
-        f"""是否是指定品牌LOGO: {logo_judgement["is_jugement_logo"]}
-        LOGO名称: {logo_judgement["brand_name"]}
-        """
-    return result
-
-def search_by_cust_id(cust_id):
-    if cust_id == '':
-        return None, None
-    else:
-        return get_cust_info(cust_id)
-
-# merchant_dict = get_merchant_list()
-# merchant_list_titles = list(merchant_dict.keys())
-merchant_list_titles = get_merchant_list()
-
-# 确保商户列表不为空
-default_merchant = merchant_list_titles[0] if merchant_list_titles else None
-default_cust_info, default_image = get_cust_info(default_merchant)
-
-license_list = get_license_list()
-license_list_str = ""
-for product in license_list:
-    license_list_str += f"{product}\n"
 with gr.Blocks() as demo:
     gr.Markdown("## 侵权识别系统", elem_id="header")
     
     with gr.Row():
-        with gr.Column():
-            brand_state = gr.State(value="李宁")
+        with gr.Column():  # 左侧控制面板
+            brand_state = gr.State(value=brand_list[0])
             brand_dropdown = gr.Dropdown(
-                ["李宁"],
+                brand_list,
                 label="品牌选择", 
-                value="李宁",
+                value=brand_list[0],
                 interactive=True)
             
             search_box = gr.Textbox(label="搜索商户")
             search_button = gr.Button("搜索")
-            
-            merchant_list = gr.Dropdown(
-                merchant_list_titles, 
-                label="商户列表", 
-                value=default_merchant,  # 设置默认值
-                interactive=True
-            )
-            
             check_button = gr.Button("查询侵权")
         
-        with gr.Column():
-            with gr.Row():
-                image_display = gr.Image(label="商品图片", interactive=False, type='pil', value=default_image)
-                product_info = gr.Textbox(
-                    label="商品信息",
-                    interactive=False,
-                    value=default_cust_info  # 预填充默认商户信息
-                )
+        with gr.Column():  # 右侧展示面板
+            # 1. 图片展示区(顶部)
+            image_display = gr.Gallery(
+                label="商品图片",
+                columns=3,  # 每行显示3张图片
+                height="auto",
+                object_fit="contain",  # 保持图片比例
+                interactive=False,
+                show_share_button=False
+            )
             
-            infringement_result = gr.Textbox(label="侵权识别结果", interactive=False)
+            # # 2. 商品信息区(中部)
+            # product_info = gr.Textbox(
+            #     label="商品信息",
+            #     interactive=False,
+            #     lines=5  # 增加显示行数
+            # )
+            product_info = gr.Dataframe(
+                headers=["", ""],
+                row_count=5,
+                col_count=2,
+                value=[
+                    ["商户名称", "某某专卖店"],
+                    ["商品ID", "SP20230001"],
+                    ["上架时间", "2023-05-15"],
+                    ["价格", "¥299"],
+                    ["销量", "1,208件"]
+                ],
+                interactive=False  # 设为True可允许编辑
+            )
+            
+            # 3. 侵权结果区(底部)
+            infringement_result = gr.Textbox(
+                label="侵权识别结果",
+                interactive=False,
+                lines=3
+            )
     
-    # 事件绑定
+    # 事件绑定(保持不变)
     brand_dropdown.change(
         fn=lambda x: x,
         inputs=brand_dropdown,
         outputs=brand_state
     )
-    search_button.click(search_by_cust_id, inputs=search_box, outputs=[product_info, image_display])
-    merchant_list.change(get_cust_info, inputs=merchant_list, outputs=[product_info, image_display])
-    check_button.click(check_infringement, inputs=[merchant_list, brand_state], outputs=infringement_result)
+    # search_button.click(search_by_cust_id, inputs=search_box, outputs=[product_info, image_display])
+    # check_button.click(check_infringement, inputs=[...], outputs=infringement_result)
 
-demo.launch(server_name = "0.0.0.0", server_port = 7860)
+demo.launch(server_name="0.0.0.0", server_port=7860)

+ 171 - 0
webui_ori.py

@@ -0,0 +1,171 @@
+import gradio as gr
+from db import MongoDao
+import requests
+from PIL import Image
+from io import BytesIO
+from agent.agent import Agent
+import json
+
+agent = Agent()
+
+headers = {
+    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/110.0.0.0 Safari/537.36",
+    "Referer": "https://www.aliexpress.com/",
+    "Accept-Language": "en-US,en;q=0.9",
+}
+
+products_dao = MongoDao("vbrand-ec")
+license_dao = MongoDao("ProductStandard")
+
+def get_merchant_list():
+    """ 返回商户列表,显示 title,存储 outId """
+    merchant_data = [item["title"] for item in products_dao.get_fields_data(["title"])]
+    # merchant_dict = {m["title"]: m["outId"] for m in merchant_data}
+    return merchant_data
+
+# def display_cust_info(title):
+#     """ 根据商户名称返回对应的 outId """
+#     merchant_dict = get_merchant_list()
+#     cust_id = merchant_dict.get(title, "-1")
+#     if cust_id != '-1':
+#         return get_cust_info(cust_id)
+#     else:
+#         return "未找到对应 ID", None
+    
+def load_image(image_url):
+    response = requests.get(image_url, headers=headers)
+    image = Image.open(BytesIO(response.content))
+    return image
+
+def get_cust_info(title):
+    record = products_dao.get_one_record_by_query({"title": title})
+    if record == None:
+        return "title不正确", None
+    res = f"""
+            商品名称:\t{record["title"]}\n
+            平台:\t{record["platFormName"]}\n
+            品牌:\t{record["brandName"]}\n
+            价格:\t{record["price"]}\n
+            链接:\t{record["url"]}\n
+    """
+    image_url = record["image"][0]
+    image = load_image(image_url)
+    return res, image
+
+def get_license_list():
+    """获取品牌方授权商品列表"""
+    license_list = []
+    records = license_dao.get_records_by_query({"BrandName":"李宁"})
+    for record in records:
+        if "ProductSeries" not in record.keys():
+            record["ProductSeries"] = ""
+        # license_list.append(
+        #     {
+        #         "产品名称":record["ProductTitle"],
+        #         "产品分类":record["Category"],
+        #         "产品系列":record["ProductSeries"]
+        #     }
+        # )
+        product = f"{record['ProductSeries']} {record['Category']}"
+        if product not in license_list:
+            license_list.append(product)
+    return license_list
+    
+
+def check_infringement(title, brandname):
+    """ 模拟侵权检测逻辑 """
+    record = products_dao.get_one_record_by_query({"title": title})
+    image_url = record["image"][0]
+    
+    if brandname not in title:
+        actual_brandname = record["brandName"]
+        if actual_brandname not in brandname:
+            key_word_falg = True
+        else:
+            key_word_falg = False
+    
+    key_word_judgement = json.loads(agent.brand_key_word_judgement(brandname, title))
+    license_judgement = json.loads(agent.license_product_judgement(title, license_list_str))
+    logo_judgement = json.loads(agent.image_logo_judgement("./logo/lining.jpg", image_url))
+   
+    key_word_falg = key_word_judgement["key_word_flag"]
+    license_judgement_flag = license_judgement["in_list"]
+    # license_flag = license_judgement["in_list"]
+    result = f"""
+        关键词引流: {key_word_falg}
+        是否为授权生产产品: {license_judgement_flag}
+        
+        产品LOGO图像判定:
+        图像中是否包含logo: {logo_judgement["is_contain_logo"]}
+    """
+    
+    if logo_judgement["is_contain_logo"]:
+        result += f"""是否是指定品牌LOGO: {logo_judgement["is_jugement_logo"]}
+        LOGO名称: {logo_judgement["brand_name"]}
+        """
+    return result
+
+def search_by_cust_id(cust_id):
+    if cust_id == '':
+        return None, None
+    else:
+        return get_cust_info(cust_id)
+
+# merchant_dict = get_merchant_list()
+# merchant_list_titles = list(merchant_dict.keys())
+merchant_list_titles = get_merchant_list()
+
+# 确保商户列表不为空
+default_merchant = merchant_list_titles[0] if merchant_list_titles else None
+default_cust_info, default_image = get_cust_info(default_merchant)
+
+license_list = get_license_list()
+license_list_str = ""
+for product in license_list:
+    license_list_str += f"{product}\n"
+with gr.Blocks() as demo:
+    gr.Markdown("## 侵权识别系统", elem_id="header")
+    
+    with gr.Row():
+        with gr.Column():
+            brand_state = gr.State(value="李宁")
+            brand_dropdown = gr.Dropdown(
+                ["李宁"],
+                label="品牌选择", 
+                value="李宁",
+                interactive=True)
+            
+            search_box = gr.Textbox(label="搜索商户")
+            search_button = gr.Button("搜索")
+            
+            merchant_list = gr.Dropdown(
+                merchant_list_titles, 
+                label="商户列表", 
+                value=default_merchant,  # 设置默认值
+                interactive=True
+            )
+            
+            check_button = gr.Button("查询侵权")
+        
+        with gr.Column():
+            with gr.Row():
+                image_display = gr.Image(label="商品图片", interactive=False, type='pil', value=default_image)
+                product_info = gr.Textbox(
+                    label="商品信息",
+                    interactive=False,
+                    value=default_cust_info  # 预填充默认商户信息
+                )
+            
+            infringement_result = gr.Textbox(label="侵权识别结果", interactive=False)
+    
+    # 事件绑定
+    brand_dropdown.change(
+        fn=lambda x: x,
+        inputs=brand_dropdown,
+        outputs=brand_state
+    )
+    search_button.click(search_by_cust_id, inputs=search_box, outputs=[product_info, image_display])
+    merchant_list.change(get_cust_info, inputs=merchant_list, outputs=[product_info, image_display])
+    check_button.click(check_infringement, inputs=[merchant_list, brand_state], outputs=infringement_result)
+
+demo.launch(server_name = "0.0.0.0", server_port = 7860)