小編給大家分享一下Tensorflow中權(quán)值和feature map可視化的示例分析,相信大部分人都還不怎么了解,因此分享這篇文章給大家參考一下,希望大家閱讀完這篇文章后大有收獲,下面讓我們一起去了解一下吧!
為秭歸等地區(qū)用戶提供了全套網(wǎng)頁設(shè)計制作服務,及秭歸網(wǎng)站建設(shè)行業(yè)解決方案。主營業(yè)務為網(wǎng)站設(shè)計、成都網(wǎng)站設(shè)計、秭歸網(wǎng)站設(shè)計,以傳統(tǒng)方式定制建設(shè)網(wǎng)站,并提供域名空間備案等一條龍服務,秉承以專業(yè)、用心的態(tài)度為用戶提供真誠的服務。我們深信只要達到每一位用戶的要求,就會得到認可,從而選擇與我們長期合作。這樣,我們也可以走得更遠!1. 卷積知識補充
為了后面方便講解代碼,這里先對卷積的部分知識進行一下簡介。關(guān)于卷積核如何在圖像的一個通道上進行滑動計算,網(wǎng)上有諸多資料,相信對卷積神經(jīng)網(wǎng)絡有一定了解的讀者都應該比較清楚,本文就不再贅述。這里主要介紹一組卷積核如何在一幅圖像上計算得到一組feature map。
以從原始圖像經(jīng)過第一個卷積層得到第一組feature map為例(從得到的feature map到再之后的feature map也是同理),假設(shè)第一組feature map共有64個,那么可以把這組feature map也看作一幅圖像,只不過它的通道數(shù)是64, 而一般意義上的圖像是RGB3個通道。為了得到這第一組feature map,我們需要64個卷積核,每個卷積核是一個k x k x 3的矩陣,其中k是卷積核的大?。僭O(shè)是正方形卷積核),3就對應著輸入圖像的通道數(shù)。下面我以一個簡單粗糙的圖示來展示一下圖像經(jīng)過一個卷積核的卷積得到一個feature map的過程。
如圖所示,其實可以看做卷積核的每一通道(不太準確,將就一下)和圖像的每一通道對應進行卷積操作,然后再逐位置相加,便得到了一個feature map。
那么用一組(64個)卷積核去卷積一幅圖像,得到64個feature map就如下圖所示,也就是每個卷積核得到一個feature map,64個卷積核就得到64個feature map。
另外,也可以稍微換一個角度看待這個問題,那就是先讓圖片的某一通道分別與64個卷積核的對應通道做卷積,得到64個feature map的中間結(jié)果,之后3個通道對應的中間結(jié)果再相加,得到最終的feature map,如下圖所示:
可以看到這其實就是第一幅圖擴展到多卷積核的情形,圖畫得較為粗糙,有些中間結(jié)果和最終結(jié)果直接用了一樣的子圖,理解時請稍微注意一下。下面代碼中對卷積核進行展示的時候使用的就是這種方式,即對應著輸入圖像逐通道的去顯示卷積核的對應通道,而不是每次顯示一個卷積核的所有通道,可能解釋的有點繞,需要注意一下。通過下面這個小圖也許更好理解。
圖中用紅框圈出的部分即是我們一次展示出的權(quán)重參數(shù)。
2. 網(wǎng)絡權(quán)值和feature map的可視化
(1) 網(wǎng)絡權(quán)重參數(shù)可視化
首先介紹一下Tensorflow中卷積核的形狀,如下代碼所示:
weights = tf.Variable(tf.random_normal([filter_size, filter_size, channels, filter_num]))
前兩維是卷積核的高和寬,第3維是上一層feature map的通道數(shù),在第一節(jié)(卷積知識補充)中,我提到了上一層的feature map有多少個(也就是通道數(shù)是多少),那么對應著一個卷積核也要有這么多通道。第4維是當前卷積層的卷積核數(shù)量,也是當前層輸出的feature map的通道數(shù)。
以下是我更改之后的網(wǎng)絡權(quán)重參數(shù)(卷積核)的可視化代碼:
from __future__ import print_function #import tensorflow as tf import numpy as np import matplotlib.pyplot as plt import matplotlib.cm as cm import os import visualize_utils def plot_conv_weights(weights, plot_dir, name, channels_all=True, filters_all=True, channels=[0], filters=[0]): """ Plots convolutional filters :param weights: numpy array of rank 4 :param name: string, name of convolutional layer :param channels_all: boolean, optional :return: nothing, plots are saved on the disk """ w_min = np.min(weights) w_max = np.max(weights) # make a list of channels if all are plotted if channels_all: channels = range(weights.shape[2]) # get number of convolutional filters if filters_all: num_filters = weights.shape[3] filters = range(weights.shape[3]) else: num_filters = len(filters) # get number of grid rows and columns grid_r, grid_c = visualize_utils.get_grid_dim(num_filters) # create figure and axes fig, axes = plt.subplots(min([grid_r, grid_c]), max([grid_r, grid_c])) # iterate channels for channel_ID in channels: # iterate filters inside every channel if num_filters == 1: img = weights[:, :, channel_ID, filters[0]] axes.imshow(img, vmin=w_min, vmax=w_max, interpolation='nearest', cmap='seismic') # remove any labels from the axes axes.set_xticks([]) axes.set_yticks([]) else: for l, ax in enumerate(axes.flat): # get a single filter img = weights[:, :, channel_ID, filters[l]] # put it on the grid ax.imshow(img, vmin=w_min, vmax=w_max, interpolation='nearest', cmap='seismic') # remove any labels from the axes ax.set_xticks([]) ax.set_yticks([]) # save figure plt.savefig(os.path.join(plot_dir, '{}-{}.png'.format(name, channel_ID)), bbox_inches='tight')
原項目的代碼是對某一層的權(quán)重參數(shù)或feature map在一個網(wǎng)格中進行全部展示,如果參數(shù)或feature map太多,那么展示出來的結(jié)果中每個圖都很小,很難看出有用的東西來,如下圖所示:
所以我對代碼做了些修改,使得其能顯示任意指定的filter或feature map。
代碼中,
w_min = np.min(weights) w_max = np.max(weights)
這兩句是為了后續(xù)顯示圖像用的,具體可查看matplotlib.pyplot的imshow()函數(shù)進行了解。
接下來是判斷是否顯示全部的channel(通道數(shù))或全部filter。如果是,那就和原代碼一致了。若不是,則畫出函數(shù)參數(shù)channels和filters指定的filter來。
再往下的兩句代碼是畫圖用的,我們可能會在一個圖中顯示多個子圖,以下這句是為了計算出大圖分為幾行幾列比較合適(一個大圖會盡量分解為方形的陣列,比如如果有64個子圖,那么就分成8 x 8的陣列),代碼細節(jié)可在原項目中的utils中找到。
grid_r, grid_c = visualize_utils.get_grid_dim(num_filters)
實際畫圖時,如果想要一個圖一個圖的去畫,需要單獨處理一下。如果還是想在一個大圖中顯示多個子圖,就按源代碼的方式去做,只不過這里可以顯示我們自己指定的那些filter,而不是不加篩選地全部輸出。主要拿到數(shù)據(jù)的是以下這句代碼:
img = weights[:, :, channel_ID, filters[l]]
剩下的都是是畫圖相關(guān)的函數(shù)了,本文就不再對畫圖做更多介紹了。
使用這段代碼可視化并保存filter時,先加載模型,然后拿到我們想要可視化的那部分參數(shù),之后直接調(diào)用函數(shù)就可以了,如下所示:
with tf.Session(graph=tf.get_default_graph()) as sess: init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) sess.run(init_op) saver.restore(sess, model_path) with tf.variable_scope('inference', reuse=True): conv_weights = tf.get_variable('conv3_1_w').eval() visualize.plot_conv_weights(conv_weights, dir_prefix, 'conv3_1')
這里并沒有對filter進行額外的指定,在feature map的可視化中,我會給出相關(guān)例子。
(2) feature map可視化
其實feature map的可視化與filter非常相似,只有細微的不同。還是先把完整代碼貼上。
def plot_conv_output(conv_img, plot_dir, name, filters_all=True, filters=[0]): w_min = np.min(conv_img) w_max = np.max(conv_img) # get number of convolutional filters if filters_all: num_filters = conv_img.shape[3] filters = range(conv_img.shape[3]) else: num_filters = len(filters) # get number of grid rows and columns grid_r, grid_c = visualize_utils.get_grid_dim(num_filters) # create figure and axes fig, axes = plt.subplots(min([grid_r, grid_c]), max([grid_r, grid_c])) # iterate filters if num_filters == 1: img = conv_img[0, :, :, filters[0]] axes.imshow(img, vmin=w_min, vmax=w_max, interpolation='bicubic', cmap=cm.hot) # remove any labels from the axes axes.set_xticks([]) axes.set_yticks([]) else: for l, ax in enumerate(axes.flat): # get a single image img = conv_img[0, :, :, filters[l]] # put it on the grid ax.imshow(img, vmin=w_min, vmax=w_max, interpolation='bicubic', cmap=cm.hot) # remove any labels from the axes ax.set_xticks([]) ax.set_yticks([]) # save figure plt.savefig(os.path.join(plot_dir, '{}.png'.format(name)), bbox_inches='tight')
代碼中和filter可視化相同的部分就不再贅述了,這里只講feature map可視化獨特的方面,其實就在于以下這句代碼,也就是要可視化的數(shù)據(jù)的獲得:
img = conv_img[0, :, :, filters[0]]
神經(jīng)網(wǎng)絡一般都是一個batch一個batch的輸入數(shù)據(jù),其輸入的形狀為
image = tf.placeholder(tf.float32, shape = [None, IMAGE_SIZE, IMAGE_SIZE, 3], name = "input_image")
第一維是一個batch中圖片的數(shù)量,為了靈活可以設(shè)置為None,Tensorflow會根據(jù)實際輸入的數(shù)據(jù)進行計算。二三維是圖片的高和寬,第4維是圖片通道數(shù),一般為3。
如果我們想要輸入一幅圖片,然后看看它的激活值(feature map),那么也要按照以上維度以一個batch的形式進行輸入,也就是[1, IMAGE_SIZE, IMAGE_SIZE, 3]。所以拿feature map數(shù)據(jù)時,第一維度肯定是取0(就對應著batch中的當前圖片),二三維取全部,第4維度再取我們想要查看的feature map的某一通道。
如果想要可視化feature map,那么構(gòu)建網(wǎng)絡時還要動點手腳,定義計算圖時,每得到一組激活值都要將其加到Tensorflow的collection中,如下:
tf.add_to_collection('activations', current)
而實際進行feature map可視化時,就要先輸入一幅圖片,然后運行網(wǎng)絡拿到相應數(shù)據(jù),最后把數(shù)據(jù)傳參給可視化函數(shù)。以下這個例子展示的是如何將每個指定卷積層的feature map的每個通道進行單獨的可視化與存儲,使用的是VGG16網(wǎng)絡:
visualize_layers = ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2', 'conv3_1', 'conv3_2', 'conv3_3', 'conv4_1', 'conv4_2', 'conv4_3', 'conv5_1', 'conv5_2', 'conv5_3'] with tf.Session(graph=tf.get_default_graph()) as sess: init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) sess.run(init_op) saver.restore(sess, model_path) image_path = root_path + 'images/train_images/sunny_0058.jpg' img = misc.imread(image_path) img = img - meanvalue img = np.float32(img) img = np.expand_dims(img, axis=0) conv_out = sess.run(tf.get_collection('activations'), feed_dict={x: img, keep_prob: 1.0}) for i, layer in enumerate(visualize_layers): visualize_utils.create_dir(dir_prefix + layer) for j in range(conv_out[i].shape[3]): visualize.plot_conv_output(conv_out[i], dir_prefix + layer, str(j), filters_all=False, filters=[j]) sess.close()
其中,conv_out包含了所有加入到collection中的feature map,這些feature map在conv_out中是按卷積層劃分的。
最終得到的結(jié)果如下圖所示:
第一個文件夾下的全部結(jié)果:
以上是“Tensorflow中權(quán)值和feature map可視化的示例分析”這篇文章的所有內(nèi)容,感謝各位的閱讀!相信大家都有了一定的了解,希望分享的內(nèi)容對大家有所幫助,如果還想學習更多知識,歡迎關(guān)注創(chuàng)新互聯(lián)行業(yè)資訊頻道!
分享名稱:Tensorflow中權(quán)值和featuremap可視化的示例分析-創(chuàng)新互聯(lián)
標題來源:http://aaarwkj.com/article32/cogipc.html
成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供網(wǎng)站策劃、網(wǎng)站排名、網(wǎng)站營銷、關(guān)鍵詞優(yōu)化、網(wǎng)站收錄、云服務器
聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶投稿、用戶轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請盡快告知,我們將會在第一時間刪除。文章觀點不代表本網(wǎng)站立場,如需處理請聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時需注明來源: 創(chuàng)新互聯(lián)
猜你還喜歡下面的內(nèi)容