root преди 6 месеца
родител
ревизия
f951f8d519
променени са 9 файла, в които са добавени 30 реда и са изтрити 29 реда
  1. 2 1
      .gitignore
  2. 3 3
      api_test.py
  3. 3 3
      model/__init__.py
  4. 2 2
      model/brand/brand_model.py
  5. 8 8
      model/brand/inference.py
  6. 3 3
      model/brand/train.py
  7. BIN
      product.png
  8. 3 3
      utils/api_service.py
  9. 6 6
      utils/utils.py

+ 2 - 1
.gitignore

@@ -1,4 +1,5 @@
 .idea/
 .vscode/
 __pycache__/
-*.pyc
+*.pyc
+*.pt

+ 3 - 3
api_test.py

@@ -35,8 +35,8 @@ import json
 # response = requests.post(url, data=json.dumps(payload), headers=headers)
 # print(response.json())
 
-url = "https://670813644644357-http-7860.northwest1.gpugeek.com:8443/brandanalysis/api/v1/infringe_judgement"
-# url = "http://127.0.0.1:7860/brandanalysis/api/v1/infringe_judgement"
+# url = "https://670813644644357-http-7860.northwest1.gpugeek.com:8443/brandanalysis/api/v1/infringe_judgement"
+url = "http://127.0.0.1:7860/brandanalysis/api/v1/infringe_judgement"
 url_data = {
     'title': '全棉时代男童内裤a类纯棉平角短裤全棉男宝中大儿童不夹屁屁内裤',
     'brand_name': '全棉时代',
@@ -55,7 +55,7 @@ basic_data = {
     'brand_name': '全棉时代',
     'similarity_logos': ["全棉時代", "全绵时代"],
     'product_images': [
-           "https://dev-govern-private-1251740668.cos.ap-guangzhou.myqcloud.com/private/20250619/39870ece5e124ce2a083487ba243998e.avif?q-sign-algorithm=sha1&q-ak=AKIDIWXN4kqgpiMm0z4T5VgcKn4KSP8cZwnO&q-sign-time=1754031190%3B1754038390&q-key-time=1754031190%3B1754038390&q-header-list=host&q-url-param-list=&q-signature=43c0703c339819f0a4e3c47223d902caf96285a4"
+           "https://gw.alicdn.com/imgextra/O1CN01pralFH1CNBzDetKtx_!!2215815550068.jpg_q95.jpg_.webp"
         ],
     'base_price': 70.5,
     'price_percent': 0.4

+ 3 - 3
model/__init__.py

@@ -1,6 +1,6 @@
-from .clip.inference import ClipModelInference, ClipCompareModelInference
+from .brand.inference import BrandModelInference, BrandCompareModelInference
 
 __all__ = [
-    'ClipModelInference',
-    'ClipCompareModelInference'
+    'BrandModelInference',
+    'BrandCompareModelInference'
 ]

+ 2 - 2
model/clip/clip_model.py → model/brand/brand_model.py

@@ -3,10 +3,10 @@ import torch
 import torch.nn.functional as F
 from utils.utils import load_image_from_url
 
-class ClipModel:
+class BrandModel:
     def __init__(self):
         self._device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
-        self._model, self._preprocess = clip.load("ViT-L/14", device=self._device)
+        self._model, self._preprocess = clip.load("./model/ckpt/brand_model.pt", device=self._device)
         self._tokenizer = clip.tokenize
         
     def extract_image_feature(self, image):

+ 8 - 8
model/clip/inference.py → model/brand/inference.py

@@ -1,12 +1,12 @@
-from .train import ClipModelTrainer
+from .train import BrandModelTrainer
 from utils.utils import load_image_from_url, load_image_from_cos
-from .clip_model import ClipModel
-class ClipModelInference:
+from .brand_model import BrandModel
+class BrandModelInference:
     def __init__(self, brand_name):
         self._products_feat_map = self._load_model(brand_name)
         
     def _load_model(self, brand_name):
-        self._trainer = ClipModelTrainer(brand_name)
+        self._trainer = BrandModelTrainer(brand_name)
         return self._trainer.train()
     
     def calulate_similarity(self, product_image):
@@ -27,13 +27,13 @@ class ClipModelInference:
         similarity_map = sorted(similarity_map, key=lambda x: x['similarity'], reverse=True)
         return similarity_map
 
-class ClipCompareModelInference:
+class BrandCompareModelInference:
     def __init__(self):
-        self._clip_model = ClipModel()
+        self._brand_model = BrandModel()
     
     def calculate_similarity(self, product_image, base_products):
         base_products_feats_map = self.get_base_products_feats_map(base_products)
-        product_image_feat = self._clip_model.extract_image_feature(product_image)
+        product_image_feat = self._brand_model.extract_image_feature(product_image)
         similarity_map = []
         for product in base_products_feats_map:
             similarity = product_image_feat @ product['image_feat'].t() * 100
@@ -53,7 +53,7 @@ class ClipCompareModelInference:
         feats_map =  []
         for url in base_products:
             image = load_image_from_cos(url).resize((512, 512))
-            feat = self._clip_model.extract_image_feature(image)
+            feat = self._brand_model.extract_image_feature(image)
             feats_map.append(
                 {
                     'image': url,

+ 3 - 3
model/clip/train.py → model/brand/train.py

@@ -1,13 +1,13 @@
-from .clip_model import ClipModel
+from .brand_model import BrandModel
 from db import MongoDao
 from tqdm import tqdm
 from utils.utils import load_image_from_url,load_image_from_cos
 
-class ClipModelTrainer:
+class BrandModelTrainer:
     def __init__(self, brand_name):
         self._brand_name = brand_name
         self._dao = MongoDao('ProductStandard')
-        self._model = ClipModel()
+        self._model = BrandModel()
         self._products_data = self.load_data()
         
     def load_data(self):

BIN
product.png


+ 3 - 3
utils/api_service.py

@@ -1,15 +1,15 @@
 from agent.agent import Agent
 from db import MongoDao
 import json5
-from model import ClipModelInference, ClipCompareModelInference
+from model import BrandModelInference, BrandCompareModelInference
 
 from utils.utils import load_image_from_url, load_image_from_cos
 
 import pandas as pd
 
 license_dao = MongoDao("ProductStandard")
-# license_infernece = ClipModelInference('全棉时代')
-license_infernece = ClipCompareModelInference()
+# license_infernece = BrandModelInference('全棉时代')
+license_infernece = BrandCompareModelInference()
 
 class ApiService:
     agent = Agent()

+ 6 - 6
utils/utils.py

@@ -41,10 +41,10 @@ def load_image_from_cos(cos_url):
             
 
 if __name__ == "__main__":
-    # url = 'https://img.alicdn.com/imgextra/i1/2212526294503/O1CN01P0qxZL1j8QU5cC4ed_!!2212526294503.jpg_q75.jpg_.webp'
-    # image = load_image_from_url(url)
-    # image.save('./product.png')  # 保存为PNG
+    url = 'https://img.alicdn.com/imgextra/i1/2212526294503/O1CN01P0qxZL1j8QU5cC4ed_!!2212526294503.jpg_q75.jpg_.webp'
+    image = load_image_from_url(url)
+    image.save('./product.png')  # 保存为PNG
     
-    cos_url = "https://dev-govern-private-1251740668.cos.ap-guangzhou.myqcloud.com/private/20250723/b9c411c4c2514367afc4dd4ad199eb0d.webp?q-sign-algorithm=sha1&q-ak=AKIDIWXN4kqgpiMm0z4T5VgcKn4KSP8cZwnO&q-sign-time=1753320960%3B1753328160&q-key-time=1753320960%3B1753328160&q-header-list=host&q-url-param-list=&q-signature=f6223c1353c1c97c8e42355625dec94967c9c2f6"
-    image = load_image_from_cos(cos_url)
-    image.save('./product.png')
+    # cos_url = "https://dev-govern-private-1251740668.cos.ap-guangzhou.myqcloud.com/private/20250723/b9c411c4c2514367afc4dd4ad199eb0d.webp?q-sign-algorithm=sha1&q-ak=AKIDIWXN4kqgpiMm0z4T5VgcKn4KSP8cZwnO&q-sign-time=1753320960%3B1753328160&q-key-time=1753320960%3B1753328160&q-header-list=host&q-url-param-list=&q-signature=f6223c1353c1c97c8e42355625dec94967c9c2f6"
+    # image = load_image_from_cos(cos_url)
+    # image.save('./product.png')