CNN实战:tensorflow训练mnist手写数字识别

用tensorflow可以轻松的搭建卷积神经网络,layer层api的加入更是方便了整个过程。本文以mnist手写数字识别的训练为例,轻松挑战99%准确率。

准备

首先,导入必要的库和函数。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

载入mnist数据,注意这里input_data.read_data_sets的参数reshape=Fales,以保留3维的图片数据,否则图片会被变换为向量。由于是黑白图片,所以这里的颜色通道数为1。这里image的shape为(None,28,28,1)。

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True, reshape=False)

定义数据流图结构

定义变量。这里我们不需要再定义各层网络的权重和偏差,这便是layers的方便之处,它们会被自动创建。

x = tf.placeholder(tf.float32, shape=(None,28,28,1))
y_ = tf.placeholder(tf.float32, shape=(None,10))
istrain = tf.placeholder(tf.bool) # dropout时参数用到

然后开始定义卷积网络结构,激活函数采用relu,这里使用卷积层⇒池化层⇒卷积层⇒池化层。

conv1 = tf.layers.conv2d(
x,
filters=30,
kernel_size=(3, 3),
strides=1,
padding='same',
activation=tf.nn.relu)
pool1 = tf.layers.max_pooling2d(
conv1, pool_size=(2, 2), strides=2, padding='same')

conv2 = tf.layers.conv2d(
pool1,
filters=60,
kernel_size=(3, 3),
strides=1,
padding='same',
activation=tf.nn.relu)
pool2 = tf.layers.max_pooling2d(
conv2, pool_size=(2, 2), strides=2, padding='same')

将3维数据打开成向量,准备后面进行全链接网络。

flat = tf.layers.flatten(pool2)

添加两层的全链接隐藏层,并添加dropout以避免过拟合。

dense1 = tf.layers.dense(flat, units=200, activation=tf.nn.relu)
dense1_drop = tf.layers.dropout(dense1,rate=0.5,training=istrain)
dense2 = tf.layers.dense(dense1_drop, units=50, activation=tf.nn.relu)
dense2_drop = tf.layers.dropout(dense2,rate=0.5,training=istrain)

输出层,用softmax变换来输出概率,同时输出logits,方便计算交叉熵。

logits = tf.layers.dense(dense2_drop,units=10) # 即未做softmax变换
y = tf.nn.softmax(logits) # 表示概率

定义loss函数,由于是多分类,因此这里采用softmax交叉熵。

loss = tf.reduce_mean(tf.losses.softmax_cross_entropy(onehot_labels=y_,logits=logits))

在梯度优化方法上,选择Adam,并将学习率设置为0.001。

train_step = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

训练过程

建立tensorflow会话,别忘了初始化全局变量。

sess = tf.Session()

init = tf.global_variables_initializer()
sess.run(init)

进行1500次更新训练,每批次输入100个训练样本,这样差不多训练3轮。并用测试集实时观察评估模型,设置每50步更新打印测试准确率。

for i in range(1500):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys, istrain:True})
# 评估模型,每50步打印准确率
if (i+1) % 50 == 0:
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print("第%d步更新,准确率:%.2f%%" % (i+1, 100*sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels, istrain:False})))

可以看到,训练到1000步时,测试数据准确率接近99%。

...
第1150步更新,准确率:98.56%
第1200步更新,准确率:98.61%
第1250步更新,准确率:98.66%
第1300步更新,准确率:98.77%
...

关闭会话

sess.close()

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注