Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 357c9d4

Browse files
katelee168Ryan Sepassi
authored andcommitted
Adding example problem to T2T documentation
PiperOrigin-RevId: 166912906
1 parent 1fc6766 commit 357c9d4

File tree

1 file changed

+229
-3
lines changed

1 file changed

+229
-3
lines changed

docs/new_problem.md

Lines changed: 229 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,234 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO
99
[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
1010
[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)
1111

12-
Here we show how to hook-up your own data to train T2T models on it.
12+
Let's add a new dataset together and train the transformer model. We'll be learning to define English words by training the transformer to "translate" between English words and their definitions on a character level.
1313

14-
## The Problem class
14+
# About the Problem
1515

16-
TODO: complete.
16+
For each problem we want to tackle we create a new problem class and register it. Let's call our problem `Word2def`.
17+
18+
Since many text2text problems share similar methods, there's already a class called `Text2TextProblem` that extends the base problem class, `Problem` (both found in `problem.py`).
19+
20+
For our problem, we can go ahead and create the file `word2def.py` in the `data_generators` folder and add our new problem, `Word2def`, which extends `TranslateProblem`. Let's also register it while we're at it so we can specify the problem through flags.
21+
22+
```python
23+
@registry.register_problem()
24+
class Word2def(problem.Text2TextProblem):
25+
"""Problem spec for English word to dictionary definition."""
26+
return NotImplementedError()
27+
```
28+
29+
We need to implement the following methods from `Text2TextProblem` in our new class:
30+
* is_character_level
31+
* targeted_vocab_size
32+
* generator
33+
* input_space_id
34+
* target_space_id
35+
* num_shards
36+
* vocab_name
37+
* use_subword_tokenizer
38+
39+
Let's tackle them one by one:
40+
41+
**input_space_id, target_space_id, is_character_level, targeted_vocab_size, use_subword_tokenizer**:
42+
43+
SpaceIDs tell Tensor2Tensor what sort of space the input and target tensors are in. These are things like, EN_CHR (English character), EN_TOK (English token), AUDIO_WAV (audio waveform), IMAGE, DNA (genetic bases). The complete list can be found at `data_generators/problem.py` in the class `SpaceID`.
44+
45+
Since we're generating definitions and feeding in words at the character level, we set `is_character_level` to true, and use the same SpaceID, EN_CHR, for both input and target. Additionally, since we aren't using tokens, we don't need to give a `targeted_vocab_size` or define `use_subword_tokenizer`.
46+
47+
**vocab_name**:
48+
49+
`vocab_name` will be used to name your vocabulary files. We can call ours `'vocab.word2def.en'`
50+
51+
**num_shards**:
52+
53+
The number of shards to break data files into.
54+
55+
```python
56+
@registry.register_problem()
57+
class Word2def(problem.Text2TextProblem):
58+
"""Problem spec for English word to dictionary definition."""
59+
def is_character_level(self):
60+
return True
61+
62+
@property
63+
def vocab_name(self):
64+
return "vocab.word2def.en"
65+
66+
@property
67+
def input_space_id(self):
68+
return problem.SpaceID.EN_CHR
69+
70+
@property
71+
def target_space_id(self):
72+
return problem.SpaceID.EN_CHR
73+
74+
@property
75+
def num_shards(self):
76+
return 100
77+
78+
@property
79+
def use_subword_tokenizer(self):
80+
return False
81+
```
82+
83+
**generator**:
84+
85+
We're almost done. `generator` generates the training and evaluation data and stores them in files like "word2def_train.lang1" in your DATA_DIR. Thankfully several commonly used methods like `character_generator`, and `token_generator` are already written in the file `wmt.py`. We will import `character_generator` and write:
86+
```python
87+
def generator(self, data_dir, tmp_dir, train):
88+
character_vocab = text_encoder.ByteTextEncoder()
89+
datasets = _WORD2DEF_TRAIN_DATASETS if train else _WORD2DEF_TEST_DATASETS
90+
tag = "train" if train else "dev"
91+
return character_generator(datasets[0], datasets[1], character_vocab, EOS)
92+
```
93+
94+
Now our `word2def.py` file looks like the below:
95+
96+
```python
97+
@registry.register_problem()
98+
class Word2def(problem.Text2TextProblem):
99+
"""Problem spec for English word to dictionary definition."""
100+
@property
101+
def is_character_level(self):
102+
return True
103+
104+
@property
105+
def vocab_name(self):
106+
return "vocab.word2def.en"
107+
108+
def generator(self, data_dir, tmp_dir, train):
109+
character_vocab = text_encoder.ByteTextEncoder()
110+
datasets = _WORD2DEF_TRAIN_DATASETS if train else _WORD2DEF_TEST_DATASETS
111+
tag = "train" if train else "dev"
112+
return character_generator(datasets[0], datasets[1], character_vocab, EOS)
113+
114+
@property
115+
def input_space_id(self):
116+
return problem.SpaceID.EN_CHR
117+
118+
@property
119+
def target_space_id(self):
120+
return problem.SpaceID.EN_CHR
121+
122+
@property
123+
def num_shards(self):
124+
return 100
125+
126+
@property
127+
def use_subword_tokenizer(self):
128+
return False
129+
```
130+
131+
## Data:
132+
Now we need to tell Tensor2Tensor where our data is located.
133+
134+
I've gone ahead and split all words into a train and test set and saved them in files called `words.train.txt`, `words.test.txt`,
135+
`definitions.train.txt`, and `definitions.test.txt` in a directory called `LOCATION_OF_DATA/`. Let's tell T2T where these files are:
136+
137+
```python
138+
# English Word2def datasets
139+
_WORD2DEF_TRAIN_DATASETS = [
140+
[
141+
"LOCATION_OF_DATA/", ("words_train.txt", "definitions_train.txt")
142+
]
143+
]
144+
_WORD2DEF_TEST_DATASETS = [
145+
[
146+
"LOCATION_OF_DATA", ("words_test.txt", "definitions_test.txt")
147+
]
148+
]
149+
```
150+
151+
## Putting it all together
152+
153+
Now our `word2def.py` file looks like: (with the correct imports)
154+
```python
155+
""" Problem definition for word to dictionary definition.
156+
"""
157+
158+
from __future__ import absolute_import
159+
from __future__ import division
160+
from __future__ import print_function
161+
162+
import os
163+
import tarfile # do we need this import
164+
165+
import google3
166+
167+
from tensor2tensor.data_generators import generator_utils
168+
from tensor2tensor.data_generators import problem
169+
from tensor2tensor.data_generators import text_encoder
170+
from tensor2tensor.data_generators.wmt import character_generator
171+
172+
from tensor2tensor.utils import registry
173+
174+
import tensorflow as tf
175+
176+
FLAGS = tf.flags.FLAGS
177+
178+
# English Word2def datasets
179+
_WORD2DEF_TRAIN_DATASETS = [
180+
LOCATION_OF_DATA+'words_train.txt',
181+
LOCATION_OF_DATA+'definitions_train.txt'
182+
]
183+
184+
_WORD2DEF_TEST_DATASETS = [
185+
LOCATION_OF_DATA+'words_test.txt',
186+
LOCATION_OF_DATA+'definitions_test.txt'
187+
]
188+
189+
@registry.register_problem()
190+
class Word2def(problem.Text2TextProblem):
191+
"""Problem spec for English word to dictionary definition."""
192+
@property
193+
def is_character_level(self):
194+
return True
195+
196+
@property
197+
def vocab_name(self):
198+
return "vocab.word2def.en"
199+
200+
def generator(self, data_dir, tmp_dir, train):
201+
character_vocab = text_encoder.ByteTextEncoder()
202+
datasets = _WORD2DEF_TRAIN_DATASETS if train else _WORD2DEF_TEST_DATASETS
203+
tag = "train" if train else "dev"
204+
return character_generator(datasets[0], datasets[1], character_vocab, EOS)
205+
206+
@property
207+
def input_space_id(self):
208+
return problem.SpaceID.EN_CHR
209+
210+
@property
211+
def target_space_id(self):
212+
return problem.SpaceID.EN_CHR
213+
214+
@property
215+
def num_shards(self):
216+
return 100
217+
218+
@property
219+
def use_subword_tokenizer(self):
220+
return False
221+
222+
```
223+
224+
# Hyperparameters
225+
All hyperparamters inherit from `_default_hparams()` in `problem.py.` If you would like to customize your hyperparameters, add another method to the file `problem_hparams.py`.
226+
227+
# Run the problem
228+
Now that we've gotten our problem set up, let's train a model and generate definitions.
229+
230+
We specify our problem name, the model, and hparams.
231+
```bash
232+
PROBLEM=word2def
233+
MODEL=transformer
234+
HPARAMS=transofmer_base_single_gpu
235+
```
236+
237+
The rest of the steps are as given in the [walkthrough](walkthrough.md).
238+
239+
240+
What if we wanted to train a model to generate words given definitions? In T2T, we can change the problem name to be `PROBLEM=word2def_rev`.
241+
242+
All done. Let us know what definitions your model generated.

0 commit comments

Comments
 (0)