Skip to content

Add wave ballot example #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions experiments/balloted-splatting/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# 2D Differentiable Gaussian Splatting

## About

This example demonstrates the use of Slang and SlangPy to implement a 2D Gaussian splatting algorithm.

This algorithm represents a simplified version of the 3D Gaussian Splatting algorithm detailed in this paper (https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/). This 2D demonstration does not have the 3D->2D projection step & assumes that the Gaussian blobs are presented in order of depth (higher index = farther away). Further, this implementation does not perform adaptive density control to add or remove blobs.

See the `computeDerivativesMain()` kernel and the `splatBlobs()` function for the bulk of the key pieces of the code. This sample uses SlangPy (see `main.py`) to easily load and dispatch the kernels. SlangPy handles the pipeline setup, buffer allocation, buffer copies, and other boilerplate tasks to make it easy to prototype high-performance differentiable code.

For a full 3D Gaussian Splatting implementation written in Slang, see this repository: https://github.com/google/slang-gaussian-rasterization

### Workaround for 'compressing' a 2D group size into a 1D group
This sample uses a workaround for SlangPy's fixed group size of `(32, 1, 1)`. The rasterizer uses a fixed `8x4` 2D tile. We use numpy commands to construct an aray of dispatch indices such that the right threads receive the right 2D thread index. `calcCompressedDispatchIDs()` in `main.py` holds the logic for this workaround.

When SlangPy is updated with the functionality to specify group sizes, this workaround will be removed.

## How to Use

### Installation

First, install slangpy and the tev viewer:

- **SlangPy** python package: `pip install slangpy`. See SlangPy's [docs](https://slangpy.shader-slang.org/en/latest/installation.html) for a full list of requirements.
- **Tev** viewer: Download from [releases](https://github.com/Tom94/tev/releases/tag/v1.29). See [tev's github](https://github.com/Tom94/tev) for more information.

Then install the example's requirements, from within the sample folder:

`pip install -r requirements.txt`

### Optional: Setup via Conda
For simpler setup, use an anaconda/miniconda installation (See [Conda's user guide](https://docs.conda.io/projects/conda/en/latest/user-guide/index.html) for more).

Ensure that your environment is using **Python 3.10**.
If you are using conda, you can create a new environment with **python 3.10** and **slangpy** both installed using the following command:
`conda create -n slangpy-env python=3.10 slangpy`. Then switch to this new environment with `conda activate slangpy-env`.

### Running the Sample
- Open the **Tev** viewer and keep it running in the background.
- From the sample folder, run `python main.py` from a terminal.

You should see a stream of images in Tev as the training progresses:

![](./example-image.png)
Binary file not shown.
121 changes: 121 additions & 0 deletions experiments/balloted-splatting/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Callable, Optional
import slangpy as spy
from pathlib import Path


class App:
def __init__(self, title="Balloted Splat Example", width=1024, height=1024, device_type=spy.DeviceType.d3d12):
super().__init__()

# Create a window
self._window = spy.Window(
width=width, height=height, title=title, resizable=True
)

# Create a device with local include path for shaders
self._device = spy.create_device(device_type,
include_paths=[Path(__file__).parent])

# Setup swapchain
self.surface = self._device.create_surface(self._window)
self.surface.configure(width=self._window.width, height=self._window.height)

self._output_texture: spy.Texture = self.device.create_texture(
width=self._window.width,
height=self._window.height,
format=spy.Format.rgba32_float,
usage=spy.TextureUsage.shader_resource | spy.TextureUsage.unordered_access,
label="output_texture",
)

# Store mouse pos
self._mouse_pos = spy.float2()

# Internal events
self._window.on_keyboard_event = self._on_window_keyboard_event
self._window.on_mouse_event = self._on_window_mouse_event
self._window.on_resize = self._on_window_resize

# Hookable events
self.on_keyboard_event: Optional[Callable[[spy.KeyboardEvent], None]] = None
self.on_mouse_event: Optional[Callable[[spy.MouseEvent], None]] = None

@property
def device(self) -> spy.Device:
return self._device

@property
def window(self) -> spy.Window:
return self._window

@property
def mouse_pos(self) -> spy.float2:
return self._mouse_pos

@property
def output(self) -> spy.Texture:
return self._output_texture

def process_events(self):
if self._window.should_close():
return False
self._window.process_events()
return True

def present(self):
image = self.surface.acquire_next_image()
if image is None:
return

if (self._output_texture == None
or self._output_texture.width != image.width
or self._output_texture.height != image.height
):
self._output_texture = self.device.create_texture(
width=image.width,
height=image.height,
format=spy.Format.rgba32_float,
usage=spy.TextureUsage.shader_resource | spy.TextureUsage.unordered_access,
label="output_texture",
)

command_buffer = self._device.create_command_encoder()
command_buffer.blit(image, self._output_texture)
command_buffer.set_texture_state(image, spy.ResourceState.present)
self._device.submit_command_buffer(command_buffer.finish())

del image
self.surface.present()

def _on_window_keyboard_event(self, event: spy.KeyboardEvent):
if event.type == spy.KeyboardEventType.key_press:
if event.key == spy.KeyCode.escape:
self._window.close()
return
elif event.key == spy.KeyCode.f1:
if self._output_texture:
spy.tev.show_async(self._output_texture)
return
elif event.key == spy.KeyCode.f2:
if self._output_texture:
bitmap = self._output_texture.to_bitmap()
bitmap.convert(
spy.Bitmap.PixelFormat.rgb,
spy.Bitmap.ComponentType.uint8,
srgb_gamma=True,
).write_async("screenshot.png")
return
if self.on_keyboard_event:
self.on_keyboard_event(event)

def _on_window_mouse_event(self, event: spy.MouseEvent):
if event.type == spy.MouseEventType.move:
self._mouse_pos = event.pos
if self.on_mouse_event:
self.on_mouse_event(event)

def _on_window_resize(self, width: int, height: int):
self._device.wait()
self.surface.configure(width=width, height=height)
Loading