From cb96d728232c822fe0f09647d3bfa5672a2e1df0 Mon Sep 17 00:00:00 2001
From: Shunsuke KITADA
Date: Sun, 29 Dec 2024 16:37:49 +0900
Subject: [PATCH] Align the variable names with the standard scheduler

---
 src/ncsn/pipeline_ncsn.py             | 20 ++++++++++----------
 src/ncsn/scheduler/scheduling_ncsn.py | 10 +++++-----
 2 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/src/ncsn/pipeline_ncsn.py b/src/ncsn/pipeline_ncsn.py
index f42f795..46d1cc9 100644
--- a/src/ncsn/pipeline_ncsn.py
+++ b/src/ncsn/pipeline_ncsn.py
@@ -128,8 +128,8 @@ def __call__(
         # Generate a random sample
         # NOTE: The behavior of random number generation is different between CPU and GPU,
         # so first generate random numbers on CPU and then move them to GPU (if available).
-        samples = torch.rand(samples_shape, generator=generator)
-        samples = samples.to(self.device)
+        sample = torch.rand(samples_shape, generator=generator)
+        sample = sample.to(self.device)
 
         # Set the number of inference steps for the scheduler
         self.scheduler.set_timesteps(num_inference_steps)
@@ -139,17 +139,17 @@
             # Perform `num_annnealed_steps` annealing steps
             for i in range(self.scheduler.num_annealed_steps):
                 # Predict the score using the model
-                model_output = self.unet(samples, t).sample  # type: ignore
+                model_output = self.unet(sample, t).sample  # type: ignore
 
                 # Perform the annealed langevin dynamics
                 output = self.scheduler.step(
                     model_output=model_output,
                     timestep=t,
-                    samples=samples,
+                    sample=sample,
                     generator=generator,
                     return_dict=return_dict,
                 )
-                samples = (
+                sample = (
                     output.prev_sample
                     if isinstance(output, AnnealedLangevinDynamicsOutput)
                     else output[0]
@@ -162,14 +162,14 @@
                         callback_kwargs[k] = locals()[k]
                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
 
-                    samples = callback_outputs.pop("samples", samples)
+                    sample = callback_outputs.pop("samples", sample)
 
-        samples = self.decode_samples(samples)
+        sample = self.decode_samples(sample)
 
         if output_type == "pil":
-            samples = self.numpy_to_pil(samples.cpu().numpy())
+            sample = self.numpy_to_pil(sample.cpu().numpy())
 
         if return_dict:
-            return ImagePipelineOutput(images=samples)  # type: ignore
+            return ImagePipelineOutput(images=sample)  # type: ignore
         else:
-            return (samples,)
+            return (sample,)
diff --git a/src/ncsn/scheduler/scheduling_ncsn.py b/src/ncsn/scheduler/scheduling_ncsn.py
index b22cb90..f31145e 100644
--- a/src/ncsn/scheduler/scheduling_ncsn.py
+++ b/src/ncsn/scheduler/scheduling_ncsn.py
@@ -105,18 +105,18 @@ def step(
         self,
         model_output: torch.Tensor,
         timestep: int,
-        samples: torch.Tensor,
+        sample: torch.Tensor,
         return_dict: bool = True,
         **kwargs,
     ) -> Union[AnnealedLangevinDynamicsOutput, Tuple]:
-        z = torch.randn_like(samples)
+        z = torch.randn_like(sample)
         step_size = self.step_size[timestep]
-        samples = samples + 0.5 * step_size * model_output + torch.sqrt(step_size) * z
+        sample = sample + 0.5 * step_size * model_output + torch.sqrt(step_size) * z
 
         if return_dict:
-            return AnnealedLangevinDynamicsOutput(prev_sample=samples)
+            return AnnealedLangevinDynamicsOutput(prev_sample=sample)
         else:
-            return (samples,)
+            return (sample,)
 
     def add_noise(
         self,
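
For reference, the update that the renamed step() applies can be sketched as a standalone function. This is a minimal illustration, not part of the patch: the helper name annealed_langevin_step and the dummy tensors below are assumptions, while the update rule and the sample/model_output names follow the hunk in scheduling_ncsn.py above.

from typing import Optional

import torch


def annealed_langevin_step(
    model_output: torch.Tensor,
    sample: torch.Tensor,
    step_size: torch.Tensor,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    # One annealed Langevin dynamics update, mirroring the renamed `step`:
    #   sample <- sample + 0.5 * step_size * score + sqrt(step_size) * z,  z ~ N(0, I)
    z = torch.randn(sample.shape, generator=generator, dtype=sample.dtype)
    return sample + 0.5 * step_size * model_output + torch.sqrt(step_size) * z


# Example with dummy tensors standing in for a score model's output:
x = torch.rand(1, 3, 32, 32)   # initial sample, as in pipeline_ncsn.py
score = torch.zeros_like(x)    # placeholder for self.unet(sample, t).sample
eps = torch.tensor(2e-5)       # placeholder for self.scheduler.step_size[t]
x = annealed_langevin_step(score, x, eps)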