diff --git a/baselines/cifar/utils.py b/baselines/cifar/utils.py index 2a488c3fd..daa5cc72e 100644 --- a/baselines/cifar/utils.py +++ b/baselines/cifar/utils.py @@ -40,7 +40,7 @@ 'Number of epochs between evaluating on the corrupted ' 'test data. Use -1 to never evaluate.') flags.DEFINE_enum('dataset', 'cifar10', - enum_values=['cifar10', 'cifar100'], + enum_values=['cifar10', 'cifar100', 'cifar10n', 'cifar100n'], help='Dataset.') flags.DEFINE_string( diff --git a/baselines/privileged_information/cifar_pi/no_pi.py b/baselines/privileged_information/cifar_pi/no_pi.py index 2ceb7e8da..1472831e8 100644 --- a/baselines/privileged_information/cifar_pi/no_pi.py +++ b/baselines/privileged_information/cifar_pi/no_pi.py @@ -228,9 +228,6 @@ def main(argv): 'augmix_depth': FLAGS.augmix_depth, 'augmix_prob_coeff': FLAGS.augmix_prob_coeff, 'augmix_width': FLAGS.augmix_width, - 'same_mix_weight_per_batch': FLAGS.same_mix_weight_per_batch, - 'use_random_shuffling': FLAGS.use_random_shuffling, - 'use_truncated_beta': FLAGS.use_truncated_beta } # Note that stateless_{fold_in,split} may incur a performance cost, but a diff --git a/baselines/privileged_information/cifar_pi/tram.py b/baselines/privileged_information/cifar_pi/tram.py index e5154acef..6c38fa3dd 100644 --- a/baselines/privileged_information/cifar_pi/tram.py +++ b/baselines/privileged_information/cifar_pi/tram.py @@ -230,9 +230,6 @@ def main(argv): 'augmix_depth': FLAGS.augmix_depth, 'augmix_prob_coeff': FLAGS.augmix_prob_coeff, 'augmix_width': FLAGS.augmix_width, - 'same_mix_weight_per_batch': FLAGS.same_mix_weight_per_batch, - 'use_random_shuffling': FLAGS.use_random_shuffling, - 'use_truncated_beta': FLAGS.use_truncated_beta } # Note that stateless_{fold_in,split} may incur a performance cost, but a diff --git a/uncertainty_baselines/datasets/__init__.py b/uncertainty_baselines/datasets/__init__.py index 0dbeacabd..309df42ac 100644 --- a/uncertainty_baselines/datasets/__init__.py +++ b/uncertainty_baselines/datasets/__init__.py @@ -27,6 +27,8 @@ from uncertainty_baselines.datasets.cifar import Cifar100Dataset from uncertainty_baselines.datasets.cifar import Cifar10CorruptedDataset from uncertainty_baselines.datasets.cifar import Cifar10Dataset +from uncertainty_baselines.datasets.cifar import Cifar10NDataset +from uncertainty_baselines.datasets.cifar import Cifar100NDataset from uncertainty_baselines.datasets.cifar100_corrupted import Cifar100CorruptedDataset from uncertainty_baselines.datasets.cityscapes import CityscapesDataset from uncertainty_baselines.datasets.cityscapes_corrupted import CityscapesCorruptedDataset diff --git a/uncertainty_baselines/datasets/datasets.py b/uncertainty_baselines/datasets/datasets.py index 1511cde46..dcc031316 100644 --- a/uncertainty_baselines/datasets/datasets.py +++ b/uncertainty_baselines/datasets/datasets.py @@ -28,6 +28,8 @@ from uncertainty_baselines.datasets.cifar import Cifar100Dataset from uncertainty_baselines.datasets.cifar import Cifar10CorruptedDataset from uncertainty_baselines.datasets.cifar import Cifar10Dataset +from uncertainty_baselines.datasets.cifar import Cifar10NDataset +from uncertainty_baselines.datasets.cifar import Cifar100NDataset from uncertainty_baselines.datasets.cifar100_corrupted import Cifar100CorruptedDataset from uncertainty_baselines.datasets.cityscapes import CityscapesDataset from uncertainty_baselines.datasets.clinc_intent import ClincIntentDetectionDataset @@ -88,6 +90,8 @@ 'aptos': APTOSDataset, 'cifar100': Cifar100Dataset, 'cifar10': Cifar10Dataset, + 'cifar10n': Cifar10NDataset, + 'cifar100n': Cifar100NDataset, 'cifar10_corrupted': Cifar10CorruptedDataset, 'cifar100_corrupted': Cifar100CorruptedDataset, 'cityscapes': CityscapesDataset,