|
| 1 | +#! /usr/bin/python |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | +"""ResNet for ImageNet. |
| 4 | +
|
| 5 | +# Reference: |
| 6 | +- [Deep Residual Learning for Image Recognition]( |
| 7 | + https://arxiv.org/abs/1512.03385) (CVPR 2016 Best Paper Award) |
| 8 | +
|
| 9 | +""" |
| 10 | + |
| 11 | +import os |
| 12 | + |
| 13 | +import tensorflow as tf |
| 14 | +from tensorlayer import logging |
| 15 | +from tensorlayer.files import (assign_weights, load_npz, maybe_download_and_extract) |
| 16 | +from tensorlayer.layers import (BatchNorm, Conv2d, Elementwise, GlobalMeanPool2d, MaxPool2d, Input, Dense) |
| 17 | +from tensorlayer.models import Model |
| 18 | + |
| 19 | +__all__ = [ |
| 20 | + 'ResNet50', |
| 21 | +] |
| 22 | + |
| 23 | + |
| 24 | +def identity_block(input, kernel_size, n_filters, stage, block): |
| 25 | + """The identity block where there is no conv layer at shortcut. |
| 26 | +
|
| 27 | + Parameters |
| 28 | + ---------- |
| 29 | + input : tf tensor |
| 30 | + Input tensor from above layer. |
| 31 | + kernel_size : int |
| 32 | + The kernel size of middle conv layer at main path. |
| 33 | + n_filters : list of integers |
| 34 | + The numbers of filters for 3 conv layer at main path. |
| 35 | + stage : int |
| 36 | + Current stage label. |
| 37 | + block : str |
| 38 | + Current block label. |
| 39 | +
|
| 40 | + Returns |
| 41 | + ------- |
| 42 | + Output tensor of this block. |
| 43 | +
|
| 44 | + """ |
| 45 | + filters1, filters2, filters3 = n_filters |
| 46 | + conv_name_base = 'res' + str(stage) + block + '_branch' |
| 47 | + bn_name_base = 'bn' + str(stage) + block + '_branch' |
| 48 | + |
| 49 | + x = Conv2d(filters1, (1, 1), W_init=tf.initializers.he_normal(), name=conv_name_base + '2a')(input) |
| 50 | + x = BatchNorm(name=bn_name_base + '2a', act='relu')(x) |
| 51 | + |
| 52 | + ks = (kernel_size, kernel_size) |
| 53 | + x = Conv2d(filters2, ks, padding='SAME', W_init=tf.initializers.he_normal(), name=conv_name_base + '2b')(x) |
| 54 | + x = BatchNorm(name=bn_name_base + '2b', act='relu')(x) |
| 55 | + |
| 56 | + x = Conv2d(filters3, (1, 1), W_init=tf.initializers.he_normal(), name=conv_name_base + '2c')(x) |
| 57 | + x = BatchNorm(name=bn_name_base + '2c')(x) |
| 58 | + |
| 59 | + x = Elementwise(tf.add, act='relu')([x, input]) |
| 60 | + return x |
| 61 | + |
| 62 | + |
| 63 | +def conv_block(input, kernel_size, n_filters, stage, block, strides=(2, 2)): |
| 64 | + """The conv block where there is a conv layer at shortcut. |
| 65 | +
|
| 66 | + Parameters |
| 67 | + ---------- |
| 68 | + input : tf tensor |
| 69 | + Input tensor from above layer. |
| 70 | + kernel_size : int |
| 71 | + The kernel size of middle conv layer at main path. |
| 72 | + n_filters : list of integers |
| 73 | + The numbers of filters for 3 conv layer at main path. |
| 74 | + stage : int |
| 75 | + Current stage label. |
| 76 | + block : str |
| 77 | + Current block label. |
| 78 | + strides : tuple |
| 79 | + Strides for the first conv layer in the block. |
| 80 | +
|
| 81 | + Returns |
| 82 | + ------- |
| 83 | + Output tensor of this block. |
| 84 | +
|
| 85 | + """ |
| 86 | + filters1, filters2, filters3 = n_filters |
| 87 | + conv_name_base = 'res' + str(stage) + block + '_branch' |
| 88 | + bn_name_base = 'bn' + str(stage) + block + '_branch' |
| 89 | + |
| 90 | + x = Conv2d(filters1, (1, 1), strides=strides, W_init=tf.initializers.he_normal(), name=conv_name_base + '2a')(input) |
| 91 | + x = BatchNorm(name=bn_name_base + '2a', act='relu')(x) |
| 92 | + |
| 93 | + ks = (kernel_size, kernel_size) |
| 94 | + x = Conv2d(filters2, ks, padding='SAME', W_init=tf.initializers.he_normal(), name=conv_name_base + '2b')(x) |
| 95 | + x = BatchNorm(name=bn_name_base + '2b', act='relu')(x) |
| 96 | + |
| 97 | + x = Conv2d(filters3, (1, 1), W_init=tf.initializers.he_normal(), name=conv_name_base + '2c')(x) |
| 98 | + x = BatchNorm(name=bn_name_base + '2c')(x) |
| 99 | + |
| 100 | + shortcut = Conv2d(filters3, (1, 1), strides=strides, W_init=tf.initializers.he_normal(), |
| 101 | + name=conv_name_base + '1')(input) |
| 102 | + shortcut = BatchNorm(name=bn_name_base + '1')(shortcut) |
| 103 | + |
| 104 | + x = Elementwise(tf.add, act='relu')([x, shortcut]) |
| 105 | + return x |
| 106 | + |
| 107 | + |
| 108 | +block_names = ['2a', '2b', '2c', '3a', '3b', '3c', '3d', '4a', '4b', '4c', '4d', '4e', '4f', '5a', '5b', '5c' |
| 109 | + ] + ['avg_pool', 'fc1000'] |
| 110 | +block_filters = [[64, 64, 256], [128, 128, 512], [256, 256, 1024], [512, 512, 2048]] |
| 111 | + |
| 112 | + |
| 113 | +def ResNet50(pretrained=False, end_with='fc1000', n_classes=1000, name=None): |
| 114 | + """Pre-trained MobileNetV1 model (static mode). Input shape [?, 224, 224, 3]. |
| 115 | + To use pretrained model, input should be in BGR format and subtracted from ImageNet mean [103.939, 116.779, 123.68]. |
| 116 | +
|
| 117 | + Parameters |
| 118 | + ---------- |
| 119 | + pretrained : boolean |
| 120 | + Whether to load pretrained weights. Default False. |
| 121 | + end_with : str |
| 122 | + The end point of the model [conv, depth1, depth2 ... depth13, globalmeanpool, out]. |
| 123 | + Default ``out`` i.e. the whole model. |
| 124 | + n_classes : int |
| 125 | + Number of classes in final prediction. |
| 126 | + name : None or str |
| 127 | + Name for this model. |
| 128 | +
|
| 129 | + Examples |
| 130 | + --------- |
| 131 | + Classify ImageNet classes, see `tutorial_models_resnet50.py` |
| 132 | +
|
| 133 | + >>> # get the whole model with pretrained weights |
| 134 | + >>> resnet = tl.models.ResNet50(pretrained=True) |
| 135 | + >>> # use for inferencing |
| 136 | + >>> output = resnet(img1, is_train=False) |
| 137 | + >>> prob = tf.nn.softmax(output)[0].numpy() |
| 138 | +
|
| 139 | + Extract the features before fc layer |
| 140 | + >>> resnet = tl.models.ResNet50(pretrained=True, end_with='5c') |
| 141 | + >>> output = resnet(img1, is_train=False) |
| 142 | +
|
| 143 | + Returns |
| 144 | + ------- |
| 145 | + ResNet50 model. |
| 146 | +
|
| 147 | + """ |
| 148 | + ni = Input([None, 224, 224, 3], name="input") |
| 149 | + n = Conv2d(64, (7, 7), strides=(2, 2), padding='SAME', W_init=tf.initializers.he_normal(), name='conv1')(ni) |
| 150 | + n = BatchNorm(name='bn_conv1', act='relu')(n) |
| 151 | + n = MaxPool2d((3, 3), strides=(2, 2), name='max_pool1')(n) |
| 152 | + |
| 153 | + for i, name in enumerate(block_names): |
| 154 | + if len(name) == 2: |
| 155 | + stage = int(name[0]) |
| 156 | + block = name[1] |
| 157 | + if block == 'a': |
| 158 | + strides = (1, 1) if stage == 2 else (2, 2) |
| 159 | + n = conv_block(n, 3, block_filters[stage - 2], stage=stage, block=block, strides=strides) |
| 160 | + else: |
| 161 | + n = identity_block(n, 3, block_filters[stage - 2], stage=stage, block=block) |
| 162 | + elif name == 'avg_pool': |
| 163 | + n = GlobalMeanPool2d(name='avg_pool')(n) |
| 164 | + elif name == 'fc1000': |
| 165 | + n = Dense(n_classes, name='fc1000')(n) |
| 166 | + |
| 167 | + if name == end_with: |
| 168 | + break |
| 169 | + |
| 170 | + network = Model(inputs=ni, outputs=n, name=name) |
| 171 | + |
| 172 | + if pretrained: |
| 173 | + restore_params(network) |
| 174 | + |
| 175 | + return network |
| 176 | + |
| 177 | + |
| 178 | +def restore_params(network, path='models'): |
| 179 | + logging.info("Restore pre-trained parameters") |
| 180 | + maybe_download_and_extract( |
| 181 | + 'resnet50_weights_tf_dim_ordering_tf_kernels.h5', |
| 182 | + path, |
| 183 | + 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/', |
| 184 | + ) # ls -al |
| 185 | + try: |
| 186 | + import h5py |
| 187 | + except Exception: |
| 188 | + raise ImportError('h5py not imported') |
| 189 | + |
| 190 | + f = h5py.File(os.path.join(path, 'resnet50_weights_tf_dim_ordering_tf_kernels.h5'), 'r') |
| 191 | + |
| 192 | + for layer in network.all_layers: |
| 193 | + if len(layer.all_weights) == 0: |
| 194 | + continue |
| 195 | + w_names = list(f[layer.name]) |
| 196 | + params = [f[layer.name][n][:] for n in w_names] |
| 197 | + if 'bn' in layer.name: |
| 198 | + params = [x.reshape(1, 1, 1, -1) for x in params] |
| 199 | + assign_weights(params, layer) |
| 200 | + del params |
| 201 | + |
| 202 | + f.close() |
0 commit comments