inference.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. from .train import ClipModelTrainer
  2. from utils.utils import load_image_from_url
  3. class ClipModelInference:
  4. def __init__(self, brand_name):
  5. self._products_feat_map = self._load_model(brand_name)
  6. def _load_model(self, brand_name):
  7. self._trainer = ClipModelTrainer(brand_name)
  8. return self._trainer.train()
  9. def calulate_similarity(self, product_image):
  10. """计算链接产品图像与授权商品产品图像相似度"""
  11. product_image_feat = self._trainer._model.extract_image_feature(product_image)
  12. similarity_map = []
  13. for product in self._products_feat_map:
  14. similarity = product_image_feat @ product['image_feat'].t() * 100
  15. similarity_map.append(
  16. {
  17. 'product_name': product['product_name'],
  18. 'image': product['image'],
  19. 'similarity': similarity.item()
  20. }
  21. )
  22. # 将列表按照similarity进行倒序排序
  23. similarity_map = sorted(similarity_map, key=lambda x: x['similarity'], reverse=True)
  24. return similarity_map
  25. if __name__ == '__main__':
  26. brand_name = '李宁'
  27. product_image_url = 'https://gw.alicdn.com/imgextra/O1CN015qx8Gw1Jdrhzk6y3v_!!3378851052.jpg_q95.jpg_.webp'
  28. product_image = load_image_from_url(product_image_url).resize((512, 512))
  29. inference = ClipModelInference(brand_name)
  30. similarity_map = inference.calulate_similarity(product_image)
  31. for item in similarity_map:
  32. print(item)