from .clip_model import ClipModel
from db import MongoDao
from tqdm import tqdm
from utils.utils import load_image_from_url


class ClipModelTrainer:
    """Compute image features for the licensed products of a single brand.

    On construction the trainer queries ``MongoDao('ProductStandard')`` for
    every record whose ``brand_name`` matches and keeps each record's product
    name and image URLs in memory. ``train`` then downloads each image and
    runs it through ``ClipModel.extract_image_feature``.
    """

    # Size passed to ``.resize()`` for every downloaded image before feature
    # extraction.
    # NOTE(review): presumably the input size ClipModel expects -- confirm.
    _IMAGE_SIZE = (512, 512)

    def __init__(self, brand_name: str):
        """Create a trainer and eagerly load the brand's product data.

        Args:
            brand_name: Value matched against the ``brand_name`` field of the
                stored records.
        """
        self._brand_name = brand_name
        # NOTE(review): 'ProductStandard' presumably names the collection
        # holding the licensed products -- confirm against MongoDao.
        self._dao = MongoDao('ProductStandard')
        self._model = ClipModel()
        # Hits the database immediately, so construction performs I/O.
        self._products_data = self.load_data()

    def load_data(self) -> list:
        """Fetch the licensed product image data for the configured brand.

        Returns:
            A list with one dict per matching record, holding
            ``'product_name'`` (the record's ``product_name``) and
            ``'images'`` (the record's ``product_images``).

        Raises:
            KeyError: If a record lacks ``product_name`` or
                ``product_images``.
        """
        records = self._dao.get_records_by_query(
            {"brand_name": self._brand_name}
        )
        return [
            {
                'product_name': record['product_name'],
                'images': record['product_images'],
            }
            for record in records
        ]

    def train(self) -> list:
        """Compute image features for every licensed product of the brand.

        Each image URL loaded by ``load_data`` is downloaded, resized to
        ``_IMAGE_SIZE`` and passed to the model. Products without images are
        skipped.

        Returns:
            A list with one dict per processed image, holding
            ``'product_name'``, ``'image'`` (the source URL) and
            ``'image_feat'`` (whatever ``extract_image_feature`` returned).
        """
        products_feature_map = []
        progress = tqdm(
            self._products_data,
            desc=f'正在计算{self._brand_name}的授权产品图像特征...',
        )
        for data in progress:
            image_urls = data['images']
            # Truthiness covers both an empty list and a null/None field;
            # the previous ``len(...) == 0`` check raised TypeError on None.
            if not image_urls:
                continue
            for image_url in image_urls:
                # NOTE(review): any exception raised while loading an image
                # propagates and aborts the whole run -- confirm that is the
                # intended behaviour for a single bad URL.
                image = load_image_from_url(image_url).resize(
                    self._IMAGE_SIZE
                )
                feat = self._model.extract_image_feature(image)
                products_feature_map.append(
                    {
                        'product_name': data['product_name'],
                        'image': image_url,
                        'image_feat': feat,
                    }
                )
        return products_feature_map

    def extract_image_feat(self, image):
        """Return the model's feature for an already-loaded image.

        The image is forwarded unchanged to
        ``ClipModel.extract_image_feature``; unlike ``train``, no resizing is
        applied here.
        """
        return self._model.extract_image_feature(image)


if __name__ == '__main__':
    # Ad-hoc smoke run: compute and print the features for one brand.
    trainer = ClipModelTrainer('李宁')
    res = trainer.train()
    for item in res:
        print(item)