Просмотр исходного кода

增加clip图像特征提取模型

Sherlock1011 8 месяцев назад
Родитель
Коммит
9884b8c079
7 измененных файлов с 167 добавлено и 13 удалено
  1. 7 7
      api_test.py
  2. 5 0
      model/__init__.py
  3. 45 0
      model/clip/clip_model.py
  4. 37 0
      model/clip/inference.py
  5. 54 0
      model/clip/train.py
  6. BIN
      product.png
  7. 19 6
      utils/api_service.py

+ 7 - 7
api_test.py

@@ -38,15 +38,15 @@ import json
 url = "https://670813644644357-http-7860.northwest1.gpugeek.com:8443/brandanalysis/api/v1/infringe_judgement"
 # url = "http://172.18.1.189:7860/brandanalysis/api/v1/infringe_judgement"
 url_data = {
-    'title': '李宁云缓震跑步鞋男鞋2025新款春夏季网面透气休闲运动鞋男慢跑鞋',
+    'title': '李宁超轻21男跑鞋2024年新款反光䨻丝高回弹轻质透气缓震ARBU001',
     'brand_name': '李宁',
     'product_images': [
-            "https://gw.alicdn.com/imgextra/O1CN01rMR0141uTSU983635_!!2145866038.jpg_q95.jpg_.webp",
-            "https://img.alicdn.com/imgextra/i4/2145866038/O1CN01ZM1xNI1uTSTGvavrL_!!2145866038.jpg_q75.jpg_.webp",
-            "https://img.alicdn.com/imgextra/i4/2145866038/O1CN01QaaAgt1uTSTFOdQGD_!!2145866038.jpg_q75.jpg_.webp",
-            "https://img.alicdn.com/imgextra/i4/2145866038/O1CN01dPZcPe1uTSSEBeiwS_!!2145866038.jpg_q75.jpg_.webp",
-            "https://img.alicdn.com/imgextra/i3/2145866038/O1CN018tszjT1uTSNbo83rm_!!2145866038.jpg_q75.jpg_.webp",
-            "https://img.alicdn.com/imgextra/i3/2145866038/O1CN017JGmX21uTSPqUQmQ8_!!2145866038.jpg_q75.jpg_.webp"
+            "https://gw.alicdn.com/imgextra/O1CN01DoWU8A1JdrhxBBsmP_!!3378851052.jpg_q95.jpg_.webp",
+            "https://img.alicdn.com/imgextra/i4/3378851052/O1CN01E1epnW1JdrgxCnPpR_!!3378851052.jpg_q75.jpg_.webp",
+            "https://img.alicdn.com/imgextra/i2/3378851052/O1CN01Kx7yq11JdrgEO6kSN_!!3378851052.jpg_q75.jpg_.webp",
+            "https://img.alicdn.com/imgextra/i1/3378851052/O1CN010sutlV1JdrgGP4GTI_!!3378851052.jpg_q75.jpg_.webp",
+            "https://img.alicdn.com/imgextra/i3/3378851052/O1CN0190Cvpw1JdriE6otol_!!3378851052.jpg_q75.jpg_.webp",
+            "https://img.alicdn.com/imgextra/i4/3378851052/O1CN01SsCo6a1JdriD00UuX_!!3378851052.jpg_q75.jpg_.webp"
         ],
     'price': 19.9
 }

+ 5 - 0
model/__init__.py

@@ -0,0 +1,5 @@
+from .clip.inference import ClipModelInference
+
+# Names re-exported as the public interface of the `model` package.
+__all__ = [
+    'ClipModelInference'
+]

+ 45 - 0
model/clip/clip_model.py

