Skip to content

Commit 3d8ad3e

Browse files
benjeffery authored and jeromekelleher committed
Iterate alleles and genotypes at the same time
1 parent d30940e commit 3d8ad3e

File tree

3 files changed

+68
-69
lines changed

3 files changed

+68
-69
lines changed

bio2zarr/icf.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -900,8 +900,12 @@ def __init__(self, path):
900900
]
901901
# Allow us to find which partition a given record is in
902902
self.partition_record_index = np.cumsum([0, *partition_num_records])
903+
self.gt_field = None
903904
for field in self.metadata.fields:
904905
self.fields[field.full_name] = IntermediateColumnarFormatField(self, field)
906+
if field.name == "GT":
907+
self.gt_field = field
908+
905909
logger.info(
906910
f"Loaded IntermediateColumnarFormat(partitions={self.num_partitions}, "
907911
f"records={self.num_records}, fields={self.num_fields})"
@@ -970,19 +974,6 @@ def root_attrs(self):
970974
"vcf_header": self.vcf_header,
971975
}
972976

973-
def iter_alleles(self, start, stop, num_alleles):
974-
ref_field = self.fields["REF"]
975-
alt_field = self.fields["ALT"]
976-
977-
for ref, alt in zip(
978-
ref_field.iter_values(start, stop),
979-
alt_field.iter_values(start, stop),
980-
):
981-
alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
982-
alleles[0] = ref[0]
983-
alleles[1 : 1 + len(alt)] = alt
984-
yield alleles
985-
986977
def iter_id(self, start, stop):
987978
for value in self.fields["ID"].iter_values(start, stop):
988979
if value is not None:
@@ -1025,6 +1016,19 @@ def iter_field(self, field_name, shape, start, stop):
10251016
for value in source_field.iter_values(start, stop):
10261017
yield sanitiser(value)
10271018

1019+
def iter_alleles(self, start, stop, num_alleles):
1020+
ref_field = self.fields["REF"]
1021+
alt_field = self.fields["ALT"]
1022+
1023+
for ref, alt in zip(
1024+
ref_field.iter_values(start, stop),
1025+
alt_field.iter_values(start, stop),
1026+
):
1027+
alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
1028+
alleles[0] = ref[0]
1029+
alleles[1 : 1 + len(alt)] = alt
1030+
yield alleles
1031+
10281032
def iter_genotypes(self, shape, start, stop):
10291033
source_field = self.fields["FORMAT/GT"]
10301034
for value in source_field.iter_values(start, stop):
@@ -1034,6 +1038,16 @@ def iter_genotypes(self, shape, start, stop):
10341038
sanitised_phased = sanitise_value_int_1d(shape[:-1], phased)
10351039
yield sanitised_genotypes, sanitised_phased
10361040

1041+
def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
1042+
if self.gt_field is None or shape is None:
1043+
for alleles in self.iter_alleles(start, stop, num_alleles):
1044+
yield alleles, (None, None)
1045+
else:
1046+
yield from zip(
1047+
self.iter_alleles(start, stop, num_alleles),
1048+
self.iter_genotypes(shape, start, stop),
1049+
)
1050+
10371051
def generate_schema(
10381052
self, variants_chunk_size=None, samples_chunk_size=None, local_alleles=None
10391053
):
@@ -1128,15 +1142,13 @@ def fixed_field_spec(name, dtype, source=None, dimensions=("variants",)):
11281142
[spec_from_field(field) for field in self.metadata.info_fields]
11291143
)
11301144

1131-
gt_field = None
11321145
for field in self.metadata.format_fields:
11331146
if field.name == "GT":
1134-
gt_field = field
11351147
continue
11361148
array_specs.append(spec_from_field(field))
11371149

1138-
if gt_field is not None and n > 0:
1139-
ploidy = max(gt_field.summary.max_number - 1, 1)
1150+
if self.gt_field is not None and n > 0:
1151+
ploidy = max(self.gt_field.summary.max_number - 1, 1)
11401152
# Add ploidy dimension only when needed
11411153
schema_instance.dimensions["ploidy"] = vcz.VcfZarrDimension(size=ploidy)
11421154

