42
42
from lerobot .policies .octo .modeling_octo import OctoPolicy
43
43
from lerobot .policies .pretrained import PreTrainedPolicy
44
44
from lerobot .policies .sac .modeling_sac import (
45
+ MLP ,
45
46
CriticEnsemble ,
46
47
CriticHead ,
47
- MLP ,
48
48
SACObservationEncoder ,
49
49
_convert_normalization_params_to_tensor ,
50
50
)
@@ -120,7 +120,7 @@ def _init_encoders(self):
120
120
self .octo_policy ,
121
121
use_proprio = self .config .use_proprio ,
122
122
state_dim = self .config .state_dim ,
123
- proprio_latent_dim = self .config .proprio_latent_dim
123
+ proprio_latent_dim = self .config .proprio_latent_dim ,
124
124
)
125
125
126
126
def _init_consistency_policy (self , continuous_action_dim ):
@@ -162,7 +162,7 @@ def _init_critics(self, continuous_action_dim):
162
162
target_heads = [
163
163
CriticHead (
164
164
input_dim = self .encoder_critic .output_dim + continuous_action_dim ,
165
- ** asdict (self .config .critic_network_kwargs )
165
+ ** asdict (self .config .critic_network_kwargs ),
166
166
)
167
167
for _ in range (self .config .num_critics )
168
168
]
@@ -245,7 +245,9 @@ def forward(
245
245
def update_target_networks (self ):
246
246
"""Update target networks with soft updates"""
247
247
tau = self .config .soft_target_update_rate
248
- for target_param , param in zip (self .critic_ensemble_target .parameters (), self .critic_ensemble .parameters (), strict = False ):
248
+ for target_param , param in zip (
249
+ self .critic_ensemble_target .parameters (), self .critic_ensemble .parameters (), strict = False
250
+ ):
249
251
target_param .data .copy_ (tau * param .data + (1 - tau ) * target_param .data )
250
252
251
253
def set_training_stage (self , stage : str ):
@@ -319,7 +321,7 @@ def compute_cal_ql_loss(self, batch):
319
321
# TODO(lilkm): Get indices before forward pass to avoid unnecessary computation
320
322
# Subsample critics
321
323
if self .config .num_subsample_critics is not None :
322
- indices = torch .randperm (self .config .num_critics )[:self .config .num_subsample_critics ]
324
+ indices = torch .randperm (self .config .num_critics )[: self .config .num_subsample_critics ]
323
325
target_q_values = target_q_values [indices ]
324
326
325
327
target_q = torch .min (target_q_values , dim = 0 )[0 ]
@@ -330,7 +332,7 @@ def compute_cal_ql_loss(self, batch):
330
332
331
333
# TODO(lilkm): Get indices before forward pass to avoid unnecessary computation
332
334
if self .config .num_subsample_critics is not None :
333
- indices = torch .randperm (self .config .num_critics )[:self .config .num_subsample_critics ]
335
+ indices = torch .randperm (self .config .num_critics )[: self .config .num_subsample_critics ]
334
336
current_q_values = current_q_values [indices ]
335
337
critic_size = self .config .num_subsample_critics
336
338
else :
@@ -406,7 +408,9 @@ def compute_cal_ql_loss(self, batch):
406
408
cql_q_samples = torch .cat ([cql_q_samples , current_q_expanded ], dim = - 1 )
407
409
408
410
# Subtract log(num_samples) * temperature
409
- cql_q_samples = cql_q_samples - torch .log (torch .tensor (cql_q_samples .shape [- 1 ])) * self .config .cql_temp
411
+ cql_q_samples = (
412
+ cql_q_samples - torch .log (torch .tensor (cql_q_samples .shape [- 1 ])) * self .config .cql_temp
413
+ )
410
414
411
415
# Compute logsumexp of OOD actions
412
416
cql_ood_values = torch .logsumexp (cql_q_samples / self .config .cql_temp , dim = - 1 ) * self .config .cql_temp
@@ -444,7 +448,9 @@ def compute_bc_loss(self, batch):
444
448
indices = torch .randint (0 , self .config .num_scales - 1 , (batch_size ,), device = device )
445
449
446
450
# Compute sigma values using the same formula as JAX
447
- t = (self .config .sigma_max ** (1 / self .config .rho ) + indices / (self .config .num_scales - 1 ) * (self .config .sigma_min ** (1 / self .config .rho ) - self .config .sigma_max ** (1 / self .config .rho )))
451
+ t = self .config .sigma_max ** (1 / self .config .rho ) + indices / (self .config .num_scales - 1 ) * (
452
+ self .config .sigma_min ** (1 / self .config .rho ) - self .config .sigma_max ** (1 / self .config .rho )
453
+ )
448
454
t = t ** self .config .rho
449
455
450
456
# Add noise to actions
@@ -483,7 +489,7 @@ def compute_actor_loss(self, batch):
483
489
q_values = self .critic_ensemble (observations , policy_action , obs_feat )
484
490
# TODO(lilkm): Get indices before forward pass to avoid unnecessary computation
485
491
if self .config .num_subsample_critics is not None :
486
- indices = torch .randperm (self .config .num_critics )[:self .config .num_subsample_critics ]
492
+ indices = torch .randperm (self .config .num_critics )[: self .config .num_subsample_critics ]
487
493
q_values = q_values [indices ]
488
494
q_value = q_values .mean (dim = 0 )
489
495
q_loss = - q_value .mean () # Negative for gradient ascent
@@ -500,8 +506,8 @@ def compute_actor_loss(self, batch):
500
506
"q_mean" : q_value .mean (),
501
507
}
502
508
503
- class SinusoidalPosEmb (nn .Module ):
504
509
510
+ class SinusoidalPosEmb (nn .Module ):
505
511
def __init__ (self , dim ):
506
512
super ().__init__ ()
507
513
self .dim = dim
@@ -515,7 +521,6 @@ def forward(self, time):
515
521
516
522
517
523
class TimeMLP (nn .Module ):
518
-
519
524
def __init__ (self , t_dim : int ):
520
525
super ().__init__ ()
521
526
self .t_dim = t_dim
@@ -534,7 +539,13 @@ def forward(self, t: torch.Tensor) -> torch.Tensor:
534
539
class OctoEncodingWrapper (nn .Module ):
535
540
"""Wrapper around Octo transformer to extract action embeddings for ConRFT."""
536
541
537
- def __init__ (self , octo_policy : OctoPolicy , use_proprio : bool = True , state_dim : int = 18 , proprio_latent_dim : int = 64 ):
542
+ def __init__ (
543
+ self ,
544
+ octo_policy : OctoPolicy ,
545
+ use_proprio : bool = True ,
546
+ state_dim : int = 18 ,
547
+ proprio_latent_dim : int = 64 ,
548
+ ):
538
549
super ().__init__ ()
539
550
self .octo_policy = octo_policy
540
551
self .octo_transformer = octo_policy .model .octo_transformer
@@ -549,7 +560,7 @@ def __init__(self, octo_policy: OctoPolicy, use_proprio: bool = True, state_dim:
549
560
self .proprio_encoder = nn .Sequential (
550
561
nn .Linear (state_dim , self .proprio_latent_dim ),
551
562
nn .LayerNorm (self .proprio_latent_dim ),
552
- nn .Tanh ()
563
+ nn .Tanh (),
553
564
)
554
565
555
566
def forward (
@@ -584,7 +595,6 @@ def forward(
584
595
# mask_expanded = mask.view(batch_size, 1, 1, 1, 1)
585
596
# image_wrist = torch.where(mask_expanded, torch.zeros_like(image_wrist), image_wrist)
586
597
587
-
588
598
# Get transformer outputs
589
599
transformer_outputs = self .octo_transformer (obs , task_dict , timestep_pad_mask )
590
600
@@ -593,13 +603,13 @@ def forward(
593
603
594
604
# Extract the actual tensor from TimestepGroup
595
605
# TimestepGroup has .tokens attribute containing the tensor
596
- if hasattr (action_embeddings , ' tokens' ):
606
+ if hasattr (action_embeddings , " tokens" ):
597
607
action_embeddings = action_embeddings .tokens
598
608
599
609
# TODO(lilkm): check this
600
610
# Mean over tokens and take last timestep like JAX
601
611
action_embeddings = action_embeddings .mean (dim = - 2 ) # Mean over tokens
602
- action_embeddings = action_embeddings [:, - 1 , :] # Take last timestep
612
+ action_embeddings = action_embeddings [:, - 1 , :] # Take last timestep
603
613
604
614
# # Flatten to [batch_size, embedding_dim] for consistency policy
605
615
# # action_embeddings shape: [batch_size, horizon, n_tokens, embedding_dim]
@@ -685,7 +695,9 @@ def forward(
685
695
repeat : int = 1 ,
686
696
) -> tuple [Tensor , Tensor ]:
687
697
"""Forward pass of consistency policy"""
688
- obs_enc , action_embeddings = self .encoder (observations , tasks = tasks , action_embeddings = action_embeddings )
698
+ obs_enc , action_embeddings = self .encoder (
699
+ observations , tasks = tasks , action_embeddings = action_embeddings
700
+ )
689
701
690
702
device = get_device_from_parameters (self )
691
703
batch_size = obs_enc .shape [0 ]
@@ -719,7 +731,10 @@ def base_network(
719
731
repeat : int = 1 ,
720
732
) -> Tensor :
721
733
# Get scaling factors and ensure proper dimensions
722
- c_skip , c_out , c_in = [append_dims (x , x_t .ndim ) for x in get_scalings_for_boundary_condition (sigmas , self .sigma_data , self .sigma_min )]
734
+ c_skip , c_out , c_in = [
735
+ append_dims (x , x_t .ndim )
736
+ for x in get_scalings_for_boundary_condition (sigmas , self .sigma_data , self .sigma_min )
737
+ ]
723
738
724
739
# Time embedding
725
740
rescaled_t = 1000 * 0.25 * torch .log (sigmas + 1e-44 )
@@ -749,17 +764,17 @@ def get_features(self, observations):
749
764
def get_sigmas_karras (num_scales , sigma_min , sigma_max , rho , device = "cpu" ):
750
765
"""Generate Karras noise schedule"""
751
766
ramp = torch .linspace (0 , 1 , num_scales , device = device )
752
- min_inv_rho = sigma_min ** (1 / rho )
753
- max_inv_rho = sigma_max ** (1 / rho )
754
- sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho ))** rho
767
+ min_inv_rho = sigma_min ** (1 / rho )
768
+ max_inv_rho = sigma_max ** (1 / rho )
769
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho )) ** rho
755
770
# Append zero for the final step
756
771
sigmas = torch .cat ([sigmas , torch .zeros (1 , device = sigmas .device )])
757
772
return sigmas
758
773
759
774
760
775
def get_scalings_for_boundary_condition (sigma , sigma_data , sigma_min ):
761
776
"""Get c_skip, c_out, c_in scalings for boundary condition"""
762
- c_skip = sigma_data ** 2 / ((sigma - sigma_min )** 2 + sigma_data ** 2 )
777
+ c_skip = sigma_data ** 2 / ((sigma - sigma_min ) ** 2 + sigma_data ** 2 )
763
778
c_out = (sigma - sigma_min ) * sigma_data / torch .sqrt (sigma ** 2 + sigma_data ** 2 )
764
779
c_in = 1 / torch .sqrt (sigma ** 2 + sigma_data ** 2 )
765
780
return c_skip , c_out , c_in
0 commit comments