|
@@ -0,0 +1,54 @@
|
|
|
|
|
+from .clip_model import ClipModel
|
|
|
|
|
+from db import MongoDao
|
|
|
|
|
+from tqdm import tqdm
|
|
|
|
|
+from utils.utils import load_image_from_url
|
|
|
|
|
+
|
|
|
|
|
+class ClipModelTrainer:
|
|
|
|
|
+ def __init__(self, brand_name):
|
|
|
|
|
+ self._brand_name = brand_name
|
|
|
|
|
+ self._dao = MongoDao('ProductStandard')
|
|
|
|
|
+ self._model = ClipModel()
|
|
|
|
|
+ self._products_data = self.load_data()
|
|
|
|
|
+
|
|
|
|
|
+ def load_data(self):
|
|
|
|
|
+ """获取指定品牌的授权商品图片数据"""
|
|
|
|
|
+ license_products_data = []
|
|
|
|
|
+ records = self._dao.get_records_by_query({"brand_name": self._brand_name})
|
|
|
|
|
+ for record in records:
|
|
|
|
|
+ license_products_data.append(
|
|
|
|
|
+ {
|
|
|
|
|
+ 'product_name': record['product_name'],
|
|
|
|
|
+ 'images': record['product_images']
|
|
|
|
|
+ }
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ return license_products_data
|
|
|
|
|
+
|
|
|
|
|
+ def train(self):
|
|
|
|
|
+ """计算指定品牌所有授权商品的图片特征"""
|
|
|
|
|
+ products_feature_map = []
|
|
|
|
|
+ for data in tqdm(self._products_data, desc=f'正在计算{self._brand_name}的授权产品图像特征...'):
|
|
|
|
|
+ if len(data['images']) == 0:
|
|
|
|
|
+ continue
|
|
|
|
|
+ for image_url in data['images']:
|
|
|
|
|
+ image = load_image_from_url(image_url).resize((512, 512))
|
|
|
|
|
+ feat = self._model.extract_image_feature(image)
|
|
|
|
|
+ products_feature_map.append(
|
|
|
|
|
+ {
|
|
|
|
|
+ 'product_name': data['product_name'],
|
|
|
|
|
+ 'image': image_url,
|
|
|
|
|
+ 'image_feat': feat
|
|
|
|
|
+ }
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ return products_feature_map
|
|
|
|
|
+
|
|
|
|
|
+ def extract_image_feat(self, image):
|
|
|
|
|
+ return self._model.extract_image_feature(image)
|
|
|
|
|
+
|
|
|
|
|
+if __name__ == '__main__':
|
|
|
|
|
+ trainer = ClipModelTrainer('李宁')
|
|
|
|
|
+ res = trainer.train()
|
|
|
|
|
+ for item in res:
|
|
|
|
|
+ print(item)
|
|
|
|
|
+
|