@@ -1152,7 +1164,7 @@ def fixed_field_spec(name, dtype, source=None, dimensions=("variants",)):
11521164
array_specs.append(
11531165
vcz.ZarrArraySpec(
11541166
name="call_genotype",
1155-
dtype=gt_field.smallest_dtype(),
1167+
dtype=self.gt_field.smallest_dtype(),
11561168
dimensions=["variants", "samples", "ploidy"],
11571169
description="",
11581170
compressor=vcz.DEFAULT_ZARR_COMPRESSOR_GENOTYPES.get_config(),

bio2zarr/plink.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -31,34 +31,28 @@ def samples(self):
3131
def num_samples(self):
3232
return len(self.samples)
3333

34-
def iter_alleles(self, start, stop, num_alleles):
35-
ref_field = self.bed.allele_1
36-
alt_field = self.bed.allele_2
37-
38-
for ref, alt in zip(
39-
ref_field[start:stop],
40-
alt_field[start:stop],
41-
):
42-
alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
43-
alleles[0] = ref
44-
alleles[1 : 1 + len(alt)] = alt
45-
yield alleles
46-
4734
def iter_field(self, field_name, shape, start, stop):
4835
assert field_name == "position" # Only position field is supported from plink
4936
yield from self.bed.bp_position[start:stop]
5037

51-
def iter_genotypes(self, shape, start, stop):
38+
def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
39+
ref_field = self.bed.allele_1
40+
alt_field = self.bed.allele_2
5241
bed_chunk = self.bed.read(slice(start, stop), dtype=np.int8).T
5342
gt = np.zeros(shape, dtype=np.int8)
5443
phased = np.zeros(shape[:-1], dtype=bool)
55-
for values in bed_chunk:
44+
for i, (ref, alt) in enumerate(
45+
zip(ref_field[start:stop], alt_field[start:stop])
46+
):
47+
alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
48+
alleles[0] = ref
49+
alleles[1 : 1 + len(alt)] = alt
5650
gt[:] = 0
57-
gt[values == -127] = -1
58-
gt[values == 2] = 1
59-
gt[values == 1, 0] = 1
51+
gt[bed_chunk[i] == -127] = -1
52+
gt[bed_chunk[i] == 2] = 1
53+
gt[bed_chunk[i] == 1, 0] = 1
6054

61-
yield gt, phased
55+
yield alleles, (gt, phased)
6256

6357
def generate_schema(
6458
self,

bio2zarr/vcz.py

Lines changed: 25 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,7 @@ def root_attrs(self):
7171
return {}
7272

7373
@abc.abstractmethod
74-
def iter_alleles(self, start, stop, num_alleles):
75-
pass
76-
77-
@abc.abstractmethod
78-
def iter_genotypes(self, start, stop, num_alleles):
74+
def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
7975
pass
8076

8177
def iter_id(self, start, stop):
@@ -724,12 +720,11 @@ def encode_partition(self, partition_index):
724720
self.encode_filters_partition(partition_index)
725721
if "variant_contig" in all_field_names:
726722
self.encode_contig_partition(partition_index)
727-
self.encode_alleles_partition(partition_index)
723+
self.encode_alleles_and_genotypes_partition(partition_index)
728724
for array_spec in self.schema.fields:
729725
if array_spec.source is not None:
730726
self.encode_array_partition(array_spec, partition_index)
731727
if self.has_genotypes():
732-
self.encode_genotypes_partition(partition_index)
733728
self.encode_genotype_mask_partition(partition_index)
734729
if self.has_local_alleles():
735730
self.encode_local_alleles_partition(partition_index)
@@ -780,22 +775,32 @@ def encode_array_partition(self, array_spec, partition_index):
780775

781776
self.finalise_partition_array(partition_index, ba)
782777

783-
def encode_genotypes_partition(self, partition_index):
778+
def encode_alleles_and_genotypes_partition(self, partition_index):
784779
partition = self.metadata.partitions[partition_index]
785-
gt = self.init_partition_array(partition_index, "call_genotype")
786-
gt_phased = self.init_partition_array(partition_index, "call_genotype_phased")
787-
788-
for genotype, phased in self.source.iter_genotypes(
789-
gt.buff.shape[1:], partition.start, partition.stop
780+
alleles = self.init_partition_array(partition_index, "variant_allele")
781+
has_gt = self.has_genotypes()
782+
shape = None
783+
if has_gt:
784+
gt = self.init_partition_array(partition_index, "call_genotype")
785+
gt_phased = self.init_partition_array(
786+
partition_index, "call_genotype_phased"
787+
)
788+
shape = gt.buff.shape[1:]
789+
for alleles_value, (genotype, phased) in self.source.iter_alleles_and_genotypes(
790+
partition.start, partition.stop, shape, alleles.array.shape[1]
790791
):
791-
j = gt.next_buffer_row()
792-
gt.buff[j] = genotype
792+
j_alleles = alleles.next_buffer_row()
793+
alleles.buff[j_alleles] = alleles_value
794+
if has_gt:
795+
j = gt.next_buffer_row()
796+
gt.buff[j] = genotype
797+
j_phased = gt_phased.next_buffer_row()
798+
gt_phased.buff[j_phased] = phased
793799

794-
j_phased = gt_phased.next_buffer_row()
795-
gt_phased.buff[j_phased] = phased
796-
797-
self.finalise_partition_array(partition_index, gt)
798-
self.finalise_partition_array(partition_index, gt_phased)
800+
self.finalise_partition_array(partition_index, alleles)
801+
if has_gt:
802+
self.finalise_partition_array(partition_index, gt)
803+
self.finalise_partition_array(partition_index, gt_phased)
799804

800805
def encode_genotype_mask_partition(self, partition_index):
801806
partition = self.metadata.partitions[partition_index]
@@ -857,18 +862,6 @@ def encode_local_allele_fields_partition(self, partition_index):
857862
buff.buff[j] = descriptor.convert(value, la)
858863
self.finalise_partition_array(partition_index, buff)
859864

860-
def encode_alleles_partition(self, partition_index):
861-
alleles = self.init_partition_array(partition_index, "variant_allele")
862-
partition = self.metadata.partitions[partition_index]
863-
864-
for value in self.source.iter_alleles(
865-
partition.start, partition.stop, alleles.array.shape[1]
866-
):
867-
j = alleles.next_buffer_row()
868-
alleles.buff[j] = value
869-
870-
self.finalise_partition_array(partition_index, alleles)
871-
872865
def encode_id_partition(self, partition_index):
873866
vid = self.init_partition_array(partition_index, "variant_id")
874867
vid_mask = self.init_partition_array(partition_index, "variant_id_mask")

0 commit comments

Comments
 (0)