Skip to content

Commit ff5dbbf

Browse files
WIP. Add dilate and erode functions to Tensor.
1 parent 3502ddb commit ff5dbbf

File tree

1 file changed

+258
-0
lines changed

1 file changed

+258
-0
lines changed

src/utils/tensor.js

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,264 @@ export class Tensor {
791791
return new Tensor('int64', [BigInt(index)], []);
792792
}
793793

794+
/**
795+
* Normalizes the anchor point for a structuring element.
796+
*
797+
* @param {Object} anchor - The anchor point {x, y}.
798+
* @param {Object} size - The size of the kernel {width, height}.
799+
* @returns {Object} The normalized anchor point.
800+
*/
801+
normalizeAnchor(anchor, size) {
802+
if (anchor.x === -1) {
803+
anchor.x = Math.floor(size.width / 2);
804+
}
805+
if (anchor.y === -1) {
806+
anchor.y = Math.floor(size.height / 2);
807+
}
808+
if (
809+
anchor.x < 0 || anchor.x >= size.width ||
810+
anchor.y < 0 || anchor.y >= size.height
811+
) {
812+
throw new Error("Anchor is out of bounds for the given kernel size.");
813+
}
814+
return anchor;
815+
}
816+
817+
/**
818+
* Creates a structuring element for morphological operations.
819+
*
820+
* @typedef {'MORPH_RECT' | 'MORPH_CROSS' | 'MORPH_ELLIPSE'} Shape
821+
* @typedef {{width: number, height: number}} Size
822+
*
823+
* @param {Shape} shape - The shape of the kernel.
824+
* @param {number | Array<number> | Size} kernelSize - The size of the kernel {width, height}.
825+
* @param {Object} [anchor={x: -1, y: -1}] - The anchor point {x, y}.
826+
* @returns {number[][]} The structuring element as a 2D array.
827+
* @throws {Error} If the shape is invalid or the size is invalid.
828+
*/
829+
getStructuringElement(shape, kernelSize, anchor = { x: -1, y: -1 }) {
830+
if (!['MORPH_RECT', 'MORPH_CROSS', 'MORPH_ELLIPSE'].includes(shape)) {
831+
throw new Error("Invalid shape. Must be 'MORPH_RECT', 'MORPH_CROSS', or 'MORPH_ELLIPSE'.");
832+
}
833+
834+
if (typeof kernelSize === 'number' && Number.isInteger(kernelSize)) {
835+
kernelSize = { width: kernelSize, height: kernelSize };
836+
} else if (Array.isArray(kernelSize) && kernelSize.length === 2) {
837+
kernelSize = { width: kernelSize[0], height: kernelSize[1] };
838+
} else if (
839+
typeof kernelSize !== 'object' ||
840+
!('width' in kernelSize) ||
841+
!('height' in kernelSize) ||
842+
typeof kernelSize.width !== 'number' ||
843+
typeof kernelSize.height !== 'number'
844+
) {
845+
throw new Error("Invalid kernel size. Must be a number, numeric array of length 2, or an object with 'width' and 'height' properties.");
846+
}
847+
848+
if (!Number.isInteger(kernelSize.width) || !Number.isInteger(kernelSize.height)) {
849+
throw new Error('Invalid kernel size. Must be an integer.');
850+
} else if (kernelSize.width % 2 === 0 || kernelSize.height % 2 === 0) {
851+
throw new Error('Invalid kernel size. Must be an odd number.');
852+
}
853+
854+
// Normalize anchor to default to the center if not specified
855+
anchor = this.normalizeAnchor(anchor, kernelSize);
856+
857+
// If the kernel size is 1x1, treat as a rectangle
858+
if (kernelSize.width === 1 && kernelSize.height === 1) {
859+
shape = 'MORPH_RECT';
860+
}
861+
862+
let rowRadius = 0, // Radius along the height
863+
colRadius = 0, // Radius along the width
864+
inverseRowRadiusSquared = 0; // Inverse squared radius for ellipses
865+
866+
if (shape === 'MORPH_ELLIPSE') {
867+
// Calculate radii and inverse squared radius for the ellipse equation
868+
rowRadius = Math.floor(kernelSize.height / 2);
869+
colRadius = Math.floor(kernelSize.width / 2);
870+
inverseRowRadiusSquared = rowRadius > 0 ? 1 / (rowRadius * rowRadius) : 0;
871+
}
872+
873+
// Create a 2D array to represent the kernel
874+
const kernel = Array.from({ length: kernelSize.height }, () => Array(kernelSize.width).fill(0));
875+
876+
for (let row = 0; row < kernelSize.height; row++) {
877+
let startColumn = 0, // Start column for the current row
878+
endColumn = 0; // End column for the current row
879+
880+
if (shape === 'MORPH_RECT' || (shape === 'MORPH_CROSS' && row === anchor.y)) {
881+
// Full width for rectangle or horizontal line for cross shape
882+
endColumn = kernelSize.width;
883+
} else if (shape === 'MORPH_CROSS') {
884+
// Single column for cross shape
885+
startColumn = anchor.x;
886+
endColumn = startColumn + 1;
887+
} else if (shape === 'MORPH_ELLIPSE') {
888+
// Calculate elliptical bounds for this row
889+
const verticalOffset = row - rowRadius; // Distance from center row
890+
if (Math.abs(verticalOffset) <= rowRadius) {
891+
// Solve for horizontal bounds using the ellipse equation: x^2/a^2 + y^2/b^2 = 1
892+
const horizontalRadius = Math.floor(
893+
colRadius * Math.sqrt(Math.max(0, rowRadius * rowRadius - verticalOffset * verticalOffset) * inverseRowRadiusSquared)
894+
);
895+
startColumn = Math.max(colRadius - horizontalRadius, 0); // Left bound of the ellipse
896+
endColumn = Math.min(colRadius + horizontalRadius + 1, kernelSize.width); // Right bound of the ellipse
897+
}
898+
}
899+
900+
// Fill the kernel row with 1s within the range [startColumn, endColumn)
901+
for (let col = startColumn; col < endColumn; col++) {
902+
kernel[row][col] = 1;
903+
}
904+
}
905+
906+
return kernel;
907+
}
908+
909+
// https://github.com/egonSchiele/OpenCV/blob/master/modules/imgproc/src/morph.cpp#L1087
910+
dilate(kernelSize = 3) {
911+
return this.morphologicalOperation(kernelSize, 'dilate');
912+
}
913+
914+
// https://github.com/egonSchiele/OpenCV/blob/master/modules/imgproc/src/morph.cpp#L1079
915+
erode(kernelSize = 3) {
916+
return this.morphologicalOperation(kernelSize, 'erode');
917+
}
918+
919+
/**
920+
* Performs a morphological operation on the input image.
921+
*
922+
* @param {'ERODE' | 'DILATE' | 'OPEN' | 'CLOSE' | 'GRADIENT' | 'TOPHAT' | 'BLACKHAT'} op
923+
* @param {number} kernelSize
924+
*/
925+
async morphologyEx(op, kernelSize, anchor = { x: -1, y: -1 }, iterations = 1, borderType = 0, borderValue = 0) {
926+
switch (op) {
927+
case 'ERODE':
928+
// Perform erosion
929+
await this.erode(kernelSize, anchor, iterations, borderType, borderValue);
930+
break;
931+
932+
case 'DILATE':
933+
// Perform dilation
934+
await this.dilate(kernelSize, anchor, iterations, borderType, borderValue);
935+
break;
936+
937+
case 'OPEN':
938+
// Opening: erosion followed by dilation
939+
this.erode(kernelSize, anchor, iterations, borderType, borderValue);
940+
this.dilate(kernelSize, anchor, iterations, borderType, borderValue);
941+
break;
942+
943+
case 'CLOSE':
944+
// Closing: dilation followed by erosion
945+
this.dilate(kernelSize, anchor, iterations, borderType, borderValue);
946+
this.erode(dst, dst, kernelSize, anchor, iterations, borderType, borderValue);
947+
break;
948+
949+
case 'GRADIENT':
950+
// Gradient: difference between dilation and erosion
951+
temp = this.erode(src, temp, kernelSize, anchor, iterations, borderType, borderValue);
952+
this.dilate(kernelSize, anchor, iterations, borderType, borderValue);
953+
subtractMatrices(dst, temp, dst); // Element-wise subtraction
954+
break;
955+
956+
case 'TOPHAT':
957+
// Tophat: original image minus opening
958+
if (src !== dst) temp = dst;
959+
this.erode(src, temp, kernelSize, anchor, iterations, borderType, borderValue);
960+
this.dilate(temp, temp, kernelSize, anchor, iterations, borderType, borderValue);
961+
subtractMatrices(src, temp, dst);
962+
break;
963+
964+
case 'BLACKHAT':
965+
// Blackhat: closing minus original image
966+
if (src !== dst) temp = dst;
967+
dilate(src, temp, kernelSize, anchor, iterations, borderType, borderValue);
968+
erode(temp, temp, kernelSize, anchor, iterations, borderType, borderValue);
969+
subtractMatrices(temp, src, dst);
970+
break;
971+
972+
default:
973+
throw new Error("Unknown morphological operation");
974+
}
975+
}
976+
977+
/**
978+
* Applies a morphological operation to this tensor.
979+
*
980+
* @param {number} kernelSize The size of the kernel.
981+
* @param {'dilate' | 'erode'} operation The operation to apply.
982+
* @returns {Promise<Tensor>} The cloned, modified output tensor.
983+
*/
984+
async morphologicalOperation(kernelSize, operation) {
985+
// Kernel must be odd because each pixel must sit evenly in the middle.
986+
if (kernelSize % 2 === 0) {
987+
throw new Error('Kernel size must be odd.');
988+
}
989+
990+
const [batches, rows, cols] = this.dims;
991+
const paddingSize = Math.floor(kernelSize / 2);
992+
const outputData = new Float32Array(this.data.length);
993+
const operationFunction = (operationType => {
994+
switch (operationType) {
995+
case 'dilate':
996+
return Math.max;
997+
case 'erode':
998+
return Math.min;
999+
default:
1000+
throw new Error(`Unknown operation: ${operationType}`);
1001+
}
1002+
})(operation);
1003+
1004+
const processChunk = async (chunk) => {
1005+
for (const { batchIndex, rowIndex, colIndex } of chunk) {
1006+
const kernelValues = [];
1007+
1008+
// Collect values in the kernel window
1009+
for (let kernelRowOffset = -paddingSize; kernelRowOffset <= paddingSize; kernelRowOffset++) {
1010+
for (let kernelColOffset = -paddingSize; kernelColOffset <= paddingSize; kernelColOffset++) {
1011+
const neighborRowIndex = rowIndex + kernelRowOffset;
1012+
const neighborColIndex = colIndex + kernelColOffset;
1013+
if (neighborRowIndex >= 0 && neighborRowIndex < rows && neighborColIndex >= 0 && neighborColIndex < cols) {
1014+
const neighborIndex = batchIndex * rows * cols + neighborRowIndex * cols + neighborColIndex;
1015+
kernelValues.push(this.data[neighborIndex]);
1016+
}
1017+
}
1018+
}
1019+
1020+
// Apply operation (e.g., max for dilation, min for erosion)
1021+
const outputIndex = batchIndex * rows * cols + rowIndex * cols + colIndex;
1022+
outputData[outputIndex] = operationFunction(...kernelValues);
1023+
}
1024+
};
1025+
1026+
// Divide work into chunks for parallel processing
1027+
const chunks = [];
1028+
const chunkSize = Math.ceil((batches * rows * cols) / (navigator.hardwareConcurrency || 4));
1029+
let currentChunk = [];
1030+
1031+
for (let batchIndex = 0; batchIndex < batches; batchIndex++) {
1032+
for (let rowIndex = 0; rowIndex < rows; rowIndex++) {
1033+
for (let colIndex = 0; colIndex < cols; colIndex++) {
1034+
currentChunk.push({ batchIndex, rowIndex, colIndex });
1035+
if (currentChunk.length >= chunkSize) {
1036+
chunks.push([...currentChunk]);
1037+
currentChunk = [];
1038+
}
1039+
}
1040+
}
1041+
}
1042+
if (currentChunk.length > 0) {
1043+
chunks.push(currentChunk);
1044+
}
1045+
1046+
// Process all chunks in parallel
1047+
await Promise.all(chunks.map(chunk => processChunk(chunk)));
1048+
1049+
return new Tensor(this.type, outputData, this.dims);
1050+
}
1051+
7941052
/**
7951053
* Performs Tensor dtype conversion.
7961054
* @param {DataType} type The desired data type.

0 commit comments

Comments
 (0)