import clip
import torch
import torch.nn.functional as F

from utils.utils import load_image_from_url
- class ClipModel:
- def __init__(self):
- self._device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
- self._model, self._preprocess = clip.load("ViT-L/14", device=self._device)
- self._tokenizer = clip.tokenize
-
- def extract_image_feature(self, image):
- """提取图像特征"""
- input_image = self._preprocess(image).unsqueeze(0).to(self._device)
- with torch.no_grad():
- image_feat = self._model.encode_image(input_image)
- image_feat = F.normalize(image_feat, dim=-1)
-
- return image_feat
-
- def extract_text_feature(self, text):
- """提取文本特征"""
- input_text = self._tokenizer(text, context_length=self._model.context_length).to(self._device)
- with torch.no_grad():
- text_feat = self._model.encode_text(input_text)
- text_feat = F.normalize(text_feat, dim=-1)
-
- return text_feat
-
-
- if __name__ == '__main__':
- clip_model = ClipModel()
- image_url1 = 'https://img.alicdn.com/imgextra/i3/2145866038/O1CN017JGmX21uTSPqUQmQ8_!!2145866038.jpg_q75.jpg_.webp'
- image1 = load_image_from_url(image_url1).resize((512, 512))
-
- image_url2 = 'https://ai-public-1251740668.cos.ap-guangzhou.myqcloud.com/basicdata/lining_price_01.jpg'
- image2 = load_image_from_url(image_url2).resize((512, 512))
-
- feat1 = clip_model.extract_image_feature(image1)
- feat2 = clip_model.extract_image_feature(image2)
-
-
- print(feat1 @ feat2.t() * 100)
-
-
|