from operator import itemgetter

from .train import ClipModelTrainer
from utils.utils import load_image_from_url


class ClipModelInference:
    """Rank a brand's authorized products by image similarity to a query image.

    On construction a ``ClipModelTrainer`` is created for ``brand_name``.
    The result of its ``train()`` call is kept as the product feature map that
    query images are compared against.
    """

    def __init__(self, brand_name):
        # List of per-product records produced by ClipModelTrainer.train();
        # each record is read for the keys 'product_name', 'image' and
        # 'image_feat' in calulate_similarity.
        self._products_feat_map = self._load_model(brand_name)

    def _load_model(self, brand_name):
        """Create the trainer for ``brand_name`` and return ``trainer.train()``.

        Side effect: stores the trainer on ``self._trainer`` so that
        ``calulate_similarity`` can reuse its model for feature extraction.
        """
        self._trainer = ClipModelTrainer(brand_name)
        return self._trainer.train()

    def calulate_similarity(self, product_image):
        """Compute the similarity between a linked product image and each authorized product image.

        Args:
            product_image: Query image passed to the trainer model's
                ``extract_image_feature``. The ``__main__`` demo passes an
                image resized to 512x512; other input types are not shown
                here (NOTE(review): confirm accepted types).

        Returns:
            A list of dicts with keys ``'product_name'``, ``'image'`` and
            ``'similarity'``, sorted by ``'similarity'`` in descending order.
            Products with equal similarity keep their order from the feature
            map. An empty feature map yields an empty list.
        """
        query_feat = self._trainer._model.extract_image_feature(product_image)
        results = [
            {
                'product_name': product['product_name'],
                'image': product['image'],
                # Matrix product of the query feature with the transposed
                # product feature, scaled by 100. .item() requires the result
                # to hold exactly one element (presumably 1x1 — confirm).
                'similarity': ((query_feat @ product['image_feat'].t()) * 100).item(),
            }
            for product in self._products_feat_map
        ]
        # Highest similarity first; list.sort is stable, so ties keep map order.
        results.sort(key=itemgetter('similarity'), reverse=True)
        return results

    # Correctly spelled alias; the original misspelled name is kept so that
    # existing callers continue to work.
    calculate_similarity = calulate_similarity


def main():
    """Demo: rank one brand's products against a single image fetched from a URL."""
    brand_name = '李宁'
    product_image_url = 'https://gw.alicdn.com/imgextra/O1CN015qx8Gw1Jdrhzk6y3v_!!3378851052.jpg_q95.jpg_.webp'
    product_image = load_image_from_url(product_image_url).resize((512, 512))
    inference = ClipModelInference(brand_name)
    for item in inference.calculate_similarity(product_image):
        print(item)


if __name__ == '__main__':
    main()