train.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. from .clip_model import ClipModel
  2. from db import MongoDao
  3. from tqdm import tqdm
  4. from utils.utils import load_image_from_url,load_image_from_cos
  5. class ClipModelTrainer:
  6. def __init__(self, brand_name):
  7. self._brand_name = brand_name
  8. self._dao = MongoDao('ProductStandard')
  9. self._model = ClipModel()
  10. self._products_data = self.load_data()
  11. def load_data(self):
  12. """获取指定品牌的授权商品图片数据"""
  13. license_products_data = []
  14. records = self._dao.get_records_by_query({"brand_name": self._brand_name})
  15. for record in records:
  16. license_products_data.append(
  17. {
  18. 'product_name': record['product_name'],
  19. 'images': record['product_images']
  20. }
  21. )
  22. return license_products_data
  23. def train(self):
  24. """计算指定品牌所有授权商品的图片特征"""
  25. products_feature_map = []
  26. for data in tqdm(self._products_data, desc=f'正在计算{self._brand_name}的授权产品图像特征...'):
  27. if len(data['images']) == 0:
  28. continue
  29. for image_url in data['images']:
  30. image = load_image_from_cos(image_url).resize((512, 512))
  31. feat = self._model.extract_image_feature(image)
  32. products_feature_map.append(
  33. {
  34. 'product_name': data['product_name'],
  35. 'image': image_url,
  36. 'image_feat': feat
  37. }
  38. )
  39. return products_feature_map
  40. def extract_image_feat(self, image):
  41. return self._model.extract_image_feature(image)
  42. if __name__ == '__main__':
  43. trainer = ClipModelTrainer('李宁')
  44. res = trainer.train()
  45. for item in res:
  46. print(item)