| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354 |
- from .clip_model import ClipModel
- from db import MongoDao
- from tqdm import tqdm
- from utils.utils import load_image_from_url
class ClipModelTrainer:
    """Compute CLIP image features for every licensed product of one brand.

    On construction, product records for ``brand_name`` are loaded from the
    ``ProductStandard`` collection via ``MongoDao``; :meth:`train` then loads
    each product image from its URL and extracts its feature with ``ClipModel``.
    """

    def __init__(self, brand_name):
        """Create a trainer for ``brand_name`` and eagerly load its product data.

        Args:
            brand_name: Brand used to filter product records (matched against
                each record's ``brand_name`` field).
        """
        self._brand_name = brand_name
        self._dao = MongoDao('ProductStandard')
        self._model = ClipModel()
        # Loaded once up front so train() can iterate without querying again.
        self._products_data = self.load_data()

    def load_data(self):
        """Fetch the licensed product image data for this trainer's brand.

        Returns:
            A list of ``{'product_name': ..., 'images': [...]}`` dicts, one per
            record returned by the DAO query. A record whose
            ``product_images`` field is missing or ``None`` yields an empty
            ``images`` list instead of aborting the whole load.
        """
        records = self._dao.get_records_by_query({"brand_name": self._brand_name})
        return [
            {
                'product_name': record['product_name'],
                # Tolerate records without images; train() then skips them.
                'images': record.get('product_images') or [],
            }
            for record in records
        ]

    def train(self, image_size=(512, 512)):
        """Compute image features for every image of every loaded product.

        Args:
            image_size: Size tuple passed to the loaded image's ``resize``
                method before feature extraction (presumably a PIL-style
                ``(width, height)`` -- TODO confirm against
                ``load_image_from_url``). Defaults to ``(512, 512)``, the value
                previously hard-coded here.

        Returns:
            A list with one dict per image URL, holding ``product_name``,
            ``image`` (the source URL) and ``image_feat`` (whatever
            :meth:`extract_image_feat` returns).
        """
        image_features = []
        for data in tqdm(self._products_data, desc=f'正在计算{self._brand_name}的授权产品图像特征...'):
            # A product with an empty image list simply contributes nothing.
            for image_url in data['images']:
                image = load_image_from_url(image_url).resize(image_size)
                image_features.append(
                    {
                        'product_name': data['product_name'],
                        'image': image_url,
                        'image_feat': self.extract_image_feat(image),
                    }
                )

        return image_features

    def extract_image_feat(self, image):
        """Return the feature of a single, already-loaded ``image``.

        Delegates to ``ClipModel.extract_image_feature`` on the model created
        in ``__init__``.
        """
        return self._model.extract_image_feature(image)
-
if __name__ == '__main__':
    import sys

    # The brand can be passed as the first CLI argument; with no arguments the
    # script falls back to the originally hard-coded brand.
    brand_name = sys.argv[1] if len(sys.argv) > 1 else '李宁'
    trainer = ClipModelTrainer(brand_name)
    res = trainer.train()
    for item in res:
        print(item)
-
|