inference.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. from .train import BrandModelTrainer
  2. from utils.utils import load_image_from_url, load_image_from_cos
  3. from .brand_model import BrandModel
  4. class BrandModelInference:
  5. def __init__(self, brand_name):
  6. self._products_feat_map = self._load_model(brand_name)
  7. def _load_model(self, brand_name):
  8. self._trainer = BrandModelTrainer(brand_name)
  9. return self._trainer.train()
  10. def calulate_similarity(self, product_image):
  11. """计算链接产品图像与授权商品产品图像相似度"""
  12. product_image_feat = self._trainer._model.extract_image_feature(product_image)
  13. similarity_map = []
  14. for product in self._products_feat_map:
  15. similarity = product_image_feat @ product['image_feat'].t() * 100
  16. similarity_map.append(
  17. {
  18. 'product_name': product['product_name'],
  19. 'image': product['image'],
  20. 'similarity': similarity.item()
  21. }
  22. )
  23. # 将列表按照similarity进行倒序排序
  24. similarity_map = sorted(similarity_map, key=lambda x: x['similarity'], reverse=True)
  25. return similarity_map
  26. class BrandCompareModelInference:
  27. def __init__(self):
  28. self._brand_model = BrandModel()
  29. def calculate_similarity(self, product_image, base_products):
  30. base_products_feats_map = self.get_base_products_feats_map(base_products)
  31. product_image_feat = self._brand_model.extract_image_feature(product_image)
  32. similarity_map = []
  33. for product in base_products_feats_map:
  34. similarity = product_image_feat @ product['image_feat'].t() * 100
  35. similarity_map.append(
  36. {
  37. 'image': product['image'],
  38. 'similarity': similarity.item()
  39. }
  40. )
  41. # 将列表按照similarity进行倒序排序
  42. similarity_map = sorted(similarity_map, key=lambda x: x['similarity'], reverse=True)
  43. return similarity_map
  44. def get_base_products_feats_map(self, base_products):
  45. feats_map = []
  46. for url in base_products:
  47. image = load_image_from_cos(url).resize((512, 512))
  48. feat = self._brand_model.extract_image_feature(image)
  49. feats_map.append(
  50. {
  51. 'image': url,
  52. 'image_feat': feat
  53. }
  54. )
  55. return feats_map
  56. if __name__ == '__main__':
  57. brand_name = '李宁'
  58. product_image_url = 'https://gw.alicdn.com/imgextra/O1CN015qx8Gw1Jdrhzk6y3v_!!3378851052.jpg_q95.jpg_.webp'
  59. product_image = load_image_from_url(product_image_url).resize((512, 512))
  60. inference = ClipModelInference(brand_name)
  61. similarity_map = inference.calulate_similarity(product_image)
  62. for item in similarity_map:
  63. print(item)