1
+
2
+ # coding=utf-8
3
+ # Copyright 2018 The Tensor2Tensor Authors.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """Data generators for translation data-sets."""
18
+
19
+ import os
20
+ from tensor2tensor .data_generators import generator_utils
21
+ from tensor2tensor .data_generators import problem
22
+ from tensor2tensor .data_generators import text_encoder
23
+ from tensor2tensor .data_generators import text_problems
24
+ from tensor2tensor .data_generators import translate
25
+ from tensor2tensor .utils import registry
26
+
27
+ import tensorflow as tf
28
+
29
# Module-level alias for text_encoder.EOS_ID (presumably the end-of-sequence
# token id -- confirm against text_encoder). Not referenced by the other
# definitions visible in this file.
EOS = text_encoder.EOS_ID
30
+
31
+
32
+ _ENTN_TRAIN_DATASETS = [
33
+ [
34
+ "https://github.com/LauraMartinus/ukuxhumana/blob/master/data/en_tn/eng_tswane.train.tar.gz?raw=true" ,
35
+ (
36
+ "entn_parallel.train.en" ,
37
+ "entn_parallel.train.tn"
38
+ )
39
+ ]
40
+ ]
41
+
42
+ _ENTN_TEST_DATASETS = [
43
+ [
44
+ "https://github.com/LauraMartinus/ukuxhumana/blob/master/data/en_tn/eng_tswane.dev.tar.gz?raw=true" ,
45
+ (
46
+ "entn_parallel.dev.en" ,
47
+ "entn_parallel.dev.tn"
48
+ )
49
+ ]
50
+ ]
51
+
52
+
53
@registry.register_problem
class TranslateEntnRma(translate.TranslateProblem):
  """English-Setswana translation problem using the RMA Autshumato dataset."""

  @property
  def approx_vocab_size(self):
    """Approximate vocabulary size for this problem: 2**15 (32768)."""
    return 1 << 15

  @property
  def vocab_filename(self):
    """Vocabulary file name, suffixed with the approximate vocab size."""
    size = self.approx_vocab_size
    return "vocab.entn.%d" % size

  def source_data_files(self, dataset_split):
    """Select the dataset list for the requested split.

    Args:
      dataset_split: the split being requested; compared against
        problem.DatasetSplit.TRAIN.

    Returns:
      _ENTN_TRAIN_DATASETS when the split equals TRAIN; _ENTN_TEST_DATASETS
      for every other value.
    """
    if dataset_split == problem.DatasetSplit.TRAIN:
      return _ENTN_TRAIN_DATASETS
    return _ENTN_TEST_DATASETS
0 commit comments