
Image Augmentation: ImageDataGenerator & Albumentations

Heaea 2022. 5. 31. 10:48

When training an image segmentation model with too few labeled (masked) images, you can enlarge the dataset with augmentation.

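The key caveat: the ground-truth mask and the original image must be augmented identically. Every random flip, shift, or crop applied to an image must also be applied to its mask, otherwise the labels no longer line up with the pixels.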
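To make the caveat concrete, here is a minimal NumPy sketch (not from the original post; the helper name `augment_pair` and the choice of flip plus shift are illustrative assumptions). The random parameters are drawn once and then applied to both arrays, so every mask pixel still lines up with the image pixel it labels.

```python
import numpy as np

def augment_pair(image, mask, rng, max_shift=0.1):
    """Apply one random horizontal flip and shift to an image and its mask.

    Hypothetical helper: the parameters are sampled once and reused for both
    arrays, which is what keeps the mask aligned with the image.
    """
    h, w = mask.shape[:2]
    flip = rng.random() < 0.5
    max_dy, max_dx = int(h * max_shift), int(w * max_shift)
    dy = int(rng.integers(-max_dy, max_dy + 1))
    dx = int(rng.integers(-max_dx, max_dx + 1))

    def apply(arr):
        out = arr[:, ::-1] if flip else arr  # flip along the width axis
        # np.roll wraps pixels around the border; augmentation libraries
        # usually fill the exposed border instead
        return np.roll(out, shift=(dy, dx), axis=(0, 1))

    return apply(image), apply(mask)

rng = np.random.default_rng(909)
image = np.arange(64, dtype=np.float32).reshape(8, 8)
mask = (image > 31).astype(np.uint8)  # toy label map derived from the image
aug_img, aug_mask = augment_pair(image, mask, rng, max_shift=0.25)
# the mask still labels exactly the pixels it labeled before augmentation
print(np.array_equal(aug_mask, (aug_img > 31).astype(np.uint8)))  # True
```

Sampling the image's and the mask's parameters separately would break this alignment, which is exactly the failure the caveat warns about.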

 

Method 1

Using TensorFlow's ImageDataGenerator

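ImageDataGenerator applies a random augmentation to each image independently. If you pass the same seed to both the image flow and the mask flow (here, flow_from_directory), the two generators produce identical augmentations.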
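Why a shared seed works: two random number generators created with the same seed produce the same sequence of draws, so two independent generators end up sampling identical parameters. The sketch below shows the principle with NumPy rather than Keras itself; `sample_shifts` is a hypothetical stand-in for the per-image shift parameters ImageDataGenerator samples internally.

```python
import numpy as np

def sample_shifts(seed, n, max_shift=0.1):
    """Sample n (height_shift, width_shift) pairs as fractions of the image size."""
    rng = np.random.default_rng(seed)
    return rng.uniform(-max_shift, max_shift, size=(n, 2))

seed = 909
image_params = sample_shifts(seed, n=4)  # stands in for the image generator
mask_params = sample_shifts(seed, n=4)   # stands in for the mask generator
print(np.array_equal(image_params, mask_params))  # True: identical parameters

# a different seed gives a different parameter sequence
other_params = sample_shifts(seed + 1, n=4)
```

This only holds while both generators are asked for the same number of draws in the same order, which is why the image and mask flows must use the same batch size and be consumed together.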

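from tensorflow.keras.preprocessing.image import ImageDataGenerator

def image_preprocessing(img):
    # hypothetical placeholder (the original custom function is not shown):
    # scale pixels to [0, 1]; a ResNet preprocess_input could be used instead
    return img / 255.0

def mask_preprocessing(mask):
    # hypothetical placeholder: assumes masks saved as 0/255; output must keep the input shape
    return mask / 255.0

seed = 909  # (IMPORTANT) same seed -> the image and its mask get the same augmentation parameters
image_datagen = ImageDataGenerator(width_shift_range=0.1,
                                   height_shift_range=0.1,
                                   preprocessing_function=image_preprocessing)
mask_datagen = ImageDataGenerator(width_shift_range=0.1,
                                  height_shift_range=0.1,
                                  preprocessing_function=mask_preprocessing)

# class_mode=None yields images only; each folder needs a subfolder, e.g. dataset/image/img/*.png
image_generator = image_datagen.flow_from_directory("dataset/image/",
                                                    class_mode=None, seed=seed)
# color_mode="grayscale" gives masks of shape (256, 256, 1) with the default target_size
mask_generator = mask_datagen.flow_from_directory("dataset/mask/",
                                                  class_mode=None, color_mode="grayscale",
                                                  seed=seed)

train_generator = zip(image_generator, mask_generator)  # yields (image_batch, mask_batch)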
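Because flow_from_directory pairs images with masks only by position (each generator reads its own folder in sorted file-name order), a missing or misnamed file silently shifts every later pair. A small check can be run once before training. The helper `check_pairing` below is hypothetical; it assumes each image and its mask share the same file name and compares the `filenames` lists that Keras directory iterators expose.

```python
import os

def check_pairing(image_files, mask_files):
    """Return (image_name, mask_name) pairs whose base names do not match.

    Assumes each image and its mask share a file name (e.g. CASE01_04.png)
    but may sit in differently named subfolders.
    """
    mismatches = []
    if len(image_files) != len(mask_files):
        mismatches.append((f"{len(image_files)} images", f"{len(mask_files)} masks"))
    for img_path, mask_path in zip(image_files, mask_files):
        img_name = os.path.basename(img_path)
        mask_name = os.path.basename(mask_path)
        if img_name != mask_name:
            mismatches.append((img_name, mask_name))
    return mismatches

# With the generators above:
# problems = check_pairing(image_generator.filenames, mask_generator.filenames)
# assert not problems, problems[:5]
```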
 

Source: "ImageDataGenerator for semantic segmentation" (stackoverflow.com)

 

Method 2

Using the Albumentations library

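import albumentations as A
import cv2
import matplotlib.pyplot as plt

# cv2 loads images as BGR; convert to RGB so matplotlib shows the right colors
img = cv2.cvtColor(cv2.imread('./data/segmentation/train/CASE01_04.png'), cv2.COLOR_BGR2RGB)
# load the mask as a single-channel label map
mask = cv2.imread('./data/segmentation/trainannot/CASE01_04.png', cv2.IMREAD_GRAYSCALE)

augmentation = A.Compose([
    A.Resize(320, 320),
    A.RandomCrop(width=128, height=128),
    A.HorizontalFlip(p=0.3)
])

# passing image and mask in one call applies the same spatial transforms to both
transformed = augmentation(image=img, mask=mask)
# transformed is a dict holding the augmented image and mask
t_img = transformed['image']
t_mask = transformed['mask']

test = [t_img, t_mask]
fig = plt.figure()
for i in range(2):
    ax = fig.add_subplot(1, 2, i + 1)
    ax.imshow(test[i])
    ax.axis('off')
plt.show()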
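To feed Albumentations output to a training loop, one option is a small Python generator that applies the same pipeline to every (image, mask) pair and stacks the results into batches. This is a sketch under assumptions: `paired_batches` is a hypothetical name, the inputs are in-memory lists of equally sized NumPy arrays, and `transform` is any callable used as `transform(image=..., mask=...)` that returns a dict with `'image'` and `'mask'` keys, such as the `augmentation` object above.

```python
import numpy as np

def paired_batches(images, masks, transform, batch_size=8, seed=None):
    """Yield (image_batch, mask_batch) tuples forever, augmenting each pair jointly.

    `transform` is called as transform(image=..., mask=...) and must return a dict
    with 'image' and 'mask' keys, as an albumentations Compose object does.
    `seed` only controls the shuffle order, not the augmentation itself.
    """
    n = len(images)
    if n != len(masks):
        raise ValueError("images and masks must have the same length")
    if n < batch_size:
        raise ValueError("need at least batch_size samples")
    rng = np.random.default_rng(seed)
    while True:
        order = rng.permutation(n)  # reshuffle at the start of every epoch
        # the last incomplete batch of each epoch is dropped
        for start in range(0, n - batch_size + 1, batch_size):
            batch_img, batch_mask = [], []
            for i in order[start:start + batch_size]:
                # one call per pair = one shared set of random parameters
                out = transform(image=images[i], mask=masks[i])
                batch_img.append(out["image"])
                batch_mask.append(out["mask"])
            yield np.stack(batch_img), np.stack(batch_mask)
```

With the pipeline above, `paired_batches(images, masks, augmentation, batch_size=16)` would yield image batches of shape (16, 128, 128, 3) and, for grayscale masks, mask batches of shape (16, 128, 128), since every sample is resized and then cropped to 128x128.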

 

 

Reference: Albumentations documentation, "Albumentations: fast and flexible image augmentations" (albumentations.ai)

 
