@@ -96,7 +96,11 @@ def reset(self, env_ids: Sequence[int] | None = None):
96
96
if env_ids is None :
97
97
env_ids = slice (None )
98
98
# reset accumulative data buffers
99
+ self ._data .pos_w [env_ids ] = 0.0
99
100
self ._data .quat_w [env_ids ] = 0.0
101
+ self ._data .quat_w [env_ids , 0 ] = 1.0
102
+ self ._data .projected_gravity_b [env_ids ] = 0.0
103
+ self ._data .projected_gravity_b [env_ids , 2 ] = - 1.0
100
104
self ._data .lin_vel_b [env_ids ] = 0.0
101
105
self ._data .ang_vel_b [env_ids ] = 0.0
102
106
self ._data .lin_acc_b [env_ids ] = 0.0
@@ -135,6 +139,12 @@ def _initialize_impl(self):
135
139
else :
136
140
raise RuntimeError (f"Failed to find a RigidBodyAPI for the prim paths: { self .cfg .prim_path } " )
137
141
142
+ # Get world gravity
143
+ gravity = self ._physics_sim_view .get_gravity ()
144
+ gravity_dir = torch .tensor ((gravity [0 ], gravity [1 ], gravity [2 ]), device = self .device )
145
+ gravity_dir = math_utils .normalize (gravity_dir .unsqueeze (0 )).squeeze (0 )
146
+ self .GRAVITY_VEC_W = gravity_dir .repeat (self .num_instances , 1 )
147
+
138
148
# Create internal buffers
139
149
self ._initialize_buffers_impl ()
140
150
@@ -167,16 +177,18 @@ def _update_buffers_impl(self, env_ids: Sequence[int]):
167
177
lin_acc_w = (lin_vel_w - self ._prev_lin_vel_w [env_ids ]) / self ._dt + self ._gravity_bias_w [env_ids ]
168
178
ang_acc_w = (ang_vel_w - self ._prev_ang_vel_w [env_ids ]) / self ._dt
169
179
# stack data in world frame and batch rotate
170
- dynamics_data = torch .stack ((lin_vel_w , ang_vel_w , lin_acc_w , ang_acc_w ), dim = 0 )
171
- dynamics_data_rot = math_utils .quat_apply_inverse (self ._data .quat_w [env_ids ].repeat (4 , 1 ), dynamics_data ).chunk (
172
- 4 , dim = 0
180
+ dynamics_data = torch .stack ((lin_vel_w , ang_vel_w , lin_acc_w , ang_acc_w , self . GRAVITY_VEC_W [ env_ids ] ), dim = 0 )
181
+ dynamics_data_rot = math_utils .quat_apply_inverse (self ._data .quat_w [env_ids ].repeat (5 , 1 ), dynamics_data ).chunk (
182
+ 5 , dim = 0
173
183
)
174
184
# store the velocities.
175
185
self ._data .lin_vel_b [env_ids ] = dynamics_data_rot [0 ]
176
186
self ._data .ang_vel_b [env_ids ] = dynamics_data_rot [1 ]
177
187
# store the accelerations
178
188
self ._data .lin_acc_b [env_ids ] = dynamics_data_rot [2 ]
179
189
self ._data .ang_acc_b [env_ids ] = dynamics_data_rot [3 ]
190
+ # store projected gravity
191
+ self ._data .projected_gravity_b [env_ids ] = dynamics_data_rot [4 ]
180
192
181
193
self ._prev_lin_vel_w [env_ids ] = lin_vel_w
182
194
self ._prev_ang_vel_w [env_ids ] = ang_vel_w
@@ -187,6 +199,7 @@ def _initialize_buffers_impl(self):
187
199
self ._data .pos_w = torch .zeros (self ._view .count , 3 , device = self ._device )
188
200
self ._data .quat_w = torch .zeros (self ._view .count , 4 , device = self ._device )
189
201
self ._data .quat_w [:, 0 ] = 1.0
202
+ self ._data .projected_gravity_b = torch .zeros (self ._view .count , 3 , device = self ._device )
190
203
self ._data .lin_vel_b = torch .zeros_like (self ._data .pos_w )
191
204
self ._data .ang_vel_b = torch .zeros_like (self ._data .pos_w )
192
205
self ._data .lin_acc_b = torch .zeros_like (self ._data .pos_w )
0 commit comments