|
| 1 | +import logging |
| 2 | +import pathlib |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import tskit |
| 6 | + |
| 7 | +from bio2zarr import constants, core, vcz |
| 8 | + |
| 9 | +logger = logging.getLogger(__name__) |
| 10 | + |
| 11 | + |
class TskitFormat(vcz.Source):
    """Expose a tskit tree sequence through the ``vcz.Source`` interface.

    The tree sequence is loaded from ``ts_path`` with ``tskit.load``. Every
    site becomes one variant record on a single contig, and sample nodes are
    grouped into VCF-style samples either through the individuals stored in
    the tree sequence or, when no sample node references an individual,
    through a fixed ``ploidy``.

    Parameters:
        ts_path: Path handed to ``tskit.load``.
        contig_id: Identifier of the single contig; defaults to ``"1"``.
        ploidy: Number of consecutive sample nodes per output sample when the
            sample nodes do not reference individuals (defaults to 1). Must be
            ``None`` if the tree sequence contains any individuals.
        isolated_as_missing: Forwarded unchanged to ``TreeSequence.variants``
            when decoding genotypes.

    Raises:
        ValueError: If the sample/individual layout or ``ploidy`` is invalid
            (see ``_make_sample_mapping``).
    """

    def __init__(self, ts_path, contig_id=None, ploidy=None, isolated_as_missing=False):
        self._path = ts_path
        self.ts = tskit.load(ts_path)
        self.contig_id = contig_id if contig_id is not None else "1"
        self.isolated_as_missing = isolated_as_missing

        self._make_sample_mapping(ploidy)
        # Site positions, indexed by variant number; used to translate the
        # [start, stop) record ranges into genomic intervals.
        self.positions = self.ts.sites_position

    @property
    def path(self):
        """The path the tree sequence was loaded from."""
        return self._path

    @property
    def num_records(self):
        """Number of variant records, i.e. the number of sites."""
        return self.ts.num_sites

    @property
    def num_samples(self):
        """Number of output samples (individuals or ploidy-sized groups)."""
        return self._num_samples

    @property
    def samples(self):
        """List of ``vcz.Sample`` objects with ids ``tsk_0``, ``tsk_1``, ..."""
        return self._samples

    @property
    def root_attrs(self):
        """Extra root attributes; this source contributes none."""
        return {}

    @property
    def contigs(self):
        """Single-element list holding the contig named ``contig_id``."""
        return [vcz.Contig(id=self.contig_id)]

    def _make_sample_mapping(self, ploidy):
        """Work out how sample nodes are grouped into output samples.

        Sets ``sample_ids`` (node ids in output order), ``individual_ploidies``
        (number of nodes per output sample), ``max_ploidy``, ``_num_samples``
        and ``_samples``.

        Raises:
            ValueError: If ``ploidy`` is given although the tables contain
                individuals, if only some sample nodes reference individuals,
                if an individual mixes sample and non-sample nodes, or if
                ``ploidy`` is < 1 or does not divide the number of samples.
        """
        ts = self.ts
        self.individual_ploidies = []
        self.max_ploidy = 0

        if ts.num_individuals > 0 and ploidy is not None:
            raise ValueError(
                "Cannot specify ploidy when individuals are present in tables"
            )

        # Find all sample nodes that reference individuals
        individuals = np.unique(ts.nodes_individual[ts.samples()])
        if len(individuals) == 1 and individuals[0] == tskit.NULL:
            # No samples refer to individuals
            individuals = None
        elif individuals[0] == tskit.NULL:
            # np.unique sorts the argument, so if NULL (-1) is present it
            # will be the first value.
            raise ValueError(
                "Sample nodes must either all be associated with individuals "
                "or not associated with any individuals"
            )

        if individuals is not None:
            self.sample_ids = []
            for i in individuals:
                if i < 0 or i >= self.ts.num_individuals:
                    raise ValueError("Invalid individual IDs provided.")
                ind = self.ts.individual(i)
                if len(ind.nodes) == 0:
                    raise ValueError(f"Individual {i} not associated with a node")
                is_sample = {ts.node(u).is_sample() for u in ind.nodes}
                if len(is_sample) != 1:
                    raise ValueError(
                        f"Individual {ind.id} has nodes that are sample and "
                        "non-samples"
                    )
                self.sample_ids.extend(ind.nodes)
                self.individual_ploidies.append(len(ind.nodes))
                self.max_ploidy = max(self.max_ploidy, len(ind.nodes))
        else:
            if ploidy is None:
                ploidy = 1
            if ploidy < 1:
                raise ValueError("Ploidy must be >= 1")
            if ts.num_samples % ploidy != 0:
                raise ValueError("Sample size must be divisible by ploidy")
            self.individual_ploidies = np.full(
                ts.num_samples // ploidy, ploidy, dtype=np.int32
            )
            self.max_ploidy = ploidy
            self.sample_ids = np.arange(ts.num_samples, dtype=np.int32)

        self._num_samples = len(self.individual_ploidies)
        self._samples = [vcz.Sample(id=f"tsk_{j}") for j in range(self.num_samples)]

    def iter_contig(self, start, stop):
        """Yield the contig index (always 0) for each record in [start, stop)."""
        yield from (0 for _ in range(start, stop))

    def iter_field(self, field_name, shape, start, stop):
        """Yield per-record values of ``field_name`` for records [start, stop).

        Only ``"position"`` is supported; positions are converted with
        ``int``. ``shape`` is accepted but not used here.

        Raises:
            ValueError: For any other field name.
        """
        if field_name == "position":
            for pos in self.ts.sites_position[start:stop]:
                yield int(pos)
        else:
            raise ValueError(f"Unknown field {field_name}")

    def _iter_variants(self, start, stop):
        """Yield tskit variants for the sites with index in [start, stop).

        An empty range (including ``start`` at or past the last site) yields
        nothing instead of indexing ``self.positions`` out of bounds.
        """
        if start >= min(stop, self.num_records):
            return
        yield from self.ts.variants(
            samples=self.sample_ids,
            isolated_as_missing=self.isolated_as_missing,
            left=self.positions[start],
            # No site follows the last one, so leave the interval open-ended.
            right=self.positions[stop] if stop < self.num_records else None,
        )

    @staticmethod
    def _padded_alleles(variant, num_alleles):
        """Return the variant's alleles as an object array of ``num_alleles``.

        Unused trailing slots hold ``constants.STR_FILL``. tskit reports
        missing data as a ``None`` entry in ``variant.alleles``; that marker
        is not an allele string and is skipped.
        """
        alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
        present = [allele for allele in variant.alleles if allele is not None]
        for i, allele in enumerate(present):
            assert i < num_alleles
            alleles[i] = allele
        return alleles

    def iter_alleles(self, start, stop, num_alleles):
        """Yield one padded allele array per record in [start, stop)."""
        for variant in self._iter_variants(start, stop):
            yield self._padded_alleles(variant, num_alleles)

    def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
        """Yield ``(alleles, (gt, phased))`` per record in [start, stop).

        ``gt`` is a fresh int8 array of ``shape`` for every record; it is
        indexed as ``gt[sample, ploidy_slot]`` (presumably ``shape`` is
        ``(num_samples, max_ploidy)`` - confirm against the caller). Slots
        beyond an individual's own ploidy hold the fill value -2. ``phased``
        has shape ``shape[:-1]`` and is all True.
        """
        # Same sentinel the previous implementation used for padding slots.
        fill_value = -2
        # In tskit, all genotypes are considered phased
        phased = np.ones(shape[:-1], dtype=bool)

        for variant in self._iter_variants(start, stop):
            alleles = self._padded_alleles(variant, num_alleles)

            # Start from the fill value so that individuals with fewer nodes
            # than max_ploidy do not report a spurious allele 0 call.
            gt = np.full(shape, fill_value, dtype=np.int8)
            genotypes = variant.genotypes
            offset = 0
            # sample_ids is the concatenation of each individual's nodes, so
            # the genotypes of individual ``row`` are the next ``ploidy``
            # entries.
            for row, ploidy in enumerate(self.individual_ploidies):
                gt[row, :ploidy] = genotypes[offset : offset + ploidy]
                offset += ploidy

            yield alleles, (gt, phased)

    def generate_schema(
        self,
        variants_chunk_size=None,
        samples_chunk_size=None,
    ):
        """Build the ``vcz.VcfZarrSchema`` describing this tree sequence.

        Scans all variants once to find the largest number of alleles at any
        site (ignoring tskit's ``None`` missing-data marker, which the allele
        iterators never emit), then declares the dimensions and the array
        specs for positions, alleles, contigs and genotype calls.

        Parameters:
            variants_chunk_size: Chunk size along the variants dimension;
                falls back to ``vcz.DEFAULT_VARIANT_CHUNK_SIZE`` when falsy.
            samples_chunk_size: Chunk size along the samples dimension;
                falls back to ``vcz.DEFAULT_SAMPLE_CHUNK_SIZE`` when falsy.

        Returns:
            The populated ``vcz.VcfZarrSchema`` instance.
        """
        n = self.num_samples
        m = self.ts.num_sites

        # Determine max number of alleles, decoding exactly as the allele
        # iterators do so the dimension matches what they produce.
        max_alleles = 0
        for variant in self.ts.variants(
            samples=self.sample_ids,
            isolated_as_missing=self.isolated_as_missing,
        ):
            num_present = sum(allele is not None for allele in variant.alleles)
            max_alleles = max(max_alleles, num_present)

        logger.info(f"Scanned tskit with {n} samples and {m} variants")
        logger.info(
            f"Maximum ploidy: {self.max_ploidy}, maximum alleles: {max_alleles}"
        )

        dimensions = {
            "variants": vcz.VcfZarrDimension(
                size=m, chunk_size=variants_chunk_size or vcz.DEFAULT_VARIANT_CHUNK_SIZE
            ),
            "samples": vcz.VcfZarrDimension(
                size=n, chunk_size=samples_chunk_size or vcz.DEFAULT_SAMPLE_CHUNK_SIZE
            ),
            "ploidy": vcz.VcfZarrDimension(size=self.max_ploidy),
            "alleles": vcz.VcfZarrDimension(size=max_alleles),
        }

        schema_instance = vcz.VcfZarrSchema(
            format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION,
            dimensions=dimensions,
            fields=[],
        )

        logger.info(
            "Generating schema with chunks="
            f"{schema_instance.dimensions['variants'].chunk_size}, "
            f"{schema_instance.dimensions['samples'].chunk_size}"
        )

        array_specs = [
            vcz.ZarrArraySpec(
                source="position",
                name="variant_position",
                dtype="i4",
                dimensions=["variants"],
                description="Position of each variant",
            ),
            vcz.ZarrArraySpec(
                source=None,
                name="variant_allele",
                dtype="O",
                dimensions=["variants", "alleles"],
                description="Alleles for each variant",
            ),
            vcz.ZarrArraySpec(
                source=None,
                name="variant_contig",
                dtype=core.min_int_dtype(0, len(self.contigs)),
                dimensions=["variants"],
                description="Contig/chromosome index for each variant",
            ),
            vcz.ZarrArraySpec(
                source=None,
                name="call_genotype_phased",
                dtype="bool",
                dimensions=["variants", "samples"],
                description="Whether the genotype is phased",
                compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(),
            ),
            vcz.ZarrArraySpec(
                source=None,
                name="call_genotype",
                dtype="i1",
                dimensions=["variants", "samples", "ploidy"],
                description="Genotype for each variant and sample",
                compressor=vcz.DEFAULT_ZARR_COMPRESSOR_GENOTYPES.get_config(),
            ),
            vcz.ZarrArraySpec(
                source=None,
                name="call_genotype_mask",
                dtype="bool",
                dimensions=["variants", "samples", "ploidy"],
                description="Mask for each genotype call",
                compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(),
            ),
        ]
        schema_instance.fields = array_specs
        return schema_instance
