Skip to content

Commit 8d0f9f3

Browse files
committed
Add tests of tskit source
1 parent 0c6fd5e commit 8d0f9f3

File tree

2 files changed

+339
-12
lines changed

2 files changed

+339
-12
lines changed

bio2zarr/ts.py

Lines changed: 9 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -95,7 +95,7 @@ def _make_sample_mapping(self, ploidy):
95 95
ts.num_samples // ploidy, ploidy, dtype=np.int32
96 96
)
97 97
self.max_ploidy = ploidy
98-
self.sample_ids = np.arange(ts.num_samples, dtype=np.int32)
98+
self.sample_ids = (ts.nodes_flags & tskit.NODE_IS_SAMPLE).nonzero()[0]
99 99

100 100
self._num_samples = len(self.individual_ploidies)
101 101

@@ -125,33 +125,30 @@ def iter_alleles(self, start, stop, num_alleles):
125 125
yield alleles
126 126

127 127
def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
128-
gt = np.zeros(shape, dtype=np.int8)
129-
phased = np.zeros(shape[:-1], dtype=bool)
128+
# In tskit, all genotypes are considered phased
129+
phased = np.ones(shape[:-1], dtype=bool)
130 130

131 131
for variant in self.ts.variants(
132 132
samples=self.sample_ids,
133 133
isolated_as_missing=self.isolated_as_missing,
134 134
left=self.positions[start],
135 135
right=self.positions[stop] if stop < self.num_records else None,
136 136
):
137+
gt = np.full(shape, constants.INT_FILL, dtype=np.int8)
137 138
alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
138 139
for i, allele in enumerate(variant.alleles):
140+
# None is returned by tskit in the case of a missing allele
141+
if allele is None:
142+
continue
139 143
assert i < num_alleles
140 144
alleles[i] = allele
141 145

142 146
genotypes = variant.genotypes
143 147
sample_index = 0
144 148
for i, ploidy in enumerate(self.individual_ploidies):
145 149
for j in range(ploidy):
146-
if j < self.max_ploidy: # Only fill up to max_ploidy
147-
try:
148-
gt[i, j] = genotypes[sample_index + j]
149-
except IndexError:
150-
# This can happen if the ploidy varies between individuals
151-
gt[i, j] = -2 # Fill value
152-
153-
# In tskit, all genotypes are considered phased
154-
phased[i] = True
150+
if j < self.max_ploidy:
151+
gt[i, j] = genotypes[sample_index + j]
155 152
sample_index += ploidy
156 153

157 154
yield alleles, (gt, phased)

tests/test_ts.py

Lines changed: 330 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
2 2
import tempfile
3 3

4 4
import numpy as np
5+
import pytest
5 6
import tskit
6 7
import zarr
7 8

@@ -58,3 +59,332 @@ def test_simple_tree_sequence(self, tmp_path):
58 59

