LSTM - 2. decode_training_set
# Decoding the training set
def decode_training_set(encoder_state, decoder_cell, decoder_embedded_input, sequence_length, decoding_scope, output_function, keep_prob, batch_size):
    """Run the attention decoder over the training inputs and project its output.

    Builds Bahdanau-style attention over a zero-initialised attention-state
    tensor, decodes ``decoder_embedded_input`` with ``decoder_cell`` through
    ``tf.contrib.seq2seq.dynamic_rnn_decoder``, applies dropout to the decoder
    output and returns ``output_function`` applied to the result.

    Args:
        encoder_state: Encoder final state; only its first element
            (``encoder_state[0]``) is handed to the decoder function.
            NOTE(review): presumably a tuple of per-layer states -- confirm
            against the encoder.
        decoder_cell: RNN cell used for decoding; its ``output_size`` fixes
            the attention width.
        decoder_embedded_input: Decoder inputs passed to the dynamic decoder
            (presumably already embedded, per the name -- confirm with caller).
        sequence_length: Sequence length argument forwarded to the decoder.
        decoding_scope: Variable scope in which the decoder runs.
        output_function: Callable applied to the dropped-out decoder output.
        keep_prob: Keep probability forwarded to ``tf.nn.dropout``.
        batch_size: Size of the first dimension of the attention-state tensor.

    Returns:
        Whatever ``output_function`` returns for the dropped-out decoder output.
    """
    # The attention states start as zeros of shape
    # [batch_size, 1, decoder_cell.output_size]. The tensor must be bound to
    # the same name that is passed to prepare_attention below; the previous
    # code bound it to `attention_states` and then read the undefined name
    # `init_attention_states`, which raised NameError.
    init_attention_states = tf.zeros([batch_size, 1, decoder_cell.output_size])
    (attention_keys,
     attention_values,
     attention_score_function,
     attention_construct_function) = tf.contrib.seq2seq.prepare_attention(
        init_attention_states,
        attention_option="bahdanau",
        num_units=decoder_cell.output_size,
    )
    training_decoder_function = tf.contrib.seq2seq.attention_decoder_fn_train(
        encoder_state[0],
        attention_keys,
        attention_values,
        attention_score_function,
        attention_construct_function,
        name="attn_dec_train",
    )
    # dynamic_rnn_decoder returns three values; only the first one (the
    # decoder output) is used during training, the other two are discarded.
    decoder_output, _, _ = tf.contrib.seq2seq.dynamic_rnn_decoder(
        decoder_cell,
        training_decoder_function,
        decoder_embedded_input,
        sequence_length,
        scope=decoding_scope,
    )
    decoder_output_dropout = tf.nn.dropout(decoder_output, keep_prob)
    return output_function(decoder_output_dropout)
init_attention_states = tf.zeros([batch_size, 1, num_hidden_neurons])
Here, `num_hidden_neurons` is set to `decoder_cell.output_size`, the dimension of the decoder's hidden state.
attention_keys, attention_values, attention_score_function, attention_construct_function =
tf.contrib.seq2seq.prepare_attention(
init_attention_states, # init hidden states
attention_option = "bahdanau",
num_units = decoder_cell.output_size # hidden state dimension
)
Last updated
Was this helpful?