Skip to content

Commit 0325683

Browse files
committed
mask: use custom outerJoin to try and preserve partitioner
#424
1 parent a1da3fb commit 0325683

File tree

2 files changed

+145
-139
lines changed

2 files changed

+145
-139
lines changed

geotrellis-common/src/main/scala/org/openeo/geotrelliscommon/DatacubeSupport.scala

+143-7
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
package org.openeo.geotrelliscommon
22

3-
import geotrellis.layer.{Boundable, Bounds, FloatingLayoutScheme, KeyBounds, LayoutDefinition, LayoutLevel, LayoutScheme, Metadata, SpaceTimeKey, TileLayerMetadata, ZoomedLayoutScheme}
3+
import geotrellis.layer.{Boundable, Bounds, FloatingLayoutScheme, KeyBounds, LayoutDefinition, LayoutLevel, LayoutScheme, Metadata, SpaceTimeKey, SpatialKey, TileLayerMetadata, ZoomedLayoutScheme}
44
import geotrellis.proj4.CRS
5-
import geotrellis.raster.{CellSize, CellType, MultibandTile, NODATA, doubleNODATA, isData}
5+
import geotrellis.raster.{CellSize, CellType, DoubleCellType, MultibandTile, NODATA, doubleNODATA, isData}
66
import geotrellis.spark.join.SpatialJoin
77
import geotrellis.spark.partition.{PartitionerIndex, SpacePartitioner}
88
import geotrellis.spark.{MultibandTileLayerRDD, _}
9-
import geotrellis.util.GetComponent
9+
import geotrellis.util._
1010
import geotrellis.vector.{Extent, MultiPolygon, ProjectedExtent}
1111
import org.apache.spark.Partitioner
1212
import org.apache.spark.rdd.{CoGroupedRDD, RDD}
@@ -195,8 +195,7 @@ object DatacubeSupport {
195195

196196

197197
}
198-
Some(SpacePartitioner(metadata.bounds)(SpaceTimeKey.Boundable,
199-
ClassTag(classOf[SpaceTimeKey]), partitionerIndex))
198+
Some(SpacePartitioner(metadata.bounds)(SpaceTimeKey.Boundable,ClassTag(classOf[SpaceTimeKey]), partitionerIndex))
200199
}
201200

202201

