diff --git a/finetuning/livecell_finetuning.py b/finetuning/livecell_finetuning.py
index 527b056df..c4a80c561 100644
--- a/finetuning/livecell_finetuning.py
+++ b/finetuning/livecell_finetuning.py
@@ -28,12 +28,12 @@ def get_dataloaders(patch_shape, data_path, cell_type=None):
     train_loader = get_livecell_loader(
         path=data_path, patch_shape=patch_shape, split="train", batch_size=2, num_workers=16,
         cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform,
-        raw_transform=raw_transform, label_dtype=torch.float32
+        raw_transform=raw_transform, label_dtype=torch.float32,
     )
     val_loader = get_livecell_loader(
         path=data_path, patch_shape=patch_shape, split="val", batch_size=1, num_workers=16,
         cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform,
-        raw_transform=raw_transform, label_dtype=torch.float32
+        raw_transform=raw_transform, label_dtype=torch.float32,
     )
     return train_loader, val_loader
@@ -62,6 +62,14 @@ def finetune_livecell(args):
     # lora classic: 48.46 GB
     # full ft: 49.35
 
+    # NOTE: memory consumption for default settings
+    # loss over logits: 47.89 GB
+    # loss over masks: 49.42 GB
+
+    # for the resource-efficient setting (i.e. n_objects=5 for 'vit_b')
+    # loss over logits: 24.01 GB
+    # loss over masks: 24.38 GB
+
     # Run training.
     sam_training.train_sam(
         name=checkpoint_name,
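The memory numbers above compare computing the loss over the low-res logits vs. the upsampled masks, and a resource-efficient setting that samples fewer objects per batch. A hedged sketch of that setting (keyword names follow micro_sam's sam_training.train_sam and the get_dataloaders helper from this script; verify them against the installed version):

    import micro_sam.training as sam_training

    # Loaders as defined above; (520, 704) is the LIVECell image shape.
    train_loader, val_loader = get_dataloaders(patch_shape=(520, 704), data_path="./data/livecell")

    sam_training.train_sam(
        name="livecell_sam_vit_b",
        model_type="vit_b",
        train_loader=train_loader,
        val_loader=val_loader,
        n_objects_per_batch=5,  # the resource-efficient setting: ~24 GB instead of ~48 GB
    )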
diff --git a/micro_sam/training/joint_sam_trainer.py b/micro_sam/training/joint_sam_trainer.py
index 7951ebb9f..df9cc6c90 100644
--- a/micro_sam/training/joint_sam_trainer.py
+++ b/micro_sam/training/joint_sam_trainer.py
@@ -5,7 +5,6 @@
 import torch
 from torch.utils.tensorboard import SummaryWriter
-from torchvision.utils import make_grid
 
 from .sam_trainer import SamTrainer
 
@@ -99,8 +98,7 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
             with forward_context():
                 # 1. train for the interactive segmentation
-                (loss, mask_loss, iou_regression_loss, model_iou,
-                 sampled_binary_y) = self._interactive_train_iteration(x, labels_instances)
+                loss, mask_loss, iou_regression_loss, model_iou = self._interactive_train_iteration(x, labels_instances)
 
             backprop(loss)
 
@@ -114,9 +112,8 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
             if self.logger is not None:
                 lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
-                samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
                 self.logger.log_train(
-                    self._iteration, loss, lr, x, labels_instances, samples,
+                    self._iteration, loss, lr, x, labels_instances,
                     mask_loss, iou_regression_loss, model_iou, unetr_loss
                 )
 
@@ -147,7 +144,7 @@ def _validate_impl(self, forward_context):
                 with forward_context():
                     # 1. validate for the interactive segmentation
                     (loss, mask_loss, iou_regression_loss, model_iou,
-                     sampled_binary_y, metric) = self._interactive_val_iteration(x, labels_instances, val_iteration)
+                     metric) = self._interactive_val_iteration(x, labels_instances, val_iteration)
 
                 with forward_context():
                     # 2. validate for the automatic instance segmentation
@@ -164,7 +161,7 @@ def _validate_impl(self, forward_context):
             if self.logger is not None:
                 self.logger.log_validation(
-                    self._iteration, metric_val, loss_val, x, labels_instances, sampled_binary_y,
+                    self._iteration, metric_val, loss_val, x, labels_instances,
                     mask_loss, iou_regression_loss, model_iou_val, unetr_loss
                 )
 
@@ -175,25 +172,22 @@ class JointSamLogger(TorchEmLogger):
     """@private"""
     def __init__(self, trainer, save_root, **unused_kwargs):
         super().__init__(trainer, save_root)
-        self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
-            os.path.join(save_root, "logs", trainer.name)
+        self.log_dir = f"./logs/{trainer.name}" if save_root is None else os.path.join(save_root, "logs", trainer.name)
         os.makedirs(self.log_dir, exist_ok=True)
 
         self.tb = SummaryWriter(self.log_dir)
         self.log_image_interval = trainer.log_image_interval
 
-    def add_image(self, x, y, samples, name, step):
+    def add_image(self, x, y, name, step):
         selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]
         image = normalize_im(x[selection].cpu())
         self.tb.add_image(tag=f"{name}/input", img_tensor=image, global_step=step)
         self.tb.add_image(tag=f"{name}/target", img_tensor=y[selection], global_step=step)
-        sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4)
-        self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step)
 
     def log_train(
-        self, step, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou, instance_loss
+        self, step, loss, lr, x, y, mask_loss, iou_regression_loss, model_iou, instance_loss
     ):
         self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
         self.tb.add_scalar(tag="train/mask_loss", scalar_value=mask_loss, global_step=step)
@@ -202,10 +196,10 @@ def log_train(
         self.tb.add_scalar(tag="train/instance_loss", scalar_value=instance_loss, global_step=step)
         self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
         if step % self.log_image_interval == 0:
-            self.add_image(x, y, samples, "train", step)
+            self.add_image(x, y, "train", step)
 
     def log_validation(
-        self, step, metric, loss, x, y, samples, mask_loss, iou_regression_loss, model_iou, instance_loss
+        self, step, metric, loss, x, y, mask_loss, iou_regression_loss, model_iou, instance_loss
     ):
         self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
         self.tb.add_scalar(tag="validation/mask_loss", scalar_value=mask_loss, global_step=step)
@@ -213,4 +207,4 @@ def log_validation(
         self.tb.add_scalar(tag="validation/model_iou", scalar_value=model_iou, global_step=step)
         self.tb.add_scalar(tag="train/instance_loss", scalar_value=instance_loss, global_step=step)
         self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
-        self.add_image(x, y, samples, "validation", step)
+        self.add_image(x, y, "validation", step)
diff --git a/micro_sam/training/sam_trainer.py b/micro_sam/training/sam_trainer.py
index 830674e01..ecdacd204 100644
--- a/micro_sam/training/sam_trainer.py
+++ b/micro_sam/training/sam_trainer.py
@@ -7,7 +7,6 @@
 import numpy as np
 import torch
-from torchvision.utils import make_grid
 
 import torch_em
 from torch_em.trainer.logger_base import TorchEmLogger
 
@@ -126,16 +125,16 @@ def _compute_iou(self, pred, true, eps=1e-7):
     def _compute_loss(self, batched_outputs, y_one_hot):
         """Compute the loss for one iteration. The loss is made up of two components:
-        - The mask loss: dice score between the predicted masks and targets.
-        - The IOU loss: L2 loss between the predicted IOU and the actual IOU of prediction and target.
+        - The mask loss: dice score between the predicted (logit) masks and the targets.
+        - The IOU loss: L2 loss between the predicted IOU and the actual IOU between the predicted (logit) masks and the targets.
         """
         mask_loss, iou_regression_loss = 0.0, 0.0
 
         # Loop over the batch.
         for batch_output, targets in zip(batched_outputs, y_one_hot):
-            predicted_objects = torch.sigmoid(batch_output["masks"])
-            # Compute the dice scores for the 1 or 3 predicted masks per true object (outer loop).
+            predicted_objects = torch.sigmoid(batch_output["low_res_masks"])
+            # Compute the dice scores for the 1 or 3 predicted (logit) masks per true object (outer loop).
             # We swap the axes that go into the dice loss so that the object axis
             # corresponds to the channel axes. This ensures that the dice is computed
             # independetly per channel. We do not reduce the channel axis in the dice,
@@ -147,7 +146,7 @@ def _compute_loss(self, batched_outputs, y_one_hot):
             dice_scores, _ = torch.min(dice_scores, dim=0)
 
             # Compute the actual IOU between the predicted and true objects.
-            # The outer loop is for the 1 or 3 predicted masks per true object.
+            # The outer loop is for the 1 or 3 predicted (logit) masks per true object.
             with torch.no_grad():
                 true_iou = torch.stack([
                     self._compute_iou(predicted_objects[:, i:i+1], targets) for i in range(predicted_objects.shape[1])
@@ -168,8 +167,7 @@ def _compute_loss(self, batched_outputs, y_one_hot):
     #
     def _get_best_masks(self, batched_outputs, batched_iou_predictions):
-        # Batched mask and logit (low-res mask) predictions.
-        masks = torch.stack([m["masks"] for m in batched_outputs])
+        # Batched logit (low-res mask) predictions.
         logits = torch.stack([m["low_res_masks"] for m in batched_outputs])
 
         # Determine the best IOU across the multi-object prediction axis
@@ -183,17 +181,13 @@ def _get_best_masks(self, batched_outputs, batched_iou_predictions):
         # Note that we squash the first two axes (batch x objects) into one when indexing.
         # That's why we need to reshape bax into (batch x objects) using a view.
         # We also keep the multi object axis as a singleton, that's why the view has (batch_size, n_objects, 1, ...)
-        batch_size, n_objects = masks.shape[:2]
-        h, w = masks.shape[-2:]
-        masks = masks[best_iou_idx].view(batch_size, n_objects, 1, h, w)
-
+        batch_size, n_objects = logits.shape[:2]
         h, w = logits.shape[-2:]
         logits = logits[best_iou_idx].view(batch_size, n_objects, 1, h, w)
 
-        # Binarize the mask. Note that the mask here also contains logits, so we use 0.0
-        # as threshold instead of using 0.5. (Hence we don't need to apply a sigmoid)
-        masks = (masks > 0.0).float()
-        return masks, logits
+        # Binarize the logits: these are logit values, so we threshold at 0.0 instead of 0.5 (no sigmoid needed).
+        logits = (logits > 0.0).float()
+        return logits
 
     def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multimask_output):
         """Compute the loss for several (sub-)iterations of iterative prompting.
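To make the boolean indexing in _get_best_masks concrete, here is a self-contained sketch; the shapes are illustrative assumptions (batch of 2, 4 sampled objects, 3 multimask candidates, 256x256 low-res logits):

    import torch

    batch_size, n_objects, n_candidates, h, w = 2, 4, 3, 256, 256
    logits = torch.randn(batch_size, n_objects, n_candidates, h, w)
    iou_predictions = torch.rand(batch_size, n_objects, n_candidates)

    # Boolean mask that is True exactly for the candidate with the highest
    # predicted IOU per (batch, object) pair.
    best_iou_idx = torch.zeros_like(iou_predictions, dtype=torch.bool)
    best_iou_idx.scatter_(2, iou_predictions.argmax(dim=2, keepdim=True), True)

    # Boolean indexing squashes (batch x objects) into one axis, so the layout
    # is restored with a view, keeping the candidate axis as a singleton.
    best_logits = logits[best_iou_idx].view(batch_size, n_objects, 1, h, w)
    print(best_logits.shape)  # torch.Size([2, 4, 1, 256, 256])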
@@ -228,10 +222,10 @@ def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multim
             if i < (num_subiter - 1):  # We need not update the prompts for the last iteration.
                 # Determine the next prompts based on current predictions.
                 with torch.no_grad():
-                    # Get the mask and logit predictions corresponding to the predicted object
+                    # Get the logit predictions corresponding to the predicted object
                     # (per actual object) with the best IOU.
-                    masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions)
-                    batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits)
+                    logits = self._get_best_masks(batched_outputs, batched_iou_predictions)
+                    batched_inputs = self._update_prompts(batched_inputs, y_one_hot, logits)
 
         loss = loss / num_subiter
         mask_loss = mask_loss / num_subiter
@@ -240,11 +234,11 @@ def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multim
 
         return loss, mask_loss, iou_regression_loss, mean_model_iou
 
-    def _update_prompts(self, batched_inputs, y_one_hot, masks, logits_masks):
+    def _update_prompts(self, batched_inputs, y_one_hot, logits_masks):
         # here, we get the pair-per-batch of predicted and true elements (and also the "batched_inputs")
-        for x1, x2, _inp, logits in zip(masks, y_one_hot, batched_inputs, logits_masks):
+        for y_, _inp, logits in zip(y_one_hot, batched_inputs, logits_masks):
             # here, we get each object in the pairs and do the point choices per-object
-            net_coords, net_labels, _, _ = self.prompt_generator(x2, x1)
+            net_coords, net_labels, _, _ = self.prompt_generator(segmentation=y_, prediction=logits)
 
             # convert the point coordinates to the expected resolution for iterative prompting
             # NOTE:
@@ -278,7 +272,12 @@ def _preprocess_batch(self, batched_inputs, y, sampled_ids):
         """Compute one hot target (one mask per channel) for the sampled ids
         and restrict the number of sampled objects to the minimal number in the batch.
         """
-        assert len(y) == len(sampled_ids)
+        # Get the downsampled masks.
+        y_downsampled = torch.from_numpy(
+            np.stack([bi["gt_downsampled"][None] for bi in batched_inputs])
+        ).to(y.dtype)
+
+        assert len(y) == len(y_downsampled) == len(sampled_ids)
 
         # Get the minimal number of objects in this batch.
         # The number of objects in a patch might be < n_objects_per_batch.
@@ -286,11 +285,11 @@ def _preprocess_batch(self, batched_inputs, y, sampled_ids):
         # number of objects across the batch.
         n_objects = min(len(ids) for ids in sampled_ids)
 
-        y = y.to(self.device, non_blocking=True)
+        y_downsampled = y_downsampled.to(self.device, non_blocking=True)
         # Compute the one hot targets for the seg-id.
         y_one_hot = torch.stack([
             torch.stack([target == seg_id for seg_id in ids[:n_objects]])
-            for target, ids in zip(y, sampled_ids)
+            for target, ids in zip(y_downsampled, sampled_ids)
         ]).float()
 
         # Also restrict the prompts to the number of objects.
@@ -306,13 +305,16 @@ def _interactive_train_iteration(self, x, y):
         batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
         batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)
 
+        # Check that every sampled object still has a valid (non-empty) mask after downsampling.
+        assert all(len(torch.unique(curr_by)) > 1 for by in y_one_hot for curr_by in by)
+
         loss, mask_loss, iou_regression_loss, model_iou = self._compute_iterative_loss(
            batched_inputs=batched_inputs, y_one_hot=y_one_hot,
            num_subiter=self.n_sub_iteration, multimask_output=multimask_output
        )
-        return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot
+        return loss, mask_loss, iou_regression_loss, model_iou
 
     def _check_input_normalization(self, x, input_check_done):
         # The expected data range of the SAM model is 8bit (0-255).
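Read together with _preprocess_batch above, a minimal, runnable sketch of the one-hot target construction and the new sanity check (the toy labels and shapes are assumptions):

    import torch

    # Toy stand-in for a downsampled label image (1 image, 256x256) with two objects.
    y_downsampled = torch.zeros(1, 1, 256, 256)
    y_downsampled[0, 0, 10:40, 10:40] = 1
    y_downsampled[0, 0, 100:150, 100:150] = 2
    sampled_ids = [[1, 2]]

    n_objects = min(len(ids) for ids in sampled_ids)
    y_one_hot = torch.stack([
        torch.stack([target == seg_id for seg_id in ids[:n_objects]])
        for target, ids in zip(y_downsampled, sampled_ids)
    ]).float()
    print(y_one_hot.shape)  # torch.Size([1, 2, 1, 256, 256])

    # The new check: each one-hot channel must contain both 0s and 1s,
    # i.e. no sampled object may vanish during downsampling.
    assert all(len(torch.unique(curr_by)) > 1 for by in y_one_hot for curr_by in by)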
@@ -347,16 +349,14 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
                 self.optimizer.zero_grad()
 
                 with forward_context():
-                    (loss, mask_loss, iou_regression_loss, model_iou,
-                     sampled_binary_y) = self._interactive_train_iteration(x, y)
+                    loss, mask_loss, iou_regression_loss, model_iou = self._interactive_train_iteration(x, y)
 
                 backprop(loss)
 
                 if self.logger is not None:
                     lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
-                    samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
                     self.logger.log_train(
-                        self._iteration, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou
+                        self._iteration, loss, lr, x, y, mask_loss, iou_regression_loss, model_iou
                     )
 
                 self._iteration += 1
@@ -387,7 +387,7 @@ def _interactive_val_iteration(self, x, y, val_iteration):
         metric = mask_loss
         model_iou = torch.mean(torch.stack([m["iou_predictions"] for m in batched_outputs]))
 
-        return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot, metric
+        return loss, mask_loss, iou_regression_loss, model_iou, metric
 
     def _validate_impl(self, forward_context):
         self.model.eval()
@@ -403,7 +403,7 @@ def _validate_impl(self, forward_context):
 
                 with forward_context():
                     (loss, mask_loss, iou_regression_loss, model_iou,
-                     sampled_binary_y, metric) = self._interactive_val_iteration(x, y, val_iteration)
+                     metric) = self._interactive_val_iteration(x, y, val_iteration)
 
                 loss_val += loss.item()
                 metric_val += metric.item()
@@ -418,8 +418,7 @@ def _validate_impl(self, forward_context):
 
         if self.logger is not None:
             self.logger.log_validation(
-                self._iteration, metric_val, loss_val, x, y,
-                sampled_binary_y, mask_loss, iou_regression_loss, model_iou_val
+                self._iteration, metric_val, loss_val, x, y, mask_loss, iou_regression_loss, model_iou_val,
             )
 
         return metric_val
@@ -435,20 +434,18 @@ def __init__(self, trainer, save_root, **unused_kwargs):
         self.tb = torch.utils.tensorboard.SummaryWriter(self.log_dir)
         self.log_image_interval = trainer.log_image_interval
 
-    def add_image(self, x, y, samples, name, step):
+    def add_image(self, x, y, name, step):
         self.tb.add_image(tag=f"{name}/input", img_tensor=x[0], global_step=step)
         self.tb.add_image(tag=f"{name}/target", img_tensor=y[0], global_step=step)
-        sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4)
-        self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step)
 
-    def log_train(self, step, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou):
+    def log_train(self, step, loss, lr, x, y, mask_loss, iou_regression_loss, model_iou):
         self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
         self.tb.add_scalar(tag="train/mask_loss", scalar_value=mask_loss, global_step=step)
         self.tb.add_scalar(tag="train/iou_loss", scalar_value=iou_regression_loss, global_step=step)
         self.tb.add_scalar(tag="train/model_iou", scalar_value=model_iou, global_step=step)
         self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
         if step % self.log_image_interval == 0:
-            self.add_image(x, y, samples, "train", step)
+            self.add_image(x, y, "train", step)
 
-    def log_validation(self, step, metric, loss, x, y, samples, mask_loss, iou_regression_loss, model_iou):
+    def log_validation(self, step, metric, loss, x, y, mask_loss, iou_regression_loss, model_iou):
         self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
@@ -456,4 +453,4 @@ def log_validation(self, step, metric, loss, x, y, samples, mask_loss, iou_regre
         self.tb.add_scalar(tag="validation/iou_loss", scalar_value=iou_regression_loss, global_step=step)
         self.tb.add_scalar(tag="validation/model_iou", scalar_value=model_iou, global_step=step)
         self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
-        self.add_image(x, y, samples, "validation", step)
+        self.add_image(x, y, "validation", step)
self.tb.add_scalar(tag="validation/model_iou", scalar_value=model_iou, global_step=step) self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step) - self.add_image(x, y, samples, "validation", step) + self.add_image(x, y, "validation", step) diff --git a/micro_sam/training/trainable_sam.py b/micro_sam/training/trainable_sam.py index 88e851492..d0e42b1f5 100644 --- a/micro_sam/training/trainable_sam.py +++ b/micro_sam/training/trainable_sam.py @@ -16,9 +16,10 @@ class TrainableSAM(nn.Module): sam: The Segment Anything Model. """ - def __init__(self, sam: Sam) -> None: + def __init__(self, sam: Sam, upsample_masks: bool = False) -> None: super().__init__() self.sam = sam + self.upsample_masks = upsample_masks self.transform = ResizeLongestSide(sam.image_encoder.img_size) def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]: @@ -103,9 +104,14 @@ def forward( multimask_output=multimask_output, ) - masks = self.sam.postprocess_masks( - masks=low_res_masks, input_size=image_record["input_size"], original_size=image_record["original_size"], - ) + curr_outputs = {"low_res_masks": low_res_masks, "iou_predictions": iou_predictions} + if self.upsample_masks: + masks = self.sam.postprocess_masks( + masks=low_res_masks, + input_size=image_record["input_size"], + original_size=image_record["original_size"], + ) + curr_outputs["masks"] = masks outputs.append( {"low_res_masks": low_res_masks, "masks": masks, "iou_predictions": iou_predictions} diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index b0845c120..a097c8680 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -6,6 +6,7 @@ import numpy as np import torch +from torch.nn import functional as F from segment_anything.utils.transforms import ResizeLongestSide @@ -183,13 +184,12 @@ def _distort_boxes(self, bbox_coordinates, shape): distorted_boxes.append([y0, x0, y1, x1]) return distorted_boxes - def _get_prompt_lists(self, gt, n_samples, prompt_generator): + def _get_prompt_lists(self, gt, cell_ids, n_samples, prompt_generator): """Returns a list of "expected" prompts subjected to the random input attributes for prompting.""" _, bbox_coordinates = get_centers_and_bounding_boxes(gt, mode="p") # get the segment ids - cell_ids = np.unique(gt)[1:] if n_samples is None: # n-samples is set to None, so we use all ids sampled_cell_ids = cell_ids @@ -209,6 +209,27 @@ def _get_prompt_lists(self, gt, n_samples, prompt_generator): point_prompts, point_label_prompts, box_prompts, _ = prompt_generator(object_masks, bbox_coordinates) return box_prompts, point_prompts, point_label_prompts, sampled_cell_ids + def _downsample_labels(self, y): + """Converts the masks to match the shape of "low_res_masks". + """ + + # Convert the labels to "low_res_mask" shape + # First step is to use the logic from `ResizeLongestSide` to resize the longest side. + target_length = self.transform.target_length + target_shape = self.transform.get_preprocess_shape(y.shape[2], y.shape[3], target_length) + y = F.interpolate(input=y, size=target_shape) + + # Next, we pad the remaining region to (1024, 1024) + h, w = y.shape[-2:] + padh = target_length - h + padw = target_length - w + y = F.pad(input=y, pad=(0, padw, 0, padh)) + + # Finally, let's resize the labels to the desired shape (i.e. 
diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py
index b0845c120..a097c8680 100644
--- a/micro_sam/training/util.py
+++ b/micro_sam/training/util.py
@@ -6,6 +6,7 @@
 import numpy as np
 
 import torch
+from torch.nn import functional as F
 
 from segment_anything.utils.transforms import ResizeLongestSide
 
@@ -183,13 +184,12 @@ def _distort_boxes(self, bbox_coordinates, shape):
             distorted_boxes.append([y0, x0, y1, x1])
         return distorted_boxes
 
-    def _get_prompt_lists(self, gt, n_samples, prompt_generator):
+    def _get_prompt_lists(self, gt, cell_ids, n_samples, prompt_generator):
         """Returns a list of "expected" prompts subjected to the random input attributes for prompting."""
         _, bbox_coordinates = get_centers_and_bounding_boxes(gt, mode="p")
 
         # get the segment ids
-        cell_ids = np.unique(gt)[1:]
 
         if n_samples is None:  # n-samples is set to None, so we use all ids
             sampled_cell_ids = cell_ids
@@ -209,6 +209,27 @@ def _get_prompt_lists(self, gt, n_samples, prompt_generator):
         point_prompts, point_label_prompts, box_prompts, _ = prompt_generator(object_masks, bbox_coordinates)
         return box_prompts, point_prompts, point_label_prompts, sampled_cell_ids
 
+    def _downsample_labels(self, y):
+        """Converts the masks to match the shape of "low_res_masks".
+        """
+
+        # Convert the labels to "low_res_mask" shape.
+        # First step is to use the logic from `ResizeLongestSide` to resize the longest side.
+        target_length = self.transform.target_length
+        target_shape = self.transform.get_preprocess_shape(y.shape[2], y.shape[3], target_length)
+        y = F.interpolate(input=y, size=target_shape)
+
+        # Next, we pad the remaining region to (1024, 1024).
+        h, w = y.shape[-2:]
+        padh = target_length - h
+        padw = target_length - w
+        y = F.pad(input=y, pad=(0, padw, 0, padh))
+
+        # Finally, let's resize the labels to the desired shape (i.e. (256, 256)).
+        y = F.interpolate(input=y, size=(256, 256))
+
+        return y
+
     def __call__(self, x, y, n_pos, n_neg, get_boxes=False, n_samples=None):
         """Convert the outputs of dataloader and prompt settings to the batch format expected by SAM.
         """
@@ -228,13 +249,21 @@ def __call__(self, x, y, n_pos, n_neg, get_boxes=False, n_samples=None):
             get_point_prompts=get_points
         )
 
+        # Downsample the labels.
+        y_downsampled = self._downsample_labels(y)
+
         batched_inputs = []
         batched_sampled_cell_ids_list = []
 
-        for image, gt in zip(x, y):
+        for image, gt, gt_downsampled in zip(x, y, y_downsampled):
             gt = gt.squeeze().numpy().astype(np.int64)
+            gt_downsampled = gt_downsampled.squeeze().numpy().astype(np.int64)
+
+            # Get the cell ids from the downsampled labels.
+            cell_ids = np.unique(gt_downsampled)[1:]
+
             box_prompts, point_prompts, point_label_prompts, sampled_cell_ids = self._get_prompt_lists(
-                gt, n_samples, prompt_generator,
+                gt=gt, cell_ids=cell_ids, n_samples=n_samples, prompt_generator=prompt_generator,
             )
 
             # check to be sure about the expected size of the no. of elements in different settings
@@ -247,7 +276,7 @@ def __call__(self, x, y, n_pos, n_neg, get_boxes=False, n_samples=None):
 
             batched_sampled_cell_ids_list.append(sampled_cell_ids)
 
-            batched_input = {"image": image, "original_size": image.shape[1:]}
+            batched_input = {"image": image, "original_size": image.shape[1:], "gt_downsampled": gt_downsampled}
             if get_boxes:
                 batched_input["boxes"] = self.transform.apply_boxes_torch(
                     box_prompts, original_size=gt.shape[-2:]
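For reference, the downsampling in _downsample_labels can be reproduced standalone. A minimal sketch under the usual SAM assumptions (longest side resized to 1024, low-res mask size 256; F.interpolate's default nearest-neighbour mode keeps instance ids intact, which is why no sigmoid or smoothing is involved):

    import torch
    from torch.nn import functional as F

    def downsample_labels(y, target_length=1024, low_res_shape=(256, 256)):
        # Resize the longest side to target_length, mirroring ResizeLongestSide.
        h, w = y.shape[2], y.shape[3]
        scale = target_length / max(h, w)
        new_shape = (int(h * scale + 0.5), int(w * scale + 0.5))
        y = F.interpolate(input=y, size=new_shape)
        # Pad to a (target_length, target_length) square, then resize to the low-res shape.
        y = F.pad(input=y, pad=(0, target_length - new_shape[1], 0, target_length - new_shape[0]))
        return F.interpolate(input=y, size=low_res_shape)

    labels = torch.randint(0, 5, (1, 1, 520, 704)).float()  # toy instance labels
    print(downsample_labels(labels).shape)  # torch.Size([1, 1, 256, 256])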