
Save weights to a custom folder of the user's choice #130


Open · wants to merge 2 commits into develop
5 changes: 3 additions & 2 deletions rfdetr/config.py
@@ -18,6 +18,7 @@ class ModelConfig(BaseModel):
     layer_norm: bool = True
     amp: bool = True
     num_classes: int = 90
+    pretrain_save_file: Optional[str] = 'model.pth'
     pretrain_weights: Optional[str] = None
     device: Literal["cpu", "cuda", "mps"] = DEVICE
     resolution: int = 560
@@ -34,7 +35,7 @@ class RFDETRBaseConfig(ModelConfig):
     num_select: int = 300
     projector_scale: List[Literal["P3", "P4", "P5"]] = ["P4"]
     out_feature_indexes: List[int] = [2, 5, 8, 11]
-    pretrain_weights: Optional[str] = "rf-detr-base.pth"
+    pretrain_weights: Optional[str] = "rfdetr_base"
 
 class RFDETRLargeConfig(RFDETRBaseConfig):
     encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_base"
@@ -43,7 +44,7 @@ class RFDETRLargeConfig(RFDETRBaseConfig):
     ca_nheads: int = 24
     dec_n_points: int = 4
     projector_scale: List[Literal["P3", "P4", "P5"]] = ["P3", "P5"]
-    pretrain_weights: Optional[str] = "rf-detr-large.pth"
+    pretrain_weights: Optional[str] = "rfdetr_large"
 
 class TrainConfig(BaseModel):
     lr: float = 1e-4
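
With this change, pretrain_weights names a hosted model type key ("rfdetr_base", "rfdetr_large"), while the new pretrain_save_file field controls where the checkpoint is written on disk. A minimal sketch of how the two fields combine, assuming the config classes above; the save path is illustrative, not part of the PR:

    from rfdetr.config import RFDETRBaseConfig

    # pretrain_weights keeps its new default, the hosted key "rfdetr_base";
    # pretrain_save_file chooses the local file the weights land in.
    config = RFDETRBaseConfig(pretrain_save_file="weights/rf-detr-base.pth")
    print(config.pretrain_weights)    # rfdetr_base
    print(config.pretrain_save_file)  # weights/rf-detr-base.pth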
3 changes: 2 additions & 1 deletion rfdetr/detr.py
@@ -26,7 +26,8 @@ def __init__(self, **kwargs):
         self.callbacks = defaultdict(list)
 
     def maybe_download_pretrain_weights(self):
-        download_pretrain_weights(self.model_config.pretrain_weights)
+        download_pretrain_weights(self.model_config.pretrain_weights, self.model_config.pretrain_save_file)
+
 
     def get_model_config(self, **kwargs):
         return ModelConfig(**kwargs)
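
The wrapper now forwards both config fields, so a caller sets them once on the model. A hedged usage sketch, assuming the package's RFDETRBase class passes its keyword arguments through to ModelConfig as in the code above:

    from rfdetr import RFDETRBase

    # maybe_download_pretrain_weights() would fetch the hosted "rfdetr_base"
    # checkpoint into the chosen file rather than the working directory.
    model = RFDETRBase(pretrain_save_file="checkpoints/rf-detr-base.pth")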
46 changes: 29 additions & 17 deletions rfdetr/main.py
@@ -52,39 +52,46 @@
 logger = getLogger(__name__)
 
 HOSTED_MODELS = {
-    "rf-detr-base.pth": "https://storage.googleapis.com/rfdetr/rf-detr-base-coco.pth",
+    "rfdetr_base": "https://storage.googleapis.com/rfdetr/rf-detr-base-coco.pth",
     # below is a less converged model that may be better for finetuning but worse for inference
-    "rf-detr-base-2.pth": "https://storage.googleapis.com/rfdetr/rf-detr-base-2.pth",
-    "rf-detr-large.pth": "https://storage.googleapis.com/rfdetr/rf-detr-large.pth"
+    "rfdetr_base2": "https://storage.googleapis.com/rfdetr/rf-detr-base-2.pth",
+    "rfdetr_large": "https://storage.googleapis.com/rfdetr/rf-detr-large.pth"
 }
 
-def download_pretrain_weights(pretrain_weights: str, redownload=False):
-    if pretrain_weights in HOSTED_MODELS:
-        if redownload or not os.path.exists(pretrain_weights):
-            logger.info(
-                f"Downloading pretrained weights for {pretrain_weights}"
-            )
-            download_file(
-                HOSTED_MODELS[pretrain_weights],
-                pretrain_weights,
-            )
+def download_pretrain_weights(model_type: str, output_path: str, redownload: bool = False):
+    if model_type not in HOSTED_MODELS:
+        raise ValueError(f"Unknown model type '{model_type}'. Valid options are: {list(HOSTED_MODELS.keys())}")
+
+    if redownload or not os.path.exists(output_path):
+        # Create parent directory only if there is one
+        output_dir = os.path.dirname(output_path)
+        if output_dir:
+            os.makedirs(output_dir, exist_ok=True)
+
+        logger.info(
+            f"Downloading pretrained weights for {model_type}"
+        )
+        download_file(
+            HOSTED_MODELS[model_type],
+            output_path,
+        )
 
 class Model:
     def __init__(self, **kwargs):
         args = populate_args(**kwargs)
         self.resolution = args.resolution
         self.model = build_model(args)
         self.device = torch.device(args.device)
-        if args.pretrain_weights is not None:
+        if args.pretrain_save_file is not None:
             print("Loading pretrain weights")
             try:
-                checkpoint = torch.load(args.pretrain_weights, map_location='cpu', weights_only=False)
+                checkpoint = torch.load(args.pretrain_save_file, map_location='cpu', weights_only=False)
             except Exception as e:
                 print(f"Failed to load pretrain weights: {e}")
                 # re-download weights if they are corrupted
                 print("Failed to load pretrain weights, re-downloading")
-                download_pretrain_weights(args.pretrain_weights, redownload=True)
-                checkpoint = torch.load(args.pretrain_weights, map_location='cpu', weights_only=False)
+                download_pretrain_weights(args.pretrain_weights, args.pretrain_save_file, redownload=True)
+                checkpoint = torch.load(args.pretrain_save_file, map_location='cpu', weights_only=False)
 
         checkpoint_num_classes = checkpoint['model']['class_embed.bias'].shape[0]
         if checkpoint_num_classes != args.num_classes + 1:
@@ -541,6 +548,7 @@ def export(self, output_dir="output", infer_dir=None, simplify=False, backbone_
        "cutoff_epoch",
        "pretrained_encoder",
        "pretrain_weights",
+       "pretrain_save_file",
        "pretrain_exclude_keys",
        "pretrain_keys_modify_to_load",
        "freeze_florence",
@@ -630,6 +638,8 @@ def get_args_parser():
     parser.add_argument('--pretrained_encoder', type=str, default=None,
                         help="Path to the pretrained encoder.")
     parser.add_argument('--pretrain_weights', type=str, default=None,
+                        help="Model type to use.")
+    parser.add_argument('--pretrain_save_file', type=str, default='model.pth',
                         help="Path to the pretrained model.")
     parser.add_argument('--pretrain_exclude_keys', type=str, default=None, nargs='+',
                         help="Keys you do not want to load.")
@@ -806,6 +816,7 @@ def populate_args(
     # Model parameters
     pretrained_encoder=None,
     pretrain_weights=None,
+    pretrain_save_file=None,
     pretrain_exclude_keys=None,
     pretrain_keys_modify_to_load=None,
     pretrained_distiller=None,
@@ -924,6 +935,7 @@
         cutoff_epoch=cutoff_epoch,
         pretrained_encoder=pretrained_encoder,
         pretrain_weights=pretrain_weights,
+        pretrain_save_file=pretrain_save_file,
         pretrain_exclude_keys=pretrain_exclude_keys,
         pretrain_keys_modify_to_load=pretrain_keys_modify_to_load,
         pretrained_distiller=pretrained_distiller,
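
Taken together, the reworked downloader can also be exercised directly. A small sketch, assuming rfdetr is installed and the hosted URLs above are reachable; the destination paths are arbitrary examples:

    from rfdetr.main import download_pretrain_weights

    # Fetch the hosted large checkpoint into a user-chosen folder; the new
    # implementation creates missing parent directories before downloading.
    download_pretrain_weights("rfdetr_large", "weights/large/rf-detr-large.pth")

    # Unknown model types now fail fast instead of silently doing nothing:
    # download_pretrain_weights("rfdetr_huge", "weights/huge.pth")  # raises ValueError

On the command line, the same pair is exposed via get_args_parser as --pretrain_weights (hosted model type) and --pretrain_save_file (destination path).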