Skip to content

Iterate alleles and genotypes at the same time #363

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 30 additions & 18 deletions bio2zarr/icf.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,8 +900,12 @@ def __init__(self, path):
]
# Allow us to find which partition a given record is in
self.partition_record_index = np.cumsum([0, *partition_num_records])
self.gt_field = None
for field in self.metadata.fields:
self.fields[field.full_name] = IntermediateColumnarFormatField(self, field)
if field.name == "GT":
self.gt_field = field

logger.info(
f"Loaded IntermediateColumnarFormat(partitions={self.num_partitions}, "
f"records={self.num_records}, fields={self.num_fields})"
Expand Down Expand Up @@ -970,19 +974,6 @@ def root_attrs(self):
"vcf_header": self.vcf_header,
}

def iter_alleles(self, start, stop, num_alleles):
ref_field = self.fields["REF"]
alt_field = self.fields["ALT"]

for ref, alt in zip(
ref_field.iter_values(start, stop),
alt_field.iter_values(start, stop),
):
alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
alleles[0] = ref[0]
alleles[1 : 1 + len(alt)] = alt
yield alleles

def iter_id(self, start, stop):
for value in self.fields["ID"].iter_values(start, stop):
if value is not None:
Expand Down Expand Up @@ -1025,6 +1016,19 @@ def iter_field(self, field_name, shape, start, stop):
for value in source_field.iter_values(start, stop):
yield sanitiser(value)

def iter_alleles(self, start, stop, num_alleles):
ref_field = self.fields["REF"]
alt_field = self.fields["ALT"]

for ref, alt in zip(
ref_field.iter_values(start, stop),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is too much in one zip - let's break it up into simpler bits

What's wrong with

if self.gt_field is None:
     for alleles in self.iter_alleles(start, stop):
            yield alleles, None, None
else:
        yield from zip(self.iter_alleles(start, stop), self.iter_genotypes(start, stop)):

and just keep the old definitions?

alt_field.iter_values(start, stop),
):
alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
alleles[0] = ref[0]
alleles[1 : 1 + len(alt)] = alt
yield alleles

def iter_genotypes(self, shape, start, stop):
source_field = self.fields["FORMAT/GT"]
for value in source_field.iter_values(start, stop):
Expand All @@ -1034,6 +1038,16 @@ def iter_genotypes(self, shape, start, stop):
sanitised_phased = sanitise_value_int_1d(shape[:-1], phased)
yield sanitised_genotypes, sanitised_phased

def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
if self.gt_field is None or shape is None:
for alleles in self.iter_alleles(start, stop, num_alleles):
yield alleles, (None, None)
else:
yield from zip(
self.iter_alleles(start, stop, num_alleles),
self.iter_genotypes(shape, start, stop),
)

def generate_schema(
self, variants_chunk_size=None, samples_chunk_size=None, local_alleles=None
):
Expand Down Expand Up @@ -1128,15 +1142,13 @@ def fixed_field_spec(name, dtype, source=None, dimensions=("variants",)):
[spec_from_field(field) for field in self.metadata.info_fields]
)

gt_field = None
for field in self.metadata.format_fields:
if field.name == "GT":
gt_field = field
continue
array_specs.append(spec_from_field(field))

if gt_field is not None and n > 0:
ploidy = max(gt_field.summary.max_number - 1, 1)
if self.gt_field is not None and n > 0:
ploidy = max(self.gt_field.summary.max_number - 1, 1)
# Add ploidy dimension only when needed
schema_instance.dimensions["ploidy"] = vcz.VcfZarrDimension(size=ploidy)