@@ -0,0 +1,45 @@
+import clip
+import torch
+import torch.nn.functional as F
+from utils.utils import load_image_from_url
+
+class ClipModel:
+    """Wrapper around a CLIP model that produces normalized image/text features.
+
+    The "ViT-L/14" checkpoint is loaded through ``clip.load`` when the
+    instance is created.
+    """
+
+    def __init__(self):
+        # Use the first CUDA device when one is available, otherwise the CPU.
+        self._device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        # clip.load yields the model together with its image preprocessing callable.
+        self._model, self._preprocess = clip.load("ViT-L/14", device=self._device)
+        self._tokenizer = clip.tokenize
+        
+    def extract_image_feature(self, image):
+        """Extract the image feature.
+
+        Args:
+            image: Input accepted by the CLIP preprocessing callable
+                (presumably a PIL image - TODO confirm against callers).
+
+        Returns:
+            The encoded image feature tensor, normalized along its last
+            dimension.
+        """
+        # unsqueeze(0) adds a leading dimension before the tensor is moved to the device.
+        input_image = self._preprocess(image).unsqueeze(0).to(self._device)
+        # Encoding runs with gradient tracking disabled.
+        with torch.no_grad():
+            image_feat = self._model.encode_image(input_image)
+        image_feat = F.normalize(image_feat, dim=-1)
+        
+        return image_feat
+    
+    def extract_text_feature(self, text):
+        """Extract the text feature.
+
+        Args:
+            text: Text handed to ``clip.tokenize``.
+
+        Returns:
+            The encoded text feature tensor, normalized along its last
+            dimension.
+        """
+        # Tokenize using the model's own context length, then move to the device.
+        input_text = self._tokenizer(text, context_length=self._model.context_length).to(self._device)
+        # Encoding runs with gradient tracking disabled.
+        with torch.no_grad():
+            text_feat = self._model.encode_text(input_text)
+        text_feat = F.normalize(text_feat, dim=-1)
+        
+        return text_feat
+        
+        
+# Manual check: embed two product images fetched by URL and print their
+# feature dot product scaled by 100.
+if __name__ == '__main__':
+    clip_model = ClipModel()
+    image_url1 = 'https://img.alicdn.com/imgextra/i3/2145866038/O1CN017JGmX21uTSPqUQmQ8_!!2145866038.jpg_q75.jpg_.webp'
+    image1 = load_image_from_url(image_url1).resize((512, 512))
+    
+    image_url2 = 'https://ai-public-1251740668.cos.ap-guangzhou.myqcloud.com/basicdata/lining_price_01.jpg'
+    image2 = load_image_from_url(image_url2).resize((512, 512))
+    
+    feat1 = clip_model.extract_image_feature(image1)
+    feat2 = clip_model.extract_image_feature(image2)
+    
+    
+    # Both features are normalized by extract_image_feature before this product.
+    print(feat1 @ feat2.t() * 100)
+        
+        

+ 37 - 0
model/clip/inference.py

@@ -0,0 +1,37 @@
+from .train import ClipModelTrainer
+from utils.utils import load_image_from_url
+class ClipModelInference:
+    """Ranks a brand's licensed product images by similarity to a query image.
+
+    Construction creates a ``ClipModelTrainer`` for the brand and runs its
+    ``train()`` to obtain the licensed-product feature entries.
+    """
+
+    def __init__(self, brand_name):
+        # Entries produced by ClipModelTrainer.train() for this brand.
+        self._products_feat_map = self._load_model(brand_name)
+        
+    def _load_model(self, brand_name):
+        """Create the trainer for ``brand_name`` and return its ``train()`` result."""
+        # The trainer is kept on the instance; calulate_similarity reuses its model.
+        self._trainer = ClipModelTrainer(brand_name)
+        return self._trainer.train()
+    
+    def calulate_similarity(self, product_image):
+        """Compute the similarity between a listing's product image and the licensed products' images.
+
+        Args:
+            product_image: Image forwarded to the model's
+                ``extract_image_feature``.
+
+        Returns:
+            A list of dicts with keys ``'product_name'``, ``'image'`` and
+            ``'similarity'``, sorted by ``'similarity'`` in descending order.
+        """
+        # NOTE(review): this reaches into the trainer's private `_model`;
+        # ClipModelTrainer also defines extract_image_feat() for the same call.
+        product_image_feat = self._trainer._model.extract_image_feature(product_image)
+        similarity_map = []
+        for product in self._products_feat_map:
+            # Dot product of the two feature tensors, scaled by 100.
+            similarity = product_image_feat @ product['image_feat'].t() * 100
+            similarity_map.append(
+                {
+                    'product_name': product['product_name'],
+                    'image': product['image'],
+                    # .item() converts the single-element tensor to a Python number.
+                    'similarity': similarity.item()
+                }
+            )
+        
+        # Sort the list by similarity in descending order.
+        similarity_map = sorted(similarity_map, key=lambda x: x['similarity'], reverse=True)
+        return similarity_map
+    
+# Manual check: rank the brand's licensed product images against one listing image
+# and print every entry of the sorted result.
+if __name__ == '__main__':
+    brand_name = '李宁'
+    product_image_url = 'https://gw.alicdn.com/imgextra/O1CN015qx8Gw1Jdrhzk6y3v_!!3378851052.jpg_q95.jpg_.webp'
+    product_image = load_image_from_url(product_image_url).resize((512, 512))
+    inference = ClipModelInference(brand_name)
+    similarity_map = inference.calulate_similarity(product_image)
+    
+    for item in similarity_map:
+        print(item)

