From e794929624ac14334d306c907efd9e07cfc33ddb Mon Sep 17 00:00:00 2001
From: jun
Date: Thu, 7 Dec 2017 23:50:34 +0900
Subject: [PATCH] hierarchical encoder decoder added

---
 config/check_tiny.yml |  1 +
 model.py              | 42 +++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 42 insertions(+), 1 deletion(-)

diff --git a/config/check_tiny.yml b/config/check_tiny.yml
index 625facd..4f2520c 100644
--- a/config/check_tiny.yml
+++ b/config/check_tiny.yml
@@ -21,6 +21,7 @@ model:
   dropout: 0.2
   encoder_type: bi
   attention_mechanism: bahdanau
+  session_level_encoder: False
 
 train:
   batch_size: 2
diff --git a/model.py b/model.py
index 9c87ae0..6e4a01b 100644
--- a/model.py
+++ b/model.py
@@ -59,6 +59,8 @@ def _init_placeholder(self, features, labels):
     def build_graph(self):
         self._build_embed()
         self._build_encoder()
+        if Config.model.session_level_encoder:
+            self._build_session_level_encoder()
         self._build_decoder()
 
         if self.mode != tf.estimator.ModeKeys.PREDICT:
@@ -106,6 +108,44 @@ def _build_encoder(self):
                 self.encoder_outputs = tf.contrib.seq2seq.tile_batch(self.encoder_outputs, beam_width)
                 self.encoder_input_lengths = tf.contrib.seq2seq.tile_batch(self.encoder_input_lengths, beam_width)
 
+    def _build_session_level_encoder(self):
+
+        if Config.model.num_layers > 1:
+            if Config.model.cell_type == "LSTM":
+                if Config.model.encoder_type == "bi":
+                    session_level_encoder_input = tf.stack(self.encoder_final_state)
+                else:
+                    session_level_encoder_input = tf.stack(self.encoder_final_state[-1])
+            else:
+                if Config.model.encoder_type == "bi":
+                    session_level_encoder_input = tf.stack([self.encoder_final_state])
+                else:
+                    session_level_encoder_input = tf.stack([self.encoder_final_state[-1]])
+        else:
+            if Config.model.cell_type == "LSTM":
+                if Config.model.encoder_type == "bi":
+                    session_level_encoder_input = tf.stack(self.encoder_final_state)
+                else:
+                    session_level_encoder_input = tf.stack(self.encoder_final_state[0])
+            else:
+                if Config.model.encoder_type == "bi":
+                    session_level_encoder_input = tf.stack([self.encoder_final_state])
+                else:
+                    session_level_encoder_input = tf.stack(self.encoder_final_state)
+
+        with tf.variable_scope("session_level_encoder"):
+
+            if Config.model.cell_type == "LSTM":
+                session_level_encoder_cell = self._single_cell("LSTM", Config.model.dropout, Config.model.num_units)
+            else:
+                session_level_encoder_cell = self._single_cell("GRU", Config.model.dropout, Config.model.num_units)
+
+            session_level_encoder_outputs, session_level_encoder_final_state = tf.nn.dynamic_rnn(
+                session_level_encoder_cell, session_level_encoder_input,
+                dtype=tf.float32)
+
+        self.encoder_final_state = tf.unstack(session_level_encoder_outputs)[0]
+
     def _build_unidirectional_rnn(self):
         cells = self._build_rnn_cells(Config.model.num_units)
         return tf.nn.dynamic_rnn(
@@ -202,7 +242,7 @@ def decode(helper=None, scope="decode"):
                 attention_mechanism,
                 attention_layer_size=attention_layer_size,
                 alignment_history=alignment_history,
-                 name="attention")
+                name="attention")
             out_cell = tf.contrib.rnn.OutputProjectionWrapper(
                 attn_cell, Config.data.vocab_size)
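
Note: a minimal standalone sketch of the hierarchy this patch wires in,
HRED-style: an utterance-level RNN encodes each utterance to a fixed
vector, the per-utterance vectors form a sequence, and a session-level
RNN runs over that sequence to produce the context state handed to the
decoder. All names, shapes, and the single-session batching below are
illustrative assumptions, not code from the patch; it uses the same
TF 1.x API the patch targets.

    import tensorflow as tf  # TF 1.x, same tf.contrib-era API as the patch

    num_units = 64

    # Utterance-level encoder: encode every utterance in one session to a
    # fixed-size vector. Shape [session_len, max_utterance_len, embed_dim]
    # is an illustrative assumption.
    utterances = tf.placeholder(tf.float32, [None, 20, 32], name="utterances")

    with tf.variable_scope("utterance_level_encoder"):
        utterance_cell = tf.nn.rnn_cell.GRUCell(num_units)
        _, utterance_final_states = tf.nn.dynamic_rnn(
            utterance_cell, utterances, dtype=tf.float32)
        # utterance_final_states: [session_len, num_units]

    # Session-level encoder: the per-utterance vectors become a single
    # sequence (a batch of one session) for a second RNN.
    session_input = tf.expand_dims(utterance_final_states, axis=0)

    with tf.variable_scope("session_level_encoder"):
        session_cell = tf.nn.rnn_cell.GRUCell(num_units)
        _, session_final_state = tf.nn.dynamic_rnn(
            session_cell, session_input, dtype=tf.float32)
        # session_final_state: [1, num_units], playing the role of
        # self.encoder_final_state after _build_session_level_encoder runs.

The extra branching in the patch's _build_session_level_encoder exists
because TF 1.x encoders return different final-state structures (a plain
tensor for GRU, an LSTMStateTuple for LSTM, and nested tuples per layer
or per direction), which have to be normalized before tf.stack.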