Skip to content

Commit 868bc9d

Browse files
committed
Inital ts convert
1 parent 3d8ad3e commit 868bc9d

File tree

3 files changed

+348
-1
lines changed

3 files changed

+348
-1
lines changed

bio2zarr/plink.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def generate_schema(
113113
dtype="i1",
114114
dimensions=["variants", "samples", "ploidy"],
115115
description=None,
116-
compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(),
116+
compressor=vcz.DEFAULT_ZARR_COMPRESSOR_GENOTYPES.get_config(),
117117
),
118118
vcz.ZarrArraySpec(
119119
name="call_genotype_mask",

bio2zarr/ts.py

Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
import logging
2+
import pathlib
3+
4+
import numpy as np
5+
import tskit
6+
7+
from bio2zarr import constants, core, vcz
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class TskitFormat(vcz.Source):
13+
def __init__(self, ts_path, contig_id=None, ploidy=None, isolated_as_missing=False):
14+
self._path = ts_path
15+
self.ts = tskit.load(ts_path)
16+
self.contig_id = contig_id if contig_id is not None else "1"
17+
self.isolated_as_missing = isolated_as_missing
18+
19+
self._make_sample_mapping(ploidy)
20+
self.positions = self.ts.sites_position
21+
22+
@property
23+
def path(self):
24+
return self._path
25+
26+
@property
27+
def num_records(self):
28+
return self.ts.num_sites
29+
30+
@property
31+
def num_samples(self):
32+
return self._num_samples
33+
34+
@property
35+
def samples(self):
36+
return self._samples
37+
38+
@property
39+
def root_attrs(self):
40+
return {}
41+
42+
@property
43+
def contigs(self):
44+
return [vcz.Contig(id=self.contig_id)]
45+
46+
def _make_sample_mapping(self, ploidy):
47+
ts = self.ts
48+
self.individual_ploidies = []
49+
self.max_ploidy = 0
50+
51+
if ts.num_individuals > 0 and ploidy is not None:
52+
raise ValueError(
53+
"Cannot specify ploidy when individuals are present in tables"
54+
)
55+
56+
# Find all sample nodes that reference individuals
57+
individuals = np.unique(ts.tables.nodes.individual[ts.samples()])
58+
if len(individuals) == 1 and individuals[0] == tskit.NULL:
59+
# No samples refer to individuals
60+
individuals = None
61+
else:
62+
# np.unique sorts the argument, so if NULL (-1) is present it
63+
# will be the first value.
64+
if individuals[0] == tskit.NULL:
65+
raise ValueError(
66+
"Sample nodes must either all be associated with individuals "
67+
"or not associated with any individuals"
68+
)
69+
70+
if individuals is not None:
71+
self.sample_ids = []
72+
for i in individuals:
73+
if i < 0 or i >= self.ts.num_individuals:
74+
raise ValueError("Invalid individual IDs provided.")
75+
ind = self.ts.individual(i)
76+
if len(ind.nodes) == 0:
77+
raise ValueError(f"Individual {i} not associated with a node")
78+
is_sample = {ts.node(u).is_sample() for u in ind.nodes}
79+
if len(is_sample) != 1:
80+
raise ValueError(
81+
f"Individual {ind.id} has nodes that are sample and "
82+
"non-samples"
83+
)
84+
self.sample_ids.extend(ind.nodes)
85+
self.individual_ploidies.append(len(ind.nodes))
86+
self.max_ploidy = max(self.max_ploidy, len(ind.nodes))
87+
else:
88+
if ploidy is None:
89+
ploidy = 1
90+
if ploidy < 1:
91+
raise ValueError("Ploidy must be >= 1")
92+
if ts.num_samples % ploidy != 0:
93+
raise ValueError("Sample size must be divisible by ploidy")
94+
self.individual_ploidies = np.full(
95+
ts.num_samples // ploidy, ploidy, dtype=np.int32
96+
)
97+
self.max_ploidy = ploidy
98+
self.sample_ids = np.arange(ts.num_samples, dtype=np.int32)
99+
100+
self._num_samples = len(self.individual_ploidies)
101+
102+
self._samples = [vcz.Sample(id=f"tsk_{j}") for j in range(self.num_samples)]
103+
104+
def iter_contig(self, start, stop):
105+
yield from (0 for _ in range(start, stop))
106+
107+
def iter_field(self, field_name, shape, start, stop):
108+
if field_name == "position":
109+
for pos in self.ts.tables.sites.position[start:stop]:
110+
yield int(pos)
111+
else:
112+
raise ValueError(f"Unknown field {field_name}")
113+
114+
def iter_alleles(self, start, stop, num_alleles):
115+
for variant in self.ts.variants(
116+
samples=self.sample_ids,
117+
isolated_as_missing=self.isolated_as_missing,
118+
left=self.positions[start],
119+
right=self.positions[stop] if stop < self.num_records else None,
120+
):
121+
alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
122+
for i, allele in enumerate(variant.alleles):
123+
assert i < num_alleles
124+
alleles[i] = allele
125+
yield alleles
126+
127+
def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
128+
gt = np.zeros(shape, dtype=np.int8)
129+
phased = np.zeros(shape[:-1], dtype=bool)
130+
131+
for variant in self.ts.variants(
132+
samples=self.sample_ids,
133+
isolated_as_missing=self.isolated_as_missing,
134+
left=self.positions[start],
135+
right=self.positions[stop] if stop < self.num_records else None,
136+
):
137+
alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
138+
for i, allele in enumerate(variant.alleles):
139+
assert i < num_alleles
140+
alleles[i] = allele
141+
142+
genotypes = variant.genotypes
143+
sample_index = 0
144+
for i, ploidy in enumerate(self.individual_ploidies):
145+
for j in range(ploidy):
146+
if j < self.max_ploidy: # Only fill up to max_ploidy
147+
try:
148+
gt[i, j] = genotypes[sample_index + j]
149+
except IndexError:
150+
# This can happen if the ploidy varies between individuals
151+
gt[i, j] = -2 # Fill value
152+
153+
# In tskit, all genotypes are considered phased
154+
phased[i] = True
155+
sample_index += ploidy
156+
157+
yield alleles, (gt, phased)
158+
159+
def generate_schema(
160+
self,
161+
variants_chunk_size=None,
162+
samples_chunk_size=None,
163+
):
164+
n = self.num_samples
165+
m = self.ts.num_sites
166+
167+
# Determine max number of alleles
168+
max_alleles = 0
169+
for variant in self.ts.variants():
170+
max_alleles = max(max_alleles, len(variant.alleles))
171+
172+
logging.info(f"Scanned tskit with {n} samples and {m} variants")
173+
logging.info(
174+
f"Maximum ploidy: {self.max_ploidy}, maximum alleles: {max_alleles}"
175+
)
176+
177+
dimensions = {
178+
"variants": vcz.VcfZarrDimension(
179+
size=m, chunk_size=variants_chunk_size or vcz.DEFAULT_VARIANT_CHUNK_SIZE
180+
),
181+
"samples": vcz.VcfZarrDimension(
182+
size=n, chunk_size=samples_chunk_size or vcz.DEFAULT_SAMPLE_CHUNK_SIZE
183+
),
184+
"ploidy": vcz.VcfZarrDimension(size=self.max_ploidy),
185+
"alleles": vcz.VcfZarrDimension(size=max_alleles),
186+
}
187+
188+
schema_instance = vcz.VcfZarrSchema(
189+
format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION,
190+
dimensions=dimensions,
191+
fields=[],
192+
)
193+
194+
logger.info(
195+
"Generating schema with chunks="
196+
f"{schema_instance.dimensions['variants'].chunk_size}, "
197+
f"{schema_instance.dimensions['samples'].chunk_size}"
198+
)
199+
200+
array_specs = [
201+
vcz.ZarrArraySpec(
202+
source="position",
203+
name="variant_position",
204+
dtype="i4",
205+
dimensions=["variants"],
206+
description="Position of each variant",
207+
),
208+
vcz.ZarrArraySpec(
209+
source=None,
210+
name="variant_allele",
211+
dtype="O",
212+
dimensions=["variants", "alleles"],
213+
description="Alleles for each variant",
214+
),
215+
vcz.ZarrArraySpec(
216+
source=None,
217+
name="variant_contig",
218+
dtype=core.min_int_dtype(0, len(self.contigs)),
219+
dimensions=["variants"],
220+
description="Contig/chromosome index for each variant",
221+
),
222+
vcz.ZarrArraySpec(
223+
source=None,
224+
name="call_genotype_phased",
225+
dtype="bool",
226+
dimensions=["variants", "samples"],
227+
description="Whether the genotype is phased",
228+
compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(),
229+
),
230+
vcz.ZarrArraySpec(
231+
source=None,
232+
name="call_genotype",
233+
dtype="i1",
234+
dimensions=["variants", "samples", "ploidy"],
235+
description="Genotype for each variant and sample",
236+
compressor=vcz.DEFAULT_ZARR_COMPRESSOR_GENOTYPES.get_config(),
237+
),
238+
vcz.ZarrArraySpec(
239+
source=None,
240+
name="call_genotype_mask",
241+
dtype="bool",
242+
dimensions=["variants", "samples", "ploidy"],
243+
description="Mask for each genotype call",
244+
compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(),
245+
),
246+
]
247+
schema_instance.fields = array_specs
248+
return schema_instance
249+
250+
251+
def convert(
252+
ts_path,
253+
zarr_path,
254+
*,
255+
contig_id=None,
256+
ploidy=None,
257+
isolated_as_missing=False,
258+
variants_chunk_size=None,
259+
samples_chunk_size=None,
260+
worker_processes=1,
261+
show_progress=False,
262+
):
263+
tskit_format = TskitFormat(
264+
ts_path,
265+
contig_id=contig_id,
266+
ploidy=ploidy,
267+
isolated_as_missing=isolated_as_missing,
268+
)
269+
schema_instance = tskit_format.generate_schema(
270+
variants_chunk_size=variants_chunk_size,
271+
samples_chunk_size=samples_chunk_size,
272+
)
273+
zarr_path = pathlib.Path(zarr_path)
274+
vzw = vcz.VcfZarrWriter(TskitFormat, zarr_path)
275+
# Rough heuristic to split work up enough to keep utilisation high
276+
target_num_partitions = max(1, worker_processes * 4)
277+
vzw.init(
278+
tskit_format,
279+
target_num_partitions=target_num_partitions,
280+
schema=schema_instance,
281+
)
282+
vzw.encode_all_partitions(
283+
worker_processes=worker_processes,
284+
show_progress=show_progress,
285+
)
286+
vzw.finalise(show_progress)
287+
vzw.create_index()

tests/test_ts.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import os
2+
import tempfile
3+
4+
import numpy as np
5+
import tskit
6+
import zarr
7+
8+
from bio2zarr import ts
9+
10+
11+
class TestTskit:
12+
def test_simple_tree_sequence(self, tmp_path):
13+
tables = tskit.TableCollection(sequence_length=100)
14+
tables.individuals.add_row(flags=0, location=(0, 0), metadata=b"")
15+
tables.individuals.add_row(flags=0, location=(0, 0), metadata=b"")
16+
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=0)
17+
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=0)
18+
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=1)
19+
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=1)
20+
tables.nodes.add_row(flags=0, time=1) # MRCA for 0,1
21+
tables.nodes.add_row(flags=0, time=1) # MRCA for 2,3
22+
tables.edges.add_row(left=0, right=100, parent=4, child=0)
23+
tables.edges.add_row(left=0, right=100, parent=4, child=1)
24+
tables.edges.add_row(left=0, right=100, parent=5, child=2)
25+
tables.edges.add_row(left=0, right=100, parent=5, child=3)
26+
site_id = tables.sites.add_row(position=10, ancestral_state="A")
27+
tables.mutations.add_row(site=site_id, node=4, derived_state="T")
28+
site_id = tables.sites.add_row(position=20, ancestral_state="C")
29+
tables.mutations.add_row(site=site_id, node=5, derived_state="G")
30+
site_id = tables.sites.add_row(position=30, ancestral_state="G")
31+
tables.mutations.add_row(site=site_id, node=0, derived_state="A")
32+
tables.sort()
33+
tree_sequence = tables.tree_sequence()
34+
tree_sequence.dump(tmp_path / "test.trees")
35+
with tempfile.TemporaryDirectory() as tempdir:
36+
zarr_path = os.path.join(tempdir, "test_output.zarr")
37+
ts.convert(tmp_path / "test.trees", zarr_path, show_progress=False)
38+
zroot = zarr.open(zarr_path, mode="r")
39+
assert zroot["variant_position"].shape == (3,)
40+
assert list(zroot["variant_position"][:]) == [10, 20, 30]
41+
42+
alleles = zroot["variant_allele"][:]
43+
assert np.array_equal(alleles, [["A", "T"], ["C", "G"], ["G", "A"]])
44+
45+
genotypes = zroot["call_genotype"][:]
46+
assert np.array_equal(
47+
genotypes, [[[1, 1], [0, 0]], [[0, 0], [1, 1]], [[1, 0], [0, 0]]]
48+
)
49+
50+
phased = zroot["call_genotype_phased"][:]
51+
assert np.all(phased)
52+
53+
contigs = zroot["contig_id"][:]
54+
assert np.array_equal(contigs, ["1"])
55+
56+
contig = zroot["variant_contig"][:]
57+
assert np.array_equal(contig, [0, 0, 0])
58+
59+
samples = zroot["sample_id"][:]
60+
assert np.array_equal(samples, ["tsk_0", "tsk_1"])

0 commit comments

Comments
 (0)