-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinteractive_plot.py
More file actions
72 lines (55 loc) · 2.58 KB
/
interactive_plot.py
File metadata and controls
72 lines (55 loc) · 2.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import argparse

import h5py
import numpy as np
import pyvista as pv
from tqdm import tqdm # Import tqdm for progress bars
# Default HDF5 file to visualise; can be overridden on the command line.
DATASET_PATH = "/home/ma1614/zero123/zero123/CORL_rgb_depth2/demo.hdf5"


def extract_poses(ee_states):
    """Split flattened 4x4 end-effector transforms into positions and rotations.

    Each row of ``ee_states`` holds the 16 entries of a 4x4 homogeneous
    transform stored in column-major (Fortran) order.
    NOTE(review): presumably the robot base -> end-effector pose; confirm
    against whatever recorded the dataset.

    Args:
        ee_states: Array-like of shape ``(T, 16)``.

    Returns:
        ``(positions, rotations)`` with shapes ``(T, 3)`` and ``(T, 3, 3)``;
        column ``k`` of each rotation is the end-effector's k-th axis.

    Raises:
        ValueError: If ``ee_states`` is not a 2-D array with 16 columns.
    """
    ee_states = np.asarray(ee_states)
    if ee_states.ndim != 2 or ee_states.shape[1] != 16:
        raise ValueError(
            f"expected ee_states of shape (T, 16), got {ee_states.shape}"
        )
    # A C-order reshape followed by swapping the last two axes is the
    # vectorised equivalent of reshaping each row with order='F'.
    transforms = ee_states.reshape(-1, 4, 4).transpose(0, 2, 1)
    return transforms[:, :3, 3], transforms[:, :3, :3]


def load_poses(dataset_path=DATASET_PATH):
    """Read every demo's end-effector poses from an HDF5 dataset.

    Demos are read from ``data/<demo_name>/obs/ee_states``. All data is
    copied into memory, so the file is closed before this returns.

    Args:
        dataset_path: Path to the HDF5 file.

    Returns:
        Dict mapping demo name -> ``(positions, rotations)`` as returned by
        :func:`extract_poses`, in the order h5py lists the demo keys.
    """
    poses = {}
    with h5py.File(dataset_path, "r") as f_org:
        demos = list(f_org["data"].keys())
        print("Available demos:", demos)
        for demo_name in tqdm(demos, desc="Processing demos"):
            # tqdm.write logs without corrupting the progress bar.
            tqdm.write(f"Processing trajectory for {demo_name}...")
            demo_data = np.array(f_org["data"][demo_name]["obs"]["ee_states"])
            tqdm.write(f"Shape of {demo_name} data: {demo_data.shape}")
            poses[demo_name] = extract_poses(demo_data)
    return poses


def _add_orientation_arrows(plotter, positions, rotations, every, scale):
    """Draw the x/y/z rotation axes (red/green/blue) at every ``every``-th pose."""
    for axis, color in enumerate(("red", "green", "blue")):
        arrows = [
            pv.Arrow(start=pos, direction=rot[:, axis], scale=scale)
            for pos, rot in zip(positions[::every], rotations[::every])
        ]
        if arrows:
            # One merged mesh per colour instead of one actor per arrow keeps
            # the scene responsive when many demos are plotted.
            plotter.add_mesh(pv.merge(arrows), color=color)


def plot_trajectories(poses, *, arrow_every=10, arrow_scale=0.01):
    """Plot every demo's path plus orientation arrows in one interactive window.

    Args:
        poses: Mapping demo name -> ``(positions, rotations)``.
        arrow_every: Draw orientation axes at every n-th pose (must be >= 1).
        arrow_scale: Length of each orientation arrow.

    Raises:
        ValueError: If ``arrow_every`` is less than 1.
    """
    if arrow_every < 1:
        raise ValueError(f"arrow_every must be >= 1, got {arrow_every}")

    plotter = pv.Plotter()
    has_labels = False
    for demo_name, (positions, rotations) in poses.items():
        if len(positions) < 2:
            # A trajectory line needs at least two points.
            print(f"Skipping {demo_name}: fewer than 2 poses")
            continue
        trajectory = pv.lines_from_points(positions)
        plotter.add_mesh(
            trajectory,
            color=np.random.rand(3),
            line_width=3,
            label=f"Trajectory for {demo_name}",
        )
        has_labels = True
        _add_orientation_arrows(
            plotter, positions, rotations, arrow_every, arrow_scale
        )

    plotter.show_grid()  # Add a grid for better visualization
    if has_labels:
        # add_legend errors out when no mesh was given a label.
        plotter.add_legend()
    plotter.view_isometric()  # Set isometric view for better perspective
    plotter.show()


def main(argv=None):
    """Command-line entry point: load the dataset and show the interactive plot."""
    parser = argparse.ArgumentParser(
        description="Interactively plot end-effector trajectories from an HDF5 demo file."
    )
    parser.add_argument(
        "dataset_path",
        nargs="?",
        default=DATASET_PATH,
        help="path to the demo HDF5 file (default: %(default)s)",
    )
    parser.add_argument(
        "--arrow-every",
        type=int,
        default=10,
        help="draw orientation axes at every N-th pose (default: %(default)s)",
    )
    parser.add_argument(
        "--arrow-scale",
        type=float,
        default=0.01,
        help="orientation arrow length (default: %(default)s)",
    )
    args = parser.parse_args(argv)

    # jupyter_backend only affects rendering inside a Jupyter notebook
    # ('static' = a plain image instead of an interactive widget); a normal
    # script run still opens the regular interactive VTK window.
    pv.global_theme.jupyter_backend = "static"

    poses = load_poses(args.dataset_path)
    plot_trajectories(
        poses, arrow_every=args.arrow_every, arrow_scale=args.arrow_scale
    )


if __name__ == "__main__":
    main()