|
|
@@ -1,12 +1,12 @@
|
|
|
-from .train import ClipModelTrainer
|
|
|
+from .train import BrandModelTrainer
|
|
|
from utils.utils import load_image_from_url, load_image_from_cos
|
|
|
-from .clip_model import ClipModel
|
|
|
-class ClipModelInference:
|
|
|
+from .brand_model import BrandModel
|
|
|
+class BrandModelInference:
|
|
|
def __init__(self, brand_name):
|
|
|
self._products_feat_map = self._load_model(brand_name)
|
|
|
|
|
|
def _load_model(self, brand_name):
|
|
|
- self._trainer = ClipModelTrainer(brand_name)
|
|
|
+ self._trainer = BrandModelTrainer(brand_name)
|
|
|
return self._trainer.train()
|
|
|
|
|
|
def calulate_similarity(self, product_image):
|
|
|
@@ -27,13 +27,13 @@ class ClipModelInference:
|
|
|
similarity_map = sorted(similarity_map, key=lambda x: x['similarity'], reverse=True)
|
|
|
return similarity_map
|
|
|
|
|
|
-class ClipCompareModelInference:
|
|
|
+class BrandCompareModelInference:
|
|
|
def __init__(self):
|
|
|
- self._clip_model = ClipModel()
|
|
|
+ self._brand_model = BrandModel()
|
|
|
|
|
|
def calculate_similarity(self, product_image, base_products):
|
|
|
base_products_feats_map = self.get_base_products_feats_map(base_products)
|
|
|
- product_image_feat = self._clip_model.extract_image_feature(product_image)
|
|
|
+ product_image_feat = self._brand_model.extract_image_feature(product_image)
|
|
|
similarity_map = []
|
|
|
for product in base_products_feats_map:
|
|
|
similarity = product_image_feat @ product['image_feat'].t() * 100
|
|
|
@@ -53,7 +53,7 @@ class ClipCompareModelInference:
|
|
|
feats_map = []
|
|
|
for url in base_products:
|
|
|
image = load_image_from_cos(url).resize((512, 512))
|
|
|
- feat = self._clip_model.extract_image_feature(image)
|
|
|
+ feat = self._brand_model.extract_image_feature(image)
|
|
|
feats_map.append(
|
|
|
{
|
|
|
'image': url,
|