
Commit 99d004f

Refactor: ruff format
1 parent a1b5b4b commit 99d004f
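
The hunks below are a pure formatting pass (line wrapping, trailing commas, quote and spacing normalization) with no behavioral change. A minimal sketch of how such a pass could be reproduced locally, assuming ruff is installed; the exact invocation and any project-specific ruff configuration are assumptions, not recorded in the commit:

    ruff format src/lerobot/policies/conrft/ src/lerobot/scripts/rl/learner.py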

File tree: 3 files changed (+48, -37 lines)

src/lerobot/policies/conrft/configuration_conrft.py

Lines changed: 0 additions & 1 deletion

@@ -23,7 +23,6 @@
 from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
 from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
 from lerobot.optim.optimizers import MultiAdamConfig
-
 from lerobot.policies.sac.configuration_sac import (
     ActorLearnerConfig,
     ConcurrencyConfig,

src/lerobot/policies/conrft/modeling_conrft.py

Lines changed: 37 additions & 22 deletions

@@ -42,9 +42,9 @@
 from lerobot.policies.octo.modeling_octo import OctoPolicy
 from lerobot.policies.pretrained import PreTrainedPolicy
 from lerobot.policies.sac.modeling_sac import (
+    MLP,
     CriticEnsemble,
     CriticHead,
-    MLP,
     SACObservationEncoder,
     _convert_normalization_params_to_tensor,
 )
@@ -120,7 +120,7 @@ def _init_encoders(self):
             self.octo_policy,
             use_proprio=self.config.use_proprio,
             state_dim=self.config.state_dim,
-            proprio_latent_dim=self.config.proprio_latent_dim
+            proprio_latent_dim=self.config.proprio_latent_dim,
         )

     def _init_consistency_policy(self, continuous_action_dim):
@@ -162,7 +162,7 @@ def _init_critics(self, continuous_action_dim):
         target_heads = [
             CriticHead(
                 input_dim=self.encoder_critic.output_dim + continuous_action_dim,
-                **asdict(self.config.critic_network_kwargs)
+                **asdict(self.config.critic_network_kwargs),
             )
             for _ in range(self.config.num_critics)
         ]
@@ -245,7 +245,9 @@ def forward(
     def update_target_networks(self):
         """Update target networks with soft updates"""
         tau = self.config.soft_target_update_rate
-        for target_param, param in zip(self.critic_ensemble_target.parameters(), self.critic_ensemble.parameters(), strict=False):
+        for target_param, param in zip(
+            self.critic_ensemble_target.parameters(), self.critic_ensemble.parameters(), strict=False
+        ):
             target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

     def set_training_stage(self, stage: str):
@@ -319,7 +321,7 @@ def compute_cal_ql_loss(self, batch):
         # TODO(lilkm): Get indices before forward pass to avoid unnecessary computation
         # Subsample critics
         if self.config.num_subsample_critics is not None:
-            indices = torch.randperm(self.config.num_critics)[:self.config.num_subsample_critics]
+            indices = torch.randperm(self.config.num_critics)[: self.config.num_subsample_critics]
             target_q_values = target_q_values[indices]

         target_q = torch.min(target_q_values, dim=0)[0]
@@ -330,7 +332,7 @@ def compute_cal_ql_loss(self, batch):

         # TODO(lilkm): Get indices before forward pass to avoid unnecessary computation
         if self.config.num_subsample_critics is not None:
-            indices = torch.randperm(self.config.num_critics)[:self.config.num_subsample_critics]
+            indices = torch.randperm(self.config.num_critics)[: self.config.num_subsample_critics]
             current_q_values = current_q_values[indices]
             critic_size = self.config.num_subsample_critics
         else:
@@ -406,7 +408,9 @@ def compute_cal_ql_loss(self, batch):
         cql_q_samples = torch.cat([cql_q_samples, current_q_expanded], dim=-1)

         # Subtract log(num_samples) * temperature
-        cql_q_samples = cql_q_samples - torch.log(torch.tensor(cql_q_samples.shape[-1])) * self.config.cql_temp
+        cql_q_samples = (
+            cql_q_samples - torch.log(torch.tensor(cql_q_samples.shape[-1])) * self.config.cql_temp
+        )

         # Compute logsumexp of OOD actions
         cql_ood_values = torch.logsumexp(cql_q_samples / self.config.cql_temp, dim=-1) * self.config.cql_temp
@@ -444,7 +448,9 @@ def compute_bc_loss(self, batch):
         indices = torch.randint(0, self.config.num_scales - 1, (batch_size,), device=device)

         # Compute sigma values using the same formula as JAX
-        t = (self.config.sigma_max**(1 / self.config.rho) + indices / (self.config.num_scales - 1) * (self.config.sigma_min**(1 / self.config.rho) - self.config.sigma_max**(1 / self.config.rho)))
+        t = self.config.sigma_max ** (1 / self.config.rho) + indices / (self.config.num_scales - 1) * (
+            self.config.sigma_min ** (1 / self.config.rho) - self.config.sigma_max ** (1 / self.config.rho)
+        )
         t = t**self.config.rho

         # Add noise to actions
@@ -483,7 +489,7 @@ def compute_actor_loss(self, batch):
         q_values = self.critic_ensemble(observations, policy_action, obs_feat)
         # TODO(lilkm): Get indices before forward pass to avoid unnecessary computation
         if self.config.num_subsample_critics is not None:
-            indices = torch.randperm(self.config.num_critics)[:self.config.num_subsample_critics]
+            indices = torch.randperm(self.config.num_critics)[: self.config.num_subsample_critics]
             q_values = q_values[indices]
         q_value = q_values.mean(dim=0)
         q_loss = -q_value.mean()  # Negative for gradient ascent
@@ -500,8 +506,8 @@ def compute_actor_loss(self, batch):
             "q_mean": q_value.mean(),
         }

-class SinusoidalPosEmb(nn.Module):

+class SinusoidalPosEmb(nn.Module):
     def __init__(self, dim):
         super().__init__()
         self.dim = dim
@@ -515,7 +521,6 @@ def forward(self, time):


 class TimeMLP(nn.Module):
-
     def __init__(self, t_dim: int):
         super().__init__()
         self.t_dim = t_dim
@@ -534,7 +539,13 @@ def forward(self, t: torch.Tensor) -> torch.Tensor:
 class OctoEncodingWrapper(nn.Module):
     """Wrapper around Octo transformer to extract action embeddings for ConRFT."""

-    def __init__(self, octo_policy: OctoPolicy, use_proprio: bool = True, state_dim: int = 18, proprio_latent_dim: int = 64):
+    def __init__(
+        self,
+        octo_policy: OctoPolicy,
+        use_proprio: bool = True,
+        state_dim: int = 18,
+        proprio_latent_dim: int = 64,
+    ):
         super().__init__()
         self.octo_policy = octo_policy
         self.octo_transformer = octo_policy.model.octo_transformer
@@ -549,7 +560,7 @@ def __init__(self, octo_policy: OctoPolicy, use_proprio: bool = True, state_dim:
         self.proprio_encoder = nn.Sequential(
             nn.Linear(state_dim, self.proprio_latent_dim),
             nn.LayerNorm(self.proprio_latent_dim),
-            nn.Tanh()
+            nn.Tanh(),
         )

     def forward(
@@ -584,7 +595,6 @@ def forward(
         # mask_expanded = mask.view(batch_size, 1, 1, 1, 1)
        # image_wrist = torch.where(mask_expanded, torch.zeros_like(image_wrist), image_wrist)

-
         # Get transformer outputs
         transformer_outputs = self.octo_transformer(obs, task_dict, timestep_pad_mask)

@@ -593,13 +603,13 @@

         # Extract the actual tensor from TimestepGroup
         # TimestepGroup has .tokens attribute containing the tensor
-        if hasattr(action_embeddings, 'tokens'):
+        if hasattr(action_embeddings, "tokens"):
             action_embeddings = action_embeddings.tokens

         # TODO(lilkm): check this
         # Mean over tokens and take last timestep like JAX
         action_embeddings = action_embeddings.mean(dim=-2)  # Mean over tokens
-        action_embeddings = action_embeddings[:, -1, :] # Take last timestep
+        action_embeddings = action_embeddings[:, -1, :]  # Take last timestep

         # # Flatten to [batch_size, embedding_dim] for consistency policy
         # # action_embeddings shape: [batch_size, horizon, n_tokens, embedding_dim]
@@ -685,7 +695,9 @@ def forward(
         repeat: int = 1,
     ) -> tuple[Tensor, Tensor]:
         """Forward pass of consistency policy"""
-        obs_enc, action_embeddings = self.encoder(observations, tasks=tasks, action_embeddings=action_embeddings)
+        obs_enc, action_embeddings = self.encoder(
+            observations, tasks=tasks, action_embeddings=action_embeddings
+        )

         device = get_device_from_parameters(self)
         batch_size = obs_enc.shape[0]
@@ -719,7 +731,10 @@ def base_network(
         repeat: int = 1,
     ) -> Tensor:
         # Get scaling factors and ensure proper dimensions
-        c_skip, c_out, c_in = [append_dims(x, x_t.ndim) for x in get_scalings_for_boundary_condition(sigmas, self.sigma_data, self.sigma_min)]
+        c_skip, c_out, c_in = [
+            append_dims(x, x_t.ndim)
+            for x in get_scalings_for_boundary_condition(sigmas, self.sigma_data, self.sigma_min)
+        ]

         # Time embedding
         rescaled_t = 1000 * 0.25 * torch.log(sigmas + 1e-44)
@@ -749,17 +764,17 @@ def get_features(self, observations):
 def get_sigmas_karras(num_scales, sigma_min, sigma_max, rho, device="cpu"):
     """Generate Karras noise schedule"""
     ramp = torch.linspace(0, 1, num_scales, device=device)
-    min_inv_rho = sigma_min**(1 / rho)
-    max_inv_rho = sigma_max**(1 / rho)
-    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho
+    min_inv_rho = sigma_min ** (1 / rho)
+    max_inv_rho = sigma_max ** (1 / rho)
+    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
     # Append zero for the final step
     sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
     return sigmas


 def get_scalings_for_boundary_condition(sigma, sigma_data, sigma_min):
     """Get c_skip, c_out, c_in scalings for boundary condition"""
-    c_skip = sigma_data**2 / ((sigma - sigma_min)**2 + sigma_data**2)
+    c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2)
     c_out = (sigma - sigma_min) * sigma_data / torch.sqrt(sigma**2 + sigma_data**2)
     c_in = 1 / torch.sqrt(sigma**2 + sigma_data**2)
     return c_skip, c_out, c_in
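
For reference (not part of the diff): the schedule implemented by get_sigmas_karras above interpolates in sigma^(1/rho) space, i.e.

    sigma_i = (sigma_max^(1/rho) + (i / (num_scales - 1)) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho,  i = 0, ..., num_scales - 1,

with a final sigma = 0 appended for the last step. In get_scalings_for_boundary_condition, c_skip evaluates to 1 and c_out to 0 at sigma = sigma_min, so the consistency model reduces to the identity at the boundary.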

src/lerobot/scripts/rl/learner.py

Lines changed: 11 additions & 14 deletions

@@ -197,10 +197,10 @@ def start_learner_threads(
     """
     # Check if this is offline-only ConRFT training (no gRPC server needed)
     is_conrft_offline_only = (
-        cfg.policy.type == "conrft" and
-        cfg.policy.offline_steps > 0 and
-        cfg.policy.online_steps == 0 and
-        cfg.dataset is not None
+        cfg.policy.type == "conrft"
+        and cfg.policy.offline_steps > 0
+        and cfg.policy.online_steps == 0
+        and cfg.dataset is not None
     )

     if is_conrft_offline_only:
@@ -391,10 +391,10 @@ def add_actor_information_and_train(

     # Check if this is ConRFT offline-only training
     is_conrft_offline_only = (
-        is_conrft and
-        cfg.policy.offline_steps > 0 and
-        cfg.policy.online_steps == 0 and
-        offline_replay_buffer is not None
+        is_conrft
+        and cfg.policy.offline_steps > 0
+        and cfg.policy.online_steps == 0
+        and offline_replay_buffer is not None
     )

     if is_conrft_offline_only:
@@ -1397,8 +1397,7 @@ def run_conrft_offline_training(
         optimizers["critic"].zero_grad()
         loss_critic.backward()
         torch.nn.utils.clip_grad_norm_(
-            parameters=policy.critic_ensemble.parameters(),
-            max_norm=clip_grad_norm_value
+            parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
         )
         optimizers["critic"].step()

@@ -1438,8 +1437,7 @@ def run_conrft_offline_training(
         optimizers["critic"].zero_grad()
         loss_critic.backward()
         critic_grad_norm = torch.nn.utils.clip_grad_norm_(
-            parameters=policy.critic_ensemble.parameters(),
-            max_norm=clip_grad_norm_value
+            parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
         ).item()
         optimizers["critic"].step()

@@ -1467,8 +1465,7 @@ def run_conrft_offline_training(
         optimizers["actor"].zero_grad()
         loss_actor.backward()
         actor_grad_norm = torch.nn.utils.clip_grad_norm_(
-            parameters=policy.consistency_policy.parameters(),
-            max_norm=clip_grad_norm_value
+            parameters=policy.consistency_policy.parameters(), max_norm=clip_grad_norm_value
         ).item()
         optimizers["actor"].step()
