Deep Learning using Python + Keras (Chapter 3): ResNet – DEVELOPPARADISE
18/06/2018


Introduction

This article doesn’t give you an introduction to deep learning. You are expected to know the basics of deep learning and a little Python. The main objective of this article is to introduce the basics of the Keras framework and to use it with another well-known library to run a quick experiment and draw some first conclusions.

Background

This article presents the ResNet architecture. Introduced by Microsoft, it won the ILSVRC (ImageNet Large Scale Visual Recognition Challenge) in 2015. You can read the paper at https://arxiv.org/abs/1512.03385.

The key concept is to increase the number of layers by introducing residual connections (identity shortcuts). A shortcut carries a block's input directly to the block's output, where it is added to the result of the block's layers, which improves the learning process.
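To make the shortcut concrete, here is a minimal sketch in plain Python. It is not the Keras implementation: the helper names (relu, dense, residual_block) are hypothetical, and two tiny linear layers stand in for the convolution and batch-normalization layers of a real block. It only illustrates the computation relu(F(x) + x).

```python
def relu(v):
    """Element-wise ReLU on a list of floats."""
    return [max(0.0, a) for a in v]

def dense(v, weights):
    """Multiply vector v by a square weight matrix given as a list of rows."""
    return [sum(w * a for w, a in zip(row, v)) for row in weights]

def residual_block(x, w1, w2):
    """Compute relu(F(x) + x), where F is two small linear layers (illustrative only)."""
    f = dense(relu(dense(x, w1)), w2)                # residual branch F(x)
    return relu([fi + xi for fi, xi in zip(f, x)])   # add the shortcut, then activate

# With all-zero weights the residual branch contributes nothing,
# so the block simply passes its (non-negative) input through.
zeros = [[0.0, 0.0], [0.0, 0.0]]
x = [1.0, 2.0]
out_identity = residual_block(x, zeros, zeros)
print(out_identity)
```

Because the block can fall back to the identity this easily, stacking more blocks should not make the network harder to fit than a shallower one, which is the motivation given in the paper.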

We will run the same experiment as in the previous chapters. I will not repeat the sections on loading the CIFAR-100 dataset, setting up the experiment and downloading the Python libraries. They are all the same as in the previous chapter.

Using the code

Keras ships this architecture ready to use, but by default its input images must be at least 197 pixels on each side, far larger than the 32x32 CIFAR-100 images, so we will define a smaller architecture.

# Imports assumed to carry over from the previous chapters. The module paths
# below match Keras 2.1.x-era releases and may differ in other versions.
from keras import backend as K
from keras.applications import resnet50
from keras.engine import get_source_inputs
from keras.layers import (AveragePooling2D, Dense, Flatten,
                          GlobalAveragePooling2D, GlobalMaxPooling2D,
                          Input, ZeroPadding2D)
from keras.models import Model

def CustomResNet50(include_top=True, input_tensor=None, input_shape=(32,32,3), pooling=None, classes=100):
    # Build the input layer, or reuse a tensor supplied by the caller.
    if input_tensor is None:
        img_input = Input(shape=input_shape)
    else:
        if not K.is_keras_tensor(input_tensor):
            img_input = Input(tensor=input_tensor, shape=input_shape)
        else:
            img_input = input_tensor
    if K.image_data_format() == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1

    x = ZeroPadding2D(padding=(2, 2), name='conv1_pad')(img_input)

    # Stage 2: one conv_block (default stride 2) followed by identity blocks.
    x = resnet50.conv_block(x, 3, [32, 32, 64], stage=2, block='a')
    x = resnet50.identity_block(x, 3, [32, 32, 64], stage=2, block='b')
    x = resnet50.identity_block(x, 3, [32, 32, 64], stage=2, block='c')

    # Stage 3: stride 1, so the spatial size is kept.
    x = resnet50.conv_block(x, 3, [64, 64, 256], stage=3, block='a', strides=(1, 1))
    x = resnet50.identity_block(x, 3, [64, 64, 256], stage=3, block='b')
    x = resnet50.identity_block(x, 3, [64, 64, 256], stage=3, block='c')

    x = resnet50.conv_block(x, 3, [128, 128, 512], stage=4, block='a')
    x = resnet50.identity_block(x, 3, [128, 128, 512], stage=4, block='b')
    x = resnet50.identity_block(x, 3, [128, 128, 512], stage=4, block='c')
    x = resnet50.identity_block(x, 3, [128, 128, 512], stage=4, block='d')

    x = resnet50.conv_block(x, 3, [256, 256, 1024], stage=5, block='a')
    x = resnet50.identity_block(x, 3, [256, 256, 1024], stage=5, block='b')
    x = resnet50.identity_block(x, 3, [256, 256, 1024], stage=5, block='c')
    x = resnet50.identity_block(x, 3, [256, 256, 1024], stage=5, block='d')
    x = resnet50.identity_block(x, 3, [256, 256, 1024], stage=5, block='e')
    x = resnet50.identity_block(x, 3, [256, 256, 1024], stage=5, block='f')

    x = resnet50.conv_block(x, 3, [512, 512, 2048], stage=6, block='a')
    x = resnet50.identity_block(x, 3, [512, 512, 2048], stage=6, block='b')
    x = resnet50.identity_block(x, 3, [512, 512, 2048], stage=6, block='c')

    x = AveragePooling2D((1, 1), name='avg_pool')(x)

    if include_top:
        # Classification head: one softmax unit per class.
        x = Flatten()(x)
        x = Dense(classes, activation='softmax', name='fc1000')(x)
    else:
        if pooling == 'avg':
            x = GlobalAveragePooling2D()(x)
        elif pooling == 'max':
            x = GlobalMaxPooling2D()(x)

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is not None:
        inputs = get_source_inputs(input_tensor)
    else:
        inputs = img_input
    # Create model.
    model = Model(inputs, x, name='resnet50')

    return model
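As a rough check of how small the feature maps become, the sketch below traces the spatial size through the network in plain Python. It assumes that each conv_block opens with an unpadded 1x1 convolution using the block's stride (2 by default, 1 for stage 3) and that identity blocks keep the size unchanged; conv_out is a hypothetical helper, not part of Keras.

```python
def conv_out(size, kernel=1, stride=2):
    """Output length of an unpadded ('valid') convolution along one axis."""
    return (size - kernel) // stride + 1

size = 32 + 2 * 2                  # ZeroPadding2D(padding=(2, 2)) adds 2 pixels per side
trace = [size]
for stride in (2, 1, 2, 2, 2):     # stride of the conv_block that opens stages 2 to 6
    size = conv_out(size, kernel=1, stride=stride)
    trace.append(size)

print(trace)
```

Under these assumptions a 32x32 image is already down to a few pixels per side by the last stage, which helps explain why the full-size architecture asks for much larger inputs.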

As in the previous articles, we compile with the same parameters:

def create_custom_resnet50():
    model = CustomResNet50(include_top=True, input_tensor=None, input_shape=(32,32,3), pooling=None, classes=100)
    return model

custom_resnet50_model = create_custom_resnet50()
custom_resnet50_model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['acc', 'mse'])
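For readers who want to see what the 'categorical_crossentropy' loss measures, here is a small plain-Python sketch for a single sample with a one-hot target. It is an illustration only, not the Keras code; the function name and the eps clipping value are assumptions made to keep the example safe against log(0).

```python
import math

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Cross-entropy between a one-hot target and predicted probabilities."""
    return -sum(t * math.log(max(p, eps)) for t, p in zip(y_true, y_pred))

y_true = [0.0, 1.0, 0.0]                                            # class 1 is correct
confident = categorical_crossentropy(y_true, [0.05, 0.90, 0.05])   # mass on the right class
unsure = categorical_crossentropy(y_true, [0.40, 0.20, 0.40])      # mass on wrong classes

print(confident, unsure)
```

The loss is low when the predicted probability of the correct class is high and grows as that probability shrinks, which is the quantity the optimizer tries to reduce during training.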

Once it’s done, we can see a summary of the model created.

custom_resnet50_model.summary()
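The Param # column of the summary can be checked by hand. The sketch below uses two hypothetical helpers and assumes the usual counts: a Conv2D layer has one weight per kernel position, input channel and output channel plus one bias per output channel, and a BatchNormalization layer has four values per channel.

```python
def conv2d_params(kernel, in_channels, out_channels):
    """Weights plus one bias per output channel for a square-kernel Conv2D."""
    return kernel * kernel * in_channels * out_channels + out_channels

def batchnorm_params(channels):
    """gamma, beta, moving mean and moving variance: four values per channel."""
    return 4 * channels

res2a_branch2a = conv2d_params(1, 3, 32)    # 1x1 conv on the 3-channel input
res2a_branch2b = conv2d_params(3, 32, 32)   # 3x3 conv
res2a_branch2c = conv2d_params(1, 32, 64)   # 1x1 conv
bn2a_branch2a = batchnorm_params(32)

print(res2a_branch2a, res2a_branch2b, res2a_branch2c, bn2a_branch2a)
```

These values can be compared with the first rows of the summary output for the layers of the same name.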

 

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, 32, 32, 3)    0
conv1_pad (ZeroPadding2D)       (None, 36, 36, 3)    0           input_1[0][0]
res2a_branch2a (Conv2D)         (None, 18, 18, 32)   128         conv1_pad[0][0]
bn2a_branch2a (BatchNormalizati (None, 18, 18, 32)   128         res2a_branch2a[0][0]
activation_1 (Activation)       (None, 18, 18, 32)   0           bn2a_branch2a[0][0]
res2a_branch2b (Conv2D)         (None, 18, 18, 32)   9248        activation_1[0][0]
bn2a_branch2b (BatchNormalizati (None, 18, 18, 32)   128         res2a_branch2b[0][0]
activation_2 (Activation)       (None, 18, 18, 32)   0           bn2a_branch2b[0][0]
res2a_branch2c (Conv2D)         (None, 18, 18, 64)   2112        activation_2[0][0]
res2a_branch1 (Conv2D)          (None, 18, 18, 64)   256         conv1_pad[0][0]
bn2a_branch2c (BatchNormalizati (None, 18, 18, 64)   256         res2a_branch2c[0][0]
bn2a_branch1 (BatchNormalizatio (None, 18, 18, 64)   256         res2a_branch1[0][0]
add_1 (Add)                     (None, 18, 18, 64)   0           bn2a_branch2c[0][0]
                                                                 bn2a_branch1[0][0]
activation_3 (Activation)       (None, 18, 18, 64)   0           add_1[0][0]
res2b_branch2a (Conv2D)         (None, 18, 18, 32)   2080        activation_3[0][0]
bn2b_branch2a (BatchNormalizati (None, 18, 18, 32)   128         res2b_branch2a[0][0]
activation_4 (Activation)       (None, 18, 18, 32)   0           bn2b_branch2a[0][0]
res2b_branch2b (Conv2D)         (None, 18, 18, 32)   9248        activation_4[0][0]
bn2b_branch2b (BatchNormalizati (None, 18, 18, 32)   128         res2b_branch2b[0][0]
activation_5 (Activation)       (None, 18, 18, 32)   0           bn2b_branch2b[0][0]
res2b_branch2c (Conv2D)         (None, 18, 18, 64)   2112        activation_5[0][0]
bn2b_branch2c (BatchNormalizati (None, 18, 18, 64)   256         res2b_branch2c[0][0]
add_2 (Add)                     (None, 18, 18, 64)   0           bn2b_branch2c[0][0]
                                                                 activation_3[0][0]
activation_6 (Activation)       (None, 18, 18, 64)   0           add_2[0][0]
res2c_branch2a (Conv2D)         (None, 18, 18, 32)   2080        activation_6[0][0]
bn2c_branch2a (BatchNormalizati (None, 18, 18, 32)   128         res2c_branch2a[0][0]
activation_7 (Activation)       (None, 18, 18, 32)   0           bn2c_branch2a[0][0]
res2c_branch2b (Conv2D)         (None, 18, 18, 32)   9248        activation_7[0][0]
bn2c_branch2b (BatchNormalizati (None, 18, 18, 32)   128         res2c_branch2b[0][0]
activation_8 (Activation)       (None, 18, 18, 32)   0           bn2c_branch2b[0][0]
res2c_branch2c (Conv2D)         (None, 18, 18, 64)   2112        activation_8[0][0]
bn2c_branch2c (BatchNormalizati (None, 18, 18, 64)   256         res2c_branch2c[0][0]
add_3 (Add)                     (None, 18, 18, 64)   0           bn2c_branch2c[0][0]
                                                                 activation_6[0][0]
activation_9 (Activation)       (None, 18, 18, 64)   0           add_3[0][0]
res3a_branch2a (Conv2D)         (None, 18, 18, 64)   4160        activation_9[0][0]
bn3a_branch2a (BatchNormalizati (None, 18, 18, 64)   256         res3a_branch2a[0][0]
activation_10 (Activation)      (None, 18, 18, 64)   0           bn3a_branch2a[0][0]
res3a_branch2b (Conv2D)         (None, 18, 18, 64)   36928       activation_10[0][0]
bn3a_branch2b (BatchNormalizati (None, 18, 18, 64)   256         res3a_branch2b[0][0]
activation_11 (Activation)      (None, 18, 18, 64)   0           bn3a_branch2b[0][0]
res3a_branch2c (Conv2D)         (None, 18, 18, 256)  16640       activation_11[0][0]
res3a_branch1 (Conv2D)          (None, 18, 18, 256)  16640       activation_9[0][0]
bn3a_branch2c (BatchNormalizati (None, 18, 18, 256)  1024        res3a_branch2c[0][0]
bn3a_branch1 (BatchNormalizatio (None, 18, 18, 256)  1024        res3a_branch1[0][0]
add_4 (Add)                     (None, 18, 18, 256)  0           bn3a_branch2c[0][0]
                                                                 bn3a_branch1[0][0]
activation_12 (Activation)      (None, 18, 18, 256)  0           add_4[0][0]
res3b_branch2a (Conv2D)         (None, 18, 18, 64)   16448       activation_12[0][0]
bn3b_branch2a (BatchNormalizati (None, 18, 18, 64)   256         res3b_branch2a[0][0]
activation_13 (Activation)      (None, 18, 18, 64)   0           bn3b_branch2a[0][0]
res3b_branch2b (Conv2D)         (None, 18, 18, 64)   36928       activation_13[0][0]
bn3b_branch2b (BatchNormalizati (None, 18, 18, 64)   256         res3b_branch2b[0][0]
activation_14 (Activation)      (None, 18, 18, 64)   0           bn3b_branch2b[0][0]
res3b_branch2c (Conv2D)         (None, 18, 18, 256)  16640       activation_14[0][0]
bn3b_branch2c (BatchNormalizati (None, 18, 18, 256)  1024        res3b_branch2c[0][0]
add_5 (Add)                     (None, 18, 18, 256)  0           bn3b_branch2c[0][0]
                                                                 activation_12[0][0]
activation_15 (Activation)      (None, 18, 18, 256)  0           add_5[0][0]
res3c_branch2a (Conv2D)         (None, 18, 18, 64)   16448       activation_15[0][0]
bn3c_branch2a (BatchNormalizati (None, 18, 18, 64)   256         res3c_branch2a[0][0]
activation_16 (Activation)      (None, 18, 18, 64)   0           bn3c_branch2a[0][0]
res3c_branch2b (Conv2D)         (None, 18, 18, 64)   36928       activation_16[0][0]
bn3c_branch2b (BatchNormalizati (None, 18, 18, 64)   256         res3c_branch2b[0][0]
activation_17 (Activation)      (None, 18, 18, 64)   0           bn3c_branch2b[0][0]
res3c_branch2c (Conv2D)         (None, 18, 18, 256)  16640       activation_17[0][0]
bn3c_branch2c (BatchNormalizati (None, 18, 18, 256)  1024        res3c_branch2c[0][0]
add_6 (Add)                     (None, 18, 18, 256)  0           bn3c_branch2c[0][0]
                                                                 activation_15[0][0]
activation_18 (Activation)      (None, 18, 18, 256)  0           add_6[0][0]
res4a_branch2a (Conv2D)         (None, 9, 9, 128)    32896       activation_18[0][0]
bn4a_branch2a (BatchNormalizati (None, 9, 9, 128)    512         res4a_branch2a[0][0]
activation_19 (Activation)      (None, 9, 9, 128)    0           bn4a_branch2a[0][0]
res4a_branch2b (Conv2D)         (None, 9, 9, 128)    147584      activation_19[0][0]
bn4a_branch2b (BatchNormalizati (None, 9, 9, 128)    512         res4a_branch2b[0][0]
activation_20 (Activation)      (None, 9, 9, 128)    0           bn4a_branch2b[0][0]
res4a_branch2c (Conv2D)         (None, 9, 9, 512)    66048       activation_20[0][0]
res4a_branch1 (Conv2D)          (None, 9, 9, 512)    131584      activation_18[0][0]
bn4a_branch2c (BatchNormalizati (None, 9, 9, 512)    2048        res4a_branch2c[0][0]
bn4a_branch1 (BatchNormalizatio (None, 9, 9, 512)    2048        res4a_branch1[0][0]
add_7 (Add)                     (None, 9, 9, 512)    0           bn4a_branch2c[0][0]
                                                                 bn4a_branch1[0][0]
activation_21 (Activation)      (None, 9, 9, 512)    0           add_7[0][0]
res4b_branch2a (Conv2D)         (None, 9, 9, 128)    65664       activation_21[0][0]
bn4b_branch2a (BatchNormalizati (None, 9, 9, 128)    512         res4b_branch2a[0][0]
activation_22 (Activation)      (None, 9, 9, 128)    0           bn4b_branch2a[0][0]
res4b_branch2b (Conv2D)         (None, 9, 9, 128)    147584      activation_22[0][0]
bn4b_branch2b (BatchNormalizati (None, 9, 9, 128)    512         res4b_branch2b[0][0]
activation_23 (Activation)      (None, 9, 9, 128)    0           bn4b_branch2b[0][0]
res4b_branch2c (Conv2D)         (None, 9, 9, 512)    66048       activation_23[0][0]
bn4b_branch2c (BatchNormalizati (None, 9, 9, 512)    2048        res4b_branch2c[0][0]
add_8 (Add)                     (None, 9, 9, 512)    0           bn4b_branch2c[0][0]
                                                                 activation_21[0][0]
activation_24 (Activation)      (None, 9, 9, 512)    0           add_8[0][0]
res4c_branch2a (Conv2D)         (None, 9, 9, 128)    65664       activation_24[0][0]
bn4c_branch2a (BatchNormalizati (None, 9, 9, 128)    512         res4c_branch2a[0][0]
activation_25 (Activation)      (None, 9, 9, 128)    0           bn4c_branch2a[0][0]
res4c_branch2b (Conv2D)         (None, 9, 9, 128)    147584      activation_25[0][0]
bn4c_branch2b (BatchNormalizati (None, 9, 9, 128)    512         res4c_branch2b[0][0]
activation_26 (Activation)      (None, 9, 9, 128)    0           bn4c_branch2b[0][0]
res4c_branch2c (Conv2D)         (None, 9, 9, 512)    66048       activation_26[0][0]
bn4c_branch2c (BatchNormalizati (None, 9, 9, 512)    2048        res4c_branch2c[0][0]
add_9 (Add)                     (None, 9, 9, 512)    0           bn4c_branch2c[0][0]
                                                                 activation_24[0][0]
activation_27 (Activation)      (None, 9, 9, 512)    0           add_9[0][0]
res4d_branch2a (Conv2D)         (None, 9, 9, 128)    65664       activation_27[0][0]
bn4d_branch2a (BatchNormalizati (None, 9, 9, 128)    512         res4d_branch2a[0][0]
activation_28 (Activation)      (None, 9, 9, 128)    0           bn4d_branch2a[0][0]
res4d_branch2b (Conv2D)         (None, 9, 9, 128)    147584      activation_28[0][0]
bn4d_branch2b (BatchNormalizati (None, 9, 9, 128)    512         res4d_branch2b[0][0]
activation_29 (Activation)      (None, 9, 9, 128)    0           bn4d_branch2b[0][0]
res4d_branch2c (Conv2D)         (None, 9, 9, 512)    66048       activation_29[0][0]
bn4d_branch2c (BatchNormalizati (None, 9, 9, 512)    2048        res4d_branch2c[0][0]
add_10 (Add)                    (None, 9, 9, 512)    0           bn4d_branch2c[0][0]
                                                                 activation_27[0][0]
activation_30 (Activation)      (None, 9, 9, 512)    0           add_10[0][0]
res5a_branch2a (Conv2D)         (None, 5, 5, 256)    131328      activation_30[0][0]
bn5a_branch2a (BatchNormalizati (None, 5, 5, 256)    1024        res5a_branch2a[0][0]
activation_31 (Activation)      (None, 5, 5, 256)    0           bn5a_branch2a[0][0]
res5a_branch2b (Conv2D)         (None, 5, 5, 256)    590080      activation_31[0][0]
bn5a_branch2b (BatchNormalizati (None, 5, 5, 256)    1024        res5a_branch2b[0][0]
activation_32 (Activation)      (None, 5, 5, 256)    0           bn5a_branch2b[0][0]
res5a_branch2c (Conv2D)         (None, 5, 5, 1024)   263168      activation_32[0][0]
res5a_branch1 (Conv2D)          (None, 5, 5, 1024)   525312      activation_30[0][0]
bn5a_branch2c (BatchNormalizati (None, 5, 5, 1024)   4096        res5a_branch2c[0][0]
bn5a_branch1 (BatchNormalizatio (None, 5, 5, 1024)   4096        res5a_branch1[0][0]
add_11 (Add)                    (None, 5, 5, 1024)   0           bn5a_branch2c[0][0]
                                                                 bn5a_branch1[0][0]
activation_33 (Activation)      (None, 5, 5, 1024)   0           add_11[0][0]
res5b_branch2a (Conv2D)         (None, 5, 5, 256)    262400      activation_33[0][0]
bn5b_branch2a (BatchNormalizati (None, 5, 5, 256)    1024        res5b_branch2a[0][0]
activation_34 (Activation)      (None, 5, 5, 256)    0           bn5b_branch2a[0][0]
res5b_branch2b (Conv2D)         (None, 5, 5, 256)    590080      activation_34[0][0]
bn5b_branch2b (BatchNormalizati (None, 5, 5, 256)    1024        res5b_branch2b[0][0]
activation_35 (Activation)      (None, 5, 5, 256)    0           bn5b_branch2b[0][0]
res5b_branch2c (Conv2D)         (None, 5, 5, 1024)   263168      activation_35[0][0]
bn5b_branch2c (BatchNormalizati (None, 5, 5, 1024)   4096        res5b_branch2c[0][0]
add_12 (Add)                    (None, 5, 5, 1024)   0           bn5b_branch2c[0][0]
                                                                 activation_33[0][0]
activation_36 (Activation)      (None, 5, 5, 1024)   0           add_12[0][0]
res5c_branch2a (Conv2D)         (None, 5, 5, 256)    262400      activation_36[0][0]
bn5c_branch2a (BatchNormalizati (None, 5, 5, 256)    1024        res5c_branch2a[0][0]
activation_37 (Activation)      (None, 5, 5, 256)    0           bn5c_branch2a[0][0]
res5c_branch2b (Conv2D)         (None, 5, 5, 256)    590080      activation_37[0][0]
bn5c_branch2b (BatchNormalizati (None, 5, 5, 256)    1024        res5c_branch2b[0][0]
activation_38 (Activation)      (None, 5, 5, 256)    0           bn5c_branch2b[0][0]
res5c_branch2c (Conv2D)         (None, 5, 5, 1024)   263168      activation_38[0][0]
bn5c_branch2c (BatchNormalizati (None, 5, 5, 1024)   4096        res5c_branch2c[0][0]
add_13 (Add)                    (None, 5, 5, 1024)   0           bn5c_branch2c[0][0]
                                                                 activation_36[0][0]
activation_39 (Activation)      (None, 5, 5, 1024)   0           add_13[0][0]
res5d_branch2a (Conv2D)         (None, 5, 5, 256)    262400      activation_39[0][0]
bn5d_branch2a (BatchNormalizati (None, 5, 5, 256)    1024        res5d_branch2a[0][0]
activation_40 (Activation)      (None, 5, 5, 256)    0           bn5d_branch2a[0][0]
res5d_branch2b (Conv2D)         (None, 5, 5, 256)    590080      activation_40[0][0]
bn5d_branch2b (BatchNormalizati (None, 5, 5, 256)    1024        res5d_branch2b[0][0]
activation_41 (Activation)      (None, 5, 5, 256)    0           bn5d_branch2b[0][0]
res5d_branch2c (Conv2D)         (None, 5, 5, 1024)   263168      activation_41[0][0]
bn5d_branch2c (BatchNormalizati (None, 5, 5, 1024)   4096        res5d_branch2c[0][0]
add_14 (Add)                    (None, 5, 5, 1024)   0           bn5d_branch2c[0][0]
                                                                 activation_39[0][0]
activation_42 (Activation)      (None, 5, 5, 1024)   0           add_14[0][0]
res5e_branch2a (Conv2D)         (None, 5, 5, 256)    262400      activation_42[0][0]
bn5e_branch2a (BatchNormalizati (None, 5, 5, 256)    1024        res5e_branch2a[0][0]
__________________________________________________________________________________________________ activation_43 (Activation)      (None, 5, 5, 256)    0           bn5e_branch2a[0][0]               __________________________________________________________________________________________________ res5e_branch2b (Conv2D)         (None, 5, 5, 256)    590080      activation_43[0][0]               __________________________________________________________________________________________________ bn5e_branch2b (BatchNormalizati (None, 5, 5, 256)    1024        res5e_branch2b[0][0]              __________________________________________________________________________________________________ activation_44 (Activation)      (None, 5, 5, 256)    0           bn5e_branch2b[0][0]               __________________________________________________________________________________________________ res5e_branch2c (Conv2D)         (None, 5, 5, 1024)   263168      activation_44[0][0]               __________________________________________________________________________________________________ bn5e_branch2c (BatchNormalizati (None, 5, 5, 1024)   4096        res5e_branch2c[0][0]              __________________________________________________________________________________________________ add_15 (Add)                    (None, 5, 5, 1024)   0           bn5e_branch2c[0][0]                                                                                activation_42[0][0]               __________________________________________________________________________________________________ activation_45 (Activation)      (None, 5, 5, 1024)   0           add_15[0][0]                      __________________________________________________________________________________________________ res5f_branch2a (Conv2D)         (None, 5, 5, 256)    262400      activation_45[0][0]               __________________________________________________________________________________________________ bn5f_branch2a 
(BatchNormalizati (None, 5, 5, 256)    1024        res5f_branch2a[0][0]              __________________________________________________________________________________________________ activation_46 (Activation)      (None, 5, 5, 256)    0           bn5f_branch2a[0][0]               __________________________________________________________________________________________________ res5f_branch2b (Conv2D)         (None, 5, 5, 256)    590080      activation_46[0][0]               __________________________________________________________________________________________________ bn5f_branch2b (BatchNormalizati (None, 5, 5, 256)    1024        res5f_branch2b[0][0]              __________________________________________________________________________________________________ activation_47 (Activation)      (None, 5, 5, 256)    0           bn5f_branch2b[0][0]               __________________________________________________________________________________________________ res5f_branch2c (Conv2D)         (None, 5, 5, 1024)   263168      activation_47[0][0]               __________________________________________________________________________________________________ bn5f_branch2c (BatchNormalizati (None, 5, 5, 1024)   4096        res5f_branch2c[0][0]              __________________________________________________________________________________________________ add_16 (Add)                    (None, 5, 5, 1024)   0           bn5f_branch2c[0][0]                                                                                activation_45[0][0]               __________________________________________________________________________________________________ activation_48 (Activation)      (None, 5, 5, 1024)   0           add_16[0][0]                      __________________________________________________________________________________________________ res6a_branch2a (Conv2D)         (None, 3, 3, 512)    524800      activation_48[0][0]               
__________________________________________________________________________________________________ bn6a_branch2a (BatchNormalizati (None, 3, 3, 512)    2048        res6a_branch2a[0][0]              __________________________________________________________________________________________________ activation_49 (Activation)      (None, 3, 3, 512)    0           bn6a_branch2a[0][0]               __________________________________________________________________________________________________ res6a_branch2b (Conv2D)         (None, 3, 3, 512)    2359808     activation_49[0][0]               __________________________________________________________________________________________________ bn6a_branch2b (BatchNormalizati (None, 3, 3, 512)    2048        res6a_branch2b[0][0]              __________________________________________________________________________________________________ activation_50 (Activation)      (None, 3, 3, 512)    0           bn6a_branch2b[0][0]               __________________________________________________________________________________________________ res6a_branch2c (Conv2D)         (None, 3, 3, 2048)   1050624     activation_50[0][0]               __________________________________________________________________________________________________ res6a_branch1 (Conv2D)          (None, 3, 3, 2048)   2099200     activation_48[0][0]               __________________________________________________________________________________________________ bn6a_branch2c (BatchNormalizati (None, 3, 3, 2048)   8192        res6a_branch2c[0][0]              __________________________________________________________________________________________________ bn6a_branch1 (BatchNormalizatio (None, 3, 3, 2048)   8192        res6a_branch1[0][0]               __________________________________________________________________________________________________ add_17 (Add)                    (None, 3, 3, 2048)   0           bn6a_branch2c[0][0]                                   
                                             bn6a_branch1[0][0]                __________________________________________________________________________________________________ activation_51 (Activation)      (None, 3, 3, 2048)   0           add_17[0][0]                      __________________________________________________________________________________________________ res6b_branch2a (Conv2D)         (None, 3, 3, 512)    1049088     activation_51[0][0]               __________________________________________________________________________________________________ bn6b_branch2a (BatchNormalizati (None, 3, 3, 512)    2048        res6b_branch2a[0][0]              __________________________________________________________________________________________________ activation_52 (Activation)      (None, 3, 3, 512)    0           bn6b_branch2a[0][0]               __________________________________________________________________________________________________ res6b_branch2b (Conv2D)         (None, 3, 3, 512)    2359808     activation_52[0][0]               __________________________________________________________________________________________________ bn6b_branch2b (BatchNormalizati (None, 3, 3, 512)    2048        res6b_branch2b[0][0]              __________________________________________________________________________________________________ activation_53 (Activation)      (None, 3, 3, 512)    0           bn6b_branch2b[0][0]               __________________________________________________________________________________________________ res6b_branch2c (Conv2D)         (None, 3, 3, 2048)   1050624     activation_53[0][0]               __________________________________________________________________________________________________ bn6b_branch2c (BatchNormalizati (None, 3, 3, 2048)   8192        res6b_branch2c[0][0]              __________________________________________________________________________________________________ add_18 (Add)                    (None, 
3, 3, 2048)   0           bn6b_branch2c[0][0]                                                                                activation_51[0][0]               __________________________________________________________________________________________________ activation_54 (Activation)      (None, 3, 3, 2048)   0           add_18[0][0]                      __________________________________________________________________________________________________ res6c_branch2a (Conv2D)         (None, 3, 3, 512)    1049088     activation_54[0][0]               __________________________________________________________________________________________________ bn6c_branch2a (BatchNormalizati (None, 3, 3, 512)    2048        res6c_branch2a[0][0]              __________________________________________________________________________________________________ activation_55 (Activation)      (None, 3, 3, 512)    0           bn6c_branch2a[0][0]               __________________________________________________________________________________________________ res6c_branch2b (Conv2D)         (None, 3, 3, 512)    2359808     activation_55[0][0]               __________________________________________________________________________________________________ bn6c_branch2b (BatchNormalizati (None, 3, 3, 512)    2048        res6c_branch2b[0][0]              __________________________________________________________________________________________________ activation_56 (Activation)      (None, 3, 3, 512)    0           bn6c_branch2b[0][0]               __________________________________________________________________________________________________ res6c_branch2c (Conv2D)         (None, 3, 3, 2048)   1050624     activation_56[0][0]               __________________________________________________________________________________________________ bn6c_branch2c (BatchNormalizati (None, 3, 3, 2048)   8192        res6c_branch2c[0][0]              
__________________________________________________________________________________________________ add_19 (Add)                    (None, 3, 3, 2048)   0           bn6c_branch2c[0][0]                                                                                activation_54[0][0]               __________________________________________________________________________________________________ activation_57 (Activation)      (None, 3, 3, 2048)   0           add_19[0][0]                      __________________________________________________________________________________________________ avg_pool (AveragePooling2D)     (None, 3, 3, 2048)   0           activation_57[0][0]               __________________________________________________________________________________________________ flatten_1 (Flatten)             (None, 18432)        0           avg_pool[0][0]                    __________________________________________________________________________________________________ fc1000 (Dense)                  (None, 100)          1843300     flatten_1[0][0]                   ================================================================================================== Total params: 25,461,700 Trainable params: 25,407,812 Non-trainable params: 53,888 __________________________________________________________________________________________________
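
The parameter counts in the summary above can be checked by hand. The following is a minimal sketch in plain Python, assuming standard convolutions with a bias term and batch normalization with four values per channel (gamma, beta, moving mean, moving variance). The helper names are illustrative and not part of Keras.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters, use_bias=True):
    """Weights plus (optional) biases of a standard 2D convolution."""
    weights = kernel_h * kernel_w * in_channels * filters
    return weights + (filters if use_bias else 0)


def batchnorm_params(channels):
    """gamma, beta, moving mean and moving variance: four values per channel."""
    return 4 * channels


def dense_params(inputs, units):
    """One weight per input/unit pair plus one bias per unit."""
    return inputs * units + units


# Layer names and shapes are taken from the summary above.
checks = {
    "res5b_branch2a (1x1 conv, 1024 -> 256)": conv2d_params(1, 1, 1024, 256),
    "res5b_branch2b (3x3 conv, 256 -> 256)": conv2d_params(3, 3, 256, 256),
    "bn5b_branch2a (256 channels)": batchnorm_params(256),
    "res6a_branch1 (1x1 conv, 1024 -> 2048)": conv2d_params(1, 1, 1024, 2048),
    "fc1000 (3*3*2048 inputs -> 100 units)": dense_params(3 * 3 * 2048, 100),
}

for name, value in checks.items():
    print(f"{name}: {value:,}")
```

The moving mean and moving variance of each batch normalization layer are updated during training but not learned by gradient descent, which is why the summary reports a number of non-trainable parameters.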

Then, the next step is to train the model.

crn50 = custom_resnet50_model.fit(x=x_train, y=y_train, batch_size=32, epochs=10, verbose=1, validation_data=(x_test, y_test), shuffle=True)

Train on 50000 samples, validate on 10000 samples
Epoch 1/10
50000/50000 [==============================] - 441s 9ms/step - loss: 4.5655 - acc: 0.0817 - mean_squared_error: 0.0101 - val_loss: 4.2085 - val_acc: 0.1228 - val_mean_squared_error: 0.0099
Epoch 2/10
50000/50000 [==============================] - 434s 9ms/step - loss: 4.1448 - acc: 0.1348 - mean_squared_error: 0.0098 - val_loss: 4.2032 - val_acc: 0.1236 - val_mean_squared_error: 0.0099
Epoch 3/10
50000/50000 [==============================] - 433s 9ms/step - loss: 4.2682 - acc: 0.1146 - mean_squared_error: 0.0099 - val_loss: 4.3306 - val_acc: 0.1066 - val_mean_squared_error: 0.0100
Epoch 4/10
50000/50000 [==============================] - 434s 9ms/step - loss: 4.1581 - acc: 0.1340 - mean_squared_error: 0.0098 - val_loss: 4.1405 - val_acc: 0.1384 - val_mean_squared_error: 0.0098
Epoch 5/10
50000/50000 [==============================] - 431s 9ms/step - loss: 3.9395 - acc: 0.1653 - mean_squared_error: 0.0096 - val_loss: 3.8838 - val_acc: 0.1718 - val_mean_squared_error: 0.0095
Epoch 6/10
50000/50000 [==============================] - 432s 9ms/step - loss: 3.9598 - acc: 0.1698 - mean_squared_error: 0.0096 - val_loss: 4.0047 - val_acc: 0.1608 - val_mean_squared_error: 0.0096
Epoch 7/10
50000/50000 [==============================] - 433s 9ms/step - loss: 3.8715 - acc: 0.1797 - mean_squared_error: 0.0095 - val_loss: 4.2620 - val_acc: 0.1184 - val_mean_squared_error: 0.0099
Epoch 8/10
50000/50000 [==============================] - 434s 9ms/step - loss: 3.9661 - acc: 0.1666 - mean_squared_error: 0.0096 - val_loss: 3.8181 - val_acc: 0.1898 - val_mean_squared_error: 0.0095
Epoch 9/10
50000/50000 [==============================] - 434s 9ms/step - loss: 3.8110 - acc: 0.1901 - mean_squared_error: 0.0095 - val_loss: 3.7521 - val_acc: 0.1966 - val_mean_squared_error: 0.0094
Epoch 10/10
50000/50000 [==============================] - 432s 9ms/step - loss: 3.7247 - acc: 0.2048 - mean_squared_error: 0.0094 - val_loss: 3.8206 - val_acc: 0.1929 - val_mean_squared_error: 0.0095

Let’s see the metrics for the train and test results graphically (using matplotlib library, of course).

plt.figure(0)
plt.plot(crn50.history['acc'],'r')
plt.plot(crn50.history['val_acc'],'g')
plt.xticks(np.arange(0, 11, 2.0))
plt.rcParams['figure.figsize'] = (8, 6)
plt.xlabel("Num of Epochs")
plt.ylabel("Accuracy")
plt.title("Training Accuracy vs Validation Accuracy")
plt.legend(['train','validation'])

plt.figure(1)
plt.plot(crn50.history['loss'],'r')
plt.plot(crn50.history['val_loss'],'g')
plt.xticks(np.arange(0, 11, 2.0))
plt.rcParams['figure.figsize'] = (8, 6)
plt.xlabel("Num of Epochs")
plt.ylabel("Loss")
plt.title("Training Loss vs Validation Loss")
plt.legend(['train','validation'])

plt.show()

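The training has given acceptable results and has generalized well: after the last epoch, the gap between training accuracy (0.2048) and validation accuracy (0.1929) is only 0.0119.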
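
As a small illustration of where the 0.0119 figure comes from, the sketch below recomputes it from the per-epoch accuracies printed in the training log. The dictionary mimics the structure of `crn50.history`; the helper name `generalization_gap` is an assumption made for this example.

```python
# Per-epoch values copied from the training log above.
history = {
    "acc": [0.0817, 0.1348, 0.1146, 0.1340, 0.1653,
            0.1698, 0.1797, 0.1666, 0.1901, 0.2048],
    "val_acc": [0.1228, 0.1236, 0.1066, 0.1384, 0.1718,
                0.1608, 0.1184, 0.1898, 0.1966, 0.1929],
}


def generalization_gap(hist, epoch=-1):
    """Training accuracy minus validation accuracy at the given epoch index."""
    return hist["acc"][epoch] - hist["val_acc"][epoch]


gap = generalization_gap(history)
print(f"final gap: {gap:.4f}")
```

A small positive gap suggests little overfitting after 10 epochs, although it says nothing about how high the accuracy itself is.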

Confusion matrix

Once we have trained our model, we want to look at other metrics before drawing any conclusion about the usability of the model we have created. For this, we will create the confusion matrix and, from it, derive the precision, recall and F1-score metrics (see Wikipedia).

To create the confusion matrix, we first make predictions over the test set; then we can build the matrix and show those metrics.

crn50_pred = custom_resnet50_model.predict(x_test, batch_size=32, verbose=1)
crn50_predicted = np.argmax(crn50_pred, axis=1)

crn50_cm = confusion_matrix(np.argmax(y_test, axis=1), crn50_predicted)

# Visualizing of confusion matrix
crn50_df_cm = pd.DataFrame(crn50_cm, range(100), range(100))
plt.figure(figsize = (20,14))
sn.set(font_scale=1.4) #for label size
sn.heatmap(crn50_df_cm, annot=True, annot_kws={"size": 12}) # font size
plt.show()

The next step is to show the metrics.

crn50_report = classification_report(np.argmax(y_test, axis=1), crn50_predicted)
print(crn50_report)

             precision    recall  f1-score   support

          0       0.46      0.32      0.38       100
          1       0.25      0.17      0.20       100
          2       0.17      0.09      0.12       100
          3       0.05      0.62      0.09       100
          4       0.18      0.06      0.09       100
          5       0.25      0.05      0.08       100
          6       0.11      0.14      0.12       100
          7       0.15      0.12      0.13       100
          8       0.21      0.20      0.20       100
          9       0.49      0.21      0.29       100
         10       0.11      0.03      0.05       100
         11       0.08      0.05      0.06       100
         12       0.38      0.13      0.19       100
         13       0.23      0.10      0.14       100
         14       0.18      0.05      0.08       100
         15       0.14      0.06      0.08       100
         16       0.19      0.24      0.21       100
         17       0.40      0.19      0.26       100
         18       0.19      0.24      0.21       100
         19       0.20      0.22      0.21       100
         20       0.42      0.31      0.36       100
         21       0.31      0.23      0.26       100
         22       0.35      0.09      0.14       100
         23       0.36      0.37      0.37       100
         24       0.31      0.49      0.38       100
         25       0.17      0.03      0.05       100
         26       0.43      0.06      0.11       100
         27       0.11      0.03      0.05       100
         28       0.31      0.35      0.33       100
         29       0.12      0.10      0.11       100
         30       0.27      0.33      0.30       100
         31       0.11      0.09      0.10       100
         32       0.22      0.20      0.21       100
         33       0.23      0.30      0.26       100
         34       0.17      0.05      0.08       100
         35       0.09      0.02      0.03       100
         36       0.10      0.23      0.14       100
         37       0.15      0.16      0.16       100
         38       0.08      0.24      0.12       100
         39       0.23      0.18      0.20       100
         40       0.26      0.20      0.22       100
         41       0.45      0.49      0.47       100
         42       0.12      0.17      0.14       100
         43       0.11      0.02      0.03       100
         44       0.14      0.09      0.11       100
         45       0.08      0.01      0.02       100
         46       0.07      0.29      0.12       100
         47       0.55      0.18      0.27       100
         48       0.23      0.31      0.26       100
         49       0.27      0.23      0.25       100
         50       0.12      0.05      0.07       100
         51       0.28      0.09      0.14       100
         52       0.47      0.62      0.54       100
         53       0.25      0.13      0.17       100
         54       0.18      0.25      0.21       100
         55       0.00      0.00      0.00       100
         56       0.27      0.27      0.27       100
         57       0.27      0.11      0.16       100
         58       0.15      0.41      0.22       100
         59       0.18      0.10      0.13       100
         60       0.41      0.63      0.50       100
         61       0.33      0.32      0.32       100
         62       0.15      0.07      0.09       100
         63       0.31      0.26      0.28       100
         64       0.11      0.11      0.11       100
         65       0.15      0.11      0.13       100
         66       0.10      0.06      0.08       100
         67       0.15      0.15      0.15       100
         68       0.37      0.66      0.47       100
         69       0.38      0.25      0.30       100
         70       0.21      0.04      0.07       100
         71       0.27      0.54      0.36       100
         72       0.20      0.01      0.02       100
         73       0.30      0.21      0.25       100
         74       0.14      0.15      0.14       100
         75       0.30      0.29      0.29       100
         76       0.40      0.40      0.40       100
         77       0.13      0.14      0.13       100
         78       0.15      0.08      0.10       100
         79       0.14      0.05      0.07       100
         80       0.08      0.05      0.06       100
         81       0.14      0.11      0.12       100
         82       0.37      0.24      0.29       100
         83       0.08      0.02      0.03       100
         84       0.10      0.11      0.10       100
         85       0.23      0.39      0.29       100
         86       0.36      0.21      0.26       100
         87       0.21      0.19      0.20       100
         88       0.05      0.06      0.05       100
         89       0.24      0.18      0.20       100
         90       0.21      0.24      0.22       100
         91       0.33      0.31      0.32       100
         92       0.11      0.11      0.11       100
         93       0.16      0.10      0.12       100
         94       0.38      0.26      0.31       100
         95       0.21      0.50      0.30       100
         96       0.22      0.23      0.22       100
         97       0.10      0.18      0.13       100
         98       0.12      0.02      0.03       100
         99       0.24      0.08      0.12       100

avg / total       0.22      0.19      0.19     10000
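
To make the link between the confusion matrix and the report explicit, here is a minimal sketch in plain Python that derives precision, recall, F1-score and support for each class. It assumes the scikit-learn convention (rows are true classes, columns are predicted classes). The 3-class matrix is made up for illustration and the function name `per_class_metrics` is hypothetical.

```python
def per_class_metrics(cm):
    """Return (precision, recall, f1, support) per class.

    cm[i][j] counts samples of true class i predicted as class j.
    """
    n = len(cm)
    results = []
    for k in range(n):
        tp = cm[k][k]
        predicted_k = sum(cm[i][k] for i in range(n))  # column sum: TP + FP
        actual_k = sum(cm[k])                          # row sum: TP + FN (support)
        # Guard against division by zero when a class is never predicted.
        precision = tp / predicted_k if predicted_k else 0.0
        recall = tp / actual_k if actual_k else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        results.append((precision, recall, f1, actual_k))
    return results


# Toy 3-class confusion matrix (illustrative data, not from the experiment).
toy_cm = [
    [5, 1, 0],
    [2, 3, 1],
    [0, 2, 6],
]

metrics = per_class_metrics(toy_cm)
for label, (p, r, f1, support) in enumerate(metrics):
    print(f"{label:>3} {p:10.2f} {r:10.2f} {f1:10.2f} {support:10d}")
```

With this reading, a row such as class 3 in the report (precision 0.05, recall 0.62) describes a class that the model predicts far too often: most of its true samples are found, but most of the predictions for it are wrong.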

ROC Curve

ROC curves are typically used with binary classifiers because they show the true positive rate against the false positive rate. The following lines show the code for a multiclass ROC curve. This code is from DloLogy, but you can also find it on the Scikit Learn documentation page.

from sklearn.datasets import make_classification
from sklearn.preprocessing import label_binarize
from itertools import cycle

n_classes = 100

from sklearn.metrics import roc_curve, auc

# Plot linewidth.
lw = 2

# Compute ROC curve and ROC area for each class
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i], crn50_pred[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Compute micro-average ROC curve and ROC area
fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), crn50_pred.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

# Compute macro-average ROC curve and ROC area

# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))

# Then interpolate all ROC curves at these points
# (np.interp replaces the deprecated scipy.interp alias)
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
    mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])

# Finally average it and compute AUC
mean_tpr /= n_classes

fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

# Plot all ROC curves
plt.figure(1)
plt.plot(fpr["micro"], tpr["micro"],
         label='micro-average ROC curve (area = {0:0.2f})'
               ''.format(roc_auc["micro"]),
         color='deeppink', linestyle=':', linewidth=4)

plt.plot(fpr["macro"], tpr["macro"],
         label='macro-average ROC curve (area = {0:0.2f})'
               ''.format(roc_auc["macro"]),
         color='navy', linestyle=':', linewidth=4)

# Only the first three classes are drawn to keep the plot readable
colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
for i, color in zip(range(n_classes-97), colors):
    plt.plot(fpr[i], tpr[i], color=color, lw=lw,
             label='ROC curve of class {0} (area = {1:0.2f})'
             ''.format(i, roc_auc[i]))

plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Some extension of Receiver operating characteristic to multi-class')
plt.legend(loc="lower right")
plt.show()

# Zoom in view of the upper left corner.
plt.figure(2)
plt.xlim(0, 0.2)
plt.ylim(0.8, 1)
plt.plot(fpr["micro"], tpr["micro"],
         label='micro-average ROC curve (area = {0:0.2f})'
               ''.format(roc_auc["micro"]),
         color='deeppink', linestyle=':', linewidth=4)

plt.plot(fpr["macro"], tpr["macro"],
         label='macro-average ROC curve (area = {0:0.2f})'
               ''.format(roc_auc["macro"]),
         color='navy', linestyle=':', linewidth=4)

colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
for i, color in zip(range(10), colors):
    plt.plot(fpr[i], tpr[i], color=color, lw=lw,
             label='ROC curve of class {0} (area = {1:0.2f})'
             ''.format(i, roc_auc[i]))

plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Some extension of Receiver operating characteristic to multi-class')
plt.legend(loc="lower right")
plt.show()
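
For readers who want to see what the per-class step computes, the following plain-Python sketch builds a one-vs-rest ROC curve for a single class and integrates it with the trapezoid rule. It is a simplified stand-in for `roc_curve` and `auc`, not their actual implementation; the function names and the four-sample data are made up for illustration, and the sketch assumes that both positive and negative samples are present.

```python
def roc_points(y_true, scores):
    """Return (fpr, tpr) lists, lowering the threshold one distinct score at a time.

    y_true holds 0/1 labels for one class (one column of a one-hot matrix);
    scores holds the predicted probability of that class.
    """
    positives = sum(y_true)
    negatives = len(y_true) - positives
    ranked = sorted(zip(scores, y_true), key=lambda pair: -pair[0])

    fpr, tpr = [0.0], [0.0]
    tp = fp = 0
    i = 0
    while i < len(ranked):
        threshold = ranked[i][0]
        # Consume every sample tied at this threshold before adding a point.
        while i < len(ranked) and ranked[i][0] == threshold:
            if ranked[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        fpr.append(fp / negatives)
        tpr.append(tp / positives)
    return fpr, tpr


def auc_trapezoid(fpr, tpr):
    """Area under the curve using the trapezoid rule."""
    return sum((fpr[i] - fpr[i - 1]) * (tpr[i] + tpr[i - 1]) / 2
               for i in range(1, len(fpr)))


# Toy data: two negatives and two positives.
y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]

fpr, tpr = roc_points(y_true, scores)
area = auc_trapezoid(fpr, tpr)
print(fpr, tpr, area)
```

In the article's setting, `y_true` would correspond to `y_test[:, i]` and `scores` to `crn50_pred[:, i]` for a given class `i`.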

Then, we will save the model and the training history for future comparisons.

#Model
custom_resnet50_model.save(path_base + '/crn50.h5')

#Historical results
with open(path_base + '/crn50_history.txt', 'wb') as file_pi:
    pickle.dump(crn50.history, file_pi)
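
When the history is needed again in a later chapter, it can be read back with `pickle.load`. The sketch below shows the round trip under a few assumptions: a temporary directory stands in for `path_base`, and the dictionary holds only the first two epochs of the log above. The saved `.h5` model would normally be reloaded with Keras's `load_model`, which is not exercised here.

```python
import os
import pickle
import tempfile

# First two epochs of the training log, in the same layout as crn50.history.
history = {
    "loss": [4.5655, 4.1448],
    "val_loss": [4.2085, 4.2032],
}

# Temporary directory used as a stand-in for the article's path_base.
path_base = tempfile.mkdtemp()
history_path = os.path.join(path_base, "crn50_history.txt")

# Save: the file is binary despite the .txt extension, hence mode 'wb'.
with open(history_path, "wb") as file_pi:
    pickle.dump(history, file_pi)

# Load: the file must be opened in binary mode as well.
with open(history_path, "rb") as file_pi:
    restored = pickle.load(file_pi)

print(restored == history)
```

Note that a restored history is a plain dictionary, so it is indexed as `restored['val_acc']` rather than through a `.history` attribute.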

Model comparison

The next step is to compare the metrics of the previous experiments with these results. We will compare accuracy, loss and mean squared error against the regular network and ConvNet models that we saw in previous chapters, and against some VGG models trained with the same parameters.

plt.figure(0)
plt.plot(snn.history['val_acc'],'r')
plt.plot(scnn.history['val_acc'],'g')
plt.plot(vgg16.history['val_acc'],'b')
plt.plot(vgg19.history['val_acc'],'y')
plt.plot(vgg16Bis.history['val_acc'],'m')
plt.plot(crn50.history['val_acc'],'gold')
plt.xticks(np.arange(0, 11, 2.0))
plt.rcParams['figure.figsize'] = (8, 6)
plt.xlabel("Num of Epochs")
plt.ylabel("Accuracy")
plt.title("Simple NN Accuracy vs simple CNN Accuracy")
plt.legend(['simple NN','CNN','VGG 16','VGG 19','Custom VGG','Custom ResNet'])

plt.figure(0)
plt.plot(snn.history['val_loss'],'r')
plt.plot(scnn.history['val_loss'],'g')
plt.plot(vgg16.history['val_loss'],'b')
plt.plot(vgg19.history['val_loss'],'y')
plt.plot(vgg16Bis.history['val_loss'],'m')
plt.plot(crn50.history['val_loss'],'gold')
plt.xticks(np.arange(0, 11, 2.0))
plt.rcParams['figure.figsize'] = (8, 6)
plt.xlabel("Num of Epochs")
plt.ylabel("Loss")
plt.title("Simple NN Loss vs simple CNN Loss")
plt.legend(['simple NN','CNN','VGG 16','VGG 19','Custom VGG','Custom ResNet'])

plt.figure(0)
plt.plot(snn.history['val_mean_squared_error'],'r')
plt.plot(scnn.history['val_mean_squared_error'],'g')
plt.plot(vgg16.history['val_mean_squared_error'],'b')
plt.plot(vgg19.history['val_mean_squared_error'],'y')
plt.plot(vgg16Bis.history['val_mean_squared_error'],'m')
plt.plot(crn50.history['val_mean_squared_error'],'gold')
plt.xticks(np.arange(0, 11, 2.0))
plt.rcParams['figure.figsize'] = (8, 6)
plt.xlabel("Num of Epochs")
plt.ylabel("Mean Squared Error")
plt.title("Simple NN MSE vs simple CNN MSE")
plt.legend(['simple NN','CNN','VGG 16','VGG 19','Custom VGG','Custom ResNet'])
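
Besides the plots, a short numeric summary can make the comparison easier to read. The sketch below is one possible way to do it, assuming each model's history is available as a dictionary like `crn50.history`. Only the Custom ResNet values (copied from the training log in this chapter) are filled in; the other models' histories would be added to the same dictionary. The function name `summarize` is hypothetical.

```python
def summarize(histories, metric="val_acc"):
    """Return (name, best value, best epoch, final value) rows, best model first."""
    rows = []
    for name, hist in histories.items():
        values = hist[metric]
        best = max(values)
        best_epoch = values.index(best) + 1  # epochs are numbered from 1
        rows.append((name, best, best_epoch, values[-1]))
    return sorted(rows, key=lambda row: row[1], reverse=True)


# Validation accuracy per epoch, copied from this chapter's training log.
histories = {
    "Custom ResNet": {
        "val_acc": [0.1228, 0.1236, 0.1066, 0.1384, 0.1718,
                    0.1608, 0.1184, 0.1898, 0.1966, 0.1929],
    },
    # Other entries, e.g. "VGG 16": vgg16.history, would be added here.
}

rows = summarize(histories)
for name, best, epoch, final in rows:
    print(f"{name}: best val_acc {best:.4f} at epoch {epoch}, final {final:.4f}")
```

For a loss or error metric, where lower is better, `max` would be replaced by `min` and the sort order reversed.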

Conclusion

As you can see, this architecture marks a turning point. It gives better results than the previous architectures, its training times stay acceptable even as the number of layers grows, and its number of parameters is considerably lower than that of the VGG architecture.

In the next article, we will look at DenseNet.