Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 554973f

Browse files
alexyku and Ryan Sepassi
authored and committed
Adding a minimum viable DNA data encoder.
PiperOrigin-RevId: 164201984
1 parent 95ee9e5 commit 554973f

File tree

4 files changed

+183
-66
lines changed

4 files changed

+183
-66
lines changed
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# coding=utf-8
2+
# Copyright 2017 The Tensor2Tensor Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Encoders for DNA data.
17+
18+
* DNAEncoder: ACTG strings to ints and back
19+
* DelimitedDNAEncoder: for delimited subsequences
20+
"""
21+
22+
from __future__ import absolute_import
23+
from __future__ import division
24+
from __future__ import print_function
25+
26+
import itertools
27+
# Dependency imports
28+
29+
from six.moves import xrange # pylint: disable=redefined-builtin
30+
from tensor2tensor.data_generators import text_encoder
31+
32+
33+
class DNAEncoder(text_encoder.TextEncoder):
  """ACTG strings to ints and back. Optionally chunks bases into single ids.

  To use a different character set, subclass and set BASES to the char set. UNK
  and PAD must not appear in the char set, but can also be reset.

  Uses 'N' as an unknown base.
  """
  BASES = list("ACTG")
  UNK = "N"
  PAD = "0"

  def __init__(self,
               chunk_size=1,
               num_reserved_ids=text_encoder.NUM_RESERVED_TOKENS):
    """Initializes the encoder.

    Args:
      chunk_size: int, number of bases folded into a single vocabulary id.
      num_reserved_ids: int, ids in [0, num_reserved_ids) are reserved and are
        never assigned to base chunks.
    """
    super(DNAEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
    # Build a vocabulary of all chunks up to chunk_size, right-padded with PAD.
    self._chunk_size = chunk_size
    tokens = self._tokens()
    # Sort with an explicit key so subclasses may mix token types: e.g.
    # DelimitedDNAEncoder adds a delimiter token next to the base tuples, and
    # comparing str to tuple directly raises TypeError on Python 3. For the
    # all-tuple base vocabulary, key=tuple is the identity and the ordering is
    # unchanged.
    tokens.sort(key=tuple)
    ids = range(self._num_reserved_ids, len(tokens) + self._num_reserved_ids)
    self._ids_to_tokens = dict(zip(ids, tokens))
    self._tokens_to_ids = dict(zip(tokens, ids))

  def _tokens(self):
    """Returns the vocabulary: tuples of bases right-padded to chunk_size."""
    chunks = []
    for size in range(1, self._chunk_size + 1):
      c = itertools.product(self.BASES + [self.UNK], repeat=size)
      num_pad = self._chunk_size - size
      padding = (self.PAD,) * num_pad
      c = [el + padding for el in c]
      chunks.extend(c)
    return chunks

  @property
  def vocab_size(self):
    """Total number of ids, including the reserved range."""
    return len(self._ids_to_tokens) + self._num_reserved_ids

  def encode(self, s):
    """Encodes a base string into a list of chunk ids.

    Args:
      s: str of bases drawn from BASES + UNK, e.g. "ACTGN".

    Returns:
      list of int ids, one per chunk. If len(s) is not a multiple of
      chunk_size, the final chunk is right-padded with PAD.

    Raises:
      ValueError: if s contains a character outside the vocabulary.
    """
    bases = list(s)
    extra = len(bases) % self._chunk_size
    if extra > 0:
      # Right-pad the final partial chunk so it maps to a vocabulary token.
      pad = [self.PAD] * (self._chunk_size - extra)
      bases.extend(pad)
    assert (len(bases) % self._chunk_size) == 0
    num_chunks = len(bases) // self._chunk_size
    ids = []
    for chunk_idx in xrange(num_chunks):
      start_idx = chunk_idx * self._chunk_size
      end_idx = start_idx + self._chunk_size
      chunk = tuple(bases[start_idx:end_idx])
      if chunk not in self._tokens_to_ids:
        raise ValueError("Unrecognized token %s" % chunk)
      ids.append(self._tokens_to_ids[chunk])
    return ids

  def decode(self, ids):
    """Inverts encode: turns a list of ids back into a base string.

    PAD characters introduced by encode are stripped, so
    decode(encode(s)) == s for any s over the vocabulary.
    """
    bases = []
    for idx in ids:
      if idx >= self._num_reserved_ids:
        chunk = self._ids_to_tokens[idx]
        if self.PAD in chunk:
          # Drop the right-padding that encode added to a partial final chunk.
          chunk = chunk[:chunk.index(self.PAD)]
      else:
        # Reserved ids decode to their reserved token text.
        chunk = [text_encoder.RESERVED_TOKENS[idx]]
      bases.extend(chunk)
    return "".join(bases)
100+
101+
102+
class DelimitedDNAEncoder(DNAEncoder):
  """DNAEncoder for delimiter separated subsequences.

  Uses ',' as default delimiter.
  """

  def __init__(self, delimiter=",", **kwargs):
    """Creates the encoder.

    Args:
      delimiter: str separating subsequences. Must not contain characters
        from BASES, UNK, or PAD.
      **kwargs: forwarded to DNAEncoder (chunk_size, num_reserved_ids).
    """
    self._delimiter = delimiter
    # Store the delimiter token as a tuple of characters so that every entry
    # in the vocabulary has the same type. Mixing a plain str in with the base
    # tuples makes tokens.sort() in DNAEncoder.__init__ raise TypeError on
    # Python 3 (str and tuple are unorderable against each other).
    self._delimiter_token = tuple(delimiter)
    super(DelimitedDNAEncoder, self).__init__(**kwargs)

  @property
  def delimiter(self):
    return self._delimiter

  def _tokens(self):
    return super(DelimitedDNAEncoder, self)._tokens() + [self._delimiter_token]

  def encode(self, delimited_string):
    """Encodes delimiter-separated subsequences, emitting delimiter ids.

    Each subsequence is encoded with DNAEncoder.encode and the subsequences
    are joined by the delimiter's own id, so decode reproduces the original
    delimited string.
    """
    ids = []
    for s in delimited_string.split(self.delimiter):
      ids.extend(super(DelimitedDNAEncoder, self).encode(s))
      ids.append(self._tokens_to_ids[self._delimiter_token])
    # The loop appends one trailing delimiter id too many; drop it.
    return ids[:-1]
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# coding=utf-8
2+
# Copyright 2017 The Tensor2Tensor Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Tests for tensor2tensor.data_generators.dna_encoder."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
# Dependency imports
23+
24+
from tensor2tensor.data_generators import dna_encoder
25+
import tensorflow as tf
26+
27+
28+
class DnaEncoderTest(tf.test.TestCase):
  """Round-trip tests for the DNA encoders."""

  def test_encode_decode(self):
    sequence = 'TTCGCGGNNNAACCCAACGCCATCTATGTANNTTGAGTTGTTGAGTTAAA'

    # Encoding should be reversible for any reasonable chunk size.
    for size in [1, 2, 4, 6, 8]:
      enc = dna_encoder.DNAEncoder(chunk_size=size)
      round_trip = enc.decode(enc.encode(sequence))
      self.assertEqual(sequence, round_trip)

  def test_delimited_dna_encoder(self):
    sequence = 'TTCGCGGNNN,AACCCAACGC,CATCTATGTA,NNTTGAGTTG,TTGAGTTAAA'

    # Encoding should be reversible for any reasonable chunk size.
    for size in [1, 2, 4, 6, 8]:
      enc = dna_encoder.DelimitedDNAEncoder(chunk_size=size)
      round_trip = enc.decode(enc.encode(sequence))
      self.assertEqual(sequence, round_trip)
49+
50+
51+
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
  tf.test.main()

tensor2tensor/data_generators/gene_expression.py

Lines changed: 4 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
from __future__ import division
3636
from __future__ import print_function
3737

38-
import itertools
3938
import math
4039
import multiprocessing as mp
4140
import os
@@ -47,6 +46,7 @@
4746

4847
from six.moves import xrange # pylint: disable=redefined-builtin
4948

49+
from tensor2tensor.data_generators import dna_encoder
5050
from tensor2tensor.data_generators import generator_utils
5151
from tensor2tensor.data_generators import problem
5252
from tensor2tensor.data_generators import text_encoder
@@ -56,7 +56,6 @@
5656
import tensorflow as tf
5757

5858
MAX_CONCURRENT_PROCESSES = 10
59-
_bases = list("ACTG")
6059

6160

6261
class GeneExpressionProblem(problem.Problem):
@@ -82,7 +81,7 @@ def chunk_size(self):
8281
def feature_encoders(self, data_dir):
8382
del data_dir
8483
return {
85-
"inputs": DNAEncoder(chunk_size=self.chunk_size),
84+
"inputs": dna_encoder.DNAEncoder(chunk_size=self.chunk_size),
8685
# TODO(rsepassi): RealEncoder?
8786
"targets": text_encoder.TextEncoder()
8887
}
@@ -244,7 +243,7 @@ def dataset_generator(filepath,
244243
chunk_size=1,
245244
start_idx=None,
246245
end_idx=None):
247-
encoder = DNAEncoder(chunk_size=chunk_size)
246+
encoder = dna_encoder.DNAEncoder(chunk_size=chunk_size)
248247
with h5py.File(filepath, "r") as h5_file:
249248
# Get input keys from h5_file
250249
src_keys = [s % dataset for s in ["%s_in", "%s_na", "%s_out"]]
@@ -278,7 +277,7 @@ def to_example_dict(encoder, inputs, mask, outputs):
278277
while idx != last_idx + 1:
279278
bases.append(encoder.UNK)
280279
last_idx += 1
281-
bases.append(_bases[base_id])
280+
bases.append(encoder.BASES[base_id])
282281
last_idx = idx
283282
assert len(inputs) == len(bases)
284283

@@ -297,62 +296,3 @@ def to_example_dict(encoder, inputs, mask, outputs):
297296
ex_dict = dict(
298297
zip(example_keys, [input_ids, targets_mask, targets, targets_shape]))
299298
return ex_dict
300-
301-
302-
class DNAEncoder(text_encoder.TextEncoder):
303-
"""ACTG strings to ints and back. Optionally chunks bases into single ids.
304-
305-
Uses 'X' as an unknown base.
306-
"""
307-
UNK = "X"
308-
PAD = "0"
309-
310-
def __init__(self,
311-
chunk_size=1,
312-
num_reserved_ids=text_encoder.NUM_RESERVED_TOKENS):
313-
super(DNAEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
314-
# Build a vocabulary of chunks of size chunk_size
315-
self._chunk_size = chunk_size
316-
chunks = []
317-
for size in range(1, chunk_size + 1):
318-
c = itertools.product(_bases + [DNAEncoder.UNK], repeat=size)
319-
num_pad = chunk_size - size
320-
padding = (DNAEncoder.PAD,) * num_pad
321-
c = [el + padding for el in c]
322-
chunks.extend(c)
323-
chunks.sort()
324-
ids = range(self._num_reserved_ids, len(chunks) + self._num_reserved_ids)
325-
self._ids_to_chunk = dict(zip(ids, chunks))
326-
self._chunks_to_ids = dict(zip(chunks, ids))
327-
328-
@property
329-
def vocab_size(self):
330-
return len(self._ids_to_chunk) + self._num_reserved_ids
331-
332-
def encode(self, s):
333-
bases = list(s)
334-
pad = [DNAEncoder.PAD] * (len(bases) % self._chunk_size)
335-
bases.extend(pad)
336-
assert (len(bases) % self._chunk_size) == 0
337-
num_chunks = len(bases) // self._chunk_size
338-
ids = []
339-
for chunk_idx in xrange(num_chunks):
340-
start_idx = chunk_idx * self._chunk_size
341-
end_idx = start_idx + self._chunk_size
342-
chunk = tuple(bases[start_idx:end_idx])
343-
if chunk not in self._chunks_to_ids:
344-
raise ValueError("Unrecognized chunk %s" % chunk)
345-
ids.append(self._chunks_to_ids[chunk])
346-
return ids
347-
348-
def decode(self, ids):
349-
bases = []
350-
for idx in ids:
351-
if idx >= self._num_reserved_ids:
352-
chunk = self._ids_to_chunk[idx]
353-
if DNAEncoder.PAD in chunk:
354-
chunk = chunk[:chunk.index(DNAEncoder.PAD)]
355-
else:
356-
chunk = [text_encoder.RESERVED_TOKENS[idx]]
357-
bases.extend(chunk)
358-
return "".join(bases)

tensor2tensor/data_generators/gene_expression_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import numpy as np
2424

25+
from tensor2tensor.data_generators import dna_encoder
2526
from tensor2tensor.data_generators import gene_expression
2627

2728
import tensorflow as tf
@@ -40,8 +41,8 @@ def _oneHotBases(self, bases):
4041
return np.array(one_hots)
4142

4243
def testRecordToExample(self):
43-
encoder = gene_expression.DNAEncoder(chunk_size=2)
44-
raw_inputs = ["A", "C", "G", "X", "C", "T"]
44+
encoder = dna_encoder.DNAEncoder(chunk_size=2)
45+
raw_inputs = ["A", "C", "G", "N", "C", "T"]
4546

4647
# Put in numpy arrays in the same format as in the h5 file
4748
inputs = self._oneHotBases(raw_inputs)

0 commit comments

Comments
 (0)