brand_model.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. import clip
  2. import torch
  3. import torch.nn.functional as F
  4. from utils.utils import load_image_from_url
  5. class BrandModel:
  6. def __init__(self):
  7. self._device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
  8. self._model, self._preprocess = clip.load("./model/ckpt/brand_model.pt", device=self._device)
  9. self._tokenizer = clip.tokenize
  10. def extract_image_feature(self, image):
  11. """提取图像特征"""
  12. input_image = self._preprocess(image).unsqueeze(0).to(self._device)
  13. with torch.no_grad():
  14. image_feat = self._model.encode_image(input_image)
  15. image_feat = F.normalize(image_feat, dim=-1)
  16. return image_feat
  17. def extract_text_feature(self, text):
  18. """提取文本特征"""
  19. input_text = self._tokenizer(text, context_length=self._model.context_length).to(self._device)
  20. with torch.no_grad():
  21. text_feat = self._model.encode_text(input_text)
  22. text_feat = F.normalize(text_feat, dim=-1)
  23. return text_feat
  24. if __name__ == '__main__':
  25. clip_model = ClipModel()
  26. image_url1 = 'https://img.alicdn.com/imgextra/i3/2145866038/O1CN017JGmX21uTSPqUQmQ8_!!2145866038.jpg_q75.jpg_.webp'
  27. image1 = load_image_from_url(image_url1).resize((512, 512))
  28. image_url2 = 'https://ai-public-1251740668.cos.ap-guangzhou.myqcloud.com/basicdata/lining_price_01.jpg'
  29. image2 = load_image_from_url(image_url2).resize((512, 512))
  30. feat1 = clip_model.extract_image_feature(image1)
  31. feat2 = clip_model.extract_image_feature(image2)
  32. print(feat1 @ feat2.t() * 100)