@@ -900,8 +900,13 @@ def __init__(self, path):
900
900
]
901
901
# Allow us to find which partition a given record is in
902
902
self .partition_record_index = np .cumsum ([0 , * partition_num_records ])
903
+ self .gt_field = None
903
904
for field in self .metadata .fields :
904
905
self .fields [field .full_name ] = IntermediateColumnarFormatField (self , field )
906
+ if field .name == "GT" :
907
+ self .gt_field = field
908
+ continue
909
+
905
910
logger .info (
906
911
f"Loaded IntermediateColumnarFormat(partitions={ self .num_partitions } , "
907
912
f"records={ self .num_records } , fields={ self .num_fields } )"
@@ -970,19 +975,6 @@ def root_attrs(self):
970
975
"vcf_header" : self .vcf_header ,
971
976
}
972
977
973
- def iter_alleles (self , start , stop , num_alleles ):
974
- ref_field = self .fields ["REF" ]
975
- alt_field = self .fields ["ALT" ]
976
-
977
- for ref , alt in zip (
978
- ref_field .iter_values (start , stop ),
979
- alt_field .iter_values (start , stop ),
980
- ):
981
- alleles = np .full (num_alleles , constants .STR_FILL , dtype = "O" )
982
- alleles [0 ] = ref [0 ]
983
- alleles [1 : 1 + len (alt )] = alt
984
- yield alleles
985
-
986
978
def iter_id (self , start , stop ):
987
979
for value in self .fields ["ID" ].iter_values (start , stop ):
988
980
if value is not None :
@@ -1025,14 +1017,30 @@ def iter_field(self, field_name, shape, start, stop):
1025
1017
for value in source_field .iter_values (start , stop ):
1026
1018
yield sanitiser (value )
1027
1019
1028
- def iter_genotypes (self , shape , start , stop ):
1029
- source_field = self .fields ["FORMAT/GT" ]
1030
- for value in source_field .iter_values (start , stop ):
1031
- genotypes = value [:, :- 1 ] if value is not None else None
1032
- phased = value [:, - 1 ] if value is not None else None
1033
- sanitised_genotypes = sanitise_value_int_2d (shape , genotypes )
1034
- sanitised_phased = sanitise_value_int_1d (shape [:- 1 ], phased )
1035
- yield sanitised_genotypes , sanitised_phased
1020
+ def iter_alleles_and_genotypes (self , start , stop , shape , num_alleles ):
1021
+ ref_field = self .fields ["REF" ]
1022
+ alt_field = self .fields ["ALT" ]
1023
+
1024
+ for ref , alt , gt in zip (
1025
+ ref_field .iter_values (start , stop ),
1026
+ alt_field .iter_values (start , stop ),
1027
+ # Create a dummy gt iterator if genotypes are not included
1028
+ (None for _ in range (stop - start ))
1029
+ if self .gt_field is None or shape is None
1030
+ else self .fields ["FORMAT/GT" ].iter_values (start , stop ),
1031
+ ):
1032
+ alleles = np .full (num_alleles , constants .STR_FILL , dtype = "O" )
1033
+ alleles [0 ] = ref [0 ]
1034
+ alleles [1 : 1 + len (alt )] = alt
1035
+
1036
+ if self .gt_field is not None and shape is not None :
1037
+ genotypes = gt [:, :- 1 ] if gt is not None else None
1038
+ phased = gt [:, - 1 ] if gt is not None else None
1039
+ sanitised_genotypes = sanitise_value_int_2d (shape , genotypes )
1040
+ sanitised_phased = sanitise_value_int_1d (shape [:- 1 ], phased )
1041
+ yield alleles , sanitised_genotypes , sanitised_phased
1042
+ else :
1043
+ yield alleles , None , None
1036
1044
1037
1045
def generate_schema (
1038
1046
self , variants_chunk_size = None , samples_chunk_size = None , local_alleles = None
@@ -1128,15 +1136,13 @@ def fixed_field_spec(name, dtype, source=None, dimensions=("variants",)):
1128
1136
[spec_from_field (field ) for field in self .metadata .info_fields ]
1129
1137
)
1130
1138
1131
- gt_field = None
1132
1139
for field in self .metadata .format_fields :
1133
1140
if field .name == "GT" :
1134
- gt_field = field
1135
1141
continue
1136
1142
array_specs .append (spec_from_field (field ))
1137
1143
1138
- if gt_field is not None and n > 0 :
1139
- ploidy = max (gt_field .summary .max_number - 1 , 1 )
1144
+ if self . gt_field is not None and n > 0 :
1145
+ ploidy = max (self . gt_field .summary .max_number - 1 , 1 )
1140
1146
# Add ploidy dimension only when needed
1141
1147
schema_instance .dimensions ["ploidy" ] = vcz .VcfZarrDimension (size = ploidy )
1142
1148
@@ -1152,7 +1158,7 @@ def fixed_field_spec(name, dtype, source=None, dimensions=("variants",)):
1152
1158
array_specs .append (
1153
1159
vcz .ZarrArraySpec (
1154
1160
name = "call_genotype" ,
1155
- dtype = gt_field .smallest_dtype (),
1161
+ dtype = self . gt_field .smallest_dtype (),
1156
1162
dimensions = ["variants" , "samples" , "ploidy" ],
1157
1163
description = "" ,
1158
1164
compressor = vcz .DEFAULT_ZARR_COMPRESSOR_GENOTYPES .get_config (),
0 commit comments