Image Semantic Segmentation

1. Overview

This document walks through image semantic segmentation on the Oxford-IIIT Pet dataset using four approaches: a Fully Convolutional Network (FCN), U-Net, U-Net with a pretrained encoder, and Mask R-CNN.

Dataset source: https://www.robots.ox.ac.uk/~vgg/data/pets/

2. Fully Convolutional Network (FCN)

2.1. Dataset

In this example, I will use the Oxford-IIIT Pet dataset. We can access it with the tensorflow-datasets package.

pip install tensorflow-datasets

Now load the data and build the input pipeline:

import tensorflow as tf
import tensorflow_datasets as tfds

dataset, info = tfds.load('oxford_iiit_pet', with_info=True)

TRAIN_SIZE = info.splits['train'].num_examples
VALIDATION_SIZE = info.splits['test'].num_examples

BATCH_SIZE = 32
BUFFER_SIZE = 1000


def normalize(input_image, input_mask):
    # Scale pixel values to [0, 1] and shift mask labels from {1, 2, 3} to {0, 1, 2}.
    input_image = tf.cast(input_image, tf.float32) / 255.0
    input_mask -= 1
    return input_image, input_mask


def load_image(dataset_element):
    # Resize image and mask to 256x256; nearest-neighbour resizing keeps the mask labels integer-valued.
    input_image = tf.image.resize(dataset_element['image'], (256, 256))
    input_mask = tf.image.resize(dataset_element['segmentation_mask'], (256, 256), method='nearest')
    input_image, input_mask = normalize(input_image, input_mask)
    return input_image, input_mask


train_dataset = (dataset['train']
                 .map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
                 .shuffle(BUFFER_SIZE)
                 .batch(BATCH_SIZE)
                 .prefetch(tf.data.experimental.AUTOTUNE))

test_dataset = (dataset['test']
                .map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
                .batch(BATCH_SIZE))
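
To sanity-check the pipeline, here is a minimal sketch (assuming matplotlib is available; the output file name is illustrative) that writes one image/mask pair to disk:

import matplotlib.pyplot as plt

# Grab one training batch and save the first image next to its mask.
sample_images, sample_masks = next(iter(train_dataset))
plt.subplot(1, 2, 1)
plt.imshow(sample_images[0])
plt.subplot(1, 2, 2)
plt.imshow(sample_masks[0, ..., 0])
plt.savefig('pipeline_check.jpg')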

2.2. Define the model

import tensorflow as tf

inputs = tf.keras.layers.Input(shape=(256, 256, 3))

# Encoder: VGG16-style convolution blocks; the outputs of blocks 3, 4 and 5 are kept for skip connections.
x = tf.keras.layers.Conv2D(
    filters=64,
    kernel_size=(3, 3),
    activation='relu',
    padding='same',
    name='block1_conv1')(inputs)
x = tf.keras.layers.Conv2D(
    filters=64,
    kernel_size=(3, 3),
    activation='relu',
    padding='same',
    name='block1_conv2')(x)
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2, name='block1_pool')(x)

x = tf.keras.layers.Conv2D(
    filters=128,
    kernel_size=(3, 3),
    activation='relu',
    padding='same',
    name='block2_conv1')(x)
x = tf.keras.layers.Conv2D(
    filters=128,
    kernel_size=(3, 3),
    activation='relu',
    padding='same',
    name='block2_conv2')(x)
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2, name='block2_pool')(x)

x = tf.keras.layers.Conv2D(
    filters=256,
    kernel_size=(3, 3),
    activation='relu',
    padding='same',
    name='block3_conv1')(x)
x = tf.keras.layers.Conv2D(
    filters=256,
    kernel_size=(3, 3),
    activation='relu',
    padding='same',
    name='block3_conv2')(x)
x = tf.keras.layers.Conv2D(
    filters=256,
    kernel_size=(3, 3),
    activation='relu',
    padding='same',
    name='block3_conv3')(x)
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2, name='block3_pool')(x)
block3_pool = x

x = tf.keras.layers.Conv2D(
    filters=512,
    kernel_size=(3, 3),
    activation='relu',
    padding='same',
    name='block4_conv1')(x)
x = tf.keras.layers.Conv2D(
    filters=512,
    kernel_size=(3, 3),
    activation='relu',
    padding='same',
    name='block4_conv2')(x)
x = tf.keras.layers.Conv2D(
    filters=512,
    kernel_size=(3, 3),
    activation='relu',
    padding='same',
    name='block4_conv3')(x)
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2, name='block4_pool')(x)
block4_pool = x

x = tf.keras.layers.Conv2D(
    filters=512,
    kernel_size=(3, 3),
    activation='relu',
    padding='same',
    name='block5_conv1')(x)
x = tf.keras.layers.Conv2D(
    filters=512,
    kernel_size=(3, 3),
    activation='relu',
    padding='same',
    name='block5_conv2')(x)
x = tf.keras.layers.Conv2D(
    filters=512,
    kernel_size=(3, 3),
    activation='relu',
    padding='same',
    name='block5_conv3')(x)
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2, name='block5_pool')(x)
block5_pool = x

# FCN-style head: predict class scores at three resolutions and fuse them.
conv6 = tf.keras.layers.Conv2D(
    filters=3,
    kernel_size=(7, 7),
    activation='relu',
    padding='same',
    name='conv6')(block5_pool)
conv6_4 = tf.keras.layers.Conv2DTranspose(
    filters=3,
    kernel_size=(4, 4),
    strides=4,
    use_bias=False)(conv6)

pool4_n = tf.keras.layers.Conv2D(
    filters=3,
    kernel_size=(1, 1),
    activation='relu',
    padding='same',
    name='pool4_n')(block4_pool)
pool4_n_2 = tf.keras.layers.Conv2DTranspose(
    filters=3,
    kernel_size=(2, 2),
    strides=2,
    use_bias=False)(pool4_n)

pool3_n = tf.keras.layers.Conv2D(
    filters=3,
    kernel_size=(3, 3),
    activation='relu',
    padding='same',
    name='pool3_n')(block3_pool)

# Fuse the three 32x32 predictions and up-sample 8x back to the input resolution.
output = tf.keras.layers.Add(name='add')([pool4_n_2, pool3_n, conv6_4])
output = tf.keras.layers.Conv2DTranspose(filters=3, kernel_size=(8, 8), strides=8, use_bias=False)(output)
output = tf.keras.layers.Softmax()(output)

fcn_model = tf.keras.models.Model(inputs, output)

fcn_model.compile(
    # The model ends with a Softmax layer, so the loss receives probabilities rather than logits.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    optimizer=tf.keras.optimizers.RMSprop(),
    metrics=['accuracy'])

fcn_model.summary()

In this step, we use SparseCategoricalCrossentropy as the loss function for the per-pixel classification task. The output has 3 channels because each pixel belongs to one of three classes in the Oxford-IIIT Pet masks: pet, pet border, or background. Because the model ends with a Softmax layer, the loss is configured with from_logits=False.
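
As a quick illustration of how this loss pairs an integer mask with a 3-channel prediction, here is a minimal sketch on dummy data (shapes match the pipeline above; not part of the training code):

# y_true holds one class id per pixel; y_pred holds one probability per class per pixel.
y_true = tf.zeros([1, 256, 256, 1], dtype=tf.int32)                      # every pixel labelled class 0
y_pred = tf.keras.layers.Softmax()(tf.random.uniform([1, 256, 256, 3]))  # fake per-pixel probabilities
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
print(loss_fn(y_true, y_pred).numpy())  # a single scalar: the mean cross-entropy over all pixels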

The network summary:

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_2 (InputLayer) [(None, 256, 256, 3) 0
__________________________________________________________________________________________________
block1_conv1 (Conv2D) (None, 256, 256, 64) 1792 input_2[0][0]
__________________________________________________________________________________________________
block1_conv2 (Conv2D) (None, 256, 256, 64) 36928 block1_conv1[0][0]
__________________________________________________________________________________________________
block1_pool (MaxPooling2D) (None, 128, 128, 64) 0 block1_conv2[0][0]
__________________________________________________________________________________________________
block2_conv1 (Conv2D) (None, 128, 128, 128 73856 block1_pool[0][0]
__________________________________________________________________________________________________
block2_conv2 (Conv2D) (None, 128, 128, 128 147584 block2_conv1[0][0]
__________________________________________________________________________________________________
block2_pool (MaxPooling2D) (None, 64, 64, 128) 0 block2_conv2[0][0]
__________________________________________________________________________________________________
block3_conv1 (Conv2D) (None, 64, 64, 256) 295168 block2_pool[0][0]
__________________________________________________________________________________________________
block3_conv2 (Conv2D) (None, 64, 64, 256) 590080 block3_conv1[0][0]
__________________________________________________________________________________________________
block3_conv3 (Conv2D) (None, 64, 64, 256) 590080 block3_conv2[0][0]
__________________________________________________________________________________________________
block3_pool (MaxPooling2D) (None, 32, 32, 256) 0 block3_conv3[0][0]
__________________________________________________________________________________________________
block4_conv1 (Conv2D) (None, 32, 32, 512) 1180160 block3_pool[0][0]
__________________________________________________________________________________________________
block4_conv2 (Conv2D) (None, 32, 32, 512) 2359808 block4_conv1[0][0]
__________________________________________________________________________________________________
block4_conv3 (Conv2D) (None, 32, 32, 512) 2359808 block4_conv2[0][0]
__________________________________________________________________________________________________
block4_pool (MaxPooling2D) (None, 16, 16, 512) 0 block4_conv3[0][0]
__________________________________________________________________________________________________
block5_conv1 (Conv2D) (None, 16, 16, 512) 2359808 block4_pool[0][0]
__________________________________________________________________________________________________
block5_conv2 (Conv2D) (None, 16, 16, 512) 2359808 block5_conv1[0][0]
__________________________________________________________________________________________________
block5_conv3 (Conv2D) (None, 16, 16, 512) 2359808 block5_conv2[0][0]
__________________________________________________________________________________________________
block5_pool (MaxPooling2D) (None, 8, 8, 512) 0 block5_conv3[0][0]
__________________________________________________________________________________________________
pool4_n (Conv2D) (None, 16, 16, 3) 1539 block4_pool[0][0]
__________________________________________________________________________________________________
conv6 (Conv2D) (None, 8, 8, 3) 75267 block5_pool[0][0]
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 32, 32, 3) 36 pool4_n[0][0]
__________________________________________________________________________________________________
pool3_n (Conv2D) (None, 32, 32, 3) 6915 block3_pool[0][0]
__________________________________________________________________________________________________
conv2d_transpose (Conv2DTranspo (None, 32, 32, 3) 144 conv6[0][0]
__________________________________________________________________________________________________
add (Add) (None, 32, 32, 3) 0 conv2d_transpose_1[0][0]
pool3_n[0][0]
conv2d_transpose[0][0]
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 256, 256, 3) 576 add[0][0]
__________________________________________________________________________________________________
softmax (Softmax) (None, 256, 256, 3) 0 conv2d_transpose_2[0][0]
==================================================================================================
Total params: 14,799,165
Trainable params: 14,799,165
Non-trainable params: 0
__________________________________________________________________________________________________

2.3. Training

model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='./model.best.hdf5',
    save_weights_only=False,
    save_best_only=True,
    monitor='val_loss')

hist = fcn_model.fit(
    train_dataset,
    epochs=120,
    validation_data=test_dataset,
    callbacks=[model_checkpoint_callback])
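
To see whether the model is over- or under-fitting, the returned History object can be plotted; a minimal sketch, assuming matplotlib is available:

import matplotlib.pyplot as plt

# hist.history contains one value per epoch for each compiled metric.
plt.plot(hist.history['loss'], label='train loss')
plt.plot(hist.history['val_loss'], label='val loss')
plt.legend()
plt.savefig('fcn_training_curves.jpg')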

2.4. Experiment

def create_mask(prediction_mask):
    # Pick the most likely class for every pixel and keep a trailing channel axis.
    prediction_mask = tf.argmax(prediction_mask, axis=-1)
    prediction_mask = prediction_mask[..., tf.newaxis]
    return prediction_mask


images, masks = next(iter(test_dataset))
results = fcn_model.predict(images)

for i, (image, mask, result) in enumerate(zip(images, masks, results), start=1):
    image = tf.keras.preprocessing.image.array_to_img(image)
    mask = tf.keras.preprocessing.image.array_to_img(mask)
    output = tf.keras.preprocessing.image.array_to_img(create_mask(result))

    image.save(f'image.{i:02d}.jpg')
    mask.save(f'mask.{i:02d}.jpg')
    output.save(f'output.{i:02d}.jpg')

Some results

[Results table: raw image, ground-truth mask, and predicted mask for four test samples; images omitted.]

Read more about FCN in the original paper, "Fully Convolutional Networks for Semantic Segmentation" (Long, Shelhamer, and Darrell).

3. U-Net

3.1. Dataset

I will use the same Oxford-IIIT Pet dataset in this example.

3.2. Define the model

We will need some building blocks:

  • Down-sampling block: a strided convolution that halves the spatial resolution of its input
def down_sample(filters, kernel_size, batch_norm=True):
    initializer = tf.random_normal_initializer(0.0, 0.02)
    layers = tf.keras.Sequential()
    layers.add(tf.keras.layers.Conv2D(
        filters=filters,
        kernel_size=kernel_size,
        strides=2,
        padding='same',
        kernel_initializer=initializer,
        use_bias=False))

    if batch_norm:
        layers.add(tf.keras.layers.BatchNormalization())

    layers.add(tf.keras.layers.LeakyReLU())
    return layers
  • Up-sampling block: a transposed convolution that doubles the spatial resolution of its input (both blocks are shape-checked in a short sketch after the model is assembled below)
def up_sample(filters, kernel_size, dropout=False):
    initializer = tf.random_normal_initializer(0.0, 0.02)
    layers = tf.keras.Sequential()
    layers.add(tf.keras.layers.Conv2DTranspose(
        filters=filters,
        kernel_size=kernel_size,
        strides=2,
        padding='same',
        kernel_initializer=initializer,
        use_bias=False))
    layers.add(tf.keras.layers.BatchNormalization())

    if dropout:
        layers.add(tf.keras.layers.Dropout(0.5))

    layers.add(tf.keras.layers.ReLU())
    return layers

  • Construct the U-Net model
down_stack = [down_sample(64, 4, batch_norm=False)]
for filters in [128, 256, 512, 512, 512, 512, 512]:
    down_stack.append(down_sample(filters, 4))

up_stack = []
for _ in range(3):
    up_stack.append(up_sample(512, 4, dropout=True))

for filters in [512, 256, 128, 64]:
    up_stack.append(up_sample(filters, 4))


inputs = tf.keras.layers.Input(shape=(256, 256, 3))
x = inputs
skip_layers = []

for down in down_stack:
    x = down(x)
    skip_layers.append(x)

skip_layers = reversed(skip_layers[:-1])

# Skip connections feed encoder features directly to the decoder, which helps
# avoid the vanishing-gradient problem and preserves spatial detail.
for up, skip_connection in zip(up_stack, skip_layers):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip_connection])


N_CLASSES = 3

initializer = tf.random_normal_initializer(0.0, 0.02)
outputs = tf.keras.layers.Conv2DTranspose(
    filters=N_CLASSES,
    kernel_size=3,
    strides=2,
    padding='same',
    kernel_initializer=initializer)(x)

model = tf.keras.models.Model(inputs=inputs, outputs=outputs)

model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.RMSprop(),
    metrics=['accuracy'])
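
As a quick sanity check on the building blocks used above, here is a minimal sketch that runs each block on dummy tensors (the shapes are illustrative):

# down_sample halves the spatial resolution; up_sample doubles it.
print(down_sample(64, 4)(tf.zeros([1, 256, 256, 3])).shape)   # (1, 128, 128, 64)
print(up_sample(256, 4)(tf.zeros([1, 8, 8, 512])).shape)      # (1, 16, 16, 256)
model.summary()  # encoder and decoder shapes should mirror each other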

3.3. Training

model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='./unet.{val_loss:.9f}.hdf5',
    save_weights_only=False,
    save_best_only=True,
    monitor='val_loss')

model.fit(train_dataset, epochs=50, validation_data=test_dataset, callbacks=[model_checkpoint_callback])

3.4. Experiment

Some results

[Results table: raw image, ground-truth mask, and predicted mask for four test samples; images omitted.]

3.5. How it works

The encoder repeatedly down-samples its input, so fine spatial detail is lost by the time the bottleneck is reached. Each skip connection concatenates an encoder feature map with the decoder feature map of the same resolution, which restores that detail and also gives gradients a short path back to the early layers, mitigating the vanishing-gradient problem.
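
A minimal sketch of what a single skip connection does, using illustrative shapes similar to those in the model above:

# After up-sampling, the decoder feature map is concatenated with the encoder
# feature map of the same resolution that was saved during the down pass.
decoder_features = tf.zeros([1, 32, 32, 512])
encoder_features = tf.zeros([1, 32, 32, 256])   # one of the tensors stored in skip_layers
merged = tf.keras.layers.Concatenate()([decoder_features, encoder_features])
print(merged.shape)  # (1, 32, 32, 768): the decoder keeps direct access to encoder detail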

Read more about U-Net in the original paper, "U-Net: Convolutional Networks for Biomedical Image Segmentation" (Ronneberger, Fischer, and Brox).

4. U-Net With a Pretrained Encoder

4.1. Dataset

I will use the same Oxford-IIIT Pet dataset in this example.

4.2. Define the model

  • The pretrained MobileNetV2 model will be used as the encoder in this example.
pretrain_model = tf.keras.applications.MobileNetV2(
    input_shape=(256, 256, 3),
    include_top=False,
    weights='imagenet')

# Intermediate activations that will serve as the encoder outputs / skip connections.
target_values = [
    'block_1_expand_relu',
    'block_3_expand_relu',
    'block_6_expand_relu',
    'block_13_expand_relu',
    'block_16_project'
]

layers = [pretrain_model.get_layer(layer).output for layer in target_values]

down_stack = tf.keras.models.Model(inputs=pretrain_model.input, outputs=layers)
down_stack.trainable = False

up_stack = []
for filters in [512, 256, 128, 64]:
    up_stack.append(up_sample(filters, 4))

inputs = tf.keras.layers.Input(shape=(256, 256, 3))
x = inputs

skip_layers = down_stack(x)

x = skip_layers[-1]
skip_layers = reversed(skip_layers[:-1])

for up, skip_connection in zip(up_stack, skip_layers):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip_connection])


N_CLASSES = 3

initializer = tf.random_normal_initializer(0.0, 0.02)
outputs = tf.keras.layers.Conv2DTranspose(
    filters=N_CLASSES,
    kernel_size=3,
    strides=2,
    padding='same',
    kernel_initializer=initializer)(x)

model = tf.keras.models.Model(inputs=inputs, outputs=outputs)

model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.RMSprop(),
    metrics=['accuracy'])

4.3. Training

model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='./unet.pretrain.{val_loss:.9f}.hdf5',
    save_weights_only=False,
    save_best_only=True,
    monitor='val_loss')

hist = model.fit(
    train_dataset,
    epochs=50,
    validation_data=test_dataset,
    callbacks=[model_checkpoint_callback])
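
Once this frozen-encoder training converges, a common follow-up is to unfreeze the backbone and fine-tune the whole model at a lower learning rate. A minimal sketch, assuming the model and datasets defined above (the epoch count and learning rate are illustrative):

# Unfreeze the MobileNetV2 encoder and recompile with a smaller learning rate
# so the pretrained weights are only gently adjusted.
down_stack.trainable = True
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-5),
    metrics=['accuracy'])

model.fit(train_dataset, epochs=10, validation_data=test_dataset,
          callbacks=[model_checkpoint_callback])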

4.4. Experiment

Some results

[Results table: raw image, ground-truth mask, and predicted mask for four test samples; images omitted.]

5. Mask R-CNN

5.1. Load the pretrained model from TensorFlow Hub

import tensorflow_hub as hub

MODEL_PATH = ('https://tfhub.dev/tensorflow/mask_rcnn/inception_resnet_v2_1024x1024/1')
mask_rcnn = hub.load(MODEL_PATH)

5.2. Install the visualization utils from TensorFlow Models

  • Shell script
git clone --depth 1 https://github.com/tensorflow/models
sudo apt install -y protobuf-compiler
cd models/research
protoc object_detection/protos/*.proto --python_out=.
cp object_detection/packages/tf2/setup.py .
python -m pip install -q .
  • Import the visualization utilities
from object_detection.utils import ops
from object_detection.utils import visualization_utils as viz
from object_detection.utils.label_map_util import create_category_index_from_labelmap

5.3. Experiment
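
The snippet below assumes image is a uint8 batch tensor of shape [1, H, W, 3] and that CATEGORY_IDX maps COCO class ids to names. A minimal preparation sketch (the input file name is illustrative; the label-map path assumes the repository cloned in the previous step):

import tensorflow as tf
import matplotlib.pyplot as plt

# Load any RGB image and add a batch dimension; the hub model expects uint8 [1, H, W, 3].
raw = tf.io.read_file('sample.jpg')  # hypothetical file name
image = tf.expand_dims(tf.image.decode_jpeg(raw, channels=3), axis=0).numpy()

# Category index for the COCO classes this pretrained Mask R-CNN predicts.
CATEGORY_IDX = create_category_index_from_labelmap(
    'models/research/object_detection/data/mscoco_label_map.pbtxt',
    use_display_name=True)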

results = mask_rcnn(image)
model_output = {k: v.numpy() for k, v in results.items()}

detection_masks = tf.convert_to_tensor(model_output['detection_masks'][0])
detection_boxes = tf.convert_to_tensor(model_output['detection_boxes'][0])

# The model returns masks in box coordinates; reframe them to full-image coordinates.
detection_masks_reframed = ops.reframe_box_masks_to_image_masks(
    detection_masks,
    detection_boxes,
    image.shape[1],
    image.shape[2])
detection_masks_reframed = tf.cast(detection_masks_reframed > 0.5, tf.uint8)

model_output['detection_masks_reframed'] = detection_masks_reframed.numpy()

boxes = model_output['detection_boxes'][0]
classes = model_output['detection_classes'][0].astype('int')
scores = model_output['detection_scores'][0]
masks = model_output['detection_masks_reframed']

# Draw boxes, labels and instance masks on a copy of the input image.
image_with_mask = image.copy()
viz.visualize_boxes_and_labels_on_image_array(
    image=image_with_mask[0],
    boxes=boxes,
    classes=classes,
    scores=scores,
    category_index=CATEGORY_IDX,
    use_normalized_coordinates=True,
    max_boxes_to_draw=200,
    min_score_thresh=0.30,
    agnostic_mode=False,
    instance_masks=masks,
    line_thickness=5)

plt.figure(figsize=(24, 32))
plt.imshow(image_with_mask[0])
plt.savefig('maskrcnn_output.jpg')
  • Some results
[Results table: raw image, ground truth, and predicted instance masks for four samples; images omitted.]

5.4. See also

  • https://tfhub.dev/tensorflow/mask_rcnn/inception_resnet_v2_1024x1024/1
  • https://arxiv.org/abs/1703.06870