diff --git a/.gitignore b/.gitignore
index dfb15106..111a017b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -27,3 +27,8 @@ dist/
.vscode
imgui.ini
+
+octo/
+checkpoints/
+ManiSkill2_real2sim/
+videos/
\ No newline at end of file
diff --git a/README.md b/README.md
index f232c0d4..6158d4dd 100644
--- a/README.md
+++ b/README.md
@@ -6,11 +6,10 @@
Significant progress has been made in building generalist robot manipulation policies, yet their scalable and reproducible evaluation remains challenging, as real-world evaluation is operationally expensive and inefficient. We propose employing physical simulators as efficient, scalable, and informative complements to real-world evaluations. These simulation evaluations offer valuable quantitative metrics for checkpoint selection, insights into potential real-world policy behaviors or failure modes, and standardized setups to enhance reproducibility.
-This repository is based in the [SAPIEN](https://sapien.ucsd.edu/) simulator and the [ManiSkill2](https://maniskill2.github.io/) benchmark (we will also integrate the evaluation envs into ManiSkill3 once it is complete).
+This repository is based in the [SAPIEN](https://sapien.ucsd.edu/) simulator and the [ManiSkill 3](https://github.com/haosulab/ManiSkill) robotics framework. Note that to reproduce the original results, you need to use `main` branch which uses a older version of ManiSkill and SAPIEN. The version used here leverages the CPU/GPU simulation and rendering capabilities of the latest ManiSkill and SAPIEN.
+
+The `maniskill3` branch of SimplerEnv currently is simply used for installing the inference setup for policies like RT-1 and Octo. The real2sim environments are written in ManiSkill 3's github repo.
-This repository encompasses 2 real-to-sim evaluation setups:
-- `Visual Matching` evaluation: Matching real & sim visual appearances for policy evaluation by overlaying real-world images onto simulation backgrounds and adjusting foreground object and robot textures in simulation.
-- `Variant Aggregation` evaluation: creating different sim environment variants (e.g., different backgrounds, lightings, distractors, table textures, etc) and averaging their results.
We hope that our work guides and inspires future real-to-sim evaluation efforts.
@@ -32,178 +31,76 @@ We hope that our work guides and inspires future real-to-sim evaluation efforts.
## Getting Started
-Follow the [Installation](#installation) section to install the minimal requirements for our environments. Then you can run the following minimal inference script with interactive python. The scripts creates prepackaged environments for our `visual matching` evaluation setup.
+Follow the [Installation](#installation) section to install the minimal requirements to create our environments. Then you can run the following minimal inference script with interactive python.
```python
-import simpler_env
-from simpler_env.utils.env.observation_utils import get_image_from_maniskill2_obs_dict
-
-env = simpler_env.make('google_robot_pick_coke_can')
-obs, reset_info = env.reset()
-instruction = env.get_language_instruction()
-print("Reset info", reset_info)
-print("Instruction", instruction)
-
-done, truncated = False, False
-while not (done or truncated):
- # action[:3]: delta xyz; action[3:6]: delta rotation in axis-angle representation;
- # action[6:7]: gripper (the meaning of open / close depends on robot URDF)
- image = get_image_from_maniskill2_obs_dict(env, obs)
- action = env.action_space.sample() # replace this with your policy inference
- obs, reward, done, truncated, info = env.step(action) # for long horizon tasks, you can call env.advance_to_next_subtask() to advance to the next subtask; the environment might also autoadvance if env._elapsed_steps is larger than a threshold
- new_instruction = env.get_language_instruction()
- if new_instruction != instruction:
- # for long horizon tasks, we get a new instruction when robot proceeds to the next subtask
- instruction = new_instruction
- print("New Instruction", instruction)
-
-episode_stats = info.get('episode_stats', {})
-print("Episode stats", episode_stats)
-```
-
-Additionally, you can play with our environments in an interactive manner through [`ManiSkill2_real2sim/mani_skill2_real2sim/examples/demo_manual_control_custom_envs.py`](https://github.com/simpler-env/ManiSkill2_real2sim/blob/main/mani_skill2_real2sim/examples/demo_manual_control_custom_envs.py). See the script for more details and commands.
+import gymnasium as gym
+from simpler_env.utils.env.observation_utils import get_image_from_maniskill3_obs_dict
+from mani_skill.envs.tasks.digital_twins.bridge_dataset_eval import *
+env = gym.make(
+ "PutSpoonOnTableClothInScene-v1",
+ obs_mode="rgb+segmentation",
+ num_envs=16, # if num_envs > 1, GPU simulation backend is used.
+)
+obs, _ = env.reset()
+# returns language instruction for each parallel env
+instruction = env.unwrapped.get_language_instruction()
+print("instruction:", instruction[0])
+
+while True:
+ # action[:3]: delta xyz; action[3:6]: delta rotation in axis-angle representation;
+ # action[6:7]: gripper (the meaning of open / close depends on robot URDF)
+ image = get_image_from_maniskill3_obs_dict(env, obs) # this is the image observation for policy inference
+ action = env.action_space.sample() # replace this with your policy inference
+ obs, reward, terminated, truncated, info = env.step(action)
+ if truncated.any():
+ break
+print("Episode Info", info)
+```
+
## Installation
+The basic installation is simply installing ManiSkill 3 which officially supports real2sim environments.
+
Prerequisites:
- CUDA version >=11.8 (this is required if you want to perform a full installation of this repo and perform RT-1 or Octo inference)
- An NVIDIA GPU (ideally RTX; for non-RTX GPUs, such as 1080Ti and A100, environments that involve ray tracing will be slow). Currently TPU is not supported as SAPIEN requires a GPU to run.
-Create an anaconda environment:
-```
-conda create -n simpler_env python=3.10 (any version above 3.10 should be fine)
-conda activate simpler_env
-```
-
-Clone this repo:
-```
-git clone https://github.com/simpler-env/SimplerEnv --recurse-submodules
-```
-
-Install numpy<2.0 (otherwise errors in IK might occur in pinocchio):
-```
-pip install numpy==1.24.4
-```
-
-Install ManiSkill2 real-to-sim environments and their dependencies:
-```
-cd {this_repo}/ManiSkill2_real2sim
-pip install -e .
+First git clone this repo:
+```bash
+git clone https://github.com/simpler-env/SimplerEnv
```
-Install this package:
-```
-cd {this_repo}
+Create a conda/mamba environment and install dependencies:
+```bash
+cd path/to/SimplerEnv
+conda create -n simpler_env python=3.10.12
+conda activate ms3-octo
+pip install --upgrade git+https://github.com/haosulab/ManiSkill.git
+pip install torch==2.3.1 tyro==0.8.5
pip install -e .
```
-**If you'd like to perform evaluations on our provided agents (e.g., RT-1, Octo), or add new robots and environments, please additionally follow the full installation instructions [here](#full-installation-rt-1-and-octo-inference-env-building).**
-
-## Examples
-
-- Simple RT-1 and Octo evaluation script on prepackaged environments with visual matching evaluation setup: see [`simpler_env/simple_inference_visual_matching_prepackaged_envs.py`](https://github.com/simpler-env/SimplerEnv/blob/main/simpler_env/simple_inference_visual_matching_prepackaged_envs.py).
-- Colab notebook for RT-1 and Octo inference: see [this link](https://colab.research.google.com/github/simpler-env/SimplerEnv/blob/main/example.ipynb).
-- Environment interactive visualization and manual control: see [`ManiSkill2_real2sim/mani_skill2_real2sim/examples/demo_manual_control_custom_envs.py`](https://github.com/simpler-env/ManiSkill2_real2sim/blob/main/mani_skill2_real2sim/examples/demo_manual_control_custom_envs.py)
-- Policy inference scripts to reproduce our Google Robot and WidowX real-to-sim evaluation results with sweeps over object / robot poses and advanced loggings. These contain both visual matching and variant aggregation evaluation setups along with RT-1, RT-1-X, and Octo policies. See [`scripts/`](https://github.com/simpler-env/SimplerEnv/tree/main/scripts).
-- Real-to-sim evaluation videos from running `scripts/*.sh`: see [this link](https://huggingface.co/datasets/xuanlinli17/simpler-env-eval-example-videos/tree/main).
+**If you'd like to perform evaluations on our provided agents (e.g., RT-1, Octo), or add new robots and environments, please additionally follow the full installation instructions [here](#full-installation-rt-1-and-octo-inference).**
## Current Environments
-To get a list of all available environments, run:
-```
-import simpler_env
-print(simpler_env.ENVIRONMENTS)
-```
+In ManiSkill 3, the following environments (a subset of the original environments in the paper) have been ported over to ManiSkill 3 with GPU simulation and rendering support.
-| Task Name | ManiSkill2 Env Name | Image (Visual Matching) |
+| Task Name | ManiSkill 3 Env Name | Image (Visual Matching) |
| ----------- | ----- | ----- |
-| google_robot_pick_coke_can | GraspSingleOpenedCokeCanInScene-v0 |
|
-| google_robot_pick_object | GraspSingleRandomObjectInScene-v0 |
|
-| google_robot_move_near | MoveNearGoogleBakedTexInScene-v1 |
|
-| google_robot_open_drawer | OpenDrawerCustomInScene-v0 |
|
-| google_robot_close_drawer | CloseDrawerCustomInScene-v0 |
|
-| google_robot_place_in_closed_drawer | PlaceIntoClosedDrawerCustomInScene-v0 |
|
-| widowx_spoon_on_towel | PutSpoonOnTableClothInScene-v0 |
|
-| widowx_carrot_on_plate | PutCarrotOnPlateInScene-v0 |
|
-| widowx_stack_cube | StackGreenCubeOnYellowCubeBakedTexInScene-v0 |
|
-| widowx_put_eggplant_in_basket | PutEggplantInBasketScene-v0 |
|
+| widowx_spoon_on_towel | PutSpoonOnTableClothInScene-v1 |
|
+| widowx_carrot_on_plate | PutCarrotOnPlateInScene-v1 |
|
+| widowx_stack_cube | StackGreenCubeOnYellowCubeBakedTexInScene-v1 |
|
+| widowx_put_eggplant_in_basket | PutEggplantInBasketScene-v1 |
|
-We also support creating sub-tasks variations such as `google_robot_pick_{horizontal/vertical/standing}_coke_can`, `google_robot_open_{top/middle/bottom}_drawer`, and `google_robot_close_{top/middle/bottom}_drawer`. For the `google_robot_place_in_closed_drawer` task, we use the `google_robot_place_apple_in_closed_top_drawer` subtask for paper evaluations.
-
-By default, Google Robot environments use a control frequency of 3hz, and Bridge environments use a control frequency of 5hz. Simulation frequency is ~500hz.
-
-
-## Compare Your Policy Evaluation Approach to SIMPLER
-
-We make it easy to compare your offline robot policy evaluation approach to SIMPLER. In [our paper](https://simpler-env.github.io/), we use two metrics to measure the quality of simulated evaluation pipelines: Mean Maximum Rank Violation (MMRV) and the Pearson Correlation Coefficient. Both capture how well the offline evaluations reflect the policy's real-world performance and behaviors during robot rollouts.
-
-To make comparisons easy, we provide our real and SIMPLER evaluation performance for all policies on all tasks. We also provide the corresponding functions for computing the aforementioned metrics we report in the paper.
-
-To compute the corresponding metrics for *your* offline policy evaluation approach `your_sim_eval(task, policy)`, you can use the following snippet:
-```
-from simpler_env.utils.metrics import mean_maximum_rank_violation, pearson_correlation, REAL_PERF
-
-sim_eval_perf = [
- your_sim_eval(task="google_robot_move_near", policy=p)
- for p in ["rt-1-x", "octo", ...]
-]
-real_eval_perf = [
- REAL_PERF["google_robot_move_near"][p] for p in ["rt-1-x", "octo", ...]
-]
-mmrv = mean_maximum_rank_violation(real_eval_perf, sim_eval_perf)
-pearson = pearson_correlation(real_eval_perf, sim_eval_perf)
-```
-
-To reproduce the key numbers from our paper for SIMPLER, you can run the [`tools/calc_metrics.py`](tools/calc_metrics.py) script:
-```
-python3 tools/calc_metrics.py
-```
-
-## Code Structure
-
-```
-ManiSkill2_real2sim/: the ManiSkill2 real-to-sim environment codebase, which contains the environments, robots, and objects for real-to-sim evaluation.
- data/
- custom/: custom object assets (e.g., coke can, cabinet) and their infos
- hab2_bench_assets/: custom scene assets
- real_inpainting/: real-world inpainting images for visual matching evaluation
- debug/: debugging assets
- mani_skill2_real2sim/
- agents/: robot agents, configs, and controller implementations
- assets/: robot assets such as URDF and meshes
- envs/: environments
- examples/demo_manual_control_custom_envs.py: interactive script for environment visualization and manual
- utils/
- ...
-simpler_env/
- evaluation/: real-to-sim evaluator with advanced environment building and logging
- argparse.py: argument parser supporting custom policy and environment building
- maniskill2_evaluator.py: evaluator that supports environment parameter sweeps and advanced logging
- policies/: policy implementations
- rt1/: RT-1 policy implementation
- octo/: Octo policy implementation
- utils/:
- env/: environment building and observation utilities
- debug/: debugging tools for policies and robots
- ...
- main_inference.py: main inference script, taking in args from evaluation.argparse and calling evaluation.maniskill2_evaluator
- simple_inference_visual_matching_prepackaged_envs.py: an independent simple inference script on prepackaged environments, doesn't depend on evaluation/*
-tools/
- robot_object_visualization/: tools for visualizing robots and objects when creating new environments
- sysid/: tools for system identification when adding new robots
- calc_metrics.py: tools for summarizing eval results and calculating metrics, such as Mean Maximum Rank Violation (MMRV) and Pearson Correlation
- coacd_process_mesh.py: tools for generating convex collision meshes through CoACD when adding new assets
- merge_videos.py: tools for merging videos into one
- ...
-scripts/: example bash scripts for policy inference under our variant aggregation / visual matching evaluation setup,
- with custom environment building and advanced logging; also useful for reproducing our evaluation results
-...
-```
## Adding New Policies
-If you want to use existing environments for evaluating new policies, you can keep `./ManiSkill2_real2sim` as is.
+If you want to use existing environments for evaluating new policies, you can follow the instructions below.
1. Implement new policy inference scripts in `simpler_env/policies/{your_new_policy}`, following the examples for RT-1 (`simpler_env/policies/rt1`) and Octo (`simpler_env/policies/octo`) policies.
2. You can now use `simpler_env/simple_inference_visual_matching_prepackaged_envs.py` to perform policy evaluations in simulation.
@@ -216,10 +113,10 @@ If you want to use existing environments for evaluating new policies, you can ke
## Adding New Real-to-Sim Evaluation Environments and Robots
-We provide a step-by-step guide to add new real-to-sim evaluation environments and robots in [this README](ADDING_NEW_ENVS_ROBOTS.md)
+This is a WIP, and a new and updated tutorial for ManiSkill 3 will be coming soon on the ManiSkill 3 github / documentation.
-## Full Installation (RT-1 and Octo Inference, Env Building)
+## Full Installation (RT-1 and Octo Inference)
If you'd like to perform evaluations on our provided agents (e.g., RT-1, Octo), or add new robots and environments, please follow the full installation instructions below.
@@ -228,9 +125,17 @@ sudo apt install ffmpeg
```
```
+cd path/to/SimplerEnv
+pip install -e .
pip install tensorflow==2.15.0
pip install -r requirements_full_install.txt
pip install tensorflow[and-cuda]==2.15.1 # tensorflow gpu support
+
+pip install --upgrade "jax[cuda12_pip]==0.4.20" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+git clone https://github.com/octo-models/octo/
+cd octo
+git checkout 653c54acde686fde619855f2eac0dd6edad7116b # we use octo-1.0
+pip install -e .
```
Install simulated annealing utils for system identification:
@@ -272,23 +177,21 @@ mv rt_1_tf_trained_for_000001120 checkpoints
### Octo Inference Setup
-Install Octo:
-```
-pip install --upgrade "jax[cuda11_pip]==0.4.20" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # or jax[cuda12_pip] if you have CUDA 12
-
-cd {this_repo}
-git clone https://github.com/octo-models/octo/
-cd octo
-git checkout 653c54acde686fde619855f2eac0dd6edad7116b # we use octo-1.0
-pip install -e .
-# You don't need to run "pip install -r requirements.txt" inside the octo repo; the package dependencies are already handled in the simpler_env repo
-# Octo checkpoints are managed by huggingface, so you don't need to download them manually.
-```
+Earlier instructions already setup Octo for inference.
If you are using CUDA 12, then to use GPU for Octo inference, you need CUDA version >= 12.2 to satisfy the requirement of Jax; in this case, you can perform a runfile install of the corresponding CUDA (e.g., version 12.3), then set the environment variables whenever you run Octo inference scripts:
`PATH=/usr/local/cuda-12.3/bin:$PATH LD_LIBRARY_PATH=/usr/local/cuda-12.3/lib64:$LD_LIBRARY_PATH bash scripts/octo_xxx_script.sh`
+### Evaluating Octo and RT-1
+
+The new ManiSkill3 evaluation script is in `simpler_env/real2sim_eval_maniskill3.py`. See the script for more details. An example usage is shown below:
+```
+XLA_PYTHON_CLIENT_PREALLOCATE=false python simpler_env/real2sim_eval_maniskill3.py \
+ --model="octo-small" -e "PutEggplantInBasketScene-v1" -s 0 --num-episodes 192 --num-envs 64
+```
+to evaluate 192 episodes of octo-small model on PutEggplantInBasketScene-v1 environment with 64 parallel environments. You can use more environments if you have enough memory. Note that this is not deterministic and results may vary between runs.
+
## Troubleshooting
1. If you encounter issues such as
@@ -299,7 +202,7 @@ Some required Vulkan extension is not present. You may not use the renderer to r
Segmentation fault (core dumped)
```
-Follow [this link](https://maniskill.readthedocs.io/en/latest/user_guide/getting_started/installation.html#vulkan) to troubleshoot the issue. (Even though the doc points to SAPIEN 3 and ManiSkill3, the troubleshooting section still applies to the current environments that use SAPIEN 2.2 and ManiSkill2).
+Follow [this link](https://maniskill.readthedocs.io/en/latest/user_guide/getting_started/installation.html#vulkan) to troubleshoot the issue.
2. You can ignore the following error if it is caused by tensorflow's internal code. Sometimes this error will occur when running the inference or debugging scripts.
@@ -319,3 +222,5 @@ If you find our ideas / environments helpful, please cite our work at
year={2024}
}
```
+
+
\ No newline at end of file
diff --git a/simpler_env/__init__.py b/simpler_env/__init__.py
index cbecacd6..dcfdb6e0 100644
--- a/simpler_env/__init__.py
+++ b/simpler_env/__init__.py
@@ -1,5 +1,5 @@
import gymnasium as gym
-import mani_skill2_real2sim.envs
+import mani_skill.envs
ENVIRONMENTS = [
"google_robot_pick_coke_can",
diff --git a/simpler_env/policies/octo/octo_model.py b/simpler_env/policies/octo/octo_model.py
index 92110985..fe939495 100644
--- a/simpler_env/policies/octo/octo_model.py
+++ b/simpler_env/policies/octo/octo_model.py
@@ -9,9 +9,24 @@
import tensorflow as tf
from transformers import AutoTokenizer
from transforms3d.euler import euler2axangle
-
+from functools import partial
from simpler_env.utils.action.action_ensemble import ActionEnsembler
+from mani_skill.utils.geometry import rotation_conversions
+from mani_skill.utils import common
+import torch
+from torch.utils import dlpack as torch_dlpack
+
+from jax import dlpack as jax_dlpack
+import jax.numpy as jnp
+def torch2jax(x_torch):
+ x_torch = x_torch.contiguous() # https://github.com/google/jax/issues/8082
+ x_jax = jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(x_torch))
+ return x_jax
+
+def jax2torch(x_jax):
+ x_torch = torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(x_jax))
+ return x_torch
class OctoInference:
def __init__(
@@ -55,6 +70,8 @@ def __init__(
self.model = OctoModel.load_pretrained(self.model_type)
self.action_mean = self.model.dataset_statistics[dataset_id]["action"]["mean"]
self.action_std = self.model.dataset_statistics[dataset_id]["action"]["std"]
+ self.action_mean = jnp.array(self.action_mean)
+ self.action_std = jnp.array(self.action_std)
else:
raise NotImplementedError()
@@ -86,13 +103,9 @@ def __init__(
self.num_image_history = 0
def _resize_image(self, image: np.ndarray) -> np.ndarray:
- image = tf.image.resize(
- image,
- size=(self.image_size, self.image_size),
- method="lanczos3",
- antialias=True,
- )
- image = tf.cast(tf.clip_by_value(tf.round(image), 0, 255), tf.uint8).numpy()
+ """resize image to a square image of size self.image_size. image should be shape (B, H, W, 3)"""
+ image = jax.vmap(partial(jax.image.resize, shape=(self.image_size, self.image_size, 3), method="lanczos3", antialias=True))(image)
+ image = jnp.clip(jnp.round(image), 0, 255).astype(jnp.uint8)
return image
def _add_image_to_history(self, image: np.ndarray) -> None:
@@ -105,17 +118,20 @@ def _add_image_to_history(self, image: np.ndarray) -> None:
self.num_image_history = min(self.num_image_history + 1, self.horizon)
def _obtain_image_history_and_mask(self) -> tuple[np.ndarray, np.ndarray]:
- images = np.stack(self.image_history, axis=0)
+ images = jnp.stack(self.image_history, axis=1)
+ batch_size = images.shape[0]
horizon = len(self.image_history)
- pad_mask = np.ones(horizon, dtype=np.float64) # note: this should be of float type, not a bool type
- pad_mask[: horizon - min(horizon, self.num_image_history)] = 0
+ pad_mask = jnp.ones((batch_size, horizon), dtype=jnp.float32) # note: this should be of float type, not a bool type
+ pad_mask = pad_mask.at[:, : horizon - min(horizon, self.num_image_history)].set(0)
# pad_mask = np.ones(self.horizon, dtype=np.float64) # note: this should be of float type, not a bool type
# pad_mask[:self.horizon - self.num_image_history] = 0
return images, pad_mask
- def reset(self, task_description: str) -> None:
- self.task = self.model.create_tasks(texts=[task_description])
- self.task_description = task_description
+ def reset(self, task_descriptions: str) -> None:
+ if isinstance(task_descriptions, str):
+ task_descriptions = [task_descriptions]
+ self.task = self.model.create_tasks(texts=task_descriptions)
+ self.task_description = task_descriptions
self.image_history.clear()
if self.action_ensemble:
self.action_ensembler.reset()
@@ -130,7 +146,7 @@ def reset(self, task_description: str) -> None:
def step(self, image: np.ndarray, task_description: Optional[str] = None, *args, **kwargs) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
"""
Input:
- image: np.ndarray of shape (H, W, 3), uint8
+ image: np.ndarray/torch tensor of shape (B, H, W, 3), uint8
task_description: Optional[str], task description; if different from previous task description, policy state is reset
Output:
raw_action: dict; raw policy action output
@@ -145,45 +161,48 @@ def step(self, image: np.ndarray, task_description: Optional[str] = None, *args,
# task description has changed; reset the policy state
self.reset(task_description)
- assert image.dtype == np.uint8
+ # assert image.dtype == np.uint8
+ assert len(image.shape) == 4, "image shape should be (batch_size, height, width, 3)"
+ batch_size = image.shape[0]
+ image = torch2jax(image)
image = self._resize_image(image)
self._add_image_to_history(image)
images, pad_mask = self._obtain_image_history_and_mask()
- images, pad_mask = images[None], pad_mask[None]
-
# we need use a different rng key for each model forward step; this has a large impact on model performance
self.rng, key = jax.random.split(self.rng) # each shape [2,]
# print("octo local rng", self.rng, key)
input_observation = {"image_primary": images, "pad_mask": pad_mask}
+ # images.shape (b, h, w, c, 3), pad_mask.shape (b, h)
norm_raw_actions = self.model.sample_actions(
input_observation,
self.task,
rng=key,
)
raw_actions = norm_raw_actions * self.action_std[None] + self.action_mean[None]
- raw_actions = raw_actions[0] # remove batch, becoming (action_pred_horizon, action_dim)
- assert raw_actions.shape == (self.pred_action_horizon, 7)
+ assert raw_actions.shape == (batch_size, self.pred_action_horizon, 7)
if self.action_ensemble:
raw_actions = self.action_ensembler.ensemble_action(raw_actions)
- raw_actions = raw_actions[None] # [1, 7]
-
+ raw_actions = jax2torch(raw_actions)
raw_action = {
- "world_vector": np.array(raw_actions[0, :3]),
- "rotation_delta": np.array(raw_actions[0, 3:6]),
- "open_gripper": np.array(raw_actions[0, 6:7]), # range [0, 1]; 1 = open; 0 = close
+ "world_vector": raw_actions[:, :3],
+ "rotation_delta": raw_actions[:, 3:6],
+ "open_gripper": raw_actions[:, 6:7], # range [0, 1]; 1 = open; 0 = close
}
-
- # process raw_action to obtain the action to be sent to the maniskill2 environment
+ raw_action = common.to_tensor(raw_action)
+
+ # TODO (stao): check if we need torch float 64s.
+ # process raw_action to obtain the action to be sent to the maniskill environment
action = {}
action["world_vector"] = raw_action["world_vector"] * self.action_scale
- action_rotation_delta = np.asarray(raw_action["rotation_delta"], dtype=np.float64)
- roll, pitch, yaw = action_rotation_delta
- action_rotation_ax, action_rotation_angle = euler2axangle(roll, pitch, yaw)
- action_rotation_axangle = action_rotation_ax * action_rotation_angle
- action["rot_axangle"] = action_rotation_axangle * self.action_scale
-
+ # action_rotation_delta = np.asarray(raw_action["rotation_delta"], dtype=np.float64)
+ # roll, pitch, yaw = action_rotation_delta
+ # action_rotation_ax, action_rotation_angle = euler2axangle(roll, pitch, yaw)
+ # action_rotation_axangle = action_rotation_ax * action_rotation_angle
+ # action["rot_axangle"] = action_rotation_axangle * self.action_scale
+ # TODO: is there a better conversion from euler angles to axis angle?
+ action["rot_axangle"] = rotation_conversions.matrix_to_axis_angle(rotation_conversions.euler_angles_to_matrix(raw_action["rotation_delta"], "XYZ"))
if self.policy_setup == "google_robot":
current_gripper_action = raw_action["open_gripper"]
diff --git a/simpler_env/real2sim_eval_maniskill3.py b/simpler_env/real2sim_eval_maniskill3.py
new file mode 100644
index 00000000..b2be73e2
--- /dev/null
+++ b/simpler_env/real2sim_eval_maniskill3.py
@@ -0,0 +1,192 @@
+from collections import defaultdict
+import json
+import os
+import signal
+import time
+import numpy as np
+from typing import Annotated, Optional
+
+import torch
+import tree
+from mani_skill.utils import common
+from mani_skill.utils import visualization
+from mani_skill.utils.visualization.misc import images_to_video
+signal.signal(signal.SIGINT, signal.SIG_DFL) # allow ctrl+c
+from simpler_env.utils.env.observation_utils import get_image_from_maniskill3_obs_dict
+
+import gymnasium as gym
+import numpy as np
+from mani_skill.envs.tasks.digital_twins.bridge_dataset_eval import *
+from mani_skill.envs.sapien_env import BaseEnv
+import tyro
+from dataclasses import dataclass
+from pathlib import Path
+
+@dataclass
+class Args:
+ """
+ This is a script to evaluate policies on real2sim environments. Example command to run:
+
+ XLA_PYTHON_CLIENT_PREALLOCATE=false python real2sim_eval_maniskill3.py \
+ --model="octo-small" -e "PutEggplantInBasketScene-v1" -s 0 --num-episodes 192 --num-envs 64
+ """
+
+
+ env_id: Annotated[str, tyro.conf.arg(aliases=["-e"])] = "PutCarrotOnPlateInScene-v1"
+ """The environment ID of the task you want to simulate. Can be one of
+ PutCarrotOnPlateInScene-v1, PutSpoonOnTableClothInScene-v1, StackGreenCubeOnYellowCubeBakedTexInScene-v1, PutEggplantInBasketScene-v1"""
+
+ shader: str = "default"
+
+ num_envs: int = 1
+ """Number of environments to run. With more than 1 environment the environment will use the GPU backend
+ which runs faster enabling faster large-scale evaluations. Note that the overall behavior of the simulation
+ will be slightly different between CPU and GPU backends."""
+
+ num_episodes: int = 100
+ """Number of episodes to run and record evaluation metrics over"""
+
+ record_dir: str = "videos"
+ """The directory to save videos and results"""
+
+ model: Optional[str] = None
+ """The model to evaluate on the given environment. Can be one of octo-base, octo-small, rt-1x. If not given, random actions are sampled."""
+
+ ckpt_path: str = ""
+ """Checkpoint path for models. Only used for RT models"""
+
+ seed: Annotated[int, tyro.conf.arg(aliases=["-s"])] = 0
+ """Seed the model and environment. Default seed is 0"""
+
+ reset_by_episode_id: bool = True
+ """Whether to reset by fixed episode ids instead of random sampling initial states."""
+
+ info_on_video: bool = False
+ """Whether to write info text onto the video"""
+
+ save_video: bool = True
+ """Whether to save videos"""
+
+ debug: bool = False
+
+def main():
+ args = tyro.cli(Args)
+ if args.seed is not None:
+ np.random.seed(args.seed)
+
+
+ sensor_configs = dict()
+ sensor_configs["shader_pack"] = args.shader
+ env: BaseEnv = gym.make(
+ args.env_id,
+ obs_mode="rgb+segmentation",
+ num_envs=args.num_envs,
+ sensor_configs=sensor_configs
+ )
+ sim_backend = 'gpu' if env.device.type == 'cuda' else 'cpu'
+
+ # Setup up the policy inference model
+ model = None
+ try:
+
+ policy_setup = "widowx_bridge"
+ if args.model is None:
+ pass
+ else:
+ from simpler_env.policies.rt1.rt1_model import RT1Inference
+ from simpler_env.policies.octo.octo_model import OctoInference
+ if args.model == "octo-base" or args.model == "octo-small":
+ model = OctoInference(model_type=args.model, policy_setup=policy_setup, init_rng=args.seed, action_scale=1)
+ elif args.model == "rt-1x":
+ ckpt_path=args.ckpt_path
+ model = RT1Inference(
+ saved_model_path=ckpt_path,
+ policy_setup=policy_setup,
+ action_scale=1,
+ )
+ elif args.model is not None:
+ raise ValueError(f"Model {args.model} does not exist / is not supported.")
+ except:
+ if args.model is not None:
+ raise Exception("SIMPLER Env Policy Inference is not installed")
+
+ model_name = args.model if args.model is not None else "random"
+ if model_name == "random":
+ print("Using random actions.")
+ exp_dir = os.path.join(args.record_dir, f"real2sim_eval/{model_name}_{args.env_id}")
+ Path(exp_dir).mkdir(parents=True, exist_ok=True)
+
+ eval_metrics = defaultdict(list)
+ eps_count = 0
+
+ print(f"Running Real2Sim Evaluation of model {args.model} on environment {args.env_id}")
+ print(f"Using {args.num_envs} environments on the {sim_backend} simulation backend")
+
+ timers = {"env.step+inference": 0, "env.step": 0, "inference": 0, "total": 0}
+ total_start_time = time.time()
+
+ while eps_count < args.num_episodes:
+ seed = args.seed + eps_count
+ obs, _ = env.reset(seed=seed, options={"episode_id": torch.tensor([seed + i for i in range(args.num_envs)])})
+ instruction = env.unwrapped.get_language_instruction()
+ print("instruction:", instruction[0])
+ if model is not None:
+ model.reset(instruction)
+ images = []
+ predicted_terminated, truncated = False, False
+ images.append(get_image_from_maniskill3_obs_dict(env, obs))
+ elapsed_steps = 0
+ while not (predicted_terminated or truncated):
+ if model is not None:
+ start_time = time.time()
+ raw_action, action = model.step(images[-1], instruction)
+ action = torch.cat([action["world_vector"], action["rot_axangle"], action["gripper"]], dim=1)
+ timers["inference"] += time.time() - start_time
+ else:
+ action = env.action_space.sample()
+
+ if elapsed_steps > 0:
+ if args.save_video and args.info_on_video:
+ for i in range(len(images[-1])):
+ images[-1][i] = visualization.put_info_on_image(images[-1][i], tree.map_structure(lambda x: x[i], info))
+
+ start_time = time.time()
+ obs, reward, terminated, truncated, info = env.step(action)
+ timers["env.step"] += time.time() - start_time
+ elapsed_steps += 1
+ info = common.to_numpy(info)
+
+ truncated = bool(truncated.any()) # note that all envs truncate and terminate at the same time.
+ images.append(get_image_from_maniskill3_obs_dict(env, obs))
+
+ for k, v in info.items():
+ eval_metrics[k].append(v.flatten())
+ if args.save_video:
+ for i in range(len(images[-1])):
+ images_to_video([img[i].cpu().numpy() for img in images], exp_dir, f"{sim_backend}_eval_{seed + i}_success={info['success'][i].item()}", fps=10, verbose=True)
+ eps_count += args.num_envs
+ if args.num_envs == 1:
+ print(f"Evaluated episode {eps_count}. Seed {seed}. Results after {eps_count} episodes:")
+ else:
+ print(f"Evaluated {args.num_envs} episodes, seeds {seed} to {eps_count}. Results after {eps_count} episodes:")
+ for k, v in eval_metrics.items():
+ print(f"{k}: {np.mean(v)}")
+ # Print timing information
+ timers["total"] = time.time() - total_start_time
+ timers["env.step+inference"] = timers["env.step"] + timers["inference"]
+ mean_metrics = {k: np.mean(v) for k, v in eval_metrics.items()}
+ mean_metrics["total_episodes"] = eps_count
+ mean_metrics["time/episodes_per_second"] = eps_count / timers["total"]
+ print("Timing Info:")
+ for key, value in timers.items():
+ mean_metrics[f"time/{key}"] = value
+ print(f"{key}: {value:.2f} seconds")
+ metrics_path = os.path.join(exp_dir, f"{sim_backend}_eval_metrics.json")
+ if sim_backend == "gpu":
+ metrics_path = metrics_path.replace("gpu", f"gpu_{args.num_envs}_envs")
+ with open(metrics_path, "w") as f:
+ json.dump(mean_metrics, f, indent=4)
+ print(f"Evaluation complete. Results saved to {exp_dir}. Metrics saved to {metrics_path}")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/simpler_env/utils/action/action_ensemble.py b/simpler_env/utils/action/action_ensemble.py
index 367ccf91..f43c0641 100644
--- a/simpler_env/utils/action/action_ensemble.py
+++ b/simpler_env/utils/action/action_ensemble.py
@@ -1,7 +1,7 @@
from collections import deque
import numpy as np
-
+import jax.numpy as jnp
class ActionEnsembler:
def __init__(self, pred_action_horizon, action_ensemble_temp=0.0):
@@ -18,13 +18,16 @@ def ensemble_action(self, cur_action):
if cur_action.ndim == 1:
curr_act_preds = np.stack(self.action_history)
else:
- curr_act_preds = np.stack(
- [pred_actions[i] for (i, pred_actions) in zip(range(num_actions - 1, -1, -1), self.action_history)]
- )
+ curr_act_preds = jnp.stack(
+ [pred_actions[:, i] for (i, pred_actions) in zip(range(num_actions - 1, -1, -1), self.action_history)]
+ ) # shape (1 to self.pred_action_horizon, batch_size, action_dim)
# more recent predictions get exponentially *less* weight than older predictions
- weights = np.exp(-self.action_ensemble_temp * np.arange(num_actions))
+ weights = jnp.exp(-self.action_ensemble_temp * jnp.arange(num_actions))
weights = weights / weights.sum()
# compute the weighted average across all predictions for this timestep
- cur_action = np.sum(weights[:, None] * curr_act_preds, axis=0)
-
+ # Expand weights to match batch and action dimensions
+ weights_expanded = weights[:, None, None]
+
+ # Apply weights across all batches and sum
+ cur_action = jnp.sum(weights_expanded * curr_act_preds, axis=0)
return cur_action
diff --git a/simpler_env/utils/env/observation_utils.py b/simpler_env/utils/env/observation_utils.py
index 05268425..68e50948 100644
--- a/simpler_env/utils/env/observation_utils.py
+++ b/simpler_env/utils/env/observation_utils.py
@@ -8,3 +8,16 @@ def get_image_from_maniskill2_obs_dict(env, obs, camera_name=None):
else:
raise NotImplementedError()
return obs["image"][camera_name]["rgb"]
+
+def get_image_from_maniskill3_obs_dict(env, obs, camera_name=None):
+ import torch
+ # obtain image from observation dictionary returned by ManiSkill environment
+ if camera_name is None:
+ if "google_robot" in env.unwrapped.robot_uids.uid:
+ camera_name = "overhead_camera"
+ elif "widowx" in env.unwrapped.robot_uids.uid:
+ camera_name = "3rd_view_camera"
+ else:
+ raise NotImplementedError()
+ img = obs["sensor_data"][camera_name]["rgb"]
+ return img.to(torch.uint8)
\ No newline at end of file