Skip to content

WebUI requirement #65

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
darcyOly999 opened this issue Mar 24, 2025 · 1 comment
Open

WebUI requirement #65

darcyOly999 opened this issue Mar 24, 2025 · 1 comment
Labels
enhancement New feature or request

Comments

@darcyOly999
Copy link

darcyOly999 commented Mar 24, 2025

A Gradio example is presented below:

app.py

import gradio as gr
from gradio_litmodel3d import LitModel3D
import torch
import trimesh
from cube3d.inference.engine import Engine, EngineFast
import os
import shutil
from typing import *


import numpy as np


MAX_SEED = np.iinfo(np.int32).max
# Per-session scratch directories live next to this script, not under the CWD.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)

config_path = "cube3d/configs/open_model.yaml"
gpt_ckpt_path = "model_weights/shape_gpt.safetensors"
shape_ckpt_path = "model_weights/shape_tokenizer.safetensors"

# EngineFast is only supported on CUDA devices; use the generic Engine
# otherwise so the demo can still start. The variable keeps the name
# `engine_fast` because text_to_3d refers to it by that name.
# NOTE(review): assumes Engine takes the same constructor arguments and
# exposes the same t2s() method as EngineFast -- confirm against cube3d.
if torch.cuda.is_available():
    engine_fast = EngineFast(
        config_path,
        gpt_ckpt_path,
        shape_ckpt_path,
        device=torch.device("cuda"),
    )
else:
    engine_fast = Engine(
        config_path,
        gpt_ckpt_path,
        shape_ckpt_path,
        device=torch.device("cpu"),
    )


def start_session(req: gr.Request):
    """
    Ensure the scratch directory for this browser session exists.

    Args:
        req (gr.Request): Injected by Gradio; its ``session_hash`` names the
            directory created under ``TMP_DIR``.
    """
    os.makedirs(os.path.join(TMP_DIR, str(req.session_hash)), exist_ok=True)


def end_session(req: gr.Request):
    """
    Remove the scratch directory belonging to this browser session.

    Args:
        req (gr.Request): Injected by Gradio; its ``session_hash`` names the
            directory under ``TMP_DIR`` to delete.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    try:
        shutil.rmtree(user_dir)
    except FileNotFoundError:
        # Nothing to clean up (the directory was never created or is already
        # gone); raising here would only surface a spurious error on unload.
        pass

def get_seed(randomize_seed: bool, seed: int) -> int:
    """
    Get the random seed.

    Args:
        randomize_seed (bool): If true, draw a fresh seed instead of using ``seed``.
        seed (int): The seed to return when ``randomize_seed`` is false.

    Returns:
        int: A plain Python int -- drawn from ``[0, MAX_SEED)`` when
        randomizing, otherwise ``seed`` converted to int.
    """
    # Cast to int: np.random.randint returns a NumPy integer, and Gradio
    # sliders may deliver floats, but callers are promised an ``int``.
    if randomize_seed:
        return int(np.random.randint(0, MAX_SEED))
    return int(seed)


def text_to_3d(
    prompt: str,
    resolution_base: float,
    top_k: int,
    req: gr.Request,
) -> Tuple[str, str]:
    """
    Generate a mesh from a text prompt and export it as a GLB file.

    Args:
        prompt (str): Text prompt describing the shape to generate.
        resolution_base (float): Passed through to ``engine_fast.t2s`` as
            ``resolution_base``.
        top_k (int): Passed through to ``engine_fast.t2s`` as ``top_k``.
        req (gr.Request): Injected by Gradio; its ``session_hash`` selects the
            per-session output directory under ``TMP_DIR``.

    Returns:
        Tuple[str, str]: The path to the exported GLB file, returned twice --
        once for the 3D viewer and once for the download button.

    Raises:
        gr.Error: If the prompt is empty or whitespace-only.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a text prompt.")

    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # start_session normally creates this directory on page load, but create
    # it here too in case that event did not run or the directory was removed.
    os.makedirs(user_dir, exist_ok=True)
    glb_path = os.path.join(user_dir, 'sample.glb')

    try:
        mesh_v_f = engine_fast.t2s(
            [prompt],
            use_kv_cache=True,
            resolution_base=resolution_base,
            # Gradio sliders may deliver floats; the engine presumably
            # expects an integer k -- confirm against cube3d.
            top_k=int(top_k),
        )
        # First (and only) prompt's result: (vertices, faces).
        vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
        trimesh.Trimesh(vertices=vertices, faces=faces).export(glb_path)
    finally:
        # Release cached GPU memory even when generation or export fails.
        torch.cuda.empty_cache()

    return glb_path, glb_path



with gr.Blocks(delete_cache=(600, 600)) as demo:
    # Layout: prompt and settings on the left, 3D preview and download on the right.
    with gr.Row():
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab(label="Prompts", id=0) as single_image_input_tab:
                    text_prompt = gr.Text(label="Text Prompt")

            with gr.Accordion(label="Generation Settings", open=False):
                # NOTE(review): seed / randomize_seed are not passed to any
                # event handler below (text_to_3d takes no seed), so they
                # currently have no effect on generation.
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1, interactive=True)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                gr.Markdown("Sparse Structure Generation")
                with gr.Row():
                    resolution_base = gr.Slider(0.0, 10.0, label="Resolution Base", value=8.0, step=0.1, interactive=True)
                    top_k = gr.Slider(1, 10, label="Top K", value=5, step=1, interactive=True)

            extract_glb_btn = gr.Button("Extract GLB", interactive=True)

        with gr.Column():
            model_output = LitModel3D(label="Extracted GLB", exposure=10.0, height=300)

            with gr.Row():
                # Starts disabled; enabled only after a GLB has been produced.
                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)

    output_buf = gr.State()

    # Handlers
    # Create / remove the per-session scratch directory under TMP_DIR.
    demo.load(start_session)
    demo.unload(end_session)

    # Generate the mesh, then enable the download button. `.success()` (unlike
    # `.then()`) only runs if text_to_3d finished without raising, so a failed
    # generation does not enable a download.
    extract_glb_btn.click(
        text_to_3d,
        inputs=[text_prompt, resolution_base, top_k],
        outputs=[model_output, download_glb],
    ).success(
        lambda: gr.DownloadButton(interactive=True),
        outputs=[download_glb],
    )

    # Clearing the 3D viewer disables the download button again.
    model_output.clear(
        lambda: gr.DownloadButton(interactive=False),
        outputs=[download_glb],
    )


# Launch the Gradio app
if __name__ == "__main__":
    # Defaults match the original (all interfaces, port 80). The standard
    # GRADIO_SERVER_NAME / GRADIO_SERVER_PORT environment variables can
    # override them, e.g. when binding port 80 would require root privileges.
    demo.launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "80")),
    )
@darcyOly999 darcyOly999 added the enhancement New feature or request label Mar 24, 2025
@akashkgarg
Copy link
Collaborator

@darcyOly999 Thank you for this. Did you want to create a PR for the proposed change?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

No branches or pull requests

2 participants