5-RNN-02: Basic RNN APIs

Azura · Updated: 2024-09-20 · 951 reads

import tensorflow as tf

"""
tf.nn.rnn_cell                       # module that defines the RNN cells
tf.nn.rnn_cell_impl                  # module with the concrete cell implementations
tf.nn.dynamic_rnn()                  # unidirectional dynamic RNN
tf.nn.bidirectional_dynamic_rnn()    # bidirectional dynamic RNN
tf.nn.static_rnn()                   # unidirectional static RNN
tf.nn.static_bidirectional_rnn()     # bidirectional static RNN
"""

"""
tf.nn.dynamic_rnn()   # unidirectional dynamic RNN
    1. The unrolled RNN graph is built while each batch is executed, so the number of
       time steps may differ from batch to batch. Slower.
    2. Input requirement:  3-D tensor [N, n_steps, num_classes]
    3. Output shape:       3-D tensor [N, n_steps, lstm_size]

tf.nn.static_rnn()    # unidirectional static RNN
    1. The unrolled graph is built before execution, so every batch must have the
       same number of time steps. Faster.
    2. The input must be a list (one entry per time step):
       [[N, num_classes], [N, num_classes], [N, num_classes], ...]
    3. The returned output is also a list (one entry per time step):
       [[N, lstm_size], [N, lstm_size], [N, lstm_size], ...]
"""

# 1. The RNN cells
"""
tf.nn.rnn_cell.BasicLSTMCell     # basic LSTM cell
tf.nn.rnn_cell.LSTMCell()        # LSTM cell with peephole connections
tf.nn.rnn_cell.BasicRNNCell()    # basic (vanilla) RNN cell
tf.nn.rnn_cell.GRUCell()         # GRU implementation
tf.nn.rnn_cell.MultiRNNCell()    # stacks several hidden layers
tf.nn.rnn_cell.DropoutWrapper()  # dropout for RNN cells
tf.nn.rnn_cell.RNNCell           # parent class of all cell implementations
"""


def BasicRNN():
    """
    How to use tf.nn.rnn_cell.BasicRNNCell()
    :return:
    """
    cell = tf.nn.rnn_cell.BasicRNNCell(num_units=64, activation=tf.nn.tanh)
    print(cell.state_size, cell.output_size)


def BasicRNN_n_steps():
    # Define the inputs: 2 samples per time step (batch_size),
    # each sample has 3 dimensions (one-hot).
    batch_size = tf.placeholder_with_default(2, shape=[], name='batch')
    inputs1 = tf.placeholder(tf.float32, shape=[2, 3])
    inputs2 = tf.placeholder(tf.float32, shape=[2, 3])
    inputs = [inputs1, inputs2]

    # Instantiate a cell
    cell = tf.nn.rnn_cell.BasicRNNCell(num_units=4)
    # Initialize the state
    state0 = cell.zero_state(batch_size, tf.float32)

    # Output at t=1 (pass in the input at t=1 and the previous state)
    # output1, state1 = cell.__call__(inputs[0], state0)
    output1, state1 = cell(inputs[0], state0)
    print(output1, state1)

    # Output at t=2 (pass in the input at t=2 and the previous state state1)
    output2, state2 = cell(inputs[1], state1)
    print(output2, state2)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        data1 = [
            [1.2, 2.3, 3.4],
            [2.3, 3.5, 3.8]
        ]
        data2 = [
            [1.32, 2.33, 3.34],
            [2.33, 3.35, 3.38]
        ]
        feed = {inputs1: data1, inputs2: data2}
        output1_, state1_, output2_, state2_ = sess.run(
            [output1, state1, output2, state2], feed)
        print(output1_, state1_)
        print('**' * 56)
        print(output2_, state2_)


def BasicLSTM():
    """
    How to use the LSTM cell
    :return:
    """
    cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=3)

    # Define the inputs: 2 time steps in total, 4 samples per time step (batch_size),
    # each sample has 2 dimensions (one-hot dimension).
    old_inputs = tf.placeholder(tf.float32, [8, 2])
    inputs = tf.split(old_inputs, num_or_size_splits=2, axis=0)
    print(inputs)

    # Initialize the state
    s0 = cell.zero_state(batch_size=4, dtype=tf.float32)

    # Feed the inputs directly into the static RNN (see also tf.nn.static_bidirectional_rnn()).
    rnn_outputs, final_state = tf.nn.static_rnn(cell=cell, inputs=inputs, initial_state=s0)
    # rnn_outputs: [[N, lstm_size], [N, lstm_size]]
    print(rnn_outputs)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        data = [
            [2.3, 2.5], [2.33, 2.5], [2.34, 2.45], [2.3, 23.5],
            [2.33, 2.45], [2.3, 2.5], [2.33, 2.65], [2.3, 2.5]
        ]
        print(sess.run([rnn_outputs, final_state], feed_dict={old_inputs: data}))


if __name__ == '__main__':
    BasicRNN()
    BasicRNN_n_steps()
    BasicLSTM()

Program output:

D:\Anaconda\python.exe D:/AI20/HJZ/04-深度学习/4-RNN/20191228___AI20_RNN/02_RNN基本api.py
64 64
Tensor("basic_rnn_cell/Tanh:0", shape=(2, 4), dtype=float32) Tensor("basic_rnn_cell/Tanh:0", shape=(2, 4), dtype=float32)
Tensor("basic_rnn_cell/Tanh_1:0", shape=(2, 4), dtype=float32) Tensor("basic_rnn_cell/Tanh_1:0", shape=(2, 4), dtype=float32)
2020-02-18 10:24:39.299294: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX AVX2
[[-0.48828983  0.9539078  -0.66409636  0.6394202 ]
 [-0.6233796   0.9742228  -0.8671773   0.5339205 ]] [[-0.48828983  0.9539078  -0.66409636  0.6394202 ]
 [-0.6233796   0.9742228  -0.8671773   0.5339205 ]]
****************************************************************************************************************
[[-0.4096039   0.71057636 -0.49753952  0.4326929 ]
 [-0.4899121   0.7949111  -0.80847675  0.07424883]] [[-0.4096039   0.71057636 -0.49753952  0.4326929 ]
 [-0.4899121   0.7949111  -0.80847675  0.07424883]]
[<tf.Tensor ...>, <tf.Tensor ...>]
[<tf.Tensor ...>, <tf.Tensor ...>]
[[array([[1.9140340e-01, 1.6005790e-02, 2.5509506e-01],
         [1.8950416e-01, 1.5850056e-02, 2.5679973e-01],
         [1.8151365e-01, 1.6453112e-02, 2.5943112e-01],
         [2.0262927e-02, 1.5629010e-12, 2.4361372e-05]], dtype=float32),
  array([[0.31692508, 0.02951055, 0.37255847],
         [0.32423702, 0.0286498 , 0.36590496],
         [0.33835158, 0.02595335, 0.35356337],
         [0.20455933, 0.01601724, 0.36540356]], dtype=float32)],
 LSTMStateTuple(c=array([[0.46875435, 0.19634287, 1.4682989 ],
                         [0.4777323 , 0.19432597, 1.4763477 ],
                         [0.49135807, 0.18868431, 1.4950181 ],
                         [0.2809511 , 0.09612122, 1.6794734 ]], dtype=float32),
                h=array([[0.31692508, 0.02951055, 0.37255847],
                         [0.32423702, 0.0286498 , 0.36590496],
                         [0.33835158, 0.02595335, 0.35356337],
                         [0.20455933, 0.01601724, 0.36540356]], dtype=float32))]

Process finished with exit code 0
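The notes at the top describe tf.nn.dynamic_rnn(), tf.nn.rnn_cell.MultiRNNCell() and tf.nn.rnn_cell.DropoutWrapper(), but the script above only exercises manual cell calls and tf.nn.static_rnn(). Below is a minimal sketch, assuming the same TF 1.x API, of feeding a single 3-D tensor [N, n_steps, input_dim] to tf.nn.dynamic_rnn() with a two-layer dropout-wrapped LSTM stack; the function name DynamicRNN_stacked, all shapes (batch 2, 5 time steps, 3 features, num_units=4) and the keep probability are illustrative choices, not taken from the course code.

import numpy as np
import tensorflow as tf


def DynamicRNN_stacked():
    # 3-D input for dynamic_rnn: [N, n_steps, input_dim] = [None, 5, 3] (illustrative)
    inputs = tf.placeholder(tf.float32, shape=[None, 5, 3])

    # Each layer: a basic LSTM cell wrapped with dropout on its outputs.
    def make_cell():
        cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=4)
        return tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=0.9)

    # Stack two layers with MultiRNNCell.
    stacked_cell = tf.nn.rnn_cell.MultiRNNCell([make_cell() for _ in range(2)])

    # dynamic_rnn unrolls the cell along the time axis at run time:
    # outputs: [N, n_steps, num_units]; final_state: one LSTMStateTuple per layer.
    outputs, final_state = tf.nn.dynamic_rnn(cell=stacked_cell, inputs=inputs, dtype=tf.float32)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        data = np.random.rand(2, 5, 3).astype(np.float32)  # 2 samples, 5 time steps
        outputs_, final_state_ = sess.run([outputs, final_state], feed_dict={inputs: data})
        print(outputs_.shape)     # (2, 5, 4)
        print(len(final_state_))  # 2 (one state tuple per stacked layer)


if __name__ == '__main__':
    DynamicRNN_stacked()

Compared with static_rnn, no per-time-step list and no tf.split are needed: the whole sequence is passed as one tensor, and the graph is unrolled dynamically at run time.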
Author: HJZ11



Tags: api, rnn
