Skip to content

Commit 35567d2

Browse files
committed
Add variant_length to output for plink and tskit, make missing arrays for index an error
1 parent 6680178 commit 35567d2

File tree

8 files changed

+110
-39
lines changed

8 files changed

+110
-39
lines changed

bio2zarr/plink.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
6161
gt[bed_chunk[i] == 2] = 1
6262
gt[bed_chunk[i] == 1, 0] = 1
6363

64-
yield alleles, (gt, phased)
64+
yield vcz.VariantData(max(len(a) for a in alleles), alleles, gt, phased)
6565

6666
def generate_schema(
6767
self,
@@ -110,6 +110,13 @@ def generate_schema(
110110
dimensions=["variants", "alleles"],
111111
description=None,
112112
),
113+
vcz.ZarrArraySpec(
114+
source=None,
115+
name="variant_length",
116+
dtype="i4",
117+
dimensions=["variants"],
118+
description="Length of each variant",
119+
),
113120
vcz.ZarrArraySpec(
114121
name="variant_contig",
115122
dtype=core.min_int_dtype(0, len(np.unique(self.bed.chromosome))),

bio2zarr/tskit.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,21 +92,23 @@ def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
9292
left=self.positions[start],
9393
right=self.positions[stop] if stop < self.num_records else None,
9494
samples=self.tskit_samples,
95+
copy=False,
9596
):
9697
gt = np.full(shape, constants.INT_FILL, dtype=np.int8)
9798
alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
99+
variant_length = 0
98100
for i, allele in enumerate(variant.alleles):
99101
# None is returned by tskit in the case of a missing allele
100102
if allele is None:
101103
continue
102104
assert i < num_alleles
103105
alleles[i] = allele
104-
106+
variant_length = max(variant_length, len(allele))
105107
gt[self.sample_indices, self.ploidy_indices] = variant.genotypes[
106108
self.genotype_indices
107109
]
108110

109-
yield alleles, (gt, phased)
111+
yield vcz.VariantData(variant_length, alleles, gt, phased)
110112

111113
def generate_schema(
112114
self,
@@ -159,6 +161,16 @@ def generate_schema(
159161
min_position = np.min(self.ts.sites_position)
160162
max_position = np.max(self.ts.sites_position)
161163

164+
tables = self.ts.tables
165+
ancestral_state_offsets = tables.sites.ancestral_state_offset
166+
derived_state_offsets = tables.mutations.derived_state_offset
167+
ancestral_lengths = ancestral_state_offsets[1:] - ancestral_state_offsets[:-1]
168+
derived_lengths = derived_state_offsets[1:] - derived_state_offsets[:-1]
169+
max_variant_length = max(
170+
np.max(ancestral_lengths) if len(ancestral_lengths) > 0 else 0,
171+
np.max(derived_lengths) if len(derived_lengths) > 0 else 0,
172+
)
173+
162174
array_specs = [
163175
vcz.ZarrArraySpec(
164176
source="position",
@@ -174,6 +186,13 @@ def generate_schema(
174186
dimensions=["variants", "alleles"],
175187
description="Alleles for each variant",
176188
),
189+
vcz.ZarrArraySpec(
190+
source=None,
191+
name="variant_length",
192+
dtype=core.min_int_dtype(0, max_variant_length),
193+
dimensions=["variants"],
194+
description="Length of each variant",
195+
),
177196
vcz.ZarrArraySpec(
178197
source=None,
179198
name="variant_contig",

bio2zarr/vcf.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,14 +1040,19 @@ def iter_genotypes(self, shape, start, stop):
10401040
yield sanitised_genotypes, sanitised_phased
10411041

10421042
def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
1043+
variant_lengths = self.fields["rlen"].iter_values(start, stop)
10431044
if self.gt_field is None or shape is None:
1044-
for alleles in self.iter_alleles(start, stop, num_alleles):
1045-
yield alleles, (None, None)
1045+
for variant_length, alleles in zip(
1046+
variant_lengths, self.iter_alleles(start, stop, num_alleles)
1047+
):
1048+
yield vcz.VariantData(variant_length, alleles, None, None)
10461049
else:
1047-
yield from zip(
1050+
for variant_length, alleles, (gt, phased) in zip(
1051+
variant_lengths,
10481052
self.iter_alleles(start, stop, num_alleles),
10491053
self.iter_genotypes(shape, start, stop),
1050-
)
1054+
):
1055+
yield vcz.VariantData(variant_length, alleles, gt, phased)
10511056

10521057
def generate_schema(
10531058
self, variants_chunk_size=None, samples_chunk_size=None, local_alleles=None
@@ -1121,6 +1126,7 @@ def fixed_field_spec(name, dtype, source=None, dimensions=("variants",)):
11211126
compressor=compressor,
11221127
)
11231128

1129+
name_map = {field.full_name: field for field in self.metadata.fields}
11241130
array_specs = [
11251131
fixed_field_spec(
11261132
name="variant_contig",
@@ -1136,6 +1142,11 @@ def fixed_field_spec(name, dtype, source=None, dimensions=("variants",)):
11361142
dtype="O",
11371143
dimensions=["variants", "alleles"],
11381144
),
1145+
fixed_field_spec(
1146+
name="variant_length",
1147+
dtype=name_map["rlen"].smallest_dtype(),
1148+
dimensions=["variants"],
1149+
),
11391150
fixed_field_spec(
11401151
name="variant_id",
11411152
dtype="O",
@@ -1145,14 +1156,12 @@ def fixed_field_spec(name, dtype, source=None, dimensions=("variants",)):
11451156
dtype="bool",
11461157
),
11471158
]
1148-
name_map = {field.full_name: field for field in self.metadata.fields}
11491159

1150-
# Only three of the fixed fields have a direct one-to-one mapping.
1160+
# Only two of the fixed fields have a direct one-to-one mapping.
11511161
array_specs.extend(
11521162
[
11531163
spec_from_field(name_map["QUAL"], array_name="variant_quality"),
11541164
spec_from_field(name_map["POS"], array_name="variant_position"),
1155-
spec_from_field(name_map["rlen"], array_name="variant_length"),
11561165
]
11571166
)
11581167
array_specs.extend(

bio2zarr/vcz.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,16 @@
3737
}
3838

3939

40+
@dataclasses.dataclass
41+
class VariantData:
42+
"""Represents variant data returned by iter_alleles_and_genotypes."""
43+
44+
variant_length: int
45+
alleles: np.ndarray
46+
genotypes: np.ndarray
47+
phased: np.ndarray
48+
49+
4050
class Source(abc.ABC):
4151
@property
4252
@abc.abstractmethod
@@ -794,6 +804,7 @@ def encode_array_partition(self, array_spec, partition_index):
794804
def encode_alleles_and_genotypes_partition(self, partition_index):
795805
partition = self.metadata.partitions[partition_index]
796806
alleles = self.init_partition_array(partition_index, "variant_allele")
807+
variant_lengths = self.init_partition_array(partition_index, "variant_length")
797808
has_gt = self.has_genotypes()
798809
shape = None
799810
if has_gt:
@@ -802,18 +813,21 @@ def encode_alleles_and_genotypes_partition(self, partition_index):
802813
partition_index, "call_genotype_phased"
803814
)
804815
shape = gt.buff.shape[1:]
805-
for alleles_value, (genotype, phased) in self.source.iter_alleles_and_genotypes(
816+
for variant_data in self.source.iter_alleles_and_genotypes(
806817
partition.start, partition.stop, shape, alleles.array.shape[1]
807818
):
808819
j_alleles = alleles.next_buffer_row()
809-
alleles.buff[j_alleles] = alleles_value
820+
alleles.buff[j_alleles] = variant_data.alleles
821+
j_variant_length = variant_lengths.next_buffer_row()
822+
variant_lengths.buff[j_variant_length] = variant_data.variant_length
810823
if has_gt:
811824
j = gt.next_buffer_row()
812-
gt.buff[j] = genotype
825+
gt.buff[j] = variant_data.genotypes
813826
j_phased = gt_phased.next_buffer_row()
814-
gt_phased.buff[j_phased] = phased
827+
gt_phased.buff[j_phased] = variant_data.phased
815828

816829
self.finalise_partition_array(partition_index, alleles)
830+
self.finalise_partition_array(partition_index, variant_lengths)
817831
if has_gt:
818832
self.finalise_partition_array(partition_index, gt)
819833
self.finalise_partition_array(partition_index, gt_phased)

tests/data/plink/example.bim

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
1 1_10 0 10 A G
2-
1 1_20 0 20 T C
1+
1 1_10 0 10 A GG
2+
1 1_20 0 20 TTT C

tests/test_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def test_examples(self, chunk_size, size, start, stop):
237237
# It works in CI on Linux, but it'll probably break at some point.
238238
# It's also necessary to update these numbers each time a new data
239239
# file gets added
240-
("tests/data", 5030777),
240+
("tests/data", 5030780),
241241
("tests/data/vcf", 5018640),
242242
("tests/data/vcf/sample.vcf.gz", 1089),
243243
],

tests/test_plink.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ class TestExample:
7777
"""
7878
.bim file looks like this:
7979
80-
1 1_10 0 10 A G
81-
1 1_20 0 20 T C
80+
1 1_10 0 10 A GG
81+
1 1_20 0 20 TTT C
8282
8383
Definition: https://www.cog-genomics.org/plink/1.9/formats#bim
8484
Chromosome code (either an integer, or 'X'/'Y'/'XY'/'MT'; '0'
@@ -104,7 +104,10 @@ def test_variant_position(self, ds):
104104
nt.assert_array_equal(ds.variant_position, [10, 20])
105105

106106
def test_variant_allele(self, ds):
107-
nt.assert_array_equal(ds.variant_allele, [["A", "G"], ["T", "C"]])
107+
nt.assert_array_equal(ds.variant_allele, [["A", "GG"], ["TTT", "C"]])
108+
109+
def test_variant_length(self, ds):
110+
nt.assert_array_equal(ds.variant_length, [2, 3])
108111

109112
def test_contig_id(self, ds):
110113
"""Test that contig identifiers are correctly extracted and stored."""
@@ -249,6 +252,9 @@ def test_chunk_size(
249252
worker_processes=worker_processes,
250253
)
251254
ds2 = sg.load_dataset(out)
255+
# Drop the region_index as it is chunk dependent
256+
ds = ds.drop_vars("region_index")
257+
ds2 = ds2.drop_vars("region_index")
252258
xt.assert_equal(ds, ds2)
253259
# TODO check array chunks
254260

@@ -355,3 +361,9 @@ def test_genotypes(self, ds):
355361

356362
def test_variant_position(self, ds):
357363
nt.assert_array_equal(ds.variant_position, [10, 20, 10, 10, 20, 10])
364+
365+
def test_variant_length(self, ds):
366+
nt.assert_array_equal(
367+
ds.variant_length,
368+
[1, 1, 1, 1, 1, 1],
369+
)

tests/test_ts.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ def test_simple_tree_sequence(self, tmp_path):
2323
tables.edges.add_row(left=0, right=100, parent=5, child=2)
2424
tables.edges.add_row(left=0, right=100, parent=5, child=3)
2525
site_id = tables.sites.add_row(position=10, ancestral_state="A")
26-
tables.mutations.add_row(site=site_id, node=4, derived_state="T")
27-
site_id = tables.sites.add_row(position=20, ancestral_state="C")
26+
tables.mutations.add_row(site=site_id, node=4, derived_state="TTTT")
27+
site_id = tables.sites.add_row(position=20, ancestral_state="CCC")
2828
tables.mutations.add_row(site=site_id, node=5, derived_state="G")
2929
site_id = tables.sites.add_row(position=30, ancestral_state="G")
30-
tables.mutations.add_row(site=site_id, node=0, derived_state="A")
30+
tables.mutations.add_row(site=site_id, node=0, derived_state="AA")
3131
tables.sort()
3232
tree_sequence = tables.tree_sequence()
3333
tree_sequence.dump(tmp_path / "test.trees")
@@ -44,7 +44,11 @@ def test_simple_tree_sequence(self, tmp_path):
4444
assert list(zroot["variant_position"][:]) == [10, 20, 30]
4545

4646
alleles = zroot["variant_allele"][:]
47-
assert np.array_equal(alleles, [["A", "T"], ["C", "G"], ["G", "A"]])
47+
assert np.array_equal(alleles, [["A", "TTTT"], ["CCC", "G"], ["G", "AA"]])
48+
49+
lengths = zroot["variant_length"][:]
50+
assert lengths.shape == (3,)
51+
assert np.array_equal(lengths, [4, 3, 2])
4852

4953
genotypes = zroot["call_genotype"][:]
5054
assert np.array_equal(
@@ -81,8 +85,8 @@ def simple_ts(self, tmp_path):
8185
tables.edges.add_row(left=0, right=100, parent=5, child=2)
8286
tables.edges.add_row(left=0, right=100, parent=5, child=3)
8387
site_id = tables.sites.add_row(position=10, ancestral_state="A")
84-
tables.mutations.add_row(site=site_id, node=4, derived_state="T")
85-
site_id = tables.sites.add_row(position=20, ancestral_state="C")
88+
tables.mutations.add_row(site=site_id, node=4, derived_state="TT")
89+
site_id = tables.sites.add_row(position=20, ancestral_state="CCC")
8690
tables.mutations.add_row(site=site_id, node=5, derived_state="G")
8791
site_id = tables.sites.add_row(position=30, ancestral_state="G")
8892
tables.mutations.add_row(site=site_id, node=0, derived_state="A")
@@ -222,6 +226,7 @@ def test_schema_generation(self, simple_ts):
222226
field_names = [field.name for field in schema.fields]
223227
assert "variant_position" in field_names
224228
assert "variant_allele" in field_names
229+
assert "variant_length" in field_names
225230
assert "variant_contig" in field_names
226231
assert "call_genotype" in field_names
227232
assert "call_genotype_phased" in field_names
@@ -295,18 +300,22 @@ def test_iter_alleles_and_genotypes(self, simple_ts, ind_nodes, expected_gts):
295300

296301
assert len(results) == 3
297302

298-
for i, (alleles, (gt, phased)) in enumerate(results):
303+
for i, variant_data in enumerate(results):
299304
if i == 0:
300-
assert tuple(alleles) == ("A", "T")
305+
assert variant_data.variant_length == 2
306+
assert np.array_equal(variant_data.alleles, ("A", "TT"))
301307
elif i == 1:
302-
assert tuple(alleles) == ("C", "G")
308+
assert variant_data.variant_length == 3
309+
assert np.array_equal(variant_data.alleles, ("CCC", "G"))
303310
elif i == 2:
304-
assert tuple(alleles) == ("G", "A")
311+
assert variant_data.variant_length == 1
312+
assert np.array_equal(variant_data.alleles, ("G", "A"))
305313

306314
assert np.array_equal(
307-
gt, expected_gts[i]
308-
), f"Mismatch at variant {i}, expected {expected_gts[i]}, got {gt}"
309-
assert np.all(phased)
315+
variant_data.genotypes, expected_gts[i]
316+
), f"Mismatch at variant {i}, expected {expected_gts[i]}, "
317+
f"got {variant_data.genotypes}"
318+
assert np.all(variant_data.phased)
310319

311320
def test_iter_alleles_and_genotypes_errors(self, simple_ts):
312321
"""Test error cases for iter_alleles_and_genotypes with invalid inputs."""
@@ -374,12 +383,12 @@ def insert_branch_sites(ts, m=1):
374383
)
375384

376385
assert len(results_default) == 1
377-
alleles, (gt_default, phased) = results_default[0]
378-
assert tuple(alleles) == ("0", "1")
386+
variant_data_default = results_default[0]
387+
assert np.array_equal(variant_data_default.alleles, ("0", "1"))
379388

380389
# Sample 2 should have the ancestral state (0) when isolated_as_missing=False
381390
expected_gt_default = np.array([[1], [0], [0]])
382-
assert np.array_equal(gt_default, expected_gt_default)
391+
assert np.array_equal(variant_data_default.genotypes, expected_gt_default)
383392

384393
format_obj_missing = ts.TskitFormat(
385394
ts_path, ind_nodes, isolated_as_missing=True
@@ -389,12 +398,13 @@ def insert_branch_sites(ts, m=1):
389398
)
390399

391400
assert len(results_missing) == 1
392-
alleles, (gt_missing, phased) = results_missing[0]
393-
assert tuple(alleles) == ("0", "1")
401+
variant_data_missing = results_missing[0]
402+
assert variant_data_missing.variant_length == 1
403+
assert np.array_equal(variant_data_missing.alleles, ("0", "1"))
394404

395405
# Individual 2 should have missing values (-1) when isolated_as_missing=True
396406
expected_gt_missing = np.array([[1], [0], [-1]])
397-
assert np.array_equal(gt_missing, expected_gt_missing)
407+
assert np.array_equal(variant_data_missing.genotypes, expected_gt_missing)
398408

399409
def test_genotype_dtype_selection(self, tmp_path):
400410
tables = tskit.TableCollection(sequence_length=100)

0 commit comments

Comments
 (0)