参考:
TensorFlow运作方式入门http://www.tensorfly.cn/tfdoc/tutorials/mnist_tf.html
注意以下代码仅为示例
Step1准备数据输入
按需制作训练集
Step2 构造图表(Build the Graph)
####2.1定义占位符
在创建session的时候数据才真正流入神经网络
1 2 3
| images_placeholder = tf.placeholder(tf.float32, shape=(batch_size, IMAGE_PIXELS)) labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))
|
####2.2构造inference()
占位符为输入,使数据经过神经网络向前反馈输出预测结果
每一层都创建于一个唯一的tf.name_scope
之下,创建于该作用域之下的所有元素都将带有其前缀
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
| def inference(images, hidden1_units, hidden2_units): with tf.name_scope('hidden1'): weights = tf.Variable( tf.truncated_normal([IMAGE_PIXELS, hidden1_units], stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))), name='weights') biases = tf.Variable(tf.zeros([hidden1_units]), name='biases') hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases) with tf.name_scope('hidden2'): weights = tf.Variable( tf.truncated_normal([hidden1_units, hidden2_units], stddev=1.0 / math.sqrt(float(hidden1_units))), name='weights') biases = tf.Variable(tf.zeros([hidden2_units]), name='biases') hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases) with tf.name_scope('softmax_linear'): weights = tf.Variable( tf.truncated_normal([hidden2_units, NUM_CLASSES], stddev=1.0 / math.sqrt(float(hidden2_units))), name='weights') biases = tf.Variable(tf.zeros([NUM_CLASSES]), name='biases') logits = tf.matmul(hidden2, weights) + biases return logits
|
####2.3损失(Loss)
1 2 3 4
| cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits, onehot_labels, name='entropy') loss = tf.reduce_mean(cross_entropy, name='entropy_mean')
|
####2.4训练(training)
#####2.4.1将损失最小化
1 2
| optimizer = tf.train.GradientDescentOptimizer(learning_rate) train_op = optimizer.minimize(loss, global_step=global_step)
|
#####2.4.2生成一个变量用于保存全局训练步骤(global training step)的数值
1
| global_step = tf.Variable(0, name='global_step', trainable=False)
|
##step3 启动会话并执行图表
####3.1关联图表构建会话
1 2 3 4 5
| saver = tf.train.Saver() with tf.Graph().as_default(): with tf.Session() as sess: init = tf.initialize_all_variables() sess.run(init)
|
####3.2 feed_dict参数传入sess.run(),真正训练模型,保存模型
1 2 3 4 5 6 7 8 9 10
| for step in xrange(FLAGS.max_steps): feed_dict = { images_placeholder: images_feed, labels_placeholder: labels_feed, } _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict) saver.save(sess, FLAGS.train_dir, global_step=step)
|
Step4 恢复并评估模型
1 2 3 4
| with tf.Session() as sess: model_dir=tf.train.latest_checkpoint('ckpt/') saver.restore(sess,model_dir) ***=sess.run(***, feed_dict={))
|