Skip to content

Commit 0066bf5

Browse files
authored
Merge pull request #10 from Roblox/tijmen-add-encode-decode-example
Add shape encode and decode example
2 parents 2f464ae + d3c43cd commit 0066bf5

File tree

5 files changed

+169
-12
lines changed

5 files changed

+169
-12
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,4 +172,5 @@ cython_debug/
172172
.DS_Store
173173

174174
# Output folder
175-
outputs/
175+
outputs/
176+
model_weights/

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ pip install -e .
2828

2929
## Inference
3030

31+
### Shape Generation
32+
3133
To generate 3D models using the downloaded models simply run:
3234
```bash
3335
python3 generate.py --gpt-ckpt-path model_weights/gpt.safetensors --shape-ckpt-path model_weights/shape.safetensors --prompt "sleek vintage green couch with clean lines and velvet material" --fast-inference
@@ -39,3 +41,12 @@ If you want to render a turntable gif of the mesh, you can use the `--render-gif
3941
and save it as `turntable.gif` in the specified `output` directory.
4042

4143
> **Note**: You must have Blender installed and available in your system's PATH to render the turntable GIF. You can download it from [Blender's official website](https://www.blender.org/). Ensure that the Blender executable is accessible from the command line.
44+
45+
### Shaple tokenization and de-tokenization
46+
To tokenize a 3D shape into token indices and reconstruct it back, you can use the following command:
47+
48+
```bash
49+
python3 vq_vae_encode_decode.py --shape-ckpt-path model_weights/shape.safetensors --mesh-path ./outputs/output.obj
50+
```
51+
52+
This will process the `.obj` file located at `./outputs/output.obj` and prints the tokenized representation as well as exports the mesh reconstructed from the token indices.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ dependencies = [
2626
"warp-lang",
2727
"accelerate",
2828
"scikit-image",
29-
"huggingface-hub",
29+
"huggingface_hub[cli]",
3030
]
3131
[project.optional-dependencies]
3232
lint = ["ruff","pyright"] # Development tools

src/cube/model/autoencoder/spherical_vq.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,11 @@ def quantize(self, z: torch.Tensor):
120120

121121
return z_q, {"z": z.detach(), "q": q}
122122

123+
def straight_through_approximation(self, z, z_q):
124+
"""passed gradient from z_q to z"""
125+
z_q = z + (z_q - z).detach()
126+
return z_q
127+
123128
def forward(self, z: torch.Tensor):
124129
"""
125130
Forward pass of the spherical vector quantization autoencoder.
@@ -150,14 +155,4 @@ def forward(self, z: torch.Tensor):
150155
z_q = self.straight_through_approximation(z_e, z_q)
151156
z_q = self.c_out(z_q)
152157

153-
with torch.no_grad():
154-
e_mean = (
155-
F.one_hot(ret_dict["q"], num_classes=self.num_codes)
156-
.view(-1, self.num_codes)
157-
.float()
158-
.mean(0)
159-
)
160-
ret_dict["perplexity"] = torch.exp(
161-
-torch.sum(e_mean * torch.log(e_mean + 1e-10))
162-
)
163158
return z_q, ret_dict

vq_vae_encode_decode.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
import argparse
2+
import logging
3+
4+
import numpy as np
5+
import torch
6+
import trimesh
7+
8+
from cube.inference.engine import load_config, load_model_weights, parse_structured
9+
from cube.model.autoencoder.one_d_autoencoder import OneDAutoEncoder
10+
11+
MESH_SCALE = 0.96
12+
13+
14+
def rescale(vertices: np.ndarray, mesh_scale: float = MESH_SCALE) -> np.ndarray:
15+
"""Rescale the vertices to a cube, e.g., [-1, -1, -1] to [1, 1, 1] when mesh_scale=1.0"""
16+
vertices = vertices
17+
bbmin = vertices.min(0)
18+
bbmax = vertices.max(0)
19+
center = (bbmin + bbmax) * 0.5
20+
scale = 2.0 * mesh_scale / (bbmax - bbmin).max()
21+
vertices = (vertices - center) * scale
22+
return vertices
23+
24+
25+
def load_scaled_mesh(file_path: str) -> trimesh.Trimesh:
26+
"""
27+
Load a mesh and scale it to a unit cube, and clean the mesh.
28+
Parameters:
29+
file_obj: str | IO
30+
file_type: str
31+
Returns:
32+
mesh: trimesh.Trimesh
33+
"""
34+
mesh: trimesh.Trimesh = trimesh.load(file_path, force="mesh")
35+
mesh.remove_infinite_values()
36+
mesh.update_faces(mesh.nondegenerate_faces())
37+
mesh.update_faces(mesh.unique_faces())
38+
mesh.remove_unreferenced_vertices()
39+
if len(mesh.vertices) == 0 or len(mesh.faces) == 0:
40+
raise ValueError("Mesh has no vertices or faces after cleaning")
41+
mesh.vertices = rescale(mesh.vertices)
42+
return mesh
43+
44+
45+
def load_and_process_mesh(file_path: str, n_samples: int = 8192):
46+
"""
47+
Loads a 3D mesh from the specified file path, samples points from its surface,
48+
and processes the sampled points into a point cloud with normals.
49+
Args:
50+
file_path (str): The file path to the 3D mesh file.
51+
n_samples (int, optional): The number of points to sample from the mesh surface. Defaults to 8192.
52+
Returns:
53+
torch.Tensor: A tensor of shape (1, n_samples, 6) containing the processed point cloud.
54+
Each point consists of its 3D position (x, y, z) and its normal vector (nx, ny, nz).
55+
"""
56+
57+
mesh = load_scaled_mesh(file_path)
58+
positions, face_indices = trimesh.sample.sample_surface(mesh, n_samples)
59+
normals = mesh.face_normals[face_indices]
60+
point_cloud = np.concatenate(
61+
[positions, normals], axis=1
62+
) # Shape: (num_samples, 6)
63+
point_cloud = torch.from_numpy(point_cloud.reshape(1, -1, 6)).float()
64+
return point_cloud
65+
66+
67+
@torch.inference_mode()
68+
def run_shape_decode(
69+
shape_model: OneDAutoEncoder,
70+
output_ids: torch.Tensor,
71+
resolution_base: float = 8.43,
72+
chunk_size: int = 100_000,
73+
):
74+
"""
75+
Decodes the shape from the given output IDs and extracts the geometry.
76+
Args:
77+
shape_model (OneDAutoEncoder): The shape model.
78+
output_ids (torch.Tensor): The tensor containing the output IDs.
79+
resolution_base (float, optional): The base resolution for geometry extraction. Defaults to 8.43.
80+
chunk_size (int, optional): The chunk size for processing. Defaults to 100,000.
81+
Returns:
82+
tuple: A tuple containing the vertices and faces of the mesh.
83+
"""
84+
shape_ids = (
85+
output_ids[:, : shape_model.cfg.num_encoder_latents, ...]
86+
.clamp_(0, shape_model.cfg.num_codes - 1)
87+
.view(-1, shape_model.cfg.num_encoder_latents)
88+
)
89+
latents = shape_model.decode_indices(shape_ids)
90+
mesh_v_f, _ = shape_model.extract_geometry(
91+
latents,
92+
resolution_base=resolution_base,
93+
chunk_size=chunk_size,
94+
use_warp=True,
95+
)
96+
return mesh_v_f
97+
98+
99+
if __name__ == "__main__":
100+
parser = argparse.ArgumentParser(
101+
description="cube shape encode and decode example script"
102+
)
103+
parser.add_argument(
104+
"--mesh-path",
105+
type=str,
106+
required=True,
107+
help="Path to the input mesh file.",
108+
)
109+
parser.add_argument(
110+
"--config-path",
111+
type=str,
112+
default="configs/open_model.yaml",
113+
help="Path to the configuration YAML file.",
114+
)
115+
parser.add_argument(
116+
"--shape-ckpt-path",
117+
type=str,
118+
required=True,
119+
help="Path to the shape encoder/decoder checkpoint file.",
120+
)
121+
parser.add_argument(
122+
"--recovered-mesh-path",
123+
type=str,
124+
default="recovered_mesh.obj",
125+
help="Path to save the recovered mesh file.",
126+
)
127+
args = parser.parse_args()
128+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
129+
logging.info(f"Using device: {device}")
130+
131+
cfg = load_config(args.config_path)
132+
133+
shape_model = OneDAutoEncoder(
134+
parse_structured(OneDAutoEncoder.Config, cfg.shape_model)
135+
)
136+
load_model_weights(
137+
shape_model,
138+
args.shape_ckpt_path,
139+
)
140+
shape_model = shape_model.eval().to(device)
141+
point_cloud = load_and_process_mesh(args.mesh_path)
142+
output = shape_model.encode(point_cloud.to(device))
143+
indices = output[3]["indices"]
144+
print("Got the following shape indices:")
145+
print(indices)
146+
print("Indices shape: ", indices.shape)
147+
mesh_v_f = run_shape_decode(shape_model, indices)
148+
vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
149+
mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
150+
mesh.export(args.recovered_mesh_path)

0 commit comments

Comments
 (0)