A Gradio example is provided below:
app.py
```python
import os
import shutil
from typing import Tuple

import gradio as gr
import numpy as np
import torch
import trimesh
from gradio_litmodel3d import LitModel3D

from cube3d.inference.engine import Engine, EngineFast

MAX_SEED = np.iinfo(np.int32).max

# Per-session temporary directory for generated assets
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
os.makedirs(TMP_DIR, exist_ok=True)

config_path = "cube3d/configs/open_model.yaml"
gpt_ckpt_path = "model_weights/shape_gpt.safetensors"
shape_ckpt_path = "model_weights/shape_tokenizer.safetensors"

engine_fast = EngineFast(  # only supported on CUDA devices, replace with Engine otherwise
    config_path,
    gpt_ckpt_path,
    shape_ckpt_path,
    device=torch.device("cuda"),
)


def start_session(req: gr.Request):
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)


def end_session(req: gr.Request):
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shutil.rmtree(user_dir)


def get_seed(randomize_seed: bool, seed: int) -> int:
    """
    Get the random seed.
    """
    return np.random.randint(0, MAX_SEED) if randomize_seed else seed


def text_to_3d(
    prompt: str,
    resolution_base: float,
    top_k: int,
    req: gr.Request,
) -> Tuple[str, str]:
    """
    Generate a 3D mesh from a text prompt and export it as a GLB file.

    Args:
        prompt (str): The text prompt describing the shape.
        resolution_base (float): Resolution base passed to the engine.
        top_k (int): Top-k sampling parameter passed to the engine.
        req (gr.Request): The Gradio request, used to resolve the session directory.

    Returns:
        Tuple[str, str]: The GLB path for the 3D viewer and for the download button.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    mesh_v_f = engine_fast.t2s(
        [prompt], use_kv_cache=True, resolution_base=resolution_base, top_k=top_k
    )
    vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
    glb_path = os.path.join(user_dir, "sample.glb")
    _ = trimesh.Trimesh(vertices=vertices, faces=faces).export(glb_path)
    torch.cuda.empty_cache()
    return glb_path, glb_path


with gr.Blocks(delete_cache=(600, 600)) as demo:
    with gr.Row():
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab(label="Prompts", id=0) as single_image_input_tab:
                    text_prompt = gr.Text(label="Text Prompt")

            with gr.Accordion(label="Generation Settings", open=False):
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1, interactive=True)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                gr.Markdown("Sparse Structure Generation")
                with gr.Row():
                    resolution_base = gr.Slider(0.0, 10.0, label="Resolution Base", value=8.0, step=0.1, interactive=True)
                    top_k = gr.Slider(1, 10, label="Top K", value=5, step=1, interactive=True)

            extract_glb_btn = gr.Button("Extract GLB", interactive=True)

        with gr.Column():
            model_output = LitModel3D(label="Extracted GLB", exposure=10.0, height=300)
            with gr.Row():
                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)

    output_buf = gr.State()

    # Handlers
    demo.load(start_session)
    demo.unload(end_session)

    extract_glb_btn.click(
        text_to_3d,
        inputs=[text_prompt, resolution_base, top_k],
        outputs=[model_output, download_glb],
    ).then(
        lambda: gr.Button(interactive=True),
        outputs=[download_glb],
    )

    model_output.clear(
        lambda: gr.Button(interactive=False),
        outputs=[download_glb],
    )


# Launch the Gradio app
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=80)
```
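For machines without a CUDA GPU, the example can fall back to the slower `Engine` class, as the inline comment above suggests. The sketch below is not part of the original example and assumes `Engine` accepts the same constructor arguments as `EngineFast`:

```python
# Hedged sketch: choose the engine based on CUDA availability.
# Assumes Engine takes the same arguments as EngineFast (per the comment above).
import torch
from cube3d.inference.engine import Engine, EngineFast

config_path = "cube3d/configs/open_model.yaml"
gpt_ckpt_path = "model_weights/shape_gpt.safetensors"
shape_ckpt_path = "model_weights/shape_tokenizer.safetensors"

if torch.cuda.is_available():
    # EngineFast is CUDA-only
    engine = EngineFast(config_path, gpt_ckpt_path, shape_ckpt_path,
                        device=torch.device("cuda"))
else:
    engine = Engine(config_path, gpt_ckpt_path, shape_ckpt_path,
                    device=torch.device("cpu"))
```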
@darcyOly999 Thank you for this. Did you want to create a PR for the proposed change?