@@ -96,7 +96,11 @@ def reset(self, env_ids: Sequence[int] | None = None):
96
96
if env_ids is None :
97
97
env_ids = slice (None )
98
98
# reset accumulative data buffers
99
+ self ._data .pos_w [env_ids ] = 0.0
99
100
self ._data .quat_w [env_ids ] = 0.0
101
+ self ._data .quat_w [env_ids , 0 ] = 1.0
102
+ self ._data .projected_gravity_b [env_ids ] = 0.0
103
+ self ._data .projected_gravity_b [env_ids , 2 ] = - 1.0
100
104
self ._data .lin_vel_b [env_ids ] = 0.0
101
105
self ._data .ang_vel_b [env_ids ] = 0.0
102
106
self ._data .lin_acc_b [env_ids ] = 0.0
@@ -129,6 +133,12 @@ def _initialize_impl(self):
129
133
else :
130
134
raise RuntimeError (f"Failed to find a RigidBodyAPI for the prim paths: { self .cfg .prim_path } " )
131
135
136
+ # Get world gravity
137
+ gravity = self ._physics_sim_view .get_gravity ()
138
+ gravity_dir = torch .tensor ((gravity [0 ], gravity [1 ], gravity [2 ]), device = self .device )
139
+ gravity_dir = math_utils .normalize (gravity_dir .unsqueeze (0 )).squeeze (0 )
140
+ self .GRAVITY_VEC_W = gravity_dir .repeat (self .num_instances , 1 )
141
+
132
142
# Create internal buffers
133
143
self ._initialize_buffers_impl ()
134
144
@@ -161,16 +171,18 @@ def _update_buffers_impl(self, env_ids: Sequence[int]):
161
171
lin_acc_w = (lin_vel_w - self ._prev_lin_vel_w [env_ids ]) / self ._dt + self ._gravity_bias_w [env_ids ]
162
172
ang_acc_w = (ang_vel_w - self ._prev_ang_vel_w [env_ids ]) / self ._dt
163
173
# stack data in world frame and batch rotate
164
- dynamics_data = torch .stack ((lin_vel_w , ang_vel_w , lin_acc_w , ang_acc_w ), dim = 0 )
165
- dynamics_data_rot = math_utils .quat_apply_inverse (self ._data .quat_w [env_ids ].repeat (4 , 1 ), dynamics_data ).chunk (
166
- 4 , dim = 0
174
+ dynamics_data = torch .stack ((lin_vel_w , ang_vel_w , lin_acc_w , ang_acc_w , self . GRAVITY_VEC_W [ env_ids ] ), dim = 0 )
175
+ dynamics_data_rot = math_utils .quat_apply_inverse (self ._data .quat_w [env_ids ].repeat (5 , 1 ), dynamics_data ).chunk (
176
+ 5 , dim = 0
167
177
)
168
178
# store the velocities.
169
179
self ._data .lin_vel_b [env_ids ] = dynamics_data_rot [0 ]
170
180
self ._data .ang_vel_b [env_ids ] = dynamics_data_rot [1 ]
171
181
# store the accelerations
172
182
self ._data .lin_acc_b [env_ids ] = dynamics_data_rot [2 ]
173
183
self ._data .ang_acc_b [env_ids ] = dynamics_data_rot [3 ]
184
+ # store projected gravity
185
+ self ._data .projected_gravity_b [env_ids ] = dynamics_data_rot [4 ]
174
186
175
187
self ._prev_lin_vel_w [env_ids ] = lin_vel_w
176
188
self ._prev_ang_vel_w [env_ids ] = ang_vel_w
@@ -181,6 +193,7 @@ def _initialize_buffers_impl(self):
181
193
self ._data .pos_w = torch .zeros (self ._view .count , 3 , device = self ._device )
182
194
self ._data .quat_w = torch .zeros (self ._view .count , 4 , device = self ._device )
183
195
self ._data .quat_w [:, 0 ] = 1.0
196
+ self ._data .projected_gravity_b = torch .zeros (self ._view .count , 3 , device = self ._device )
184
197
self ._data .lin_vel_b = torch .zeros_like (self ._data .pos_w )
185
198
self ._data .ang_vel_b = torch .zeros_like (self ._data .pos_w )
186
199
self ._data .lin_acc_b = torch .zeros_like (self ._data .pos_w )
0 commit comments