mnist_inference.py
Sets the network parameters and implements the forward-pass computation:
#!/usr/bin/python
#-*- coding:utf-8 -*-
############################
#File Name: mnist_inference.py
#Author: yang
#Mail: milkyang2008@126.com
#Created Time: 2017-08-26 14:28:23
############################
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
# define the structure and parameters of the neural network
input_node = 784     # 28x28 input pixels
output_node = 10     # 10 digit classes
Layer1_node = 500    # width of the hidden layer


def get_weight_variable(shape, regularizer):
    weights = tf.get_variable(
        "weights", shape,
        initializer=tf.truncated_normal_initializer(stddev=0.1))
    # add this weight's regularization term to the 'losses' collection
    if regularizer is not None:
        tf.add_to_collection('losses', regularizer(weights))
    return weights


# define the forward pass
def inference(input_tensor, regularizer):
    # layer 1: fully connected + ReLU
    with tf.variable_scope('layer1'):
        weights = get_weight_variable(
            [input_node, Layer1_node], regularizer)
        biases = tf.get_variable(
            "biases", [Layer1_node],
            initializer=tf.constant_initializer(0.0))
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)
    # layer 2: fully connected output layer (logits, no softmax)
    with tf.variable_scope('layer2'):
        weights = get_weight_variable(
            [Layer1_node, output_node], regularizer)
        biases = tf.get_variable(
            "biases", [output_node],
            initializer=tf.constant_initializer(0.0))
        layer2 = tf.matmul(layer1, weights) + biases
    return layer2
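The sketch below makes the shapes in `inference` concrete. It is a minimal pure-Python version of the same two-layer computation: a ReLU hidden layer followed by a linear output layer that produces logits. It also counts the trainable parameters of the 784 → 500 → 10 network. It does not use TensorFlow. The helper names `dense`, `relu`, `forward` and `count_params` are hypothetical, and the tiny 2 → 2 → 1 sizes in the usage lines are chosen only to keep the arithmetic easy to follow.

```python
def dense(inputs, weights, biases):
    """inputs @ weights + biases for one example.

    inputs: list of length n_in; weights: n_in x n_out nested list;
    biases: list of length n_out.
    """
    return [sum(inputs[i] * weights[i][j] for i in range(len(inputs))) + biases[j]
            for j in range(len(biases))]


def relu(values):
    return [max(0.0, v) for v in values]


def forward(x, w1, b1, w2, b2):
    """Mirrors inference(): ReLU hidden layer, then linear logits."""
    hidden = relu(dense(x, w1, b1))
    return dense(hidden, w2, b2)  # logits; softmax is applied inside the loss


def count_params(layer_sizes):
    """Total weights plus biases of a fully connected network."""
    return sum(n_in * n_out + n_out
               for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))


# tiny 2 -> 2 -> 1 example with made-up weights
w1 = [[1.0, 0.5], [0.5, 1.0]]
b1 = [0.0, 0.0]
w2 = [[2.0], [1.0]]
b2 = [0.5]
print(forward([2.0, -1.0], w1, b1, w2, b2))  # -> [3.5]
print(count_params([784, 500, 10]))          # -> 397510
```

In the example, the second hidden unit's pre-activation is exactly 0.0, so ReLU passes only the first unit to the output layer.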
mnist_train.py
Trains the network:
#!/usr/bin/python
#-*- coding:utf-8 -*-
############################
#File Name: mnist_train.py
#Author: yang
#Mail: milkyang2008@126.com
#Created Time: 2017-08-26 15:04:53
############################
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
#load training data
from tensorflow.examples.tutorials.mnist import input_data
#load mnist_inference.py
import mnist_inference
# define the training hyperparameters
batch_size = 100
learning_rate_base = 0.8
learning_rate_decay = 0.99
regularization_rate = 0.0001
training_steps = 30000
moving_average_decay = 0.99
# directory and file name used to save the model
model_save_path = "./path/to/model"
model_name = "model.ckpt"


def train(mnist):
    # placeholders for the input images and the one-hot labels
    x = tf.placeholder(
        tf.float32, [None, mnist_inference.input_node], name='x-input')
    y_ = tf.placeholder(
        tf.float32, [None, mnist_inference.output_node], name='y-input')
    # define the L2 regularizer
    regularizer = tf.contrib.layers.l2_regularizer(regularization_rate)
    # compute the forward-pass output (logits)
    y = mnist_inference.inference(x, regularizer)
    global_step = tf.Variable(0, trainable=False)
    # maintain exponential moving averages of all trainable variables
    variable_averages = tf.train.ExponentialMovingAverage(
        moving_average_decay, global_step)
    variables_averages_op = variable_averages.apply(
        tf.trainable_variables())
    # cross-entropy; labels are one-hot, so argmax recovers the class index
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    # total loss = mean cross-entropy + L2 regularization terms
    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
    # exponentially decaying learning rate
    learning_rate = tf.train.exponential_decay(
        learning_rate_base,
        global_step,
        mnist.train.num_examples / batch_size,
        learning_rate_decay)
    train_step = tf.train.GradientDescentOptimizer(learning_rate) \
        .minimize(loss, global_step=global_step)
    # run the gradient step and the moving-average update as one op
    with tf.control_dependencies([train_step, variables_averages_op]):
        train_op = tf.no_op(name='train')
    # create the Saver used to persist the model
    saver = tf.train.Saver()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        for i in range(training_steps):
            xs, ys = mnist.train.next_batch(batch_size)
            _, loss_value, step = sess.run([train_op, loss, global_step],
                                           feed_dict={x: xs, y_: ys})
            # every 1000 steps, report the batch loss and save a checkpoint
            if i % 1000 == 0:
                print("After %d training step(s), loss on training "
                      "batch is %g." % (step, loss_value))
                saver.save(
                    sess, os.path.join(model_save_path, model_name),
                    global_step=global_step)


def main(argv=None):
    mnist = input_data.read_data_sets("/tmp/data", one_hot=True)
    # create the checkpoint directory if it does not exist yet
    if not os.path.exists(model_save_path):
        os.makedirs(model_save_path)
    train(mnist)


if __name__ == '__main__':
    tf.app.run()
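With the default `staircase=False`, `tf.train.exponential_decay` computes `learning_rate_base * learning_rate_decay ** (global_step / decay_steps)`. Here `decay_steps` is `mnist.train.num_examples / batch_size`. That is 550 if the training split keeps its default 55,000 images (an assumption about the download), so the learning rate shrinks by a factor of 0.99 per pass over the training data. The plain-Python stand-in below restates that formula for illustration; it is not the TensorFlow implementation.

```python
def exponential_decay(base_lr, global_step, decay_steps, decay_rate, staircase=False):
    """base_lr * decay_rate ** (global_step / decay_steps).

    With staircase=True the exponent is truncated to an integer, so the
    rate drops in discrete jumps once every decay_steps steps.
    """
    exponent = global_step / decay_steps
    if staircase:
        exponent = global_step // decay_steps
    return base_lr * decay_rate ** exponent


decay_steps = 55000 / 100  # assumed: 55,000 training images / batch_size
for step in (0, 550, 30000):
    print(step, round(exponential_decay(0.8, step, decay_steps, 0.99), 4))
```

By the last step of the 30,000-step run, the rate has decayed from 0.8 to roughly 0.46.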
How the loss changes over training:
Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting /tmp/data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
After 1 training step(s), loss on training batch is 3.08818.
After 1001 training step(s), loss on training batch is 0.232921.
After 2001 training step(s), loss on training batch is 0.213298.
After 3001 training step(s), loss on training batch is 0.140477.
After 4001 training step(s), loss on training batch is 0.135925.
After 5001 training step(s), loss on training batch is 0.11498.
After 6001 training step(s), loss on training batch is 0.111357.
After 7001 training step(s), loss on training batch is 0.100002.
After 8001 training step(s), loss on training batch is 0.0771622.
After 9001 training step(s), loss on training batch is 0.0791518.
After 10001 training step(s), loss on training batch is 0.0728585.
After 11001 training step(s), loss on training batch is 0.0680792.
After 12001 training step(s), loss on training batch is 0.06702.
After 13001 training step(s), loss on training batch is 0.0538534.
After 14001 training step(s), loss on training batch is 0.0551859.
After 15001 training step(s), loss on training batch is 0.0519553.
After 16001 training step(s), loss on training batch is 0.0486625.
After 17001 training step(s), loss on training batch is 0.0451939.
After 18001 training step(s), loss on training batch is 0.0508607.
After 19001 training step(s), loss on training batch is 0.0480105.
After 20001 training step(s), loss on training batch is 0.0426104.
After 21001 training step(s), loss on training batch is 0.0388245.
After 22001 training step(s), loss on training batch is 0.0413718.
After 23001 training step(s), loss on training batch is 0.0370494.
After 24001 training step(s), loss on training batch is 0.0424059.
After 25001 training step(s), loss on training batch is 0.0353014.
After 26001 training step(s), loss on training batch is 0.0341191.
After 27001 training step(s), loss on training batch is 0.0336432.
After 28001 training step(s), loss on training batch is 0.0346693.
After 29001 training step(s), loss on training batch is 0.0361443.
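Each value in the log above is `loss`, not the cross-entropy alone. It is the batch's mean cross-entropy plus the L2 penalties that `get_weight_variable` put into the `'losses'` collection. `tf.contrib.layers.l2_regularizer(scale)` contributes `scale * sum(w**2) / 2` per weight tensor, since it is built on `tf.nn.l2_loss`. The pure-Python sketch below mirrors that composition. The helper names are hypothetical, and each weight tensor is flattened into a plain list.

```python
import math


def sparse_softmax_cross_entropy(logits, label):
    """-log(softmax(logits)[label]), computed stably by subtracting the max logit."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def l2_penalty(flat_weights, scale):
    """scale * sum(w^2) / 2 for one (flattened) weight tensor."""
    return scale * sum(w * w for w in flat_weights) / 2


def total_loss(batch_logits, one_hot_labels, weight_tensors, scale):
    # recover class indices from one-hot labels, like tf.argmax(y_, 1)
    labels = [row.index(max(row)) for row in one_hot_labels]
    ce = [sparse_softmax_cross_entropy(z, k) for z, k in zip(batch_logits, labels)]
    cross_entropy_mean = sum(ce) / len(ce)
    return cross_entropy_mean + sum(l2_penalty(w, scale) for w in weight_tensors)


# one example with uniform logits over 10 classes, true class 3
one_hot = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
print(total_loss([[0.0] * 10], [one_hot], [[1.0, 1.0], [2.0]], 0.0001))
```

For reference, a uniform prediction over 10 classes has a cross-entropy of ln 10 ≈ 2.303. The example adds a penalty of 0.0003 on top of that.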
mnist_eval.py
Evaluates the trained model on the validation set:
#!/usr/bin/python
#-*- coding:utf-8 -*-
############################
#File Name: mnist_eval.py
#Author: yang
#Mail: milkyang2008@126.com
#Created Time: 2017-08-26 17:02:46
############################
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference
import mnist_train
# reload the latest model every 10 s and test its accuracy
eval_interval_secs = 10


def evaluate(mnist):
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32, [None, mnist_inference.input_node],
                           name='x-input')
        y_ = tf.placeholder(tf.float32, [None, mnist_inference.output_node],
                            name='y-input')
        validate_feed = {x: mnist.validation.images,
                         y_: mnist.validation.labels}
        # no regularizer is needed at evaluation time
        y = mnist_inference.inference(x, None)
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        # load the moving-average (shadow) values into the model variables
        variable_averages = tf.train.ExponentialMovingAverage(
            mnist_train.moving_average_decay)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)
        # test the model accuracy every 10 s
        while True:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(
                    mnist_train.model_save_path)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # recover the step from a path like ".../model.ckpt-29001"
                    global_step = ckpt.model_checkpoint_path \
                        .split('/')[-1].split('-')[-1]
                    accuracy_score = sess.run(accuracy,
                                              feed_dict=validate_feed)
                    print("After %s training step(s), validation "
                          "accuracy = %g" % (global_step, accuracy_score))
                else:
                    print('No checkpoint file found!')
                    return
            time.sleep(eval_interval_secs)


def main(argv=None):
    mnist = input_data.read_data_sets("/tmp/data", one_hot=True)
    evaluate(mnist)


if __name__ == '__main__':
    tf.app.run()
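During training, `variable_averages.apply(...)` keeps a shadow copy of every trainable variable. In `evaluate`, `variables_to_restore()` builds a `Saver` that loads those shadow values into the model's variables, so accuracy is measured with the averaged weights. Because training passes `global_step` as `num_updates`, the effective decay is `min(moving_average_decay, (1 + step) / (10 + step))`. The plain-Python sketch below restates that update rule; the helper names are illustrative and the tracked values are made up.

```python
def ema_decay(decay, num_updates=None):
    """Effective decay of tf.train.ExponentialMovingAverage.

    If num_updates is given, the decay is capped at
    (1 + num_updates) / (10 + num_updates), so early in training the
    shadow value follows the variable more closely.
    """
    if num_updates is None:
        return decay
    return min(decay, (1 + num_updates) / (10 + num_updates))


def ema_update(shadow, value, decay):
    """shadow <- decay * shadow + (1 - decay) * value"""
    return decay * shadow + (1 - decay) * value


# track one scalar "weight" that jumps from 0.0 to 5.0 (made-up values)
shadow = 0.0  # the shadow starts from the variable's initial value
for step in range(3):
    d = ema_decay(0.99, step)
    shadow = ema_update(shadow, 5.0, d)
    print(step, round(d, 4), round(shadow, 4))
```

With `num_updates` small, the capped decay (0.1 at step 0) lets the shadow catch up quickly. Once the step count is large, the decay settles at 0.99.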
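Accuracy on the validation data. The evaluator reloads the latest checkpoint every 10 seconds, so a step is printed twice whenever no new checkpoint was saved in that interval: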
Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
After 1001 training step(s), validation accuracy = 0.9772
After 2001 training step(s), validation accuracy = 0.982
After 3001 training step(s), validation accuracy = 0.9826
After 4001 training step(s), validation accuracy = 0.984
After 5001 training step(s), validation accuracy = 0.9838
After 6001 training step(s), validation accuracy = 0.9846
After 7001 training step(s), validation accuracy = 0.9834
After 8001 training step(s), validation accuracy = 0.9844
After 8001 training step(s), validation accuracy = 0.9844
After 9001 training step(s), validation accuracy = 0.984
After 10001 training step(s), validation accuracy = 0.9834
After 11001 training step(s), validation accuracy = 0.984
After 12001 training step(s), validation accuracy = 0.9846
After 13001 training step(s), validation accuracy = 0.984
After 14001 training step(s), validation accuracy = 0.985
After 15001 training step(s), validation accuracy = 0.9844
After 16001 training step(s), validation accuracy = 0.9854
After 17001 training step(s), validation accuracy = 0.985
After 17001 training step(s), validation accuracy = 0.985
After 18001 training step(s), validation accuracy = 0.9844
After 19001 training step(s), validation accuracy = 0.9846
After 20001 training step(s), validation accuracy = 0.9852
After 21001 training step(s), validation accuracy = 0.9852
After 22001 training step(s), validation accuracy = 0.9848
After 23001 training step(s), validation accuracy = 0.9848
After 24001 training step(s), validation accuracy = 0.9854
After 25001 training step(s), validation accuracy = 0.9848
After 26001 training step(s), validation accuracy = 0.985
After 27001 training step(s), validation accuracy = 0.9842
After 28001 training step(s), validation accuracy = 0.9844
After 28001 training step(s), validation accuracy = 0.9844
After 29001 training step(s), validation accuracy = 0.9844
After 29001 training step(s), validation accuracy = 0.9844
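The accuracies above come from `tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)), tf.float32))`. That is the fraction of validation images whose largest logit matches the labelled class. With the default 5,000-image validation split, 0.9844 corresponds to 4,922 correct images. The pure-Python sketch below uses hypothetical helper names:

```python
def argmax(values):
    """Index of the largest value (first one if several are equal)."""
    return max(range(len(values)), key=values.__getitem__)


def accuracy(batch_logits, one_hot_labels):
    """Fraction of examples whose top logit matches the one-hot label."""
    correct = [argmax(z) == argmax(t) for z, t in zip(batch_logits, one_hot_labels)]
    return sum(correct) / len(correct)


logits = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
labels = [[0, 1], [1, 0], [1, 0]]
print(accuracy(logits, labels))  # two of three predictions match
```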