import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets('mnist_data/',one_hot=True)
#注意这里用了one_hot表示,标签的形状是(batch_size,num_batches),类型是float,如果不用one_hot,那么标签的形状是(batch_size,),类型是int
num_classes=10
batch_size=64
hidden_dim1=32
hidden_dim2=64
epochs=10
embedding_dim=28
time_step=28
'''
这里解析一下tensorflow.nn.dynamic_rnn(cell,inputs,initial_state=None,dtype=None,time_major=False)
返回值有两个outputs,states.outputs的形状是[batch_size,time_step,cell.output_dim]
这个三维张量的第一个维度是batch_size,我们来看二三维度,相当于每一个batch_size对应于一个二维矩阵,
二维矩阵的每一行长度是rnn的输出维度,代表在当前的batch_size中的当前的time_step的特征。
将outputs取转置并且在第一个维度上取最后一个元素[-1,batch_size,:],也就得到了每一个batch_size中所有数据的最后一个time_step,也就是
最后一个时刻的输出[batch_size,output_dim],
在语言模型中,time_step就是一个句子,所以这个张量[batch_size,output_dim]通常作为下一个时刻的输入,因为这个张量的每一行代表的意义是在
读取了整个句子后所得到的特征,这是很有意义的。如果将outputs取了转置后[time_step,batch_size,output_dim]对这个张量取[1,batch_size,output_dim]
那么现在得到的张量仅仅是在读取了句子中第一个单词后获得的输出,这样的张量如果作为下一层的输入那么模型的效果将惨不忍睹。
现在来讨论states,如果cell是LSTM类型的cell,那么states是一个有两个元素的元祖,为什么有两个元素呢,因为LSTM类型的cell有两个状态
一个是cell state代表该神经元的细胞状态,另一个是hidden state代表该神经元的隐藏状态。
而且这两个状态都是最后一个时刻的神经元的状态,所以states的形状与time_step无关
这两个状态张量的形状都是[batch_size,output_dim]
states[0]是cell的状态,states[1]是hidden的状态
所以states[1]与上面提到的提取每一个time_step的最后一个时刻得到的张量[batch_size,output_dim]是一样的,因为它们都是每一个batch_size中
每一个数据的最后一个时刻的特征提取。
'''
class RNN_model:
def __init__(self):
tf.reset_default_graph()
def add_placeholder(self):
self.xs=tf.placeholder(shape=[None,784],dtype=tf.float32)
self.ys=tf.placeholder(shape=[None,10],dtype=tf.float32)
#由于用了one_hot表示,所以shape是(batch_size,10),dtype是float
#不用one_hot表示的话应该改为self.ys=tf.placeholder(shape=[None],dtype=tf.int32)
def rnn_layer(self):
rnn_input=tf.reshape(tensor=self.xs,shape=[-1,time_step,embedding_dim])
cell_1=tf.contrib.rnn.BasicLSTMCell(num_units=hidden_dim1)
cell_2=tf.contrib.rnn.BasicLSTMCell(num_units=hidden_dim2)
cells=tf.contrib.rnn.MultiRNNCell([cell_1,cell_2])
initial_state=cells.zero_state(batch_size=batch_size,dtype=tf.float32)
outputs,states=tf.nn.dynamic_rnn(cells,rnn_input,initial_state=initial_state,time_major=False)
#outputs.shape==(batch_size,time_step,hidden_dim2)
#states表示的是batch_size里的每一个长度为time_step的数据的最后一个时刻状态,所以与time_step无关
#states[0][0].shape==states[0][1].shape==(batch_size,hidden_dim1)
#states[1][0].shape==states[1][1].shape==(batch_size,hidden_dim2)
#states[1][1]==tf.transpose(outputs,[1,0,2])[-1]
outputs=tf.transpose(outputs,perm=[1,0,2])
self.rnn_output=outputs[-1]#(batch_size,hidden_dim2)
def output_layer(self):
weights=tf.Variable(tf.random_normal(shape=[hidden_dim2,num_classes],dtype=tf.float32))
biases=tf.Variable(tf.random_normal(shape=[num_classes],dtype=tf.float32))
self.predict=tf.matmul(self.rnn_output,weights)+biases#(batch_size,num_classes)
#注意softmax_cross_entropy_with_logits与sparse_softmax_cross_entropy_with_logits的区别
#前者的logits与labels必须是同样的shape(batch_size,num_classes)以及同样的dtype(float32)
#后者的logits形状是(batch_size,num_classes),dtype是float32,而labels的形状是(batch_size),dtype必须是int,每一个数值表示logits的每一行数据属于哪一类
def loss_layer(self):
self.loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.predict,labels=self.ys))
self.train_op=tf.train.AdamOptimizer(0.01).minimize(self.loss)
self.accuracy=tf.reduce_mean(tf.cast(tf.equal(tf.argmax(self.ys,1),tf.argmax(self.predict,1)),dtype=tf.float32))
def build_graph(self):
self.add_placeholder()
self.rnn_layer()
self.output_layer()
def train(self):
num_batches=mnist.train.num_examples//batch_size
saver=tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for epoch in range(epochs):
epoch_loss=0.0
for i in range(num_batches):
batch_x,batch_y=mnist.train.next_batch(batch_size)
feed_dict={self.xs:batch_x,self.ys:batch_y}
_,loss_value=sess.run([self.train_op,self.loss],feed_dict=feed_dict)
epoch_loss+=loss_value.item()
test_xs,test_ys=mnist.test.next_batch(batch_size)
assert test_xs.shape==(batch_size,784) and test_ys.shape==(batch_size,10)
acc=sess.run(self.accuracy,feed_dict={self.xs:test_xs,self.ys:test_ys})
print("After %d epoch,loss value is %f ,and accuracy is %f " %(epoch+1,epoch_loss/num_batches,acc.item()))
saver.save(sess,"checkpoints/rnn_mnist.ckpt")
if __name__=="__main__":
model=RNN_model()
model.build_graph()
model.loss_layer()
model.train()