This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 3b201b9

Merge pull request #263 from rsepassi/push
v1.2.1
2 parents 9ef2517 + f715f85 commit 3b201b9


41 files changed, +2161 -547 lines

README.md

Lines changed: 14 additions & 5 deletions
@@ -29,15 +29,26 @@ You can chat with us and other users on
 with T2T announcements.
 
 Here is a one-command version that installs tensor2tensor, downloads the data,
-trains an English-German translation model, and lets you use it interactively:
+trains an English-German translation model, and evaluates it:
 ```
 pip install tensor2tensor && t2t-trainer \
   --generate_data \
   --data_dir=~/t2t_data \
   --problems=translate_ende_wmt32k \
   --model=transformer \
   --hparams_set=transformer_base_single_gpu \
-  --output_dir=~/t2t_train/base \
+  --output_dir=~/t2t_train/base
+```
+
+You can decode from the model interactively:
+
+```
+t2t-decoder \
+  --data_dir=~/t2t_data \
+  --problems=translate_ende_wmt32k \
+  --model=transformer \
+  --hparams_set=transformer_base_single_gpu \
+  --output_dir=~/t2t_train/base \
   --decode_interactive
 ```

@@ -106,14 +117,12 @@ echo "Goodbye world" >> $DECODE_FILE
 BEAM_SIZE=4
 ALPHA=0.6
 
-t2t-trainer \
+t2t-decoder \
   --data_dir=$DATA_DIR \
   --problems=$PROBLEM \
   --model=$MODEL \
   --hparams_set=$HPARAMS \
   --output_dir=$TRAIN_DIR \
-  --train_steps=0 \
-  --eval_steps=0 \
   --decode_beam_size=$BEAM_SIZE \
   --decode_alpha=$ALPHA \
   --decode_from_file=$DECODE_FILE

docs/example_life.md

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
# T2T: Life of an Example

[![PyPI
version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/tensor2tensor)
[![GitHub
Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](https://github.com/tensorflow/tensor2tensor/issues)
[![Contributions
welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)

This document shows how a training example passes through the T2T pipeline,
and how all its parts are connected to work together.

## The Life of an Example

A training example passes through the following stages in T2T:
* raw input (text from command line or file)
* encoded input after the `encode` function of [Problem.feature_encoder](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173), usually a sparse tensor, e.g., a vector of `tf.int32`s
* batched input after the [data input pipeline](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/data_reader.py#L242), where the inputs, after [Problem.preprocess_examples](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L188), are grouped by their length and made into batches
* dense input after being processed by a [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30) `bottom` function
* dense output after [T2T.model_fn_body](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/t2t_model.py#L542)
* back to sparse output through a [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30) `top` function
* if decoding, back through the `decode` function of [Problem.feature_encoder](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173) to display on the screen

We go into these phases step by step below.

## Feature Encoders

TODO: describe [Problem.feature_encoder](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173), which is a dict of encoders that have `encode` and `decode` functions.

## Modalities

TODO: describe [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30), which has `bottom` and `top` but also sharded versions and one for targets.
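
For a concrete feel for the first and last stages above, here is a minimal sketch using `ByteTextEncoder` from `text_encoder.py`, one of the encoders a problem's feature-encoder dict can hold; the batching and Modality stages in between are elided:

```python
# Sketch: raw text -> sparse ids (encode), and ids -> text (decode).
# ByteTextEncoder maps each byte to an id offset by the two reserved
# ids (PAD=0, EOS=1), so the round trip is lossless for byte-level text.
from tensor2tensor.data_generators import text_encoder

encoder = text_encoder.ByteTextEncoder()

ids = encoder.encode("hello")  # -> [106, 103, 110, 110, 113]
text = encoder.decode(ids)     # -> "hello"
```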

docs/index.md

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,4 @@
1-
# T2T: Tensor2Tensor Transformers
2-
3-
Check us out on
4-
<a href=https://github.com/tensorflow/tensor2tensor>
5-
GitHub
6-
<img src="https://github.com/favicon.ico" width="16">
7-
</a>
8-
.
1+
# Tensor2Tensor Docs Index
92

103
[![PyPI
114
version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/tensor2tensor)
@@ -16,8 +9,26 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO
169
[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
1710
[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)
1811

19-
See our
20-
[README](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/README.md)
21-
for documentation.
2212

23-
More documentation and tutorials coming soon...
13+
Welcome to Tensor2Tensor!
14+
15+
Tensor2Tensor, or T2T for short, is a library we use to create,
16+
investigate and deploy deep learning models. This page hosts our
17+
documentation, from basic tutorials to full code documentation.
18+
19+
## Basics
20+
21+
* [Walkthrough: Install and Run](walkthrough.md)
22+
* [Tutorial: Train on Your Data](new_problem.md)
23+
* [Tutorial: Create Your Own Model](new_model.md)
24+
25+
## Deep Dive
26+
27+
* [Life of an Example](example_life.md): how all parts of T2T are connected and work together
28+
* [Distributed Training](distributed_training.md)
29+
30+
## Code documentation
31+
32+
See our
33+
[README](https://github.com/tensorflow/tensor2tensor/blob/master/README.md)
34+
for now, code docs coming.

docs/new_model.md

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
# T2T: Create Your Own Model

[![PyPI
version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/tensor2tensor)
[![GitHub
Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](https://github.com/tensorflow/tensor2tensor/issues)
[![Contributions
welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)

Here we show how to create your own model in T2T.

## The T2TModel class

TODO: complete.
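
Until that section is written, here is a hedged sketch of the general shape such a model takes (the class name and layer size are hypothetical, not from this commit): subclass `T2TModel`, register it, and implement `model_fn_body`, which maps the dense features produced by the input Modality's `bottom` to a dense body output that the target Modality's `top` turns into logits.

```python
import tensorflow as tf

from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model


@registry.register_model
class MySimpleModel(t2t_model.T2TModel):  # hypothetical example model
  def model_fn_body(self, features):
    # features["inputs"] arrives already embedded by the input Modality.
    inputs = features["inputs"]
    # A single hidden layer stands in for a real architecture here.
    return tf.layers.dense(inputs, 64, activation=tf.nn.relu)
```

Once registered, such a model could be selected with `--model=my_simple_model`.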

docs/new_problem.md

Lines changed: 240 additions & 0 deletions
@@ -0,0 +1,240 @@
# T2T: Train on Your Own Data

[![PyPI
version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/tensor2tensor)
[![GitHub
Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](https://github.com/tensorflow/tensor2tensor/issues)
[![Contributions
welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)

Let's add a new dataset together and train the transformer model. We'll learn to define English words by training the transformer to "translate" between English words and their definitions at the character level.

# About the Problem

For each problem we want to tackle, we create a new problem class and register it. Let's call our problem `Word2def`.

Since many text-to-text problems share similar methods, there's already a class called `Text2TextProblem` that extends the base problem class, `Problem` (both found in `problem.py`).

For our problem, we can go ahead and create the file `word2def.py` in the `data_generators` folder and add our new problem, `Word2def`, which extends `Text2TextProblem`. Let's also register it while we're at it so we can specify the problem through flags.

```python
@registry.register_problem()
class Word2def(problem.Text2TextProblem):
  """Problem spec for English word to dictionary definition."""
  pass  # We'll fill in the required methods below.
```
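
Registration is what makes the problem addressable from the command line: the registry keys it by the snake_cased class name, so this class becomes `word2def`. As a hedged illustration of the lookup (not part of the file we're writing):

```python
from tensor2tensor.utils import registry

# Resolves the registered Word2def class by its snake_cased name,
# the same name the --problems flag uses.
word2def_problem = registry.problem("word2def")
```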

We need to implement the following methods from `Text2TextProblem` in our new class:
* `is_character_level`
* `targeted_vocab_size`
* `generator`
* `input_space_id`
* `target_space_id`
* `num_shards`
* `vocab_name`
* `use_subword_tokenizer`

Let's tackle them one by one:

**input_space_id, target_space_id, is_character_level, targeted_vocab_size, use_subword_tokenizer**:

SpaceIDs tell Tensor2Tensor what sort of space the input and target tensors are in. These are things like EN_CHR (English character), EN_TOK (English token), AUDIO_WAV (audio waveform), IMAGE, and DNA (genetic bases). The complete list can be found in `data_generators/problem.py` in the class `SpaceID`.

Since we're generating definitions and feeding in words at the character level, we set `is_character_level` to true and use the same SpaceID, EN_CHR, for both input and target. Additionally, since we aren't using tokens, we don't need to give a `targeted_vocab_size` or define `use_subword_tokenizer`.

**vocab_name**:

`vocab_name` will be used to name your vocabulary files. We can call ours `'vocab.word2def.en'`.

**num_shards**:

The number of shards to break data files into.

```python
@registry.register_problem()
class Word2def(problem.Text2TextProblem):
  """Problem spec for English word to dictionary definition."""

  @property
  def is_character_level(self):
    return True

  @property
  def vocab_name(self):
    return "vocab.word2def.en"

  @property
  def input_space_id(self):
    return problem.SpaceID.EN_CHR

  @property
  def target_space_id(self):
    return problem.SpaceID.EN_CHR

  @property
  def num_shards(self):
    return 100

  @property
  def use_subword_tokenizer(self):
    return False
```

**generator**:

We're almost done. `generator` generates the training and evaluation data and stores it in files like "word2def_train.lang1" in your DATA_DIR. Thankfully, several commonly used generators like `character_generator` and `token_generator` are already written in the file `wmt.py`. We will import `character_generator` and write:

```python
def generator(self, data_dir, tmp_dir, train):
  character_vocab = text_encoder.ByteTextEncoder()
  datasets = _WORD2DEF_TRAIN_DATASETS if train else _WORD2DEF_TEST_DATASETS
  tag = "train" if train else "dev"
  return character_generator(datasets[0], datasets[1], character_vocab, EOS)
```
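
To make the data format concrete: each example `character_generator` yields is a dict of integer ids for one source/target line pair, with EOS appended to both sides. An illustrative sketch (the target ids below are made up):

```python
# "hello" under ByteTextEncoder: each byte id is offset by the two
# reserved ids (PAD=0, EOS=1), and EOS (id 1) terminates the sequence.
example = {
    "inputs": [106, 103, 110, 110, 113, 1],  # "hello" + EOS
    "targets": [73, 32, 105, 116, 103, 1],   # made-up definition ids + EOS
}
```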

Now our `word2def.py` file looks like this:

```python
@registry.register_problem()
class Word2def(problem.Text2TextProblem):
  """Problem spec for English word to dictionary definition."""

  @property
  def is_character_level(self):
    return True

  @property
  def vocab_name(self):
    return "vocab.word2def.en"

  def generator(self, data_dir, tmp_dir, train):
    character_vocab = text_encoder.ByteTextEncoder()
    datasets = _WORD2DEF_TRAIN_DATASETS if train else _WORD2DEF_TEST_DATASETS
    tag = "train" if train else "dev"
    return character_generator(datasets[0], datasets[1], character_vocab, EOS)

  @property
  def input_space_id(self):
    return problem.SpaceID.EN_CHR

  @property
  def target_space_id(self):
    return problem.SpaceID.EN_CHR

  @property
  def num_shards(self):
    return 100

  @property
  def use_subword_tokenizer(self):
    return False
```

## Data

Now we need to tell Tensor2Tensor where our data is located.

I've gone ahead and split all words into a train and test set and saved them in files called `words_train.txt`, `words_test.txt`, `definitions_train.txt`, and `definitions_test.txt` in a directory called `LOCATION_OF_DATA/`. Let's tell T2T where these files are (flat source/target paths, matching what `generator` indexes with `datasets[0]` and `datasets[1]`):

```python
# English Word2def datasets
_WORD2DEF_TRAIN_DATASETS = [
    LOCATION_OF_DATA + 'words_train.txt',
    LOCATION_OF_DATA + 'definitions_train.txt'
]
_WORD2DEF_TEST_DATASETS = [
    LOCATION_OF_DATA + 'words_test.txt',
    LOCATION_OF_DATA + 'definitions_test.txt'
]
```

## Putting it all together

Now our `word2def.py` file (with the correct imports) looks like this:

```python
"""Problem definition for word to dictionary definition."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators.wmt import character_generator

from tensor2tensor.utils import registry

import tensorflow as tf

FLAGS = tf.flags.FLAGS

# End-of-sequence id used by the generator.
EOS = text_encoder.EOS_ID

# English Word2def datasets.
# LOCATION_OF_DATA is a placeholder: set it to your data directory.
_WORD2DEF_TRAIN_DATASETS = [
    LOCATION_OF_DATA + 'words_train.txt',
    LOCATION_OF_DATA + 'definitions_train.txt'
]

_WORD2DEF_TEST_DATASETS = [
    LOCATION_OF_DATA + 'words_test.txt',
    LOCATION_OF_DATA + 'definitions_test.txt'
]


@registry.register_problem()
class Word2def(problem.Text2TextProblem):
  """Problem spec for English word to dictionary definition."""

  @property
  def is_character_level(self):
    return True

  @property
  def vocab_name(self):
    return "vocab.word2def.en"

  def generator(self, data_dir, tmp_dir, train):
    character_vocab = text_encoder.ByteTextEncoder()
    datasets = _WORD2DEF_TRAIN_DATASETS if train else _WORD2DEF_TEST_DATASETS
    tag = "train" if train else "dev"
    return character_generator(datasets[0], datasets[1], character_vocab, EOS)

  @property
  def input_space_id(self):
    return problem.SpaceID.EN_CHR

  @property
  def target_space_id(self):
    return problem.SpaceID.EN_CHR

  @property
  def num_shards(self):
    return 100

  @property
  def use_subword_tokenizer(self):
    return False
```

# Hyperparameters

All hyperparameters inherit from `_default_hparams()` in `problem.py`. If you would like to customize your hyperparameters, add another method to the file `problem_hparams.py`.
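
On the model side, a custom hparams set can be registered and then selected with the `--hparams_set` flag. A hedged sketch (the set name and the tweaked value are hypothetical, not part of this tutorial's code):

```python
from tensor2tensor.models import transformer
from tensor2tensor.utils import registry


@registry.register_hparams
def transformer_word2def():
  """Transformer base single-GPU hparams, lightly tuned for word2def."""
  hparams = transformer.transformer_base_single_gpu()
  hparams.batch_size = 1024  # hypothetical tweak for short char sequences
  return hparams
```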

# Run the problem

Now that we've gotten our problem set up, let's train a model and generate definitions.

We specify our problem name, the model, and hparams:

```bash
PROBLEM=word2def
MODEL=transformer
HPARAMS=transformer_base_single_gpu
```

The rest of the steps are as given in the [walkthrough](walkthrough.md).

What if we wanted to train a model to generate words given definitions? In T2T, we can simply change the problem name to `PROBLEM=word2def_rev`.

All done. Let us know what definitions your model generated.
