import clip
import torch
import torch.nn.functional as F

from utils.utils import load_image_from_url


class ClipModel:
    """Thin wrapper around OpenAI CLIP for L2-normalized image/text embeddings.

    Features from both encoders are unit-normalized, so the dot product of
    two feature tensors is their cosine similarity.
    """

    def __init__(self, model_name="ViT-L/14", device=None):
        """Load a CLIP model and its preprocessing transform.

        Args:
            model_name: CLIP architecture name passed to ``clip.load``.
                Defaults to ``"ViT-L/14"`` (the original behavior).
            device: torch device string. Defaults to ``"cuda:0"`` when CUDA
                is available, else ``"cpu"`` (the original behavior).
        """
        if device is None:
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self._device = device
        self._model, self._preprocess = clip.load(model_name, device=self._device)
        self._tokenizer = clip.tokenize

    def extract_image_feature(self, image):
        """Extract an L2-normalized image feature.

        Args:
            image: a PIL image accepted by the CLIP preprocessing transform.

        Returns:
            A ``(1, dim)`` tensor on ``self._device``, unit-normalized along
            the last dimension.
        """
        input_image = self._preprocess(image).unsqueeze(0).to(self._device)
        with torch.no_grad():
            image_feat = self._model.encode_image(input_image)
        # Unit-normalize so that feat_a @ feat_b.t() is cosine similarity.
        image_feat = F.normalize(image_feat, dim=-1)
        return image_feat

    def extract_text_feature(self, text, truncate=False):
        """Extract L2-normalized text feature(s).

        Args:
            text: a string or list of strings to encode.
            truncate: if True, silently truncate inputs longer than the
                model's context length instead of raising (``clip.tokenize``
                raises by default, which is the original behavior).

        Returns:
            An ``(n, dim)`` tensor on ``self._device``, unit-normalized along
            the last dimension.
        """
        input_text = self._tokenizer(
            text,
            context_length=self._model.context_length,
            truncate=truncate,
        ).to(self._device)
        with torch.no_grad():
            text_feat = self._model.encode_text(input_text)
        text_feat = F.normalize(text_feat, dim=-1)
        return text_feat


def main():
    """Demo: cosine similarity (scaled by 100) between two product images."""
    clip_model = ClipModel()
    image_url1 = 'https://img.alicdn.com/imgextra/i3/2145866038/O1CN017JGmX21uTSPqUQmQ8_!!2145866038.jpg_q75.jpg_.webp'
    image1 = load_image_from_url(image_url1).resize((512, 512))
    image_url2 = 'https://ai-public-1251740668.cos.ap-guangzhou.myqcloud.com/basicdata/lining_price_01.jpg'
    image2 = load_image_from_url(image_url2).resize((512, 512))
    feat1 = clip_model.extract_image_feature(image1)
    feat2 = clip_model.extract_image_feature(image2)
    print(feat1 @ feat2.t() * 100)


if __name__ == '__main__':
    main()