| 249 | + |
| 250 | + |
def convert(
    ts_path,
    zarr_path,
    *,
    contig_id=None,
    ploidy=None,
    isolated_as_missing=False,
    variants_chunk_size=None,
    samples_chunk_size=None,
    worker_processes=1,
    show_progress=False,
):
    """Write the tree sequence at ``ts_path`` out through a ``VcfZarrWriter``.

    A ``TskitFormat`` source is built from ``ts_path`` (with ``contig_id``,
    ``ploidy`` and ``isolated_as_missing`` passed straight through), its
    schema is generated with the requested chunk sizes, and a
    ``vcz.VcfZarrWriter`` rooted at ``zarr_path`` is then driven through
    ``init``, ``encode_all_partitions``, ``finalise`` and ``create_index``.

    Parameters:
        ts_path: Location of the tree sequence to load.
        zarr_path: Output location; wrapped in ``pathlib.Path``.
        contig_id, ploidy, isolated_as_missing: See ``TskitFormat``.
        variants_chunk_size, samples_chunk_size: See
            ``TskitFormat.generate_schema``.
        worker_processes: Passed to ``encode_all_partitions``; also scales
            the target number of partitions.
        show_progress: Passed to ``encode_all_partitions`` and ``finalise``.
    """
    source = TskitFormat(
        ts_path,
        contig_id=contig_id,
        ploidy=ploidy,
        isolated_as_missing=isolated_as_missing,
    )
    schema = source.generate_schema(
        variants_chunk_size=variants_chunk_size,
        samples_chunk_size=samples_chunk_size,
    )

    writer = vcz.VcfZarrWriter(TskitFormat, pathlib.Path(zarr_path))
    writer.init(
        source,
        # Rough heuristic to split work up enough to keep utilisation high
        target_num_partitions=max(1, worker_processes * 4),
        schema=schema,
    )
    writer.encode_all_partitions(
        worker_processes=worker_processes,
        show_progress=show_progress,
    )
    writer.finalise(show_progress)
    writer.create_index()
0 commit comments