Here we show how to hook up your own data to train T2T models on it.

Let's add a new dataset together and train the transformer model. We'll be learning to define English words by training the transformer to "translate" between English words and their definitions on a character level.

# About the Problem

For each problem we want to tackle we create a new problem class and register it. Let's call our problem `Word2def`.

Since many text2text problems share similar methods, there's already a class called `Text2TextProblem` that extends the base problem class, `Problem` (both found in `problem.py`).

For our problem, we can go ahead and create the file `word2def.py` in the `data_generators` folder and add our new problem, `Word2def`, which extends `Text2TextProblem`. Let's also register it while we're at it so we can specify the problem through flags.
```python
@registry.register_problem()
class Word2def(problem.Text2TextProblem):
  """Problem spec for English word to dictionary definition."""
  pass  # Required methods are filled in below.
```
We need to implement the following methods from `Text2TextProblem` in our new class: `is_character_level`, `vocab_name`, `input_space_id`, `target_space_id`, `num_shards`, `use_subword_tokenizer`, and `generator`.

**input_space_id, target_space_id, is_character_level**:

SpaceIDs tell Tensor2Tensor what sort of space the input and target tensors are in. These are things like EN_CHR (English character), EN_TOK (English token), AUDIO_WAV (audio waveform), IMAGE, and DNA (genetic bases). The complete list can be found at `data_generators/problem.py` in the class `SpaceID`.

Since we're generating definitions and feeding in words at the character level, we set `is_character_level` to `True` and use the same SpaceID, EN_CHR, for both input and target. Additionally, since we aren't using tokens, we don't need to give a `targeted_vocab_size` or define `use_subword_tokenizer`.

**vocab_name**:

`vocab_name` will be used to name your vocabulary files. We can call ours `'vocab.word2def.en'`.

**num_shards**:

The number of shards to break the data files into.
```python
@registry.register_problem()
class Word2def(problem.Text2TextProblem):
  """Problem spec for English word to dictionary definition."""

  @property
  def is_character_level(self):
    return True

  @property
  def vocab_name(self):
    return "vocab.word2def.en"

  @property
  def input_space_id(self):
    return problem.SpaceID.EN_CHR

  @property
  def target_space_id(self):
    return problem.SpaceID.EN_CHR

  @property
  def num_shards(self):
    return 100

  @property
  def use_subword_tokenizer(self):
    return False
```
**generator**:

We're almost done. `generator` generates the training and evaluation data and stores them in files like "word2def_train.lang1" in your DATA_DIR. Thankfully, several commonly used methods like `character_generator` and `token_generator` are already written in the file `wmt.py`. We will import `character_generator` and write something along the lines of the sketch below.
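In this sketch, the dataset paths are placeholders for wherever your parallel word/definition text files live (one example per line, aligned across the two files), `ByteTextEncoder` supplies the character-level vocabulary, and the elided properties are the ones we defined above; treat it as a rough illustration rather than the exact implementation.

```python
from tensor2tensor.data_generators import problem, text_encoder
from tensor2tensor.data_generators.wmt import character_generator
from tensor2tensor.utils import registry

EOS = text_encoder.EOS_ID

# Placeholder paths -- point these at your own word/definition files.
_WORD2DEF_TRAIN_DATASETS = ["words_train.txt", "definitions_train.txt"]
_WORD2DEF_DEV_DATASETS = ["words_dev.txt", "definitions_dev.txt"]


@registry.register_problem()
class Word2def(problem.Text2TextProblem):
  # ... the properties defined above ...

  def generator(self, data_dir, tmp_dir, train):
    """Yields dicts of character IDs for inputs (words) and targets (definitions)."""
    character_vocab = text_encoder.ByteTextEncoder()
    source, target = (_WORD2DEF_TRAIN_DATASETS if train
                      else _WORD2DEF_DEV_DATASETS)
    return character_generator(source, target, character_vocab, EOS)
```

`character_generator` reads the two files line by line, encodes each pair with the vocabulary, appends EOS, and yields input/target dicts.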
All hyperparameters inherit from `_default_hparams()` in `problem.py`. If you would like to customize your hyperparameters, add another method to the file `problem_hparams.py`.
# Run the problem
Now that we've gotten our problem set up, let's train a model and generate definitions.

We specify our problem name, the model, and the hparams set.
```bash
PROBLEM=word2def
MODEL=transformer
HPARAMS=transformer_base_single_gpu
```
The rest of the steps are as given in the [walkthrough](walkthrough.md).
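Roughly, the datagen and training steps look like the sketch below; the flag names follow the walkthrough-era CLI and can differ between T2T versions, and the directory variables are placeholders you can set however you like.

```bash
DATA_DIR=$HOME/t2t_data             # where generated data files will live
TMP_DIR=/tmp/t2t_datagen            # scratch space for raw data
TRAIN_DIR=$HOME/t2t_train/$PROBLEM  # model checkpoints and logs

# Generate training and dev data for our problem.
t2t-datagen \
  --data_dir=$DATA_DIR \
  --tmp_dir=$TMP_DIR \
  --problem=$PROBLEM

# Train the transformer with our hparams set.
t2t-trainer \
  --data_dir=$DATA_DIR \
  --problems=$PROBLEM \
  --model=$MODEL \
  --hparams_set=$HPARAMS \
  --output_dir=$TRAIN_DIR
```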
What if we wanted to train a model to generate words given definitions? In T2T, we can simply change the problem name to `PROBLEM=word2def_rev`.

All done. Let us know what definitions your model generated.