import argparse
import logging

import numpy as np
import torch
import trimesh

from cube.inference.engine import load_config, load_model_weights, parse_structured
from cube.model.autoencoder.one_d_autoencoder import OneDAutoEncoder

MESH_SCALE = 0.96


def rescale(vertices: np.ndarray, mesh_scale: float = MESH_SCALE) -> np.ndarray:
    """Center the vertices at the origin and rescale them to fit the cube
    [-mesh_scale, mesh_scale]^3, e.g. [-1, 1]^3 when mesh_scale=1.0.
    """
    bbmin = vertices.min(0)
    bbmax = vertices.max(0)
    center = (bbmin + bbmax) * 0.5
    scale = 2.0 * mesh_scale / (bbmax - bbmin).max()
    vertices = (vertices - center) * scale
    return vertices

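# Illustrative worked example for rescale (comment only, not executed; assumes the
# default mesh_scale=0.96):
#   >>> rescale(np.array([[0.0, 0.0, 0.0], [2.0, 2.0, 2.0]]))
#   array([[-0.96, -0.96, -0.96],
#          [ 0.96,  0.96,  0.96]])
# The bounding box is centered at the origin and its longest axis is scaled to
# span exactly [-0.96, 0.96]; the other axes land inside that range.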

def load_scaled_mesh(file_path: str) -> trimesh.Trimesh:
    """
    Load a mesh, clean it, and rescale it to fit inside the unit cube (see MESH_SCALE).
    Parameters:
        file_path: str
            Path to the mesh file on disk.
    Returns:
        mesh: trimesh.Trimesh
    """
    # force="mesh" flattens multi-geometry scenes into a single Trimesh.
    mesh: trimesh.Trimesh = trimesh.load(file_path, force="mesh")
    mesh.remove_infinite_values()
    mesh.update_faces(mesh.nondegenerate_faces())  # drop degenerate (zero-area) faces
    mesh.update_faces(mesh.unique_faces())  # drop duplicate faces
    mesh.remove_unreferenced_vertices()
    if len(mesh.vertices) == 0 or len(mesh.faces) == 0:
        raise ValueError("Mesh has no vertices or faces after cleaning")
    mesh.vertices = rescale(mesh.vertices)
    return mesh

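# Illustrative check (hypothetical path "bunny.obj"; comment only, not executed):
#   >>> m = load_scaled_mesh("bunny.obj")
#   >>> float(np.abs(m.vertices).max())  # longest axis rescaled to [-MESH_SCALE, MESH_SCALE]
#   0.96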

def load_and_process_mesh(file_path: str, n_samples: int = 8192):
    """
    Loads a 3D mesh from the specified file path, samples points from its surface,
    and processes the sampled points into a point cloud with normals.
    Args:
        file_path (str): The file path to the 3D mesh file.
        n_samples (int, optional): The number of points to sample from the mesh surface. Defaults to 8192.
    Returns:
        torch.Tensor: A tensor of shape (1, n_samples, 6) containing the processed point cloud.
            Each point consists of its 3D position (x, y, z) and its normal vector (nx, ny, nz).
    """
    mesh = load_scaled_mesh(file_path)
    positions, face_indices = trimesh.sample.sample_surface(mesh, n_samples)
    normals = mesh.face_normals[face_indices]
    point_cloud = np.concatenate(
        [positions, normals], axis=1
    )  # Shape: (n_samples, 6)
    point_cloud = torch.from_numpy(point_cloud.reshape(1, -1, 6)).float()
    return point_cloud

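# Illustrative usage (hypothetical path; comment only, not executed at import time):
#   >>> pc = load_and_process_mesh("bunny.obj", n_samples=8192)
#   >>> pc.shape
#   torch.Size([1, 8192, 6])  # xyz positions followed by per-point face normals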

@torch.inference_mode()
def run_shape_decode(
    shape_model: OneDAutoEncoder,
    output_ids: torch.Tensor,
    resolution_base: float = 8.43,
    chunk_size: int = 100_000,
):
    """
    Decodes the shape from the given output IDs and extracts the geometry.
    Args:
        shape_model (OneDAutoEncoder): The shape model.
        output_ids (torch.Tensor): The tensor containing the output IDs.
        resolution_base (float, optional): The base resolution for geometry extraction. Defaults to 8.43.
        chunk_size (int, optional): The chunk size for processing. Defaults to 100,000.
    Returns:
        A sequence with one (vertices, faces) pair per item in the batch.
    """
    shape_ids = (
        output_ids[:, : shape_model.cfg.num_encoder_latents, ...]
        .clamp_(0, shape_model.cfg.num_codes - 1)
        .view(-1, shape_model.cfg.num_encoder_latents)
    )
    latents = shape_model.decode_indices(shape_ids)
    mesh_v_f, _ = shape_model.extract_geometry(
        latents,
        resolution_base=resolution_base,
        chunk_size=chunk_size,
        use_warp=True,
    )
    return mesh_v_f

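# Illustrative round-trip sketch (comment only; mirrors the __main__ block below,
# "roundtrip.obj" is a hypothetical output path):
#   >>> indices = shape_model.encode(point_cloud.to(device))[3]["indices"]
#   >>> mesh_v_f = run_shape_decode(shape_model, indices)
#   >>> verts, faces = mesh_v_f[0]
#   >>> trimesh.Trimesh(vertices=verts, faces=faces).export("roundtrip.obj")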

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="cube shape encode and decode example script"
    )
    parser.add_argument(
        "--mesh-path",
        type=str,
        required=True,
        help="Path to the input mesh file.",
    )
    parser.add_argument(
        "--config-path",
        type=str,
        default="configs/open_model.yaml",
        help="Path to the configuration YAML file.",
    )
    parser.add_argument(
        "--shape-ckpt-path",
        type=str,
        required=True,
        help="Path to the shape encoder/decoder checkpoint file.",
    )
    parser.add_argument(
        "--recovered-mesh-path",
        type=str,
        default="recovered_mesh.obj",
        help="Path to save the recovered mesh file.",
    )
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)  # without this, logging.info output is suppressed
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    logging.info(f"Using device: {device}")

    cfg = load_config(args.config_path)

    # Build the shape autoencoder from the config and load its checkpoint.
    shape_model = OneDAutoEncoder(
        parse_structured(OneDAutoEncoder.Config, cfg.shape_model)
    )
    load_model_weights(
        shape_model,
        args.shape_ckpt_path,
    )
    shape_model = shape_model.eval().to(device)

    # Encode the sampled point cloud into discrete shape token indices.
    point_cloud = load_and_process_mesh(args.mesh_path)
    output = shape_model.encode(point_cloud.to(device))
    indices = output[3]["indices"]
    print("Got the following shape indices:")
    print(indices)
    print("Indices shape:", indices.shape)

    # Decode the indices back into geometry and export the recovered mesh.
    mesh_v_f = run_shape_decode(shape_model, indices)
    vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
    mesh.export(args.recovered_mesh_path)