From f19420c9dc8bdcd450c212c4fe2ed49bd534c89e Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Fri, 25 Apr 2025 19:23:33 +0100 Subject: [PATCH] Iterate alleles and genotypes at the same time --- bio2zarr/icf.py | 48 ++++++++++++++++++++++++--------------- bio2zarr/plink.py | 32 +++++++++++--------------- bio2zarr/vcz.py | 57 +++++++++++++++++++++-------------------------- 3 files changed, 68 insertions(+), 69 deletions(-) diff --git a/bio2zarr/icf.py b/bio2zarr/icf.py index 3ac7425a..f828989d 100644 --- a/bio2zarr/icf.py +++ b/bio2zarr/icf.py @@ -900,8 +900,12 @@ def __init__(self, path): ] # Allow us to find which partition a given record is in self.partition_record_index = np.cumsum([0, *partition_num_records]) + self.gt_field = None for field in self.metadata.fields: self.fields[field.full_name] = IntermediateColumnarFormatField(self, field) + if field.name == "GT": + self.gt_field = field + logger.info( f"Loaded IntermediateColumnarFormat(partitions={self.num_partitions}, " f"records={self.num_records}, fields={self.num_fields})" @@ -970,19 +974,6 @@ def root_attrs(self): "vcf_header": self.vcf_header, } - def iter_alleles(self, start, stop, num_alleles): - ref_field = self.fields["REF"] - alt_field = self.fields["ALT"] - - for ref, alt in zip( - ref_field.iter_values(start, stop), - alt_field.iter_values(start, stop), - ): - alleles = np.full(num_alleles, constants.STR_FILL, dtype="O") - alleles[0] = ref[0] - alleles[1 : 1 + len(alt)] = alt - yield alleles - def iter_id(self, start, stop): for value in self.fields["ID"].iter_values(start, stop): if value is not None: @@ -1025,6 +1016,19 @@ def iter_field(self, field_name, shape, start, stop): for value in source_field.iter_values(start, stop): yield sanitiser(value) + def iter_alleles(self, start, stop, num_alleles): + ref_field = self.fields["REF"] + alt_field = self.fields["ALT"] + + for ref, alt in zip( + ref_field.iter_values(start, stop), + alt_field.iter_values(start, stop), + ): + alleles = np.full(num_alleles, constants.STR_FILL, dtype="O") + alleles[0] = ref[0] + alleles[1 : 1 + len(alt)] = alt + yield alleles + def iter_genotypes(self, shape, start, stop): source_field = self.fields["FORMAT/GT"] for value in source_field.iter_values(start, stop): @@ -1034,6 +1038,16 @@ def iter_genotypes(self, shape, start, stop): sanitised_phased = sanitise_value_int_1d(shape[:-1], phased) yield sanitised_genotypes, sanitised_phased + def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles): + if self.gt_field is None or shape is None: + for alleles in self.iter_alleles(start, stop, num_alleles): + yield alleles, (None, None) + else: + yield from zip( + self.iter_alleles(start, stop, num_alleles), + self.iter_genotypes(shape, start, stop), + ) + def generate_schema( self, variants_chunk_size=None, samples_chunk_size=None, local_alleles=None ): @@ -1128,15 +1142,13 @@ def fixed_field_spec(name, dtype, source=None, dimensions=("variants",)): [spec_from_field(field) for field in self.metadata.info_fields] ) - gt_field = None for field in self.metadata.format_fields: if field.name == "GT": - gt_field = field continue array_specs.append(spec_from_field(field)) - if gt_field is not None and n > 0: - ploidy = max(gt_field.summary.max_number - 1, 1) + if self.gt_field is not None and n > 0: + ploidy = max(self.gt_field.summary.max_number - 1, 1) # Add ploidy dimension only when needed schema_instance.dimensions["ploidy"] = vcz.VcfZarrDimension(size=ploidy) @@ -1152,7 +1164,7 @@ def fixed_field_spec(name, dtype, source=None, dimensions=("variants",)): array_specs.append( vcz.ZarrArraySpec( name="call_genotype", - dtype=gt_field.smallest_dtype(), + dtype=self.gt_field.smallest_dtype(), dimensions=["variants", "samples", "ploidy"], description="", compressor=vcz.DEFAULT_ZARR_COMPRESSOR_GENOTYPES.get_config(), diff --git a/bio2zarr/plink.py b/bio2zarr/plink.py index 7c672696..2eadfb38 100644 --- a/bio2zarr/plink.py +++ b/bio2zarr/plink.py @@ -31,34 +31,28 @@ def samples(self): def num_samples(self): return len(self.samples) - def iter_alleles(self, start, stop, num_alleles): - ref_field = self.bed.allele_1 - alt_field = self.bed.allele_2 - - for ref, alt in zip( - ref_field[start:stop], - alt_field[start:stop], - ): - alleles = np.full(num_alleles, constants.STR_FILL, dtype="O") - alleles[0] = ref - alleles[1 : 1 + len(alt)] = alt - yield alleles - def iter_field(self, field_name, shape, start, stop): assert field_name == "position" # Only position field is supported from plink yield from self.bed.bp_position[start:stop] - def iter_genotypes(self, shape, start, stop): + def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles): + ref_field = self.bed.allele_1 + alt_field = self.bed.allele_2 bed_chunk = self.bed.read(slice(start, stop), dtype=np.int8).T gt = np.zeros(shape, dtype=np.int8) phased = np.zeros(shape[:-1], dtype=bool) - for values in bed_chunk: + for i, (ref, alt) in enumerate( + zip(ref_field[start:stop], alt_field[start:stop]) + ): + alleles = np.full(num_alleles, constants.STR_FILL, dtype="O") + alleles[0] = ref + alleles[1 : 1 + len(alt)] = alt gt[:] = 0 - gt[values == -127] = -1 - gt[values == 2] = 1 - gt[values == 1, 0] = 1 + gt[bed_chunk[i] == -127] = -1 + gt[bed_chunk[i] == 2] = 1 + gt[bed_chunk[i] == 1, 0] = 1 - yield gt, phased + yield alleles, (gt, phased) def generate_schema( self, diff --git a/bio2zarr/vcz.py b/bio2zarr/vcz.py index 488f0a05..85cf550c 100644 --- a/bio2zarr/vcz.py +++ b/bio2zarr/vcz.py @@ -71,11 +71,7 @@ def root_attrs(self): return {} @abc.abstractmethod - def iter_alleles(self, start, stop, num_alleles): - pass - - @abc.abstractmethod - def iter_genotypes(self, start, stop, num_alleles): + def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles): pass def iter_id(self, start, stop): @@ -724,12 +720,11 @@ def encode_partition(self, partition_index): self.encode_filters_partition(partition_index) if "variant_contig" in all_field_names: self.encode_contig_partition(partition_index) - self.encode_alleles_partition(partition_index) + self.encode_alleles_and_genotypes_partition(partition_index) for array_spec in self.schema.fields: if array_spec.source is not None: self.encode_array_partition(array_spec, partition_index) if self.has_genotypes(): - self.encode_genotypes_partition(partition_index) self.encode_genotype_mask_partition(partition_index) if self.has_local_alleles(): self.encode_local_alleles_partition(partition_index) @@ -780,22 +775,32 @@ def encode_array_partition(self, array_spec, partition_index): self.finalise_partition_array(partition_index, ba) - def encode_genotypes_partition(self, partition_index): + def encode_alleles_and_genotypes_partition(self, partition_index): partition = self.metadata.partitions[partition_index] - gt = self.init_partition_array(partition_index, "call_genotype") - gt_phased = self.init_partition_array(partition_index, "call_genotype_phased") - - for genotype, phased in self.source.iter_genotypes( - gt.buff.shape[1:], partition.start, partition.stop + alleles = self.init_partition_array(partition_index, "variant_allele") + has_gt = self.has_genotypes() + shape = None + if has_gt: + gt = self.init_partition_array(partition_index, "call_genotype") + gt_phased = self.init_partition_array( + partition_index, "call_genotype_phased" + ) + shape = gt.buff.shape[1:] + for alleles_value, (genotype, phased) in self.source.iter_alleles_and_genotypes( + partition.start, partition.stop, shape, alleles.array.shape[1] ): - j = gt.next_buffer_row() - gt.buff[j] = genotype + j_alleles = alleles.next_buffer_row() + alleles.buff[j_alleles] = alleles_value + if has_gt: + j = gt.next_buffer_row() + gt.buff[j] = genotype + j_phased = gt_phased.next_buffer_row() + gt_phased.buff[j_phased] = phased - j_phased = gt_phased.next_buffer_row() - gt_phased.buff[j_phased] = phased - - self.finalise_partition_array(partition_index, gt) - self.finalise_partition_array(partition_index, gt_phased) + self.finalise_partition_array(partition_index, alleles) + if has_gt: + self.finalise_partition_array(partition_index, gt) + self.finalise_partition_array(partition_index, gt_phased) def encode_genotype_mask_partition(self, partition_index): partition = self.metadata.partitions[partition_index] @@ -857,18 +862,6 @@ def encode_local_allele_fields_partition(self, partition_index): buff.buff[j] = descriptor.convert(value, la) self.finalise_partition_array(partition_index, buff) - def encode_alleles_partition(self, partition_index): - alleles = self.init_partition_array(partition_index, "variant_allele") - partition = self.metadata.partitions[partition_index] - - for value in self.source.iter_alleles( - partition.start, partition.stop, alleles.array.shape[1] - ): - j = alleles.next_buffer_row() - alleles.buff[j] = value - - self.finalise_partition_array(partition_index, alleles) - def encode_id_partition(self, partition_index): vid = self.init_partition_array(partition_index, "variant_id") vid_mask = self.init_partition_array(partition_index, "variant_id_mask")