Skip to content

Commit 3829dbb

Browse files
committed
Add variant_length to output for plink and tskit, make missing arrays for index an error
1 parent 587a29e commit 3829dbb

File tree

9 files changed

+121
-40
lines changed

9 files changed

+121
-40
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
- Add contigs to plink output (#344)
44

5+
- Add variant_length and indexing to plink output (#382)
6+
57
Breaking changes
68

79
- Remove explicit sample, contig and filter lists from the schema.

bio2zarr/plink.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
6161
gt[bed_chunk[i] == 2] = 1
6262
gt[bed_chunk[i] == 1, 0] = 1
6363

64-
yield alleles, (gt, phased)
64+
yield vcz.VariantData(max(len(a) for a in alleles), alleles, gt, phased)
6565

6666
def generate_schema(
6767
self,
@@ -110,6 +110,13 @@ def generate_schema(
110110
dimensions=["variants", "alleles"],
111111
description=None,
112112
),
113+
vcz.ZarrArraySpec(
114+
source=None,
115+
name="variant_length",
116+
dtype="i4",
117+
dimensions=["variants"],
118+
description="Length of each variant",
119+
),
113120
vcz.ZarrArraySpec(
114121
name="variant_contig",
115122
dtype=core.min_int_dtype(0, len(np.unique(self.bed.chromosome))),

bio2zarr/tskit.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,21 +95,23 @@ def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
9595
left=self.positions[start],
9696
right=self.positions[stop] if stop < self.num_records else None,
9797
samples=self.tskit_samples,
98+
copy=False,
9899
):
99100
gt = np.full(shape, constants.INT_FILL, dtype=np.int8)
100101
alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
102+
variant_length = 0
101103
for i, allele in enumerate(variant.alleles):
102104
# None is returned by tskit in the case of a missing allele
103105
if allele is None:
104106
continue
105107
assert i < num_alleles
106108
alleles[i] = allele
107-
109+
variant_length = max(variant_length, len(allele))
108110
gt[self.sample_indices, self.ploidy_indices] = variant.genotypes[
109111
self.genotype_indices
110112
]
111113

112-
yield alleles, (gt, phased)
114+
yield vcz.VariantData(variant_length, alleles, gt, phased)
113115

114116
def generate_schema(
115117
self,
@@ -162,6 +164,16 @@ def generate_schema(
162164
min_position = np.min(self.ts.sites_position)
163165
max_position = np.max(self.ts.sites_position)
164166

167+
tables = self.ts.tables
168+
ancestral_state_offsets = tables.sites.ancestral_state_offset
169+
derived_state_offsets = tables.mutations.derived_state_offset
170+
ancestral_lengths = ancestral_state_offsets[1:] - ancestral_state_offsets[:-1]
171+
derived_lengths = derived_state_offsets[1:] - derived_state_offsets[:-1]
172+
max_variant_length = max(
173+
np.max(ancestral_lengths) if len(ancestral_lengths) > 0 else 0,
174+
np.max(derived_lengths) if len(derived_lengths) > 0 else 0,
175+
)
176+
165177
array_specs = [
166178
vcz.ZarrArraySpec(
167179
source="position",
@@ -177,6 +189,13 @@ def generate_schema(
177189
dimensions=["variants", "alleles"],
178190
description="Alleles for each variant",
179191
),
192+
vcz.ZarrArraySpec(
193+
source=None,
194+
name="variant_length",
195+
dtype=core.min_int_dtype(0, max_variant_length),
196+
dimensions=["variants"],
197+
description="Length of each variant",
198+
),
180199
vcz.ZarrArraySpec(
181200
source=None,
182201
name="variant_contig",

bio2zarr/vcf.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,14 +1040,19 @@ def iter_genotypes(self, shape, start, stop):
10401040
yield sanitised_genotypes, sanitised_phased
10411041

10421042
def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
1043+
variant_lengths = self.fields["rlen"].iter_values(start, stop)
10431044
if self.gt_field is None or shape is None:
1044-
for alleles in self.iter_alleles(start, stop, num_alleles):
1045-
yield alleles, (None, None)
1045+
for variant_length, alleles in zip(
1046+
variant_lengths, self.iter_alleles(start, stop, num_alleles)
1047+
):
1048+
yield vcz.VariantData(variant_length, alleles, None, None)
10461049
else:
1047-
yield from zip(
1050+
for variant_length, alleles, (gt, phased) in zip(
1051+
variant_lengths,
10481052
self.iter_alleles(start, stop, num_alleles),
10491053
self.iter_genotypes(shape, start, stop),
1050-
)
1054+
):
1055+
yield vcz.VariantData(variant_length, alleles, gt, phased)
10511056

10521057
def generate_schema(
10531058
self, variants_chunk_size=None, samples_chunk_size=None, local_alleles=None
@@ -1121,6 +1126,7 @@ def fixed_field_spec(name, dtype, source=None, dimensions=("variants",)):
11211126
compressor=compressor,
11221127
)
11231128

1129+
name_map = {field.full_name: field for field in self.metadata.fields}
11241130
array_specs = [
11251131
fixed_field_spec(
11261132
name="variant_contig",
@@ -1136,6 +1142,11 @@ def fixed_field_spec(name, dtype, source=None, dimensions=("variants",)):
11361142
dtype="O",
11371143
dimensions=["variants", "alleles"],
11381144
),
1145+
fixed_field_spec(
1146+
name="variant_length",
1147+
dtype=name_map["rlen"].smallest_dtype(),
1148+
dimensions=["variants"],
1149+
),
11391150
fixed_field_spec(
11401151
name="variant_id",
11411152
dtype="O",
@@ -1145,14 +1156,12 @@ def fixed_field_spec(name, dtype, source=None, dimensions=("variants",)):
11451156
dtype="bool",
11461157
),
11471158
]
1148-
name_map = {field.full_name: field for field in self.metadata.fields}
11491159

1150-
# Only three of the fixed fields have a direct one-to-one mapping.
1160+
# Only two of the fixed fields have a direct one-to-one mapping.
11511161
array_specs.extend(
11521162
[
11531163
spec_from_field(name_map["QUAL"], array_name="variant_quality"),
11541164
spec_from_field(name_map["POS"], array_name="variant_position"),
1155-
spec_from_field(name_map["rlen"], array_name="variant_length"),
11561165
]
11571166
)
11581167
array_specs.extend(

bio2zarr/vcz.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,16 @@
3737
}
3838

3939

40+
@dataclasses.dataclass
41+
class VariantData:
42+
"""Represents variant data returned by iter_alleles_and_genotypes."""
43+
44+
variant_length: int
45+
alleles: np.ndarray
46+
genotypes: np.ndarray
47+
phased: np.ndarray
48+
49+
4050
class Source(abc.ABC):
4151
@property
4252
@abc.abstractmethod
@@ -794,6 +804,7 @@ def encode_array_partition(self, array_spec, partition_index):
794804
def encode_alleles_and_genotypes_partition(self, partition_index):
795805
partition = self.metadata.partitions[partition_index]
796806
alleles = self.init_partition_array(partition_index, "variant_allele")
807+
variant_lengths = self.init_partition_array(partition_index, "variant_length")
797808
has_gt = self.has_genotypes()
798809
shape = None
799810
if has_gt:
@@ -802,18 +813,21 @@ def encode_alleles_and_genotypes_partition(self, partition_index):
802813
partition_index, "call_genotype_phased"
803814
)
804815
shape = gt.buff.shape[1:]
805-
for alleles_value, (genotype, phased) in self.source.iter_alleles_and_genotypes(
816+
for variant_data in self.source.iter_alleles_and_genotypes(
806817
partition.start, partition.stop, shape, alleles.array.shape[1]
807818
):
808819
j_alleles = alleles.next_buffer_row()
809-
alleles.buff[j_alleles] = alleles_value
820+
alleles.buff[j_alleles] = variant_data.alleles
821+
j_variant_length = variant_lengths.next_buffer_row()
822+
variant_lengths.buff[j_variant_length] = variant_data.variant_length
810823
if has_gt:
811824
j = gt.next_buffer_row()
812-
gt.buff[j] = genotype
825+
gt.buff[j] = variant_data.genotypes
813826
j_phased = gt_phased.next_buffer_row()
814-
gt_phased.buff[j_phased] = phased
827+
gt_phased.buff[j_phased] = variant_data.phased
815828

816829
self.finalise_partition_array(partition_index, alleles)
830+
self.finalise_partition_array(partition_index, variant_lengths)
817831
if has_gt:
818832
self.finalise_partition_array(partition_index, gt)
819833
self.finalise_partition_array(partition_index, gt_phased)

tests/data/plink/example.bim

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
1 1_10 0 10 A G
2-
1 1_20 0 20 T C
1+
1 1_10 0 10 A GG
2+
1 1_20 0 20 TTT C

tests/test_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def test_examples(self, chunk_size, size, start, stop):
237237
# It works in CI on Linux, but it'll probably break at some point.
238238
# It's also necessary to update these numbers each time a new data
239239
# file gets added
240-
("tests/data", 5045029),
240+
("tests/data", 5045032),
241241
("tests/data/vcf", 5018640),
242242
("tests/data/vcf/sample.vcf.gz", 1089),
243243
],

tests/test_plink.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ class TestExample:
7777
"""
7878
.bim file looks like this:
7979
80-
1 1_10 0 10 A G
81-
1 1_20 0 20 T C
80+
1 1_10 0 10 A GG
81+
1 1_20 0 20 TTT C
8282
8383
Definition: https://www.cog-genomics.org/plink/1.9/formats#bim
8484
Chromosome code (either an integer, or 'X'/'Y'/'XY'/'MT'; '0'
@@ -104,7 +104,10 @@ def test_variant_position(self, ds):
104104
nt.assert_array_equal(ds.variant_position, [10, 20])
105105

106106
def test_variant_allele(self, ds):
107-
nt.assert_array_equal(ds.variant_allele, [["A", "G"], ["T", "C"]])
107+
nt.assert_array_equal(ds.variant_allele, [["A", "GG"], ["TTT", "C"]])
108+
109+
def test_variant_length(self, ds):
110+
nt.assert_array_equal(ds.variant_length, [2, 3])
108111

109112
def test_contig_id(self, ds):
110113
"""Test that contig identifiers are correctly extracted and stored."""
@@ -249,6 +252,9 @@ def test_chunk_size(
249252
worker_processes=worker_processes,
250253
)
251254
ds2 = sg.load_dataset(out)
255+
# Drop the region_index as it is chunk dependent
256+
ds = ds.drop_vars("region_index")
257+
ds2 = ds2.drop_vars("region_index")
252258
xt.assert_equal(ds, ds2)
253259
# TODO check array chunks
254260

@@ -355,3 +361,9 @@ def test_genotypes(self, ds):
355361

356362
def test_variant_position(self, ds):
357363
nt.assert_array_equal(ds.variant_position, [10, 20, 10, 10, 20, 10])
364+
365+
def test_variant_length(self, ds):
366+
nt.assert_array_equal(
367+
ds.variant_length,
368+
[1, 1, 1, 1, 1, 1],
369+
)

tests/test_ts.py

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ def test_simple_tree_sequence(self, tmp_path):
2323
tables.edges.add_row(left=0, right=100, parent=5, child=2)
2424
tables.edges.add_row(left=0, right=100, parent=5, child=3)
2525
site_id = tables.sites.add_row(position=10, ancestral_state="A")
26-
tables.mutations.add_row(site=site_id, node=4, derived_state="T")
27-
site_id = tables.sites.add_row(position=20, ancestral_state="C")
26+
tables.mutations.add_row(site=site_id, node=4, derived_state="TTTT")
27+
site_id = tables.sites.add_row(position=20, ancestral_state="CCC")
2828
tables.mutations.add_row(site=site_id, node=5, derived_state="G")
2929
site_id = tables.sites.add_row(position=30, ancestral_state="G")
30-
tables.mutations.add_row(site=site_id, node=0, derived_state="A")
30+
tables.mutations.add_row(site=site_id, node=0, derived_state="AA")
3131
tables.sort()
3232
tree_sequence = tables.tree_sequence()
3333
tree_sequence.dump(tmp_path / "test.trees")
@@ -53,7 +53,12 @@ def test_simple_tree_sequence(self, tmp_path):
5353
alleles = zroot["variant_allele"][:]
5454
assert alleles.shape == (3, 2)
5555
assert alleles.dtype == "O"
56-
assert np.array_equal(alleles, [["A", "T"], ["C", "G"], ["G", "A"]])
56+
assert np.array_equal(alleles, [["A", "TTTT"], ["CCC", "G"], ["G", "AA"]])
57+
58+
lengths = zroot["variant_length"][:]
59+
assert lengths.shape == (3,)
60+
assert lengths.dtype == np.int8
61+
assert np.array_equal(lengths, [4, 3, 2])
5762

5863
genotypes = zroot["call_genotype"][:]
5964
assert genotypes.shape == (3, 2, 2)
@@ -64,7 +69,7 @@ def test_simple_tree_sequence(self, tmp_path):
6469

6570
phased = zroot["call_genotype_phased"][:]
6671
assert phased.shape == (3, 2)
67-
assert phased.dtype == np.bool
72+
assert phased.dtype == "bool"
6873
assert np.all(phased)
6974

7075
contigs = zroot["contig_id"][:]
@@ -82,15 +87,22 @@ def test_simple_tree_sequence(self, tmp_path):
8287
assert samples.dtype == "O"
8388
assert np.array_equal(samples, ["tsk_0", "tsk_1"])
8489

90+
region_index = zroot["region_index"][:]
91+
assert region_index.shape == (1,6)
92+
assert region_index.dtype == np.int8
93+
assert np.array_equal(region_index, [[ 0, 0, 10, 30, 31, 3]])
94+
8595
assert set(zroot.array_keys()) == {
8696
"variant_position",
8797
"variant_allele",
98+
"variant_length",
8899
"call_genotype",
89100
"call_genotype_phased",
90101
"call_genotype_mask",
91102
"contig_id",
92103
"variant_contig",
93104
"sample_id",
105+
"region_index",
94106
}
95107

96108

@@ -113,8 +125,8 @@ def simple_ts(self, tmp_path):
113125
tables.edges.add_row(left=0, right=100, parent=5, child=2)
114126
tables.edges.add_row(left=0, right=100, parent=5, child=3)
115127
site_id = tables.sites.add_row(position=10, ancestral_state="A")
116-
tables.mutations.add_row(site=site_id, node=4, derived_state="T")
117-
site_id = tables.sites.add_row(position=20, ancestral_state="C")
128+
tables.mutations.add_row(site=site_id, node=4, derived_state="TT")
129+
site_id = tables.sites.add_row(position=20, ancestral_state="CCC")
118130
tables.mutations.add_row(site=site_id, node=5, derived_state="G")
119131
site_id = tables.sites.add_row(position=30, ancestral_state="G")
120132
tables.mutations.add_row(site=site_id, node=0, derived_state="A")
@@ -248,6 +260,7 @@ def test_schema_generation(self, simple_ts):
248260
field_names = [field.name for field in schema.fields]
249261
assert "variant_position" in field_names
250262
assert "variant_allele" in field_names
263+
assert "variant_length" in field_names
251264
assert "variant_contig" in field_names
252265
assert "call_genotype" in field_names
253266
assert "call_genotype_phased" in field_names
@@ -319,18 +332,22 @@ def test_iter_alleles_and_genotypes(self, simple_ts, ind_nodes, expected_gts):
319332

320333
assert len(results) == 3
321334

322-
for i, (alleles, (gt, phased)) in enumerate(results):
335+
for i, variant_data in enumerate(results):
323336
if i == 0:
324-
assert tuple(alleles) == ("A", "T")
337+
assert variant_data.variant_length == 2
338+
assert np.array_equal(variant_data.alleles, ("A", "TT"))
325339
elif i == 1:
326-
assert tuple(alleles) == ("C", "G")
340+
assert variant_data.variant_length == 3
341+
assert np.array_equal(variant_data.alleles, ("CCC", "G"))
327342
elif i == 2:
328-
assert tuple(alleles) == ("G", "A")
343+
assert variant_data.variant_length == 1
344+
assert np.array_equal(variant_data.alleles, ("G", "A"))
329345

330346
assert np.array_equal(
331-
gt, expected_gts[i]
332-
), f"Mismatch at variant {i}, expected {expected_gts[i]}, got {gt}"
333-
assert np.all(phased)
347+
variant_data.genotypes, expected_gts[i]
348+
), f"Mismatch at variant {i}, expected {expected_gts[i]}, "
349+
f"got {variant_data.genotypes}"
350+
assert np.all(variant_data.phased)
334351

335352
def test_iter_alleles_and_genotypes_errors(self, simple_ts):
336353
"""Test error cases for iter_alleles_and_genotypes with invalid inputs."""
@@ -398,12 +415,12 @@ def insert_branch_sites(ts, m=1):
398415
)
399416

400417
assert len(results_default) == 1
401-
alleles, (gt_default, phased) = results_default[0]
402-
assert tuple(alleles) == ("0", "1")
418+
variant_data_default = results_default[0]
419+
assert np.array_equal(variant_data_default.alleles, ("0", "1"))
403420

404421
# Sample 2 should have the ancestral state (0) when isolated_as_missing=False
405422
expected_gt_default = np.array([[1], [0], [0]])
406-
assert np.array_equal(gt_default, expected_gt_default)
423+
assert np.array_equal(variant_data_default.genotypes, expected_gt_default)
407424

408425
format_obj_missing = ts.TskitFormat(
409426
ts_path, individuals_nodes=ind_nodes, isolated_as_missing=True
@@ -413,12 +430,13 @@ def insert_branch_sites(ts, m=1):
413430
)
414431

415432
assert len(results_missing) == 1
416-
alleles, (gt_missing, phased) = results_missing[0]
417-
assert tuple(alleles) == ("0", "1")
433+
variant_data_missing = results_missing[0]
434+
assert variant_data_missing.variant_length == 1
435+
assert np.array_equal(variant_data_missing.alleles, ("0", "1"))
418436

419437
# Individual 2 should have missing values (-1) when isolated_as_missing=True
420438
expected_gt_missing = np.array([[1], [0], [-1]])
421-
assert np.array_equal(gt_missing, expected_gt_missing)
439+
assert np.array_equal(variant_data_missing.genotypes, expected_gt_missing)
422440

423441
def test_genotype_dtype_selection(self, tmp_path):
424442
tables = tskit.TableCollection(sequence_length=100)

0 commit comments

Comments
 (0)