From c6f031be3a5019fd8ec08faa2611c52f7aa01e80 Mon Sep 17 00:00:00 2001 From: Shannon Woods Date: Fri, 2 May 2025 11:32:52 -0400 Subject: [PATCH 1/4] Add a wave-ballot version of the 2D diff splatting experiment --- experiments/balloted-splatting/README.md | 44 ++ .../__pycache__/app.cpython-313.pyc | Bin 0 -> 7599 bytes experiments/balloted-splatting/app.py | 122 ++++ .../balloted-splatting/diffsplatting2d.slang | 665 ++++++++++++++++++ .../balloted-splatting/example-image.png | 3 + experiments/balloted-splatting/jeep.jpg | 3 + experiments/balloted-splatting/main.py | 101 +++ .../balloted-splatting/requirements.txt | 2 + 8 files changed, 940 insertions(+) create mode 100644 experiments/balloted-splatting/README.md create mode 100644 experiments/balloted-splatting/__pycache__/app.cpython-313.pyc create mode 100644 experiments/balloted-splatting/app.py create mode 100644 experiments/balloted-splatting/diffsplatting2d.slang create mode 100644 experiments/balloted-splatting/example-image.png create mode 100644 experiments/balloted-splatting/jeep.jpg create mode 100644 experiments/balloted-splatting/main.py create mode 100644 experiments/balloted-splatting/requirements.txt diff --git a/experiments/balloted-splatting/README.md b/experiments/balloted-splatting/README.md new file mode 100644 index 0000000..edc98cf --- /dev/null +++ b/experiments/balloted-splatting/README.md @@ -0,0 +1,44 @@ +# 2D Differentiable Gaussian Splatting + +## About + +This example demonstrates the use of Slang and SlangPy to implement a 2D Gaussian splatting algorithm. + +This algorithm represents a simplified version of the 3D Gaussian Splatting algorithm detailed in this paper (https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/). This 2D demonstration does not have the 3D->2D projection step & assumes that the Gaussian blobs are presented in order of depth (higher index = farther away). Further, this implementation does not perform adaptive density control to add or remove blobs. + +See the `computeDerivativesMain()` kernel and the `splatBlobs()` function for the bulk of the key pieces of the code. This sample uses SlangPy (see `main.py`) to easily load and dispatch the kernels. SlangPy handles the pipeline setup, buffer allocation, buffer copies, and other boilerplate tasks to make it easy to prototype high-performance differentiable code. + +For a full 3D Gaussian Splatting implementation written in Slang, see this repository: https://github.com/google/slang-gaussian-rasterization + +### Workaround for 'compressing' a 2D group size into a 1D group +This sample uses a workaround for SlangPy's fixed group size of `(32, 1, 1)`. The rasterizer uses a fixed `8x4` 2D tile. We use numpy commands to construct an aray of dispatch indices such that the right threads receive the right 2D thread index. `calcCompressedDispatchIDs()` in `main.py` holds the logic for this workaround. + +When SlangPy is updated with the functionality to specify group sizes, this workaround will be removed. + +## How to Use + +### Installation + +First, install slangpy and the tev viewer: + +- **SlangPy** python package: `pip install slangpy`. See SlangPy's [docs](https://slangpy.shader-slang.org/en/latest/installation.html) for a full list of requirements. +- **Tev** viewer: Download from [releases](https://github.com/Tom94/tev/releases/tag/v1.29). See [tev's github](https://github.com/Tom94/tev) for more information. + +Then install the example's requirements, from within the sample folder: + +`pip install -r requirements.txt` + +### Optional: Setup via Conda +For simpler setup, use an anaconda/miniconda installation (See [Conda's user guide](https://docs.conda.io/projects/conda/en/latest/user-guide/index.html) for more). + +Ensure that your environment is using **Python 3.10**. +If you are using conda, you can create a new environment with **python 3.10** and **slangpy** both installed using the following command: +`conda create -n slangpy-env python=3.10 slangpy`. Then switch to this new environment with `conda activate slangpy-env`. + +### Running the Sample +- Open the **Tev** viewer and keep it running in the background. +- From the sample folder, run `python main.py` from a terminal. + +You should see a stream of images in Tev as the training progresses: + +![](./example-image.png) \ No newline at end of file diff --git a/experiments/balloted-splatting/__pycache__/app.cpython-313.pyc b/experiments/balloted-splatting/__pycache__/app.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..324c61a4273b8a5ef9381d6c2398b2532c1d6fa6 GIT binary patch literal 7599 zcmcIoT}&HS7QW-(2{vGC@DIO^O-xLPA>AbB~tph=gxS>#tcf8?q0$5 zoqNwY_ug~P{qDJEr@FcdLHWm4{Tx$=(0|B>RxHKF&M-7SMhTRluA!?GrmmW>DM-Fe z*UZ?g{93RDe$CgcS8do9M1B-P2}>hNSWlW%DcC-1irRi8A5oJ`53%VqJC){S`i-2B z%4XQ~NR*N-Z?eK%Qe{wUeGM53p?6>vlL%7@giQ$wW;P|v3G<{SuI!31n?pQIC3;C35Oni4jkvYadq;i@97*O1n}i`H4BZ8W5%chS}qXG3DrI539J zQMj=kMLn|lLN52?0T6MwZa6hF!{^eh5WIMYUC5=m6aw2Mt!46QXq(}G=P}ghsgKcw zJ~(O}LvI7Erq3Xn&@7AMHl0jwgwT!f7?q$Brn6wPv`#yrTiVFJAv9*0sMUl_xU|+u z@EuZ&EZt2d%x6t7hc0;x#p_l2@;dG-|5DpPN%mY(wAnzW?t?(-uONfZc+=iKW0~>h zUrlM?NeRnYIIkc%xK~O%ytnWFv4%2xx9^?N$dk)B4jXV7c=k&;y7o#rYDa)k3ENrl z&$MpgGC4hac|y7u15ZF32c4bJ+NED(O{R6b!}ysJwbt5)tEFswKx>zN_u*>OC=%cS zrJ0p4UE(Z?+DAsC7TJ0`H7(4^wmB{}J159iAteAvAHW=+x=ZjHK)O1WNv8AD9FqgE z<{z0bIXTz?m0Xj}3%R_&2;3bZkGV(IvN79C7B8?MGS9QKoNNX7=hD%jY~}Mg4$Cyd zq%tXiVPrEuo0e_ksmyftwrpo^7aM$<&CKQ&V@47NN#hPGMAW@r9n{45^^b96| zRi+Myst!h2%mF#WB-1RfJxzEx3ITxo?D@%qD-;lI26^^D5+TlYCv@`??-ZYz++l)Bm>-Y2dmzaQ#Vt z{3kQzAEgM{HjVtvtM^v!Nxoju*ZcUy*8^V;Z1@JrSb=ulcP=|6x>clGAB=uJ`Prm& z=(Kp~^mmqjIsfjI1{307;t3fYeHQEeV*b&*6uT_OE*IS1)vlE;$=xox+Y27w>cGl? zu+PBgtc{)W;=a#!ca)(5BNOJdz?%uBw(#ebB$&1p-w?z1Jzg2LDwjHSZ z3bl=F-LR)ky7{>qHMOpH{-yJ;U9iB~Z5!&0NzqGU^wN@rEOnEP6e68ctz6?Ie)*VO z8>F|Qw4+B9L~NMcHjgOW3qL$z3}SM{;}BGG6|tx4K{5E=AOQFX4iiIVsls1%(`wPCR~{Dcngou5SLiox$dV?Tq(2}MQGqB-q) zX2AnyDPxC;mq&_4IU0Wu!P~$*^7?4htk_g`@^jgIdK&O1%X4x~4rd8`svyW$S{z>i zak86eemCKH=;w*WK0pOW?frq}0m;!WI@%>iM07+79>0Fz!y+Btqz@HpJ$tpIN3~}_ z1Fxc;3>-b;<+J1{S&CaP`$?d_8wgEbM*EB##m8fmQE%UEkkh`e9I}&&0&)y`i5&Ki ziX1&EE-}WdOKq)XydGoe6T~N#64xVit;aYzterFf)zeBZaCctPB$jT#*kw`(FPoZAKA?&-^-^k z$7CRkVp0nb7Gee5^ng;iM{t-h`bb4~i^)`&ynHd)2tG_Iat+DeVyKtQE-bJah?z3U zEQDXOWh$M5*oo)#z?9(y2qER_crjE3-&PNzS0?3_cj%a2|$;Lh^Koo({=#RP-Eu-0*eNmrWaio)l!KVa0TfWBCyDN7kUrhAHw!8t!8xg$` z$$Lulp8D1!oxLKSy&|2xDZ-!kX2BbQc%=P2gd^=WOT(Ma;I<1jwyrm?HA{g3F)*+Z zIKNa=puG~^B+^ZvT$IA6#qjA(`i#Z1~D7^;{5pE=WCB#GWgg^i>^mJ4L$l z;d$xEpm=0(xuK#mWfY`L=p068 z&k#WYi7!#?J8&Wlk+Dpo5i*jyzM>FpXW_jE*stSWi6l%Ce2)ljOlj97?vTfu3pRm~sKgW`Vl})6QjYGc3QDNy?^~UipBK zWu{WX0-KZV@OHn&VL`UNrgk{qOx@wq#YegsmMd2eWfyYUjJ6JIK9v#9$d21M1-S!7 zC?*;xf;k>1%dn~9cSAVQq})OvyzB)HTPb^&zYQe3TLpE8>we#I-=`ia5EBD2DR4#% zoY{1oDb&{AUtC`NbZGtB+BK;;CN{@5YGe8p5E5OXO;>xts4!2l0PQ;W0JpD^!MsrJ)*xy@}Cm@rwZYYZ98fRZzD^C_XUl7 zp_Tb3q25jJaWb`|R&pGKzXG|pLZUPD;N(+hblZ%aAvK=XaTpXuP+LLnZ=nJ4^#2gK zSwK@`tbv$Le}91Q8{l-s=`f6Ocb^XR6>7&fVIQjh*ZiI&aZ1i_S;*evWHTgnu<~&1 z2Fj>Oi5?V(E6d^b}gd>#4QW z((s3GC^l)89Ic8?+C*pD1M5>~gxI7_u}LY*lO)#A3el2cop+(}DzBX&!4``quJjv+ zOlxJi;Z9H_GYv>%8o?xSAGN5VoaHu~5|k(rlc-obrFvkv`-qYxt=v$dIyJL(2ZneO zCxX_r^iDve1DkY{{;i~Y#^7oGJ>941CTdq!p&CtzNR>OLXCjQdjrS3tod+<4ekL^y2 zwd3K?7bA~Go+D^(J5AR92mCgIpT|Ql$j{HDQjZPUa+H>BaO*+-5|guniMy*06cV(^ zc9IxRr>2zLw4$R&YUWANoJu5unP0X5tjh$rvGZ8JqtFKmd4{D7){KsEnH{uj3Q MC<-PgwKDsE0j)b$ sgl.Device: + return self._device + + @property + def window(self) -> sgl.Window: + return self._window + + @property + def mouse_pos(self) -> sgl.float2: + return self._mouse_pos + + @property + def output(self) -> sgl.Texture: + return self._output_texture + + def process_events(self): + if self._window.should_close(): + return False + self._window.process_events() + return True + + def present(self): + image = self.surface.acquire_next_image() + if image is None: + return + + if (self._output_texture == None + or self._output_texture.width != image.width + or self._output_texture.height != image.height + ): + self._output_texture = self.device.create_texture( + width=image.width, + height=image.height, + format=sgl.Format.rgba32_float, + usage=sgl.TextureUsage.shader_resource | sgl.TextureUsage.unordered_access, + label="output_texture", + ) + + command_buffer = self._device.create_command_encoder() + command_buffer.blit(image, self._output_texture) + command_buffer.set_texture_state(image, sgl.ResourceState.present) + self._device.submit_command_buffer(command_buffer.finish()) + + del image + self.surface.present() + + def _on_window_keyboard_event(self, event: sgl.KeyboardEvent): + if event.type == sgl.KeyboardEventType.key_press: + if event.key == sgl.KeyCode.escape: + self._window.close() + return + elif event.key == sgl.KeyCode.f1: + if self._output_texture: + sgl.tev.show_async(self._output_texture) + return + elif event.key == sgl.KeyCode.f2: + if self._output_texture: + bitmap = self._output_texture.to_bitmap() + bitmap.convert( + sgl.Bitmap.PixelFormat.rgb, + sgl.Bitmap.ComponentType.uint8, + srgb_gamma=True, + ).write_async("screenshot.png") + return + if self.on_keyboard_event: + self.on_keyboard_event(event) + + def _on_window_mouse_event(self, event: sgl.MouseEvent): + if event.type == sgl.MouseEventType.move: + self._mouse_pos = event.pos + if self.on_mouse_event: + self.on_mouse_event(event) + + def _on_window_resize(self, width: int, height: int): + self._device.wait() + self.surface.configure(width=width, height=height) diff --git a/experiments/balloted-splatting/diffsplatting2d.slang b/experiments/balloted-splatting/diffsplatting2d.slang new file mode 100644 index 0000000..a7bc968 --- /dev/null +++ b/experiments/balloted-splatting/diffsplatting2d.slang @@ -0,0 +1,665 @@ +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +/* + * 2D Gaussian Splatting Example in Slang + * + * This example demonstrates the use of Slang's differentiable programming capabilities to implement + * a 2D Gaussian splatting algorithm that can be trained within the browser using the Slang Playground. + * + * This algorithm represents a simplified version of the 3D Gaussian Splatting algorithm detailed in + * this paper (https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/). + * This 2D demonstration does not have the 3D->2D projection step & assumes that the Gaussian blobs + * are presented in order of depth (higher index = farther away). Further, this implementation does + * not perform adaptive density control to add or remove blobs. + * + * See the `computeDerivativesMain()` kernel and the `splatBlobs()` function for the bulk of the key + * pieces of the code. + * + * Key Slang features used in this example include the autodiff operator `bwd_diff(fn)`, the + * `[Differentiable]` attribute, and custom derivatives for a few specific components via + * the `[BackwardDerivative(fn)]` attribute. + * + * For a full 3D Gaussian Splatting implementation written in Slang, see this repository: + * https://github.com/google/slang-gaussian-rasterization + * + */ + +import slangpy; + +// ----- Constants and definitions -------- + +static const int GAUSSIANS_PER_BLOCK = 512; +static const int WG_X = 8; +static const int WG_Y = 4; +static const int WG_SIZE = WG_X * WG_Y; + +static const float ADAM_ETA = 0.002; +static const float ADAM_BETA_1 = 0.9; +static const float ADAM_BETA_2 = 0.999; +static const float ADAM_EPSILON = 1e-8; + +// ----- Shared memory declarations -------- + +// Note: In Slang, the 'groupshared' identifier is used to define +// workgroup-level shared memory. This is equivalent to '__shared__' in CUDA + +groupshared uint intersectingBlobCount; + +// The blobs buffer is used to store the indices of the blobs that intersect +// with the current tile. +// +// Reduce to 16bit to conserve group shared memory +groupshared uint16_t intersectingBlobList[GAUSSIANS_PER_BLOCK]; + +// The maxCount and finalVal buffers are used to store the final PixelState objects +// after the forward pass. This data is read-back for the backwards pass. +// +// Reduce int to 16bit to conserve group shared memory +groupshared uint16_t maxCount[WG_X * WG_Y]; +groupshared float4 finalVal[WG_X * WG_Y]; + +// ----------------------------------------- + +struct Blobs +{ + GradInOutTensor blobsBuffer; + + __subscript(uint idx) -> float + { + [Differentiable] get { return blobsBuffer[idx]; } + } +}; + +/* + * Oriented bounding box (OBB) data-structure + * + * Can be used to represent the bounds of an anisotropic Gaussian blob. + * The bounding box can be extracted by taking a canonical box + * formed by (-1,-1), (1,-1), (1,1), (-1,1), then translating, rotating, and scaling it. + */ +struct OBB +{ + float2 center; + float2x2 rotation; + float2 scale; + + /* + * intersects() returns true if the OBB intersects with another OBB. + * + * The implementation is based on the separating axis theorem (see + * https://dyn4j.org/2010/01/sat/#sat-algo for a detailed explanation). + * At a high level, the SAT algorithm checks if the projections of the + * points of the two OBBs are disjoint along the normals of all of the + * faces of each OBB. + */ + bool intersects(OBB other) + { + float2 canonicalPts[4] = float2[4](float2(-1, -1), float2(1, -1), float2(1, 1), float2(-1, 1)); + + float2x2 invRotation = inverse(rotation); + float2x2 otherInvRotation = inverse(other.rotation); + float2 pts[4]; + for (int i = 0; i < 4; i++) + pts[i] = center + float2( + dot(invRotation[0], (canonicalPts[i] * scale)), + dot(invRotation[1], (canonicalPts[i] * scale))); + + float2 otherPts[4]; + for (int i = 0; i < 4; i++) + otherPts[i] = other.center + float2( + dot(otherInvRotation[0], (canonicalPts[i] * other.scale)), + dot(otherInvRotation[1], (canonicalPts[i] * other.scale))); + + return !(arePtsSeparatedAlongAxes(pts, otherPts, rotation) || + arePtsSeparatedAlongAxes(pts, otherPts, other.rotation)); + } + + static bool arePtsSeparatedAlongAxes(float2[4] pts, float2[4] otherPts, float2x2 axes) + { + // If any set of points are entirely on one side of the other, they are separated. + // + for (int i = 0; i < 2; i++) + { + float2 axis = axes[i]; + float2 proj = float2(dot(pts[0], axis), dot(pts[0], axis)); + float2 otherProj = float2(dot(otherPts[0], axis), dot(otherPts[0], axis)); + + for (int j = 1; j < 4; j++) + { + proj.x = min(proj.x, dot(pts[j], axis)); + proj.y = max(proj.y, dot(pts[j], axis)); + + otherProj.x = min(otherProj.x, dot(otherPts[j], axis)); + otherProj.y = max(otherProj.y, dot(otherPts[j], axis)); + } + + if (proj.y < otherProj.x || otherProj.y < proj.x) + return true; + } + + return false; + } + + // In Slang, constructors are defined through special methods named `__init`. + // Several constructors can be defined, and overload resolution will pick the right one. + // + __init(float2 center, float2x2 rotation, float2 scale) + { + this.center = center; + this.rotation = rotation; + this.scale = scale; + } +}; + +/* + * A utility function to premultiply the color by the alpha value. + * This is a key part of the alpha blending routine used in the + * Gaussian splatting algorithm. + */ +[Differentiable] +float4 preMult(float4 pixel) +{ + return float4(pixel.rgb * pixel.a, pixel.a); +} + +/* + * alphaBlend() implements the standard alpha blending algorithm. + * + * Takes the current pixel value 'pixel' & blends it with a + * contribution 'gval' from a new Gaussian. + */ +[Differentiable] +float4 alphaBlend(float4 pixel, float4 gval) +{ + gval = preMult(gval); + + return float4( + pixel.rgb + gval.rgb * pixel.a, + pixel.a * (1 - gval.a)); +} + +/* + * undoAlphaBlend() implements the reverse of the alpha blending algorithm. + * + * Takes a pixel value 'pixel' and the same 'gval' contribution & + * computes the previous pixel value. + * + * This is a critical piece of the backwards pass. + */ +float4 undoAlphaBlend(float4 pixel, float4 gval) +{ + gval = preMult(gval); + + var oldPixelAlpha = pixel.a / (1 - gval.a); + return float4( + pixel.rgb - gval.rgb * oldPixelAlpha, + oldPixelAlpha); +} + +/* + * PixelState encapsulates all the info for a pixel as it is being rasterized + * through the sorted list of blobs. + */ +struct PixelState : IDifferentiable +{ + float4 value; + uint finalCount; +}; + +/* + * transformPixelState() applies the alpha blending operation to the pixel state & + * updates the counter accordingly. + * + * This state transition also stops further blending once the pixel is effectively opaque. + * This is important to avoid the alpha becoming too low (or even 0), at which point + * the blending is not reversible. + * + */ +[Differentiable] +PixelState transformPixelState(PixelState pixel, float4 gval, uint blobIdx) +{ + var newState = alphaBlend(pixel.value, gval); + + if (pixel.value.a < 1.f / 255.f) + return { pixel.value, pixel.finalCount }; + + //return { newState, pixel.finalCount + 1 }; + return { newState, blobIdx }; +} + +/* + * undoPixelState() reverses the alpha blending operation and restores the previous pixel + * state. + */ +PixelState undoPixelState(PixelState nextState, uint index, float4 gval) +{ + if (index > nextState.finalCount) + return { nextState.value, nextState.finalCount }; + + return { undoAlphaBlend(nextState.value, gval), nextState.finalCount - 1 }; +} + +[Differentiable] +float2x2 inverse(float2x2 mat) +{ + float2x2 output; + + float det = determinant(mat); + output[0][0] = mat[1][1] / det; + output[0][1] = -mat[0][1] / det; + output[1][0] = -mat[1][0] / det; + output[1][1] = mat[0][0] / det; + + return output; +} + +struct Gaussian2D : IDifferentiable +{ + float2 center; + float2x2 sigma; + float3 color; + float opacity; + + [Differentiable] + static Gaussian2D load(Blobs blobs, uint idx) + { + uint total = Gaussian2D.count(blobs); + Gaussian2D gaussian; + gaussian.center = smoothstep( + float2(0, 0), + float2(1, 1), + float2(blobs[total * 0 + idx], blobs[total * 1 + idx])); + + // Add a small padding value to avoid singularities or unstable Gaussians. + gaussian.sigma[0][0] = smoothstep(0.f, 1.f, blobs[total * 2 + idx] * 0.8f) + 0.005f; + gaussian.sigma[1][1] = smoothstep(0.f, 1.f, blobs[total * 3 + idx] * 0.8f) + 0.005f; + + float aniso = (smoothstep(0.f, 1.f, blobs[total * 4 + idx] * 0.6f) - 0.5f) * 1.65f; + gaussian.sigma[0][1] = sqrt(gaussian.sigma[0][0] * gaussian.sigma[1][1]) * aniso; + gaussian.sigma[1][0] = sqrt(gaussian.sigma[0][0] * gaussian.sigma[1][1]) * aniso; + + // Scale the sigma so the blobs aren't too large + gaussian.sigma *= 0.0001; + + gaussian.color = smoothstep(0, 1, float3( + blobs[total * 5 + idx], + blobs[total * 6 + idx], + blobs[total * 7 + idx]) * 0.8f); + + gaussian.opacity = smoothstep(0, 1, blobs[total * 8 + idx] * 0.9f + 0.1f); + return gaussian; + } + + // Simple helper method to get the number of elements in the buffer + static uint count(Blobs blobs) + { + return blobs.blobsBuffer.primal.shape[0] / 9; + } + + /* + * eval() calculates the color and weight of the Gaussian at a given UV coordinate. + * + * This method calculates an alpha by applying the standard multi-variate Gaussian formula + * to calculate the power which is then scaled by an opacity value. The color components + * are represented by additional fields. + */ + [Differentiable] + float4 eval(float2 uv) + { + float2x2 invCov = inverse(sigma); + float2 diff = uv - center; + float power = -0.5f * ((diff.x * diff.x * invCov[0][0]) + + (diff.y * diff.y * invCov[1][1]) + + (diff.x * diff.y * invCov[0][1]) + + (diff.y * diff.x * invCov[1][0])); + + float weight = min(.99f, opacity * exp(power)); + return float4(color, weight); + } + + OBB bounds() + { + // Calculate eigenvectors for the 2x2 matrix. + float2x2 cov = sigma; + + float a = cov[0][0]; + float b = cov[0][1]; + float c = cov[1][0]; + float d = cov[1][1]; + + float n_stddev = 4.f; + + if (abs(b) < 1e-6 || abs(c) < 1e-6) + { + // The covariance matrix is diagonal (or close enough..), so the eigenvectors are the x and y axes. + float2x2 eigenvectors = float2x2(float2(1, 0), float2(0, 1)); + float2 scale = float2(sqrt(a), sqrt(d)); + + return OBB(center, eigenvectors, scale * n_stddev); + } + else + { + float trace = a + d; + float det = a * d - b * c; + + float lambda1 = 0.5 * (trace + sqrt(trace * trace - 4 * det)); + float lambda2 = 0.5 * (trace - sqrt(trace * trace - 4 * det)); + + float2x2 eigenvectors; + eigenvectors[0] = float2(lambda1 - d, c) / length(float2(lambda1 - d, c)); + eigenvectors[1] = float2(b, lambda2 - a) / length(float2(b, lambda2 - a)); + + // Calculate the scale of the OBB + float2 scale = float2(sqrt(lambda1), sqrt(lambda2)); + + return OBB(center, eigenvectors, scale * n_stddev); + } + } +}; + +[Differentiable] +float4 eval(Blobs blobs, uint blob_id, no_diff float2 uv) +{ + Gaussian2D gaussian = Gaussian2D.load(blobs, blob_id); + return gaussian.eval(uv); +} + +/* + * cullAndApplyBlobs finds blobs which intersect the current tile and evaluates them in a single pass using + * wave intrinsics. + * + * This uses the multiplicative alpha blending algorithm laid out in the original GS paper (https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/) + * This is represented as a 'state transition' (transformPixelState) as we go through the blobs in order, so that we can + * concisely represent the 'state undo' operation in the custom backwards pass (fineRasterize_bwd). + * + * In Slang, custom derivative functions can be defiened using the `[BackwardDerivative(custom_fn)]` attribute. + */ +[BackwardDerivative(fineRasterize_bwd)] +float4 cullAndApplyBlobs(Blobs blobs, OBB tileBounds, uint localIdx, no_diff float2 uv) +{ + PixelState pixelState = PixelState(float4(0, 0, 0, 1), 0); + uint numIntersectingBlobs = 0; + + // Traverse the list in workgroup-sized chunks. Each lane in the workgroup/wave will be responsible for + // determining if one gaussian in the chunk intersects the current tile. + for (uint wgStart = 0, numGaussians = Gaussian2D.count(blobs); wgStart < numGaussians; wgStart += WG_SIZE) + { + // lane 0 will load the blob represented at position wgStart, and other lanes will get the subsequent blobs + Gaussian2D coarseBlob = Gaussian2D.load(blobs, wgStart + localIdx); + bool intersects = coarseBlob.bounds().intersects(tileBounds); + + // All lanes write to the ballot bitmask to indicate whether intersection is true; + // so all lanes will have the same value for intersectionMask + uint intersectionMask = WaveActiveBallot(intersects).x; + + while(intersectionMask != 0) + { + // identify the next lane with intersects == true in this chunk + uint idxInChunk = firstbitlow(intersectionMask); + uint16_t blobIdx = wgStart + idxInChunk; // then get the index for that blob + + intersectionMask &= intersectionMask - 1; // remove the least significant 1 bit from the mask + + float4 blobEval = eval(blobs, blobIdx, uv); + pixelState = transformPixelState(pixelState, blobEval, blobIdx); + + intersectingBlobList[min(numIntersectingBlobs++, GAUSSIANS_PER_BLOCK - 1)] = blobIdx; + } + + // if ALL the blobs processed in this chunk are below the alpha threshold, + // stop processing blobs. + if (WaveActiveAllTrue(pixelState.value.a < 1.f / 255.f)) + { + break; + } + } + + intersectingBlobCount = numIntersectingBlobs; + maxCount[localIdx] = pixelState.finalCount; + finalVal[localIdx] = pixelState.value; + return pixelState.value; +} + +/* + * fineRasterize_bwd() is the user-provided backwards pass for the fine rasterization step. + * + * This is implemented as a custom derivative function because, while applying auto-diff directly to a function + * with a loop can result in excessive state caching (a necessary part of standard automatic differentiation methods) + * + * For Gaussian splatting, there is a 'state undo' (undoPixelState) operation available. fineRasterize_bwd takes advantage of this + * to recreate the states at each step of the forward pass instead of letting auto-diff store them. + * + * While it is important to represent the backwards loop explicitly in this way, the contents of the loop body (loading, evaluation, + * blending, etc..) can still be differentiated automatically (and it would be tedioush to do so manually). + * + * The loop body therefore invokes `bwd_diff` to backprop the derivatives via auto-diff. + */ +void fineRasterize_bwd(Blobs blobset, OBB tileBounds, uint localIdx, float2 uv, float4 dOut) +{ + GroupMemoryBarrierWithGroupSync(); + + PixelState pixelState = { finalVal[localIdx], maxCount[localIdx] }; + + PixelState.Differential dColor = { dOut }; + + // The backwards pass manually performs an 'undo' to reproduce the state at each step. + // The inner loop body still uses auto-diff, so the bulk of the computation is still + // handled by the auto-diff engine. + // + for (uint _i = intersectingBlobCount; _i > 0; _i--) + { + uint i = _i - 1; + var blobID = intersectingBlobList[i]; + var gval = eval(blobset, blobID, uv); + var prevState = undoPixelState(pixelState, blobID, gval); + + var dpState = diffPair(prevState); + var dpGVal = diffPair(gval); + + // Once we have the previous state, we can continue with the backpropagation via auto-diff within + // the loop body. Note that the `bwd_diff` calls writeback the differentials to dpState and dpGVal, + // and can be obtained via `getDifferential()` (or simply '.d') + // + bwd_diff(transformPixelState)(dpState, dpGVal, blobID, dColor); + bwd_diff(eval)(blobset, blobID, uv, dpGVal.getDifferential()); + + pixelState = prevState; + dColor = dpState.getDifferential(); + } +} + +/* + * calcUV() computes a 'stretch-free' mapping from the requested render-target dimensions (renderSize) to the + * image in the texture (imageSize) + */ +float2 calcUV(uint2 dispatchThreadID, int2 renderSize, int2 imageSize) +{ + // Easy case. + if (all(renderSize == imageSize)) + return ((float2)dispatchThreadID) / renderSize; + + float aspectRatioRT = ((float)renderSize.x) / renderSize.y; + float aspectRatioTEX = ((float)imageSize.x) / imageSize.y; + + if (aspectRatioRT > aspectRatioTEX) + { + // Render target is wider than the texture. + // Match the widths. + // + float xCoord = ((float)dispatchThreadID.x) / renderSize.x; + float yCoord = ((float)dispatchThreadID.y * aspectRatioTEX) / renderSize.x; + + // We'll re-center the y-coord around 0.5. + float yCoordMax = aspectRatioTEX / aspectRatioRT; + yCoord = yCoord + (1.0 - yCoordMax) / 2.0f; + return float2(xCoord, yCoord); + } + else + { + // Render target is taller than the texture. + // Match the heights. + // + float yCoord = ((float)dispatchThreadID.y) / renderSize.y; + float xCoord = ((float)dispatchThreadID.x) / (renderSize.y * aspectRatioTEX); + + // We'll recenter the x-coord around 0.5. + float xCoordMax = aspectRatioRT / aspectRatioTEX; + xCoord = xCoord + (1.0 - xCoordMax) / 2.0f; + return float2(xCoord, yCoord); + } +} + +/* + * splatBlobs() is the main rendering routine that computes a final color for the pixel. + * + * This function will use a single pass to cull the full blob list to only the ones instersecting + * with the current workgroup tile, and evaluate them. This avoids having to re-sort the list. + * This is accomplished by using WaveActiveBallot to allow the entire subgroup to see the result + * of the intersection test for each lane in the wave. + */ +[Differentiable] +float4 splatBlobs(GradInOutTensor blobsBuffer, uint2 dispatchThreadID, int2 dispatchSize, int2 texSize) +{ + Blobs blobs = {blobsBuffer}; + + // Calculate effective uv coordinate for the current pixel. This is used for + // evaluating the 2D Gaussians. + float2 uv = no_diff calcUV(dispatchThreadID, dispatchSize, texSize); + + // + // Calculate a bounding box in uv coordinates for the current workgroup. + // + + uint2 tileCoords = uint2(dispatchThreadID.x / WG_X, dispatchThreadID.y / WG_Y); + + float2 tileLow = calcUV(tileCoords * uint2(WG_X, WG_Y), dispatchSize, texSize); + float2 tileHigh = calcUV((tileCoords + 1) * uint2(WG_X, WG_Y), dispatchSize, texSize); + + float2 tileCenter = (tileLow + tileHigh) / 2; + float2x2 tileRotation = float2x2(1, 0, 0, 1); + float2 tileScale = (tileHigh - tileLow) / 2; + + OBB tileBounds = OBB(tileCenter, tileRotation, tileScale); + + uint localIdx = WaveGetLaneIndex(); // identify the specific lane within the workgroup -- + // this will be used to correctly identify intersected + // blobs + + float4 color = cullAndApplyBlobs(blobs, tileBounds, localIdx, uv); + return float4(color.rgb * (1.0 - color.a) + color.a, 1.0); +} + +void renderBlobsToTexture( + RWTexture2D output, + GradInOutTensor blobsBuffer, + uint2 dispatchThreadID) +{ + uint2 imageSize; + output.GetDimensions(imageSize.x, imageSize.y); + output[dispatchThreadID] = splatBlobs(blobsBuffer, dispatchThreadID, imageSize, imageSize); +} + +/* + * loss() implements the standard L2 loss function to quantify the difference between + * the rendered image and the target texture. + */ +[Differentiable] +float loss(uint2 dispatchThreadID, int2 imageSize, Blobs blobs, Texture2D targetTexture) +{ + int texWidth; + int texHeight; + targetTexture.GetDimensions(texWidth, texHeight); + int2 texSize = int2(texWidth, texHeight); + + // Splat the blobs and calculate the color for this pixel. + float4 color = splatBlobs(blobs.blobsBuffer, dispatchThreadID, imageSize, texSize); + + float4 targetColor; + float weight; + if (dispatchThreadID.x >= imageSize.x || dispatchThreadID.y >= imageSize.y) + { + return 0.f; + } + else + { + //uint2 flippedCoords = uint2(dispatchThreadID.x, imageSize.y - dispatchThreadID.y); + targetColor = no_diff targetTexture[dispatchThreadID]; + return dot(color.rgb - targetColor.rgb, color.rgb - targetColor.rgb); + } + + return 0.f; +} + +[Differentiable] +void perPixelLoss( + GradInOutTensor output, + uint2 dispatchThreadID, + GradInOutTensor blobsBuffer, + Texture2D targetTexture) +{ + uint2 imageSize; + targetTexture.GetDimensions(imageSize.x, imageSize.y); + output.set( + {dispatchThreadID.x, dispatchThreadID.y}, + loss(dispatchThreadID, imageSize, {blobsBuffer}, targetTexture)); +} + +/* + * clearDerivativesMain() is a kernel that resets the derivative buffer to all 0s + */ +void clearDerivs(uint3 dispatchThreadID, RWNDBuffer derivBuffer) +{ + derivBuffer[dispatchThreadID.x] = asuint(0.f); +} + +/* + * Output a constant. Useful to quickly clear a buffer to a specific + * value with slangpy + */ +void ones(out float4 val) +{ + val = float4(1.f); +} + +/* + * updateBlobsMain() is a kernel that updates the blob parameters using the Adam optimizer. + * + * Since all the parameters are laid out in a single float buffer, there is no need to re-interpret + * the buffer into a struct. + * + * The Adam optimization method (https://arxiv.org/abs/1412.6980) is used to process the gradients before + * applying the update. It acts as a temporal filter on the gradients, and stores per-parameter state that + * persists across iterations to help stabilize the optimization process. + * + */ +void adamUpdate( + inout float val, + inout float dVal, + inout float firstMoment, + inout float secondMoment) +{ + // Read & reset the derivative + float g_t = dVal; + + float g_t_2 = g_t * g_t; + + // + // Perform a gradient update using Adam optimizer rules for + // a smoother optimization. + // + + float m_t_prev = firstMoment; + float v_t_prev = secondMoment; + float m_t = ADAM_BETA_1 * m_t_prev + (1 - ADAM_BETA_1) * g_t; + float v_t = ADAM_BETA_2 * v_t_prev + (1 - ADAM_BETA_2) * g_t_2; + + firstMoment = m_t; + secondMoment = v_t; + + float m_t_hat = m_t / (1 - ADAM_BETA_1); + float v_t_hat = v_t / (1 - ADAM_BETA_2); + + float update = (ADAM_ETA / (sqrt(v_t_hat) + ADAM_EPSILON)) * m_t_hat; + + val -= update; + dVal = 0.f; +} diff --git a/experiments/balloted-splatting/example-image.png b/experiments/balloted-splatting/example-image.png new file mode 100644 index 0000000..bda2592 --- /dev/null +++ b/experiments/balloted-splatting/example-image.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2eb51c7e5063d2bb2ce8f8b5c62c7d9b21849bf37602dcc5c5a7e21be5f69cfb +size 641975 diff --git a/experiments/balloted-splatting/jeep.jpg b/experiments/balloted-splatting/jeep.jpg new file mode 100644 index 0000000..f87b5f2 --- /dev/null +++ b/experiments/balloted-splatting/jeep.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1e22272493724b6b8c94fe67f7ca31cf5b79b068656da8f7a204df1652bc8c00 +size 158253 diff --git a/experiments/balloted-splatting/main.py b/experiments/balloted-splatting/main.py new file mode 100644 index 0000000..1930557 --- /dev/null +++ b/experiments/balloted-splatting/main.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from app import App +import slangpy as spy +import sgl +import pathlib +import numpy as np +import imageio +from tqdm import tqdm + +# Load up an input image. +image = imageio.imread("./jeep.jpg") +W = image.shape[0] +H = image.shape[1] + +# Create SGL windowed app and get the device +app = App("Diffusion Splatting 2D", W, H) +device = app.device + +# 2D -> 1D dispatch-ID mapping utility to help us work around slangpy's 1D dispatch restriction + + +def calcCompressedDispatchIDs(x_max: int, y_max: int, wg_x: int, wg_y: int): + local_x = np.arange(0, wg_x, dtype=np.uint32) + local_y = np.arange(0, wg_y, dtype=np.uint32) + local_xv, local_yv = np.meshgrid(local_x, local_y, indexing="ij") + local_xyv = np.stack([local_xv, local_yv], axis=-1) + local_xyv = np.tile(local_xyv.reshape(wg_x * wg_y, 2).astype(np.uint32), + ((x_max // wg_x) * (y_max // wg_y), 1)) + local_xyv = local_xyv.reshape((x_max * y_max, 2)) + + group_x = np.arange(0, (x_max // wg_x), dtype=np.uint32) + group_y = np.arange(0, (y_max // wg_y), dtype=np.uint32) + group_xv, group_yv = np.meshgrid(group_x, group_y, indexing="ij") + group_xyv = np.stack([group_xv, group_yv], axis=-1) + group_xyv = np.tile(group_xyv[:, :, None, None, :], (1, 1, wg_y, wg_x, 1)) + group_xyv = group_xyv.reshape((x_max * y_max, 2)).astype(np.uint32) + + return ((group_xyv * np.array([wg_x, wg_y])[None, :] + local_xyv).astype(np.uint32)) + + +# Load module +module = spy.Module.load_from_file(device, "diffsplatting2d.slang") + +# Randomize the blobs buffer +NUM_BLOBS = 20480 * 2 +FLOATS_PER_BLOB = 9 +blobs = spy.Tensor.numpy(device, np.random.rand( + NUM_BLOBS * FLOATS_PER_BLOB).astype(np.float32)).with_grads() + +WORKGROUP_X, WORKGROUP_Y = 8, 4 + + +assert (W % WORKGROUP_X == 0) and (H % WORKGROUP_Y == 0) + +# Go from RGB_u8 -> RGBA_f32 +image = (image / 256.0).astype(np.float32) +image = np.concatenate([image, np.ones((W, H, 1), dtype=np.float32)], axis=-1) +input_image = device.create_texture( + data=image, + width=W, + height=H, + format=sgl.Format.rgba32_float, + usage=sgl.TextureUsage.shader_resource) + +dispatch_ids = spy.NDBuffer(device, dtype=module.uint2, shape=(W, H)) +dispatch_ids.copy_from_numpy(calcCompressedDispatchIDs(W, H, WORKGROUP_X, WORKGROUP_Y)) + +per_pixel_loss = spy.Tensor.empty(device, dtype=module.float4, shape=(W, H)) +per_pixel_loss = per_pixel_loss.with_grads() +# Set per-pixel loss' derivative to 1 (using a 1-line function in the slang file) +module.ones(per_pixel_loss.grad_in) + +adam_first_moment = spy.Tensor.zeros_like(blobs) +adam_second_moment = spy.Tensor.zeros_like(blobs) + +# Pre-allocate a texture to send data to tev occasionally. +current_render = device.create_texture( + width=W, + height=H, + format=sgl.Format.rgba32_float, + usage=sgl.TextureUsage.shader_resource | sgl.TextureUsage.unordered_access) + +iterations = 10000 +for iter in tqdm(range(iterations)): + if not app.process_events(): + exit(0) + + # Backprop the unit per-pixel loss with auto-diff. + module.perPixelLoss.bwds(per_pixel_loss, dispatch_ids, blobs, input_image) + + # Update + module.adamUpdate(blobs, blobs.grad_out, adam_first_moment, adam_second_moment) + + if iter % 10 == 0: + module.renderBlobsToTexture(app.output, blobs, dispatch_ids) + app.present() + +# Keep window processing events until user closes it. +while app.process_events(): + app.present() diff --git a/experiments/balloted-splatting/requirements.txt b/experiments/balloted-splatting/requirements.txt new file mode 100644 index 0000000..1f0ae24 --- /dev/null +++ b/experiments/balloted-splatting/requirements.txt @@ -0,0 +1,2 @@ +tqdm +imageio \ No newline at end of file From 9bc72d11476078cf0a525a874c8f3241f069c5ac Mon Sep 17 00:00:00 2001 From: Shannon Woods Date: Tue, 20 May 2025 14:55:04 -0400 Subject: [PATCH 2/4] Updates balloted splatting example to declare its capability requirements. Also removes sgl import/dependency. --- experiments/balloted-splatting/app.py | 61 +++++++++---------- ...atting2d.slang => ballotsplatting2d.slang} | 6 ++ experiments/balloted-splatting/main.py | 14 ++--- 3 files changed, 42 insertions(+), 39 deletions(-) rename experiments/balloted-splatting/{diffsplatting2d.slang => ballotsplatting2d.slang} (99%) diff --git a/experiments/balloted-splatting/app.py b/experiments/balloted-splatting/app.py index a111186..18a9d6d 100644 --- a/experiments/balloted-splatting/app.py +++ b/experiments/balloted-splatting/app.py @@ -1,38 +1,37 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from typing import Callable, Optional -import sgl -import slangpy +import slangpy as spy from pathlib import Path class App: - def __init__(self, title="Diffsplat Example", width=1024, height=1024, device_type=sgl.DeviceType.d3d12): + def __init__(self, title="Balloted Splat Example", width=1024, height=1024, device_type=spy.DeviceType.d3d12): super().__init__() - # Create SGL window - self._window = sgl.Window( + # Create a window + self._window = spy.Window( width=width, height=height, title=title, resizable=True ) - # Create SlangPy device with local include path for shaders - self._device = slangpy.create_device(device_type, + # Create a device with local include path for shaders + self._device = spy.create_device(device_type, include_paths=[Path(__file__).parent]) # Setup swapchain self.surface = self._device.create_surface(self._window) self.surface.configure(width=self._window.width, height=self._window.height) - self._output_texture: 'sgl.Texture' = self.device.create_texture( + self._output_texture: spy.Texture = self.device.create_texture( width=self._window.width, height=self._window.height, - format=sgl.Format.rgba32_float, - usage=sgl.TextureUsage.shader_resource | sgl.TextureUsage.unordered_access, + format=spy.Format.rgba32_float, + usage=spy.TextureUsage.shader_resource | spy.TextureUsage.unordered_access, label="output_texture", ) # Store mouse pos - self._mouse_pos = sgl.float2() + self._mouse_pos = spy.float2() # Internal events self._window.on_keyboard_event = self._on_window_keyboard_event @@ -40,23 +39,23 @@ def __init__(self, title="Diffsplat Example", width=1024, height=1024, device_ty self._window.on_resize = self._on_window_resize # Hookable events - self.on_keyboard_event: Optional[Callable[[sgl.KeyboardEvent], None]] = None - self.on_mouse_event: Optional[Callable[[sgl.MouseEvent], None]] = None + self.on_keyboard_event: Optional[Callable[[spy.KeyboardEvent], None]] = None + self.on_mouse_event: Optional[Callable[[spy.MouseEvent], None]] = None @property - def device(self) -> sgl.Device: + def device(self) -> spy.Device: return self._device @property - def window(self) -> sgl.Window: + def window(self) -> spy.Window: return self._window @property - def mouse_pos(self) -> sgl.float2: + def mouse_pos(self) -> spy.float2: return self._mouse_pos @property - def output(self) -> sgl.Texture: + def output(self) -> spy.Texture: return self._output_texture def process_events(self): @@ -77,46 +76,46 @@ def present(self): self._output_texture = self.device.create_texture( width=image.width, height=image.height, - format=sgl.Format.rgba32_float, - usage=sgl.TextureUsage.shader_resource | sgl.TextureUsage.unordered_access, + format=spy.Format.rgba32_float, + usage=spy.TextureUsage.shader_resource | spy.TextureUsage.unordered_access, label="output_texture", ) command_buffer = self._device.create_command_encoder() command_buffer.blit(image, self._output_texture) - command_buffer.set_texture_state(image, sgl.ResourceState.present) + command_buffer.set_texture_state(image, spy.ResourceState.present) self._device.submit_command_buffer(command_buffer.finish()) del image self.surface.present() - def _on_window_keyboard_event(self, event: sgl.KeyboardEvent): - if event.type == sgl.KeyboardEventType.key_press: - if event.key == sgl.KeyCode.escape: + def _on_window_keyboard_event(self, event: spy.KeyboardEvent): + if event.type == spy.KeyboardEventType.key_press: + if event.key == spy.KeyCode.escape: self._window.close() return - elif event.key == sgl.KeyCode.f1: + elif event.key == spy.KeyCode.f1: if self._output_texture: - sgl.tev.show_async(self._output_texture) + spy.tev.show_async(self._output_texture) return - elif event.key == sgl.KeyCode.f2: + elif event.key == spy.KeyCode.f2: if self._output_texture: bitmap = self._output_texture.to_bitmap() bitmap.convert( - sgl.Bitmap.PixelFormat.rgb, - sgl.Bitmap.ComponentType.uint8, + spy.Bitmap.PixelFormat.rgb, + spy.Bitmap.ComponentType.uint8, srgb_gamma=True, ).write_async("screenshot.png") return if self.on_keyboard_event: self.on_keyboard_event(event) - def _on_window_mouse_event(self, event: sgl.MouseEvent): - if event.type == sgl.MouseEventType.move: + def _on_window_mouse_event(self, event: spy.MouseEvent): + if event.type == spy.MouseEventType.move: self._mouse_pos = event.pos if self.on_mouse_event: self.on_mouse_event(event) def _on_window_resize(self, width: int, height: int): self._device.wait() - self.surface.configure(width=width, height=height) + self.surface.configure(width=width, height=height) \ No newline at end of file diff --git a/experiments/balloted-splatting/diffsplatting2d.slang b/experiments/balloted-splatting/ballotsplatting2d.slang similarity index 99% rename from experiments/balloted-splatting/diffsplatting2d.slang rename to experiments/balloted-splatting/ballotsplatting2d.slang index a7bc968..880237e 100644 --- a/experiments/balloted-splatting/diffsplatting2d.slang +++ b/experiments/balloted-splatting/ballotsplatting2d.slang @@ -373,6 +373,8 @@ float4 eval(Blobs blobs, uint blob_id, no_diff float2 uv) * * In Slang, custom derivative functions can be defiened using the `[BackwardDerivative(custom_fn)]` attribute. */ + [require (subgroup_ballot)] + [require (subgroup_vote)] [BackwardDerivative(fineRasterize_bwd)] float4 cullAndApplyBlobs(Blobs blobs, OBB tileBounds, uint localIdx, no_diff float2 uv) { @@ -511,12 +513,16 @@ float2 calcUV(uint2 dispatchThreadID, int2 renderSize, int2 imageSize) /* * splatBlobs() is the main rendering routine that computes a final color for the pixel. * + * It will check for support of * This function will use a single pass to cull the full blob list to only the ones instersecting * with the current workgroup tile, and evaluate them. This avoids having to re-sort the list. * This is accomplished by using WaveActiveBallot to allow the entire subgroup to see the result * of the intersection test for each lane in the wave. */ [Differentiable] +[require(subgroup_ballot)] +[require(subgroup_vote)] +[require(subgroup_basic)] float4 splatBlobs(GradInOutTensor blobsBuffer, uint2 dispatchThreadID, int2 dispatchSize, int2 texSize) { Blobs blobs = {blobsBuffer}; diff --git a/experiments/balloted-splatting/main.py b/experiments/balloted-splatting/main.py index 1930557..98e5df3 100644 --- a/experiments/balloted-splatting/main.py +++ b/experiments/balloted-splatting/main.py @@ -2,8 +2,6 @@ from app import App import slangpy as spy -import sgl -import pathlib import numpy as np import imageio from tqdm import tqdm @@ -14,7 +12,7 @@ H = image.shape[1] # Create SGL windowed app and get the device -app = App("Diffusion Splatting 2D", W, H) +app = App("Balloted Splatting 2D", W, H) device = app.device # 2D -> 1D dispatch-ID mapping utility to help us work around slangpy's 1D dispatch restriction @@ -40,7 +38,7 @@ def calcCompressedDispatchIDs(x_max: int, y_max: int, wg_x: int, wg_y: int): # Load module -module = spy.Module.load_from_file(device, "diffsplatting2d.slang") +module = spy.Module.load_from_file(device, "ballotsplatting2d.slang") # Randomize the blobs buffer NUM_BLOBS = 20480 * 2 @@ -60,8 +58,8 @@ def calcCompressedDispatchIDs(x_max: int, y_max: int, wg_x: int, wg_y: int): data=image, width=W, height=H, - format=sgl.Format.rgba32_float, - usage=sgl.TextureUsage.shader_resource) + format=spy.Format.rgba32_float, + usage=spy.TextureUsage.shader_resource) dispatch_ids = spy.NDBuffer(device, dtype=module.uint2, shape=(W, H)) dispatch_ids.copy_from_numpy(calcCompressedDispatchIDs(W, H, WORKGROUP_X, WORKGROUP_Y)) @@ -78,8 +76,8 @@ def calcCompressedDispatchIDs(x_max: int, y_max: int, wg_x: int, wg_y: int): current_render = device.create_texture( width=W, height=H, - format=sgl.Format.rgba32_float, - usage=sgl.TextureUsage.shader_resource | sgl.TextureUsage.unordered_access) + format=spy.Format.rgba32_float, + usage=spy.TextureUsage.shader_resource | spy.TextureUsage.unordered_access) iterations = 10000 for iter in tqdm(range(iterations)): From caba4786ae754a73391ee4ec86406036f466be0b Mon Sep 17 00:00:00 2001 From: Shannon Woods <158105547+swoods-nv@users.noreply.github.com> Date: Tue, 20 May 2025 19:28:04 -0400 Subject: [PATCH 3/4] Update whitespace for capability tokens Co-authored-by: ArielG-NV <159081215+ArielG-NV@users.noreply.github.com> --- experiments/balloted-splatting/ballotsplatting2d.slang | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/experiments/balloted-splatting/ballotsplatting2d.slang b/experiments/balloted-splatting/ballotsplatting2d.slang index 880237e..d502cb4 100644 --- a/experiments/balloted-splatting/ballotsplatting2d.slang +++ b/experiments/balloted-splatting/ballotsplatting2d.slang @@ -373,8 +373,8 @@ float4 eval(Blobs blobs, uint blob_id, no_diff float2 uv) * * In Slang, custom derivative functions can be defiened using the `[BackwardDerivative(custom_fn)]` attribute. */ - [require (subgroup_ballot)] - [require (subgroup_vote)] +[require (subgroup_ballot)] +[require (subgroup_vote)] [BackwardDerivative(fineRasterize_bwd)] float4 cullAndApplyBlobs(Blobs blobs, OBB tileBounds, uint localIdx, no_diff float2 uv) { From 39d8e9fa70bc6f332c8c10a86b00a683b7b6e986 Mon Sep 17 00:00:00 2001 From: Shannon Woods Date: Thu, 22 May 2025 13:27:18 -0400 Subject: [PATCH 4/4] Remove capability requirements The Slang compiler will automatically enforce requirements, no need to annotate. --- experiments/balloted-splatting/ballotsplatting2d.slang | 5 ----- 1 file changed, 5 deletions(-) diff --git a/experiments/balloted-splatting/ballotsplatting2d.slang b/experiments/balloted-splatting/ballotsplatting2d.slang index 880237e..208c415 100644 --- a/experiments/balloted-splatting/ballotsplatting2d.slang +++ b/experiments/balloted-splatting/ballotsplatting2d.slang @@ -373,8 +373,6 @@ float4 eval(Blobs blobs, uint blob_id, no_diff float2 uv) * * In Slang, custom derivative functions can be defiened using the `[BackwardDerivative(custom_fn)]` attribute. */ - [require (subgroup_ballot)] - [require (subgroup_vote)] [BackwardDerivative(fineRasterize_bwd)] float4 cullAndApplyBlobs(Blobs blobs, OBB tileBounds, uint localIdx, no_diff float2 uv) { @@ -520,9 +518,6 @@ float2 calcUV(uint2 dispatchThreadID, int2 renderSize, int2 imageSize) * of the intersection test for each lane in the wave. */ [Differentiable] -[require(subgroup_ballot)] -[require(subgroup_vote)] -[require(subgroup_basic)] float4 splatBlobs(GradInOutTensor blobsBuffer, uint2 dispatchThreadID, int2 dispatchSize, int2 texSize) { Blobs blobs = {blobsBuffer};