@@ -206,7 +205,7 @@ object DatacubeSupport {
206205
replacement: java.lang.Double,
207206
ignoreKeysWithoutMask: Boolean = false,
208207
): RDD[(K, MultibandTile)] with Metadata[M] = {
209-
val joined = if (ignoreKeysWithoutMask) {
208+
val joined: RDD[(K, (MultibandTile, Option[MultibandTile]))] with Metadata[_ >: M with Bounds[K]] = if (ignoreKeysWithoutMask) {
210209
//inner join, try to preserve partitioner
211210
val tmpRdd: RDD[(K, (MultibandTile, Option[MultibandTile]))] =
212211
if(datacube.partitioner.isDefined && datacube.partitioner.get.isInstanceOf[SpacePartitioner[K]]){
@@ -228,7 +227,7 @@ object DatacubeSupport {
228227

229228
ContextRDD(tmpRdd, datacube.metadata)
230229
} else {
231-
SpatialJoin.leftOuterJoin(datacube, mask)
230+
outerJoin(datacube,mask).withContext(_.filter(_._2._1.isDefined).map(t => (t._1,(t._2._1.get,t._2._2))))
232231
}
233232
val replacementInt: Int = if (replacement == null) NODATA else replacement.intValue()
234233
val replacementDouble: Double = if (replacement == null) doubleNODATA else replacement
@@ -254,6 +253,143 @@ object DatacubeSupport {
254253
new ContextRDD(masked, datacube.metadata)
255254
}
256255

256+
def maybeBandCount[K](cube: RDD[(K, MultibandTile)]): Option[Int] = {
257+
if (cube.isInstanceOf[OpenEORasterCube[K]] && cube.asInstanceOf[OpenEORasterCube[K]].openEOMetadata.bandCount > 0) {
258+
val count = cube.asInstanceOf[OpenEORasterCube[K]].openEOMetadata.bandCount
259+
logger.info(s"Computed band count ${count} from metadata of ${cube}")
260+
return Some(count)
261+
}else{
262+
return None
263+
}
264+
}
265+
266+
def getManyBandsIndexGeneric[K]()(implicit t:ClassTag[K]):PartitionerIndex[K] = {
267+
import reflect.ClassTag
268+
val spacetimeKeyTag = classOf[SpaceTimeKey]
269+
val index: PartitionerIndex[K] = t match {
270+
case strtag if strtag == ClassTag(spacetimeKeyTag) => SpaceTimeByMonthPartitioner.asInstanceOf[PartitionerIndex[K]]
271+
case _ => ByTileSpatialPartitioner.asInstanceOf[PartitionerIndex[K]]
272+
}
273+
index
274+
}
275+
276+
277+
def maybeCellType[K](cube: RDD[(K, MultibandTile)]): Option[CellType] = {
278+
if (cube.isInstanceOf[MultibandTileLayerRDD[K]]) {
279+
return Some(cube.asInstanceOf[MultibandTileLayerRDD[K]].metadata.cellType)
280+
}
281+
return None
282+
}
283+
284+
def maybeTileSize[K](cube: RDD[(K, MultibandTile)]): Option[Int] = {
285+
if (cube.isInstanceOf[MultibandTileLayerRDD[K]]) {
286+
return Some(cube.asInstanceOf[MultibandTileLayerRDD[K]].metadata.tileLayout.tileSize)
287+
}
288+
return None
289+
}
290+
291+
/**
292+
* Determines the appropriate partitioner index for the maximum partition size based on the input parameters.
293+
*
294+
* @param nrBands The number of bands in the tile.
295+
* @param tileSize The size of the tile in number of pixels (cols*rows).
296+
* @param cellTypeBits The number of bits used for the cell type (e.g., 8 for Byte, 16 for Short, etc.).
297+
* @param maxPartitionSizeInMb The maximum size of each partition in megabytes. Default is 500.0 MB.
298+
* @return A partitioner index of type `PartitionerIndex[K]` tailored to ensure the partitions adhere to the specified maximum size.
299+
*/
300+
def getPartitionerIndexForMaxPartitionSize[K](nrBands: Int, tileSize: Int, cellTypeBits: Int, maxPartitionSizeInMb: Double = 500.0)(implicit t:ClassTag[K]): PartitionerIndex[K] = {
301+
// Estimate the maximum amount of records required to hit maxPartitionSizeInMb,
302+
// then calculate the max indexReduction that remains under this amount of records.
303+
val tileSizeInMb: Double = (nrBands * tileSize * cellTypeBits).toDouble / (8 * 1024 * 1024)
304+
val maxRecordsPerPartition: Double = math.min(maxPartitionSizeInMb / tileSizeInMb, 1024)
305+
val indexReduction = math.max(math.ceil(math.log(maxRecordsPerPartition) / math.log(2)).toInt - 1, 1)
306+
t match {
307+
case spaceTag if spaceTag == ClassTag(classOf[SpatialKey]) => {
308+
logger.info(s"Creating ConfigurableSpatialPartitionerReduceZ($indexReduction) based on tile size: $tileSize, band count: $nrBands, cell type bits: $cellTypeBits, tileSizeInMb: $tileSizeInMb")
309+
new ConfigurableSpatialPartitionerReduceZ(indexReduction).asInstanceOf[PartitionerIndex[K]]
310+
}
311+
case _ => {
312+
logger.info(s"Creating ConfigurableSpaceTimePartitioner($indexReduction) based on tile size: $tileSize, band count: $nrBands, cell type bits: $cellTypeBits, tileSizeInMb: $tileSizeInMb")
313+
new ConfigurableSpaceTimePartitioner(indexReduction).asInstanceOf[PartitionerIndex[K]]
314+
}
315+
}
316+
}
317+
318+
def outerJoin[K: Boundable: PartitionerIndex: ClassTag,
319+
M: GetComponent[*, Bounds[K]],
320+
M1: GetComponent[*, Bounds[K]]
321+
](leftCube: RDD[(K, MultibandTile)] with Metadata[M], rightCube: RDD[(K, MultibandTile)] with Metadata[M1]): RDD[(K, (Option[MultibandTile], Option[MultibandTile]))] with Metadata[Bounds[K]] = {
322+
323+
val kbLeft: Bounds[K] = leftCube.metadata.getComponent[Bounds[K]]
324+
val kbRight: Bounds[K] = rightCube.metadata.getComponent[Bounds[K]]
325+
val kb: Bounds[K] = kbLeft.combine(kbRight)
326+
327+
val leftCount = maybeBandCount(leftCube)
328+
val rightCount = maybeBandCount(rightCube)
329+
//fairly arbitrary heuristic if we're going to create a cube with a high number of bands
330+
val manyBands = leftCount.getOrElse(1) + rightCount.getOrElse(1) > 25
331+
332+
val part =if( leftCube.partitioner.isDefined && rightCube.partitioner.isDefined && leftCube.partitioner.get.isInstanceOf[SpacePartitioner[K]] && rightCube.partitioner.get.isInstanceOf[SpacePartitioner[K]]) {
333+
val leftPart = leftCube.partitioner.get.asInstanceOf[SpacePartitioner[K]]
334+
val rightPart = rightCube.partitioner.get.asInstanceOf[SpacePartitioner[K]]
335+
logger.info(s"Merging cubes with spatial indices: ${leftPart.index} - ${rightPart.index}")
336+
if(leftPart.index == rightPart.index && leftPart.index.isInstanceOf[SparseSpaceTimePartitioner]) {
337+
val newIndices: Array[BigInt] = (leftPart.index.asInstanceOf[SparseSpaceTimePartitioner].indices ++ rightPart.index.asInstanceOf[SparseSpaceTimePartitioner].indices).distinct.sorted
338+
implicit val newIndex: PartitionerIndex[K] = new SparseSpaceTimePartitioner(newIndices,leftPart.index.asInstanceOf[SparseSpaceTimePartitioner].indexReduction).asInstanceOf[PartitionerIndex[K]]
339+
SpacePartitioner[K](kb)(implicitly,implicitly,newIndex)
340+
}else if(leftPart.index == rightPart.index && leftPart.index.isInstanceOf[SparseSpaceOnlyPartitioner]) {
341+
val newIndices: Array[BigInt] = (leftPart.index.asInstanceOf[SparseSpaceOnlyPartitioner].indices ++ rightPart.index.asInstanceOf[SparseSpaceOnlyPartitioner].indices).distinct.sorted
342+
implicit val newIndex: PartitionerIndex[K] = new SparseSpaceOnlyPartitioner(newIndices,leftPart.index.asInstanceOf[SparseSpaceOnlyPartitioner].indexReduction).asInstanceOf[PartitionerIndex[K]]
343+
SpacePartitioner[K](kb)(implicitly,implicitly,newIndex)
344+
}
345+
else if(leftPart.index == rightPart.index && (leftPart.index == ByTileSpatialPartitioner || leftPart.index.isInstanceOf[ByTileSpacetimePartitioner])) {
346+
leftPart
347+
}
348+
else if(leftPart.index == rightPart.index && leftPart.index.isInstanceOf[ConfigurableSpaceTimePartitioner] ) {
349+
leftPart
350+
}
351+
else if(leftPart.index == rightPart.index && leftPart.index.isInstanceOf[ConfigurableSpatialPartitionerReduceZ] ) {
352+
val indexReduction: Int = leftPart.index.asInstanceOf[ConfigurableSpatialPartitionerReduceZ].indexReduction
353+
logger.info(s"Using ConfigurableSpatialPartitionerReduceZ with indexReduction: ${indexReduction}")
354+
leftPart
355+
}
356+
else if(leftPart.index == rightPart.index && leftPart.index.isInstanceOf[SparseSpatialPartitioner] ) {
357+
val newIndices: Array[BigInt] = (leftPart.index.asInstanceOf[SparseSpatialPartitioner].indices ++ rightPart.index.asInstanceOf[SparseSpatialPartitioner].indices).distinct.sorted
358+
implicit val newIndex: PartitionerIndex[K] = new SparseSpatialPartitioner(newIndices,leftPart.index.asInstanceOf[SparseSpatialPartitioner].indexReduction).asInstanceOf[PartitionerIndex[K]]
359+
SpacePartitioner[K](kb)(implicitly,implicitly,newIndex)
360+
}
361+
else{
362+
SpacePartitioner[K](kb)
363+
}
364+
} else {
365+
// At least one partitioner is undefined.
366+
logger.info(s"Merging cubes with partitioners: ${leftCube.partitioner} - ${rightCube.partitioner} - many band case detected: $manyBands")
367+
if(manyBands) {
368+
val index: PartitionerIndex[K] = getManyBandsIndexGeneric[K]()
369+
SpacePartitioner[K](kb)(implicitly,implicitly,index)
370+
} else {
371+
val nrBands = leftCount.getOrElse(10) + rightCount.getOrElse(10)
372+
val outputCellType = maybeCellType(leftCube).getOrElse(DoubleCellType).union(maybeCellType(rightCube).getOrElse(DoubleCellType))
373+
val tileSize = maybeTileSize(leftCube).getOrElse(128 * 128)
374+
val newIndex = getPartitionerIndexForMaxPartitionSize[K](nrBands, tileSize, outputCellType.bits)
375+
SpacePartitioner[K](kb)(implicitly, implicitly, newIndex)
376+
}
377+
}
378+
379+
val joinRdd =
380+
new CoGroupedRDD[K](List(part(leftCube), part(rightCube)), part)
381+
.flatMapValues { case Array(l, r) =>
382+
if (l.isEmpty)
383+
for (v <- r.iterator) yield (None, Some(v))
384+
else if (r.isEmpty)
385+
for (v <- l.iterator) yield (Some(v), None)
386+
else
387+
for (v <- l.iterator; w <- r.iterator) yield (Some(v), Some(w))
388+
}.asInstanceOf[RDD[(K, (Option[MultibandTile], Option[MultibandTile]))]]
389+
390+
ContextRDD(joinRdd, part.bounds)
391+
}
392+
257393
def applyDataMask(datacubeParams: Option[DataCubeParameters],
258394
rdd: RDD[(SpaceTimeKey, MultibandTile)],
259395
metadata: TileLayerMetadata[SpaceTimeKey],

0 commit comments

Comments
 (0)