+ 54 - 0
model/clip/train.py

@@ -0,0 +1,54 @@
+from .clip_model import ClipModel
+from db import MongoDao
+from tqdm import tqdm
+from utils.utils import load_image_from_url
+
+class ClipModelTrainer:
+    """Builds image feature entries for a brand's licensed products.
+
+    Records are read from the 'ProductStandard' collection through
+    ``MongoDao`` and each product image is embedded with ``ClipModel``.
+    """
+
+    def __init__(self, brand_name):
+        self._brand_name = brand_name
+        self._dao = MongoDao('ProductStandard')
+        self._model = ClipModel()
+        # The brand's records are loaded once, at construction time.
+        self._products_data = self.load_data()
+        
+    def load_data(self):
+        """Fetch the licensed-product image data for the configured brand.
+
+        Returns:
+            A list of dicts with keys ``'product_name'`` and ``'images'``,
+            one per record returned by the query on ``brand_name``.
+        """
+        license_products_data = []
+        records = self._dao.get_records_by_query({"brand_name": self._brand_name})
+        for record in records:
+            license_products_data.append(
+                {
+                    'product_name': record['product_name'],
+                    'images': record['product_images']
+                }
+            )
+            
+        return license_products_data
+        
+    def train(self):
+        """Compute image features for all licensed products of the configured brand.
+
+        Returns:
+            A list with one dict per product image, holding
+            ``'product_name'``, ``'image'`` (the image URL) and
+            ``'image_feat'`` (the feature returned by the model).
+        """
+        products_feature_map = []
+        for data in tqdm(self._products_data, desc=f'正在计算{self._brand_name}的授权产品图像特征...'):
+            # Skip products that have no images.
+            if len(data['images']) == 0:
+                continue
+            for image_url in data['images']:
+                # Each image is fetched by URL and resized to 512x512 before embedding.
+                image = load_image_from_url(image_url).resize((512, 512))
+                feat = self._model.extract_image_feature(image)
+                products_feature_map.append(
+                    {
+                        'product_name': data['product_name'],
+                        'image': image_url,
+                        'image_feat': feat
+                    }
+                )
+                
+        return products_feature_map
+    
+    def extract_image_feat(self, image):
+        """Return the model's feature for ``image`` (delegates to ``ClipModel.extract_image_feature``)."""
+        return self._model.extract_image_feature(image)
+    
+# Manual check: build the feature entries for one brand and print each of them.
+if __name__ == '__main__':
+    trainer = ClipModelTrainer('李宁')
+    res = trainer.train()
+    for item in res:
+        print(item)
+    


+ 19 - 6
utils/api_service.py

@@ -1,8 +1,12 @@
 from agent.agent import Agent
 from db import MongoDao
 import json5
+from model import ClipModelInference
+
+from utils.utils import load_image_from_url
 
 license_dao = MongoDao("ProductStandard")
+license_infernece = ClipModelInference('李宁')
 
 class ApiService:
     agent = Agent()
@@ -27,10 +31,19 @@ class ApiService:
                         return True
             
             # 与授权商品对比
-            similarity_judgement = json5.loads(ApiService.agent.multi_products_images_similarity_judgement(url_data['product_images'], basic_data['product_images']))
-            if similarity_judgement['is_similarity_product']:
-                return False
-            else:
+            # similarity_judgement = json5.loads(ApiService.agent.multi_products_images_similarity_judgement(url_data['product_images'], basic_data['product_images']))
+            # if similarity_judgement['is_similarity_product']:
+            #     return False
+            # else:
+            #     return True
+            
+            if len(url_data['product_images']) != 0:
+                for image_url in url_data['product_images']:
+                    product_image = load_image_from_url(image_url).resize((512, 512))
+                    similarity_map = license_infernece.calulate_similarity(product_image)
+                    if similarity_map[0]['similarity'] >= 90.0:
+                        return False
+                    
                 return True
         
         # 图像判定
@@ -110,9 +123,9 @@ class ApiService:
     def get_license_list(brand_name):
         """获取品牌方授权生成的商品列表"""
         license_list = []
-        records = license_dao.get_records_by_query({"BrandName": brand_name})
+        records = license_dao.get_records_by_query({"brand_name": brand_name})
         for record in records:
-            license_list.append(record['ProductTitle'].strip())
+            license_list.append(record['product_name'].strip())
         return license_list
      
     @staticmethod