59 60
samples = zroot["sample_id"][:]
60 61
assert np.array_equal(samples, ["tsk_0", "tsk_1"])
62+
63+
64+
class TestTskitFormat:
65+
"""Unit tests for TskitFormat without using full conversion."""
66+
67+
@pytest.fixture()
68+
def simple_ts(self, tmp_path):
69+
tables = tskit.TableCollection(sequence_length=100)
70+
tables.individuals.add_row(flags=0, location=(0, 0), metadata=b"")
71+
tables.individuals.add_row(flags=0, location=(0, 0), metadata=b"")
72+
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=0)
73+
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=0)
74+
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=1)
75+
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=1)
76+
tables.nodes.add_row(flags=0, time=1) # MRCA for 0,1
77+
tables.nodes.add_row(flags=0, time=1) # MRCA for 2,3
78+
tables.edges.add_row(left=0, right=100, parent=4, child=0)
79+
tables.edges.add_row(left=0, right=100, parent=4, child=1)
80+
tables.edges.add_row(left=0, right=100, parent=5, child=2)
81+
tables.edges.add_row(left=0, right=100, parent=5, child=3)
82+
site_id = tables.sites.add_row(position=10, ancestral_state="A")
83+
tables.mutations.add_row(site=site_id, node=4, derived_state="T")
84+
site_id = tables.sites.add_row(position=20, ancestral_state="C")
85+
tables.mutations.add_row(site=site_id, node=5, derived_state="G")
86+
site_id = tables.sites.add_row(position=30, ancestral_state="G")
87+
tables.mutations.add_row(site=site_id, node=0, derived_state="A")
88+
tables.sort()
89+
tree_sequence = tables.tree_sequence()
90+
ts_path = tmp_path / "test.trees"
91+
tree_sequence.dump(ts_path)
92+
return ts_path, tree_sequence
93+
94+
@pytest.fixture()
95+
def no_individuals_ts(self, tmp_path):
96+
tables = tskit.TableCollection(sequence_length=100)
97+
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
98+
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
99+
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
100+
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
101+
tables.nodes.add_row(flags=0, time=1) # MRCA for 0,1
102+
tables.nodes.add_row(flags=0, time=1) # MRCA for 2,3
103+
tables.edges.add_row(left=0, right=100, parent=4, child=0)
104+
tables.edges.add_row(left=0, right=100, parent=4, child=1)
105+
tables.edges.add_row(left=0, right=100, parent=5, child=2)
106+
tables.edges.add_row(left=0, right=100, parent=5, child=3)
107+
site_id = tables.sites.add_row(position=10, ancestral_state="A")
108+
tables.mutations.add_row(site=site_id, node=4, derived_state="T")
109+
site_id = tables.sites.add_row(position=20, ancestral_state="C")
110+
tables.mutations.add_row(site=site_id, node=5, derived_state="G")
111+
tables.sort()
112+
tree_sequence = tables.tree_sequence()
113+
ts_path = tmp_path / "no_individuals.trees"
114+
tree_sequence.dump(ts_path)
115+
return ts_path, tree_sequence
116+
117+
def test_initialization(self, simple_ts):
118+
ts_path, tree_sequence = simple_ts
119+
120+
# Test with default parameters
121+
format_obj = ts.TskitFormat(ts_path)
122+
assert format_obj.path == ts_path
123+
assert format_obj.ts.num_sites == tree_sequence.num_sites
124+
assert format_obj.contig_id == "1"
125+
assert not format_obj.isolated_as_missing
126+
127+
# Test with custom parameters
128+
format_obj = ts.TskitFormat(ts_path, contig_id="chr1", isolated_as_missing=True)
129+
assert format_obj.contig_id == "chr1"
130+
assert format_obj.isolated_as_missing
131+
assert format_obj.path == ts_path
132+
133+
def test_basic_properties(self, simple_ts):
134+
ts_path, _ = simple_ts
135+
format_obj = ts.TskitFormat(ts_path)
136+
137+
assert format_obj.num_records == format_obj.ts.num_sites
138+
assert format_obj.num_samples == 2 # Two individuals
139+
assert len(format_obj.samples) == 2
140+
assert format_obj.samples[0].id == "tsk_0"
141+
assert format_obj.samples[1].id == "tsk_1"
142+
143+
assert format_obj.root_attrs == {}
144+
145+
contigs = format_obj.contigs
146+
assert len(contigs) == 1
147+
assert contigs[0].id == "1"
148+
149+
def test_sample_mapping_with_individuals(self, simple_ts):
150+
ts_path, _ = simple_ts
151+
152+
format_obj = ts.TskitFormat(ts_path)
153+
assert format_obj.num_samples == 2
154+
assert format_obj.max_ploidy == 2
155+
assert format_obj.individual_ploidies == [2, 2]
156+
157+
# Should raise error if ploidy specified with individuals
158+
with pytest.raises(
159+
ValueError, match="Cannot specify ploidy when individuals are present"
160+
):
161+
ts.TskitFormat(ts_path, ploidy=2)
162+
163+
def test_sample_mapping_without_individuals(self, no_individuals_ts):
164+
ts_path, tree_sequence = no_individuals_ts
165+
166+
# Default ploidy should be 1
167+
format_obj = ts.TskitFormat(ts_path)
168+
assert format_obj.num_samples == 4
169+
assert format_obj.max_ploidy == 1
170+
assert list(format_obj.individual_ploidies) == [1, 1, 1, 1]
171+
172+
# Explicitly set ploidy to 2
173+
format_obj = ts.TskitFormat(ts_path, ploidy=2)
174+
assert format_obj.num_samples == 2
175+
assert format_obj.max_ploidy == 2
176+
assert list(format_obj.individual_ploidies) == [2, 2]
177+
178+
with pytest.raises(ValueError, match="Ploidy must be >= 1"):
179+
ts.TskitFormat(ts_path, ploidy=0)
180+
181+
with pytest.raises(ValueError, match="Sample size must be divisible by ploidy"):
182+
ts.TskitFormat(ts_path, ploidy=3)
183+
184+
def test_schema_generation(self, simple_ts):
185+
ts_path, _ = simple_ts
186+
format_obj = ts.TskitFormat(ts_path)
187+
188+
schema = format_obj.generate_schema()
189+
assert schema.dimensions["variants"].size == 3
190+
assert schema.dimensions["samples"].size == 2
191+
assert schema.dimensions["ploidy"].size == 2
192+
assert schema.dimensions["alleles"].size == 2 # A/T, C/G, G/A -> max is 2
193+
field_names = [field.name for field in schema.fields]
194+
assert "variant_position" in field_names
195+
assert "variant_allele" in field_names
196+
assert "variant_contig" in field_names
197+
assert "call_genotype" in field_names
198+
assert "call_genotype_phased" in field_names
199+
assert "call_genotype_mask" in field_names
200+
schema = format_obj.generate_schema(
201+
variants_chunk_size=10, samples_chunk_size=5
202+
)
203+
assert schema.dimensions["variants"].chunk_size == 10
204+
assert schema.dimensions["samples"].chunk_size == 5
205+
206+
def test_iter_contig(self, simple_ts):
207+
ts_path, _ = simple_ts
208+
format_obj = ts.TskitFormat(ts_path)
209+
contig_indices = list(format_obj.iter_contig(1, 3))
210+
assert contig_indices == [0, 0]
211+
212+
def test_iter_field(self, simple_ts):
213+
ts_path, _ = simple_ts
214+
format_obj = ts.TskitFormat(ts_path)
215+
positions = list(format_obj.iter_field("position", None, 0, 3))
216+
assert positions == [10, 20, 30]
217+
positions = list(format_obj.iter_field("position", None, 1, 3))
218+
assert positions == [20, 30]
219+
with pytest.raises(ValueError, match="Unknown field"):
220+
list(format_obj.iter_field("unknown_field", None, 0, 3))
221+
222+
def test_iter_alleles(self, simple_ts):
223+
ts_path, _ = simple_ts
224+
format_obj = ts.TskitFormat(ts_path)
225+
alleles_list = list(format_obj.iter_alleles(0, 3, 2))
226+
227+
expected_alleles = np.array([["A", "T"], ["C", "G"], ["G", "A"]])
228+
assert len(alleles_list) == 3
229+
assert np.array_equal(alleles_list, expected_alleles)
230+
231+
# Test with different start/stop
232+
alleles_list = list(format_obj.iter_alleles(1, 3, 2))
233+
assert len(alleles_list) == 2
234+
expected_alleles = np.array([["C", "G"], ["G", "A"]])
235+
assert np.array_equal(alleles_list, expected_alleles)
236+
237+
def test_iter_alleles_and_genotypes(self, simple_ts):
238+
ts_path, _ = simple_ts
239+
format_obj = ts.TskitFormat(ts_path)
240+
241+
shape = (2, 2) # (num_samples, max_ploidy)
242+
results = list(format_obj.iter_alleles_and_genotypes(0, 3, shape, 2))
243+
244+
assert len(results) == 3
245+
246+
for i, (alleles, (gt, phased)) in enumerate(results):
247+
if i == 0:
248+
assert tuple(alleles) == ("A", "T")
249+
assert np.array_equal(gt, [[1, 1], [0, 0]])
250+
elif i == 1:
251+
assert tuple(alleles) == ("C", "G")
252+
assert np.array_equal(gt, [[0, 0], [1, 1]])
253+
elif i == 2:
254+
assert tuple(alleles) == ("G", "A")
255+
assert np.array_equal(gt, [[1, 0], [0, 0]])
256+
assert np.all(phased)
257+
258+
def test_partial_iter_alleles_and_genotypes(self, simple_ts):
259+
ts_path, _ = simple_ts
260+
format_obj = ts.TskitFormat(ts_path)
261+
262+
shape = (2, 2)
263+
results = list(format_obj.iter_alleles_and_genotypes(1, 3, shape, 2))
264+
assert len(results) == 2
265+
alleles, (gt, phased) = results[0]
266+
assert tuple(alleles) == ("C", "G")
267+
assert np.array_equal(gt, [[0, 0], [1, 1]])
268+
assert np.all(phased)
269+
270+
alleles, (gt, phased) = results[1]
271+
assert tuple(alleles) == ("G", "A")
272+
assert np.array_equal(gt, [[1, 0], [0, 0]])
273+
assert np.all(phased)
274+
275+
def test_variable_ploidy(self, tmp_path):
276+
# Create a tree sequence with mixed ploidy
277+
tables = tskit.TableCollection(sequence_length=100)
278+
tables.individuals.add_row(flags=0, location=(0, 0), metadata=b"")
279+
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=0)
280+
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=0)
281+
tables.individuals.add_row(flags=0, location=(0, 0), metadata=b"")
282+
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=1)
283+
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=1)
284+
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=1)
285+
tables.nodes.add_row(flags=0, time=1)
286+
tables.nodes.add_row(flags=0, time=1)
287+
tables.edges.add_row(left=0, right=100, parent=5, child=0)
288+
tables.edges.add_row(left=0, right=100, parent=5, child=1)
289+
tables.edges.add_row(left=0, right=100, parent=6, child=2)
290+
tables.edges.add_row(left=0, right=100, parent=6, child=3)
291+
tables.edges.add_row(left=0, right=100, parent=6, child=4)
292+
293+
# Add a site with a mutation at individual 0, node 0
294+
site_id = tables.sites.add_row(position=10, ancestral_state="A")
295+
tables.mutations.add_row(site=site_id, node=0, derived_state="T")
296+
297+
# Add another site with mutation at individual 1, node 3
298+
# (middle chromosome of triploid)
299+
site_id = tables.sites.add_row(position=20, ancestral_state="C")
300+
tables.mutations.add_row(site=site_id, node=3, derived_state="G")
301+
302+
tables.sort()
303+
tree_sequence = tables.tree_sequence()
304+
ts_path = tmp_path / "mixed_ploidy.trees"
305+
tree_sequence.dump(ts_path)
306+
307+
format_obj = ts.TskitFormat(ts_path)
308+
309+
assert format_obj.max_ploidy == 3
310+
assert format_obj.individual_ploidies == [2, 3]
311+
312+
shape = (2, 3) # (num_samples, max_ploidy)
313+
results = list(format_obj.iter_alleles_and_genotypes(0, 2, shape, 2))
314+
315+
assert len(results) == 2
316+
317+
alleles, (gt, phased) = results[0]
318+
assert tuple(alleles) == ("A", "T")
319+
# First site - derived on 1st of diploid
320+
expected_gt = np.array([[1, 0, -2], [0, 0, 0]])
321+
assert np.array_equal(gt, expected_gt)
322+
assert np.all(phased)
323+
324+
alleles, (gt, phased) = results[1]
325+
assert tuple(alleles) == ("C", "G")
326+
# Second site - derived on 2nd of triploid
327+
expected_gt = np.array([[0, 0, -2], [0, 1, 0]])
328+
assert np.array_equal(gt, expected_gt)
329+
assert np.all(phased)
330+
331+
# Check the fill value
332+
assert gt[0, 2] == -2
333+
334+
def test_isolated_as_missing(self, tmp_path):
335+
def insert_branch_sites(ts, m=1):
336+
if m == 0:
337+
return ts
338+
tables = ts.dump_tables()
339+
tables.sites.clear()
340+
tables.mutations.clear()
341+
for tree in ts.trees():
342+
left, right = tree.interval
343+
delta = (right - left) / (m * len(list(tree.nodes())))
344+
x = left
345+
for u in tree.nodes():
346+
if tree.parent(u) != tskit.NULL:
347+
for _ in range(m):
348+
site = tables.sites.add_row(position=x, ancestral_state="0")
349+
tables.mutations.add_row(
350+
site=site, node=u, derived_state="1"
351+
)
352+
x += delta
353+
return tables.tree_sequence()
354+
355+
tables = tskit.Tree.generate_balanced(2, span=10).tree_sequence.dump_tables()
356+
# This also tests sample nodes that are not a single block at
357+
# the start of the nodes table.
358+
tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE)
359+
tree_sequence = insert_branch_sites(tables.tree_sequence())
360+
print(tree_sequence.tables)
361+
362+
ts_path = tmp_path / "isolated_sample.trees"
363+
tree_sequence.dump(ts_path)
364+
365+
format_obj_default = ts.TskitFormat(ts_path, isolated_as_missing=False)
366+
shape = (3, 1) # (num_samples, max_ploidy)
367+
results_default = list(
368+
format_obj_default.iter_alleles_and_genotypes(0, 1, shape, 2)
369+
)
370+
371+
assert len(results_default) == 1
372+
alleles, (gt_default, phased) = results_default[0]
373+
assert tuple(alleles) == ("0", "1")
374+
375+
# Sample 2 should have the ancestral state (0) when isolated_as_missing=False
376+
expected_gt_default = np.array([[1], [0], [0]])
377+
assert np.array_equal(gt_default, expected_gt_default)
378+
379+
format_obj_missing = ts.TskitFormat(ts_path, isolated_as_missing=True)
380+
results_missing = list(
381+
format_obj_missing.iter_alleles_and_genotypes(0, 1, shape, 2)
382+
)
383+
384+
assert len(results_missing) == 1
385+
alleles, (gt_missing, phased) = results_missing[0]
386+
assert tuple(alleles) == ("0", "1")
387+
388+
# Individual 2 should have missing values (-1) when isolated_as_missing=True
389+
expected_gt_missing = np.array([[1], [0], [-1]])
390+
assert np.array_equal(gt_missing, expected_gt_missing)

0 commit comments

Comments
 (0)