Skip to content

Commit 6181d44

Browse files
WIP. Add dilate and erode functions to Tensor.
1 parent 307a490 commit 6181d44

File tree

1 file changed

+258
-0
lines changed

1 file changed

+258
-0
lines changed

src/utils/tensor.js

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,264 @@ export class Tensor {
789789
return new Tensor('int64', [BigInt(index)], []);
790790
}
791791

792+
/**
793+
* Normalizes the anchor point for a structuring element.
794+
*
795+
* @param {Object} anchor - The anchor point {x, y}.
796+
* @param {Object} size - The size of the kernel {width, height}.
797+
* @returns {Object} The normalized anchor point.
798+
*/
799+
normalizeAnchor(anchor, size) {
800+
if (anchor.x === -1) {
801+
anchor.x = Math.floor(size.width / 2);
802+
}
803+
if (anchor.y === -1) {
804+
anchor.y = Math.floor(size.height / 2);
805+
}
806+
if (
807+
anchor.x < 0 || anchor.x >= size.width ||
808+
anchor.y < 0 || anchor.y >= size.height
809+
) {
810+
throw new Error("Anchor is out of bounds for the given kernel size.");
811+
}
812+
return anchor;
813+
}
814+
815+
/**
816+
* Creates a structuring element for morphological operations.
817+
*
818+
* @typedef {'MORPH_RECT' | 'MORPH_CROSS' | 'MORPH_ELLIPSE'} Shape
819+
* @typedef {{width: number, height: number}} Size
820+
*
821+
* @param {Shape} shape - The shape of the kernel.
822+
* @param {number | Array<number> | Size} kernelSize - The size of the kernel {width, height}.
823+
* @param {Object} [anchor={x: -1, y: -1}] - The anchor point {x, y}.
824+
* @returns {number[][]} The structuring element as a 2D array.
825+
* @throws {Error} If the shape is invalid or the size is invalid.
826+
*/
827+
getStructuringElement(shape, kernelSize, anchor = { x: -1, y: -1 }) {
828+
if (!['MORPH_RECT', 'MORPH_CROSS', 'MORPH_ELLIPSE'].includes(shape)) {
829+
throw new Error("Invalid shape. Must be 'MORPH_RECT', 'MORPH_CROSS', or 'MORPH_ELLIPSE'.");
830+
}
831+
832+
if (typeof kernelSize === 'number' && Number.isInteger(kernelSize)) {
833+
kernelSize = { width: kernelSize, height: kernelSize };
834+
} else if (Array.isArray(kernelSize) && kernelSize.length === 2) {
835+
kernelSize = { width: kernelSize[0], height: kernelSize[1] };
836+
} else if (
837+
typeof kernelSize !== 'object' ||
838+
!('width' in kernelSize) ||
839+
!('height' in kernelSize) ||
840+
typeof kernelSize.width !== 'number' ||
841+
typeof kernelSize.height !== 'number'
842+
) {
843+
throw new Error("Invalid kernel size. Must be a number, numeric array of length 2, or an object with 'width' and 'height' properties.");
844+
}
845+
846+
if (!Number.isInteger(kernelSize.width) || !Number.isInteger(kernelSize.height)) {
847+
throw new Error('Invalid kernel size. Must be an integer.');
848+
} else if (kernelSize.width % 2 === 0 || kernelSize.height % 2 === 0) {
849+
throw new Error('Invalid kernel size. Must be an odd number.');
850+
}
851+
852+
// Normalize anchor to default to the center if not specified
853+
anchor = this.normalizeAnchor(anchor, kernelSize);
854+
855+
// If the kernel size is 1x1, treat as a rectangle
856+
if (kernelSize.width === 1 && kernelSize.height === 1) {
857+
shape = 'MORPH_RECT';
858+
}
859+
860+
let rowRadius = 0, // Radius along the height
861+
colRadius = 0, // Radius along the width
862+
inverseRowRadiusSquared = 0; // Inverse squared radius for ellipses
863+
864+
if (shape === 'MORPH_ELLIPSE') {
865+
// Calculate radii and inverse squared radius for the ellipse equation
866+
rowRadius = Math.floor(kernelSize.height / 2);
867+
colRadius = Math.floor(kernelSize.width / 2);
868+
inverseRowRadiusSquared = rowRadius > 0 ? 1 / (rowRadius * rowRadius) : 0;
869+
}
870+
871+
// Create a 2D array to represent the kernel
872+
const kernel = Array.from({ length: kernelSize.height }, () => Array(kernelSize.width).fill(0));
873+
874+
for (let row = 0; row < kernelSize.height; row++) {
875+
let startColumn = 0, // Start column for the current row
876+
endColumn = 0; // End column for the current row
877+
878+
if (shape === 'MORPH_RECT' || (shape === 'MORPH_CROSS' && row === anchor.y)) {
879+
// Full width for rectangle or horizontal line for cross shape
880+
endColumn = kernelSize.width;
881+
} else if (shape === 'MORPH_CROSS') {
882+
// Single column for cross shape
883+
startColumn = anchor.x;
884+
endColumn = startColumn + 1;
885+
} else if (shape === 'MORPH_ELLIPSE') {
886+
// Calculate elliptical bounds for this row
887+
const verticalOffset = row - rowRadius; // Distance from center row
888+
if (Math.abs(verticalOffset) <= rowRadius) {
889+
// Solve for horizontal bounds using the ellipse equation: x^2/a^2 + y^2/b^2 = 1
890+
const horizontalRadius = Math.floor(
891+
colRadius * Math.sqrt(Math.max(0, rowRadius * rowRadius - verticalOffset * verticalOffset) * inverseRowRadiusSquared)
892+
);
893+
startColumn = Math.max(colRadius - horizontalRadius, 0); // Left bound of the ellipse
894+
endColumn = Math.min(colRadius + horizontalRadius + 1, kernelSize.width); // Right bound of the ellipse
895+
}
896+
}
897+
898+
// Fill the kernel row with 1s within the range [startColumn, endColumn)
899+
for (let col = startColumn; col < endColumn; col++) {
900+
kernel[row][col] = 1;
901+
}
902+
}
903+
904+
return kernel;
905+
}
906+
907+
// https://github.com/egonSchiele/OpenCV/blob/master/modules/imgproc/src/morph.cpp#L1087
908+
dilate(kernelSize = 3) {
909+
return this.morphologicalOperation(kernelSize, 'dilate');
910+
}
911+
912+
// https://github.com/egonSchiele/OpenCV/blob/master/modules/imgproc/src/morph.cpp#L1079
913+
erode(kernelSize = 3) {
914+
return this.morphologicalOperation(kernelSize, 'erode');
915+
}
916+
917+
/**
918+
* Performs a morphological operation on the input image.
919+
*
920+
* @param {'ERODE' | 'DILATE' | 'OPEN' | 'CLOSE' | 'GRADIENT' | 'TOPHAT' | 'BLACKHAT'} op
921+
* @param {number} kernelSize
922+
*/
923+
async morphologyEx(op, kernelSize, anchor = { x: -1, y: -1 }, iterations = 1, borderType = 0, borderValue = 0) {
924+
switch (op) {
925+
case 'ERODE':
926+
// Perform erosion
927+
await this.erode(kernelSize, anchor, iterations, borderType, borderValue);
928+
break;
929+
930+
case 'DILATE':
931+
// Perform dilation
932+
await this.dilate(kernelSize, anchor, iterations, borderType, borderValue);
933+
break;
934+
935+
case 'OPEN':
936+
// Opening: erosion followed by dilation
937+
this.erode(kernelSize, anchor, iterations, borderType, borderValue);
938+
this.dilate(kernelSize, anchor, iterations, borderType, borderValue);
939+
break;
940+
941+
case 'CLOSE':
942+
// Closing: dilation followed by erosion
943+
this.dilate(kernelSize, anchor, iterations, borderType, borderValue);
944+
this.erode(dst, dst, kernelSize, anchor, iterations, borderType, borderValue);
945+
break;
946+
947+
case 'GRADIENT':
948+
// Gradient: difference between dilation and erosion
949+
temp = this.erode(src, temp, kernelSize, anchor, iterations, borderType, borderValue);
950+
this.dilate(kernelSize, anchor, iterations, borderType, borderValue);
951+
subtractMatrices(dst, temp, dst); // Element-wise subtraction
952+
break;
953+
954+
case 'TOPHAT':
955+
// Tophat: original image minus opening
956+
if (src !== dst) temp = dst;
957+
this.erode(src, temp, kernelSize, anchor, iterations, borderType, borderValue);
958+
this.dilate(temp, temp, kernelSize, anchor, iterations, borderType, borderValue);
959+
subtractMatrices(src, temp, dst);
960+
break;
961+
962+
case 'BLACKHAT':
963+
// Blackhat: closing minus original image
964+
if (src !== dst) temp = dst;
965+
dilate(src, temp, kernelSize, anchor, iterations, borderType, borderValue);
966+
erode(temp, temp, kernelSize, anchor, iterations, borderType, borderValue);
967+
subtractMatrices(temp, src, dst);
968+
break;
969+
970+
default:
971+
throw new Error("Unknown morphological operation");
972+
}
973+
}
974+
975+
/**
976+
* Applies a morphological operation to this tensor.
977+
*
978+
* @param {number} kernelSize The size of the kernel.
979+
* @param {'dilate' | 'erode'} operation The operation to apply.
980+
* @returns {Promise<Tensor>} The cloned, modified output tensor.
981+
*/
982+
async morphologicalOperation(kernelSize, operation) {
983+
// Kernel must be odd because each pixel must sit evenly in the middle.
984+
if (kernelSize % 2 === 0) {
985+
throw new Error('Kernel size must be odd.');
986+
}
987+
988+
const [batches, rows, cols] = this.dims;
989+
const paddingSize = Math.floor(kernelSize / 2);
990+
const outputData = new Float32Array(this.data.length);
991+
const operationFunction = (operationType => {
992+
switch (operationType) {
993+
case 'dilate':
994+
return Math.max;
995+
case 'erode':
996+
return Math.min;
997+
default:
998+
throw new Error(`Unknown operation: ${operationType}`);
999+
}
1000+
})(operation);
1001+
1002+
const processChunk = async (chunk) => {
1003+
for (const { batchIndex, rowIndex, colIndex } of chunk) {
1004+
const kernelValues = [];
1005+
1006+
// Collect values in the kernel window
1007+
for (let kernelRowOffset = -paddingSize; kernelRowOffset <= paddingSize; kernelRowOffset++) {
1008+
for (let kernelColOffset = -paddingSize; kernelColOffset <= paddingSize; kernelColOffset++) {
1009+
const neighborRowIndex = rowIndex + kernelRowOffset;
1010+
const neighborColIndex = colIndex + kernelColOffset;
1011+
if (neighborRowIndex >= 0 && neighborRowIndex < rows && neighborColIndex >= 0 && neighborColIndex < cols) {
1012+
const neighborIndex = batchIndex * rows * cols + neighborRowIndex * cols + neighborColIndex;
1013+
kernelValues.push(this.data[neighborIndex]);
1014+
}
1015+
}
1016+
}
1017+
1018+
// Apply operation (e.g., max for dilation, min for erosion)
1019+
const outputIndex = batchIndex * rows * cols + rowIndex * cols + colIndex;
1020+
outputData[outputIndex] = operationFunction(...kernelValues);
1021+
}
1022+
};
1023+
1024+
// Divide work into chunks for parallel processing
1025+
const chunks = [];
1026+
const chunkSize = Math.ceil((batches * rows * cols) / (navigator.hardwareConcurrency || 4));
1027+
let currentChunk = [];
1028+
1029+
for (let batchIndex = 0; batchIndex < batches; batchIndex++) {
1030+
for (let rowIndex = 0; rowIndex < rows; rowIndex++) {
1031+
for (let colIndex = 0; colIndex < cols; colIndex++) {
1032+
currentChunk.push({ batchIndex, rowIndex, colIndex });
1033+
if (currentChunk.length >= chunkSize) {
1034+
chunks.push([...currentChunk]);
1035+
currentChunk = [];
1036+
}
1037+
}
1038+
}
1039+
}
1040+
if (currentChunk.length > 0) {
1041+
chunks.push(currentChunk);
1042+
}
1043+
1044+
// Process all chunks in parallel
1045+
await Promise.all(chunks.map(chunk => processChunk(chunk)));
1046+
1047+
return new Tensor(this.type, outputData, this.dims);
1048+
}
1049+
7921050
/**
7931051
* Performs Tensor dtype conversion.
7941052
* @param {DataType} type The desired data type.

0 commit comments

Comments
 (0)