| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273 |
- from .train import ClipModelTrainer
- from utils.utils import load_image_from_url, load_image_from_cos
- from .clip_model import ClipModel
class ClipModelInference:
    """Rank a brand's authorized product images by similarity to a query image.

    Construction creates a ``ClipModelTrainer`` for the brand and keeps the
    result of its ``train()`` call as the list of per-product records that
    queries are compared against.
    """

    def __init__(self, brand_name):
        # Records returned by the trainer; each one is read below through the
        # keys 'product_name', 'image' and 'image_feat'.
        self._products_feat_map = self._load_model(brand_name)

    def _load_model(self, brand_name):
        """Create the brand's trainer, store it on ``self._trainer`` and return ``train()``'s result."""
        self._trainer = ClipModelTrainer(brand_name)
        return self._trainer.train()

    def calculate_similarity(self, product_image):
        """Compute the similarity between a linked product image and every authorized product image.

        Args:
            product_image: query image, passed unchanged to the trainer model's
                ``extract_image_feature`` (presumably a PIL image — TODO confirm
                the accepted types against ClipModel).

        Returns:
            A list of dicts with keys ``'product_name'``, ``'image'`` and
            ``'similarity'``, sorted by similarity in descending order.
        """
        query_feat = self._trainer._model.extract_image_feature(product_image)
        similarity_map = [
            {
                'product_name': product['product_name'],
                'image': product['image'],
                # Dot product scaled by 100. Presumably the features are
                # L2-normalised CLIP embeddings, making this 100 * cosine
                # similarity — TODO confirm against ClipModel.
                'similarity': (query_feat @ product['image_feat'].t() * 100).item(),
            }
            for product in self._products_feat_map
        ]
        # Sort by similarity, highest first (stable, like the original sorted()).
        similarity_map.sort(key=lambda x: x['similarity'], reverse=True)
        return similarity_map

    def calulate_similarity(self, product_image):
        """Deprecated misspelled alias of :meth:`calculate_similarity`, kept for existing callers."""
        return self.calculate_similarity(product_image)
class ClipCompareModelInference:
    """Rank a list of base images (given by URL) by similarity to a query image.

    Uses a plain ``ClipModel`` for both the query and the base images, with no
    per-brand trainer. Base-image features are recomputed (and the images
    reloaded via ``load_image_from_cos``) on every ``calculate_similarity``
    call, so cache upstream if the same base list is queried repeatedly.
    """

    def __init__(self):
        self._clip_model = ClipModel()

    def calculate_similarity(self, product_image, base_products, image_size=(512, 512)):
        """Score ``product_image`` against every image in ``base_products``.

        Args:
            product_image: query image, passed unchanged to
                ``ClipModel.extract_image_feature`` (it is NOT resized here).
            base_products: iterable of image URLs loaded via ``load_image_from_cos``.
            image_size: size each base image is resized to before feature
                extraction; defaults to the previously hard-coded ``(512, 512)``.

        Returns:
            A list of ``{'image': url, 'similarity': value}`` dicts sorted by
            similarity in descending order (empty when ``base_products`` is empty).
        """
        # Base features are extracted before the query feature, as in the original order.
        base_products_feats_map = self.get_base_products_feats_map(base_products, image_size=image_size)
        product_image_feat = self._clip_model.extract_image_feature(product_image)
        similarity_map = [
            {
                'image': product['image'],
                # Dot product scaled by 100; presumably 100 * cosine similarity
                # of normalised CLIP embeddings — TODO confirm against ClipModel.
                'similarity': (product_image_feat @ product['image_feat'].t() * 100).item(),
            }
            for product in base_products_feats_map
        ]
        # Sort by similarity, highest first (stable, like the original sorted()).
        similarity_map.sort(key=lambda x: x['similarity'], reverse=True)
        return similarity_map

    def get_base_products_feats_map(self, base_products, image_size=(512, 512)):
        """Load each base image by URL, resize it and extract its feature.

        Args:
            base_products: iterable of image URLs passed to ``load_image_from_cos``.
            image_size: size passed to the loaded image's ``resize``.

        Returns:
            A list of ``{'image': url, 'image_feat': feature}`` dicts in input order.
        """
        return [
            {
                'image': url,
                'image_feat': self._clip_model.extract_image_feature(
                    load_image_from_cos(url).resize(image_size)
                ),
            }
            for url in base_products
        ]
-
if __name__ == '__main__':
    # Demo: score one product-link image against every authorized product
    # image of the brand and print the ranking, best match first.
    demo_brand = '李宁'
    listing_url = 'https://gw.alicdn.com/imgextra/O1CN015qx8Gw1Jdrhzk6y3v_!!3378851052.jpg_q95.jpg_.webp'
    listing_image = load_image_from_url(listing_url)
    listing_image = listing_image.resize((512, 512))
    ranking = ClipModelInference(demo_brand).calulate_similarity(listing_image)
    for entry in ranking:
        print(entry)
|