Segnet

segnet网络使用编解码结构,编码部分使用的是VGG16的前13个卷积层。关于VGG参考:VGG

编码过程中池化时使用池化索引,记录索引位置,供上采样时使用,通常情况下每经过一次池化,卷即核数量翻倍,编码部分确实如此,从最开始的64到最后是512,但要注意解码部分,由于需要与编码对应池化索引,卷即核数量则不符合上边规律,需要重新计算。

下面是池化索引核心代码,使用tensorflow2.1.0,Segnet详细网络结构点我

代码使用了tf.keras.layers.Lambda 调用函数,它可以快速搭建简单函数API模型,可以使用任意的Tensorflow函数,相比tf.keras.layers.Layer 自定义层更快速,但相对于高级的函数模型,还是推荐继承Layer自定义,

使用时需注意,除此之外,若需要save_model()时,需要有get_config()方法,此时只能使用自定义层。

def MaxPool2DWithArgmax(input_tensor, ksize=(1,2,2,1), strides=(1,2,2,1)):
    p, m = tf.nn.max_pool_with_argmax(input_tensor, ksize=ksize, strides=strides, padding="SAME", include_batch_in_index=True)
    m = K.cast(m, dtype=tf.int32)
    return [p, m]

def Unpool2D(input_tensors, factor=(1,2,2,1)):
    pool, mask = input_tensors
    indices = tf.reshape(mask, (-1,mask.shape[1]*mask.shape[2]*mask.shape[3],1))
    values = tf.reshape(pool, (-1,pool.shape[1]*pool.shape[2]*mask.shape[3]))
    size = tf.size(indices) * factor[1] * factor[2] # 获取上采样后的数据数量
    size = tf.reshape(size, [-1]) # 转为1维向量,此时里边应该只有一个数
    t = tf.scatter_nd(indices, values, size) # 使用方式如下所示
    t = tf.reshape(t, (-1, mask.shape[1]*factor[1], mask.shape[2]*factor[2], mask.shape[3])) # 恢复四维数据
    return t

Scatter_nd介绍

 indices = tf.constant([[4], [3], [1], [7]])
 updates = tf.constant([9, 10, 11, 12])
 shape = tf.constant([8])
 scatter = tf.scatter_nd(indices, updates, shape)
 print(scatter)
 # output :[0, 11, 0, 10, 9, 0, 0, 12]

Dean0731

海纳百川,有容乃大,壁立千仞,无欲则刚

相关推荐

发表评论