Expand All @@ -1152,7 +1164,7 @@ def fixed_field_spec(name, dtype, source=None, dimensions=("variants",)):
array_specs.append(
vcz.ZarrArraySpec(
name="call_genotype",
dtype=gt_field.smallest_dtype(),
dtype=self.gt_field.smallest_dtype(),
dimensions=["variants", "samples", "ploidy"],
description="",
compressor=vcz.DEFAULT_ZARR_COMPRESSOR_GENOTYPES.get_config(),
Expand Down
32 changes: 13 additions & 19 deletions bio2zarr/plink.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,34 +31,28 @@ def samples(self):
def num_samples(self):
return len(self.samples)

def iter_alleles(self, start, stop, num_alleles):
ref_field = self.bed.allele_1
alt_field = self.bed.allele_2

for ref, alt in zip(
ref_field[start:stop],
alt_field[start:stop],
):
alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
alleles[0] = ref
alleles[1 : 1 + len(alt)] = alt
yield alleles

def iter_field(self, field_name, shape, start, stop):
assert field_name == "position" # Only position field is supported from plink
yield from self.bed.bp_position[start:stop]

def iter_genotypes(self, shape, start, stop):
def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
ref_field = self.bed.allele_1
alt_field = self.bed.allele_2
bed_chunk = self.bed.read(slice(start, stop), dtype=np.int8).T
gt = np.zeros(shape, dtype=np.int8)
phased = np.zeros(shape[:-1], dtype=bool)
for values in bed_chunk:
for i, (ref, alt) in enumerate(
zip(ref_field[start:stop], alt_field[start:stop])
):
alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
alleles[0] = ref
alleles[1 : 1 + len(alt)] = alt
gt[:] = 0
gt[values == -127] = -1
gt[values == 2] = 1
gt[values == 1, 0] = 1
gt[bed_chunk[i] == -127] = -1
gt[bed_chunk[i] == 2] = 1
gt[bed_chunk[i] == 1, 0] = 1

yield gt, phased
yield alleles, (gt, phased)

def generate_schema(
self,
Expand Down
57 changes: 25 additions & 32 deletions bio2zarr/vcz.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,7 @@ def root_attrs(self):
return {}

@abc.abstractmethod
def iter_alleles(self, start, stop, num_alleles):
pass

@abc.abstractmethod
def iter_genotypes(self, start, stop, num_alleles):
def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
pass

def iter_id(self, start, stop):
Expand Down Expand Up @@ -724,12 +720,11 @@ def encode_partition(self, partition_index):
self.encode_filters_partition(partition_index)
if "variant_contig" in all_field_names:
self.encode_contig_partition(partition_index)
self.encode_alleles_partition(partition_index)
self.encode_alleles_and_genotypes_partition(partition_index)
for array_spec in self.schema.fields:
if array_spec.source is not None:
self.encode_array_partition(array_spec, partition_index)
if self.has_genotypes():
self.encode_genotypes_partition(partition_index)
self.encode_genotype_mask_partition(partition_index)
if self.has_local_alleles():
self.encode_local_alleles_partition(partition_index)
Expand Down Expand Up @@ -780,22 +775,32 @@ def encode_array_partition(self, array_spec, partition_index):

self.finalise_partition_array(partition_index, ba)

def encode_genotypes_partition(self, partition_index):
def encode_alleles_and_genotypes_partition(self, partition_index):
partition = self.metadata.partitions[partition_index]
gt = self.init_partition_array(partition_index, "call_genotype")
gt_phased = self.init_partition_array(partition_index, "call_genotype_phased")

for genotype, phased in self.source.iter_genotypes(
gt.buff.shape[1:], partition.start, partition.stop
alleles = self.init_partition_array(partition_index, "variant_allele")
has_gt = self.has_genotypes()
shape = None
if has_gt:
gt = self.init_partition_array(partition_index, "call_genotype")
gt_phased = self.init_partition_array(
partition_index, "call_genotype_phased"
)
shape = gt.buff.shape[1:]
for alleles_value, (genotype, phased) in self.source.iter_alleles_and_genotypes(
partition.start, partition.stop, shape, alleles.array.shape[1]
):
j = gt.next_buffer_row()
gt.buff[j] = genotype
j_alleles = alleles.next_buffer_row()
alleles.buff[j_alleles] = alleles_value
if has_gt:
j = gt.next_buffer_row()
gt.buff[j] = genotype
j_phased = gt_phased.next_buffer_row()
gt_phased.buff[j_phased] = phased

j_phased = gt_phased.next_buffer_row()
gt_phased.buff[j_phased] = phased

self.finalise_partition_array(partition_index, gt)
self.finalise_partition_array(partition_index, gt_phased)
self.finalise_partition_array(partition_index, alleles)
if has_gt:
self.finalise_partition_array(partition_index, gt)
self.finalise_partition_array(partition_index, gt_phased)

def encode_genotype_mask_partition(self, partition_index):
partition = self.metadata.partitions[partition_index]
Expand Down Expand Up @@ -857,18 +862,6 @@ def encode_local_allele_fields_partition(self, partition_index):
buff.buff[j] = descriptor.convert(value, la)
self.finalise_partition_array(partition_index, buff)

def encode_alleles_partition(self, partition_index):
alleles = self.init_partition_array(partition_index, "variant_allele")
partition = self.metadata.partitions[partition_index]

for value in self.source.iter_alleles(
partition.start, partition.stop, alleles.array.shape[1]
):
j = alleles.next_buffer_row()
alleles.buff[j] = value

self.finalise_partition_array(partition_index, alleles)

def encode_id_partition(self, partition_index):
vid = self.init_partition_array(partition_index, "variant_id")
vid_mask = self.init_partition_array(partition_index, "variant_id_mask")
Expand Down