From 0d9cf86b876a039125708b7be7aa31a0fcb718e4 Mon Sep 17 00:00:00 2001 From: robimc14 Date: Mon, 26 Aug 2019 14:15:55 +0100 Subject: [PATCH 1/4] added most likely hetero GP fitting procedure --- botorch/models/utils.py | 125 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) diff --git a/botorch/models/utils.py b/botorch/models/utils.py index e1bec1d9ed..7694db58bb 100644 --- a/botorch/models/utils.py +++ b/botorch/models/utils.py @@ -15,6 +15,16 @@ from ..exceptions import InputDataError, InputDataWarning +from .models import SingleTaskGP +from .models import HeteroskedasticSingleTaskGP +from ..sampling import IIDNormalSampler +from gpytorch.constraints import GreaterThan +from gpytorch.mlls import ExactMarginalLogLikelihood + +from gpytorch.kernels.scale_kernel import ScaleKernel +from gpytorch.kernels.rbf_kernel import RBFKernel + + def _make_X_full(X: Tensor, output_indices: List[int], tf: int) -> Tensor: r"""Helper to construct input tensor with task indices. @@ -179,3 +189,118 @@ def check_standardization( if raise_on_fail: raise InputDataError(msg) warnings.warn(msg, InputDataWarning) + + +def fit_most_likely_HeteroskedasticGP( + train_X: Tensor, + train_Y: Tensor, + covar_module: Optional[Module] = None, + num_var_samples: int = 100, + max_iter: int = 10, + atol_mean: float = 1e-04, + atol_var: float = 1e-04, + ) -> HeteroskedasticSingleTaskGP: + r"""Fit the Most Likely Heteroskedastic GP. + + The original algorithm is described in + http://people.csail.mit.edu/kersting/papers/kersting07icml_mlHetGP.pdf + + Args: + train_X: A `n x d` or `batch_shape x n x d` (batch mode) tensor of training + features. + train_Y: A `n x m` or `batch_shape x n x m` (batch mode) tensor of + training observations. + covar_module: The covariance (kernel) matrix for the initial homoskedastic GP. + If omitted, use the RBFKernel. + num_var_samples: Number of samples to draw from posterior when estimating noise. + max_iter: Maximum number of iterations used when fitting the model. + atol_mean: The tolerance for the mean check. + atol_std: The tolerance for the var check. + Returns: + HeteroskedasticSingleTaskGP Model fit using the "most-likely" procedure. + """ + + if covar_module is None: + covar_module = ScaleKernel(RBFKernel()) + + # check to see if input Tensors are normalized and standardized + check_min_max_scaling(train_X) + check_standardization(train_Y) + + # fit initial homoskedastic model used to estimate noise levels + homo_model = SingleTaskGP(train_X=train_X, train_Y=train_Y, + covar_module=covar_module) + homo_model.likelihood.noise_covar.register_constraint("raw_noise", + GreaterThan(1e-5)) + homo_mll = gpytorch.mlls.ExactMarginalLogLikelihood(homo_model.likelihood, + homo_model) + botorch.fit.fit_gpytorch_model(homo_mll) + + # get estimates of noise + homo_mll.eval() + with torch.no_grad(): + homo_posterior = homo_mll.model.posterior(train_X.clone()) + homo_predictive_posterior = homo_mll.model.posterior(train_X.clone(), + observation_noise=True) + sampler = IIDNormalSampler(num_samples=num_var_samples, resample=True) + predictive_samples = sampler(homo_predictive_posterior) + observed_var = 0.5 * ((predictive_samples - train_Y.reshape(-1,1))**2).mean(dim=0) + + # save mean and variance to check if they change later + saved_mean = homo_posterior.mean + saved_var = homo_posterior.variance + + for i in range(max_iter): + + # now train hetero model using computed noise + hetero_model = HeteroskedasticSingleTaskGP(train_X=train_X, train_Y=train_Y, + train_Yvar=observed_var) + hetero_mll = gpytorch.mlls.ExactMarginalLogLikelihood(hetero_model.likelihood, + hetero_model) + try: + botorch.fit.fit_gpytorch_model(hetero_mll) + except Exception as e: + msg = f'Fitting failed on iteration {i}. Returning the current MLL' + warnings.warn(msg, e) + return saved_hetero_mll + + hetero_mll.eval() + with torch.no_grad(): + hetero_posterior = hetero_mll.model.posterior(train_X.clone()) + hetero_predictive_posterior = hetero_mll.model.posterior(train_X.clone(), + observation_noise=True) + + new_mean = hetero_posterior.mean + new_var = hetero_posterior.variance + + means_equal = torch.all(torch.lt(torch.abs(torch.add(saved_mean, -new_mean)), atol_mean)) + max_change_in_means = torch.max(torch.abs(torch.add(saved_mean, -new_mean))) + + var_equal = torch.all(torch.lt(torch.abs(torch.add(saved_var, -new_var)), atol_var)) + max_change_in_var = torch.max(torch.abs(torch.add(saved_var -new_var))) + + if means_eq and variances_eq: + return hetero_mll + else: + saved_hetero_mll = hetero_mll + + saved_mean = new_mean + saved_var = new_var + + # get new noise estimate + sampler = IIDNormalSampler(num_samples=num_var_samples, resample=True) + predictive_samples = sampler(hetero_predictive_posterior) + observed_var = 0.5 * ((predictive_samples - train_Y.reshape(-1,1))**2).mean(dim=0) + + + msg = f'Did not reach convergence after {max_iter} iterations. Returning the current MLL.' + warnings.warn(msg) + return hetero_mll + + + + + + + + From 0c8c69417cb1511d5e3a6fcdffb31737a86b157f Mon Sep 17 00:00:00 2001 From: robimc14 Date: Mon, 26 Aug 2019 16:23:52 +0100 Subject: [PATCH 2/4] deleted standardization and scaling testing --- botorch/models/utils.py | 80 +++++++++++++++++++---------------------- 1 file changed, 36 insertions(+), 44 deletions(-) diff --git a/botorch/models/utils.py b/botorch/models/utils.py index 7694db58bb..68cc8f878d 100644 --- a/botorch/models/utils.py +++ b/botorch/models/utils.py @@ -192,7 +192,7 @@ def check_standardization( def fit_most_likely_HeteroskedasticGP( - train_X: Tensor, + train_X: Tensor, train_Y: Tensor, covar_module: Optional[Module] = None, num_var_samples: int = 100, @@ -200,10 +200,10 @@ def fit_most_likely_HeteroskedasticGP( atol_mean: float = 1e-04, atol_var: float = 1e-04, ) -> HeteroskedasticSingleTaskGP: - r"""Fit the Most Likely Heteroskedastic GP. + r"""Fit the Most Likely Heteroskedastic GP. - The original algorithm is described in - http://people.csail.mit.edu/kersting/papers/kersting07icml_mlHetGP.pdf + The original algorithm is described in + http://people.csail.mit.edu/kersting/papers/kersting07icml_mlHetGP.pdf Args: train_X: A `n x d` or `batch_shape x n x d` (batch mode) tensor of training @@ -211,37 +211,38 @@ def fit_most_likely_HeteroskedasticGP( train_Y: A `n x m` or `batch_shape x n x m` (batch mode) tensor of training observations. covar_module: The covariance (kernel) matrix for the initial homoskedastic GP. - If omitted, use the RBFKernel. + If omitted, use the RBFKernel. num_var_samples: Number of samples to draw from posterior when estimating noise. max_iter: Maximum number of iterations used when fitting the model. - atol_mean: The tolerance for the mean check. - atol_std: The tolerance for the var check. + atol_mean: The tolerance for the mean check. + atol_std: The tolerance for the var check. Returns: - HeteroskedasticSingleTaskGP Model fit using the "most-likely" procedure. + HeteroskedasticSingleTaskGP Model fit using the "most-likely" procedure. """ - if covar_module is None: - covar_module = ScaleKernel(RBFKernel()) + if covar_module is None: + covar_module = ScaleKernel(RBFKernel()) - # check to see if input Tensors are normalized and standardized - check_min_max_scaling(train_X) - check_standardization(train_Y) + # CANNOT CHECK RIGHT NOW BECAUSE NEED TO FIRST ADD BATCH DIMENSION + # check to see if input Tensors are normalized and standardized + # check_min_max_scaling(train_X) + # check_standardization(train_Y) - # fit initial homoskedastic model used to estimate noise levels - homo_model = SingleTaskGP(train_X=train_X, train_Y=train_Y, + # fit initial homoskedastic model used to estimate noise levels + homo_model = SingleTaskGP(train_X=train_X, train_Y=train_Y, covar_module=covar_module) - homo_model.likelihood.noise_covar.register_constraint("raw_noise", + homo_model.likelihood.noise_covar.register_constraint("raw_noise", GreaterThan(1e-5)) - homo_mll = gpytorch.mlls.ExactMarginalLogLikelihood(homo_model.likelihood, + homo_mll = gpytorch.mlls.ExactMarginalLogLikelihood(homo_model.likelihood, homo_model) - botorch.fit.fit_gpytorch_model(homo_mll) - - # get estimates of noise - homo_mll.eval() - with torch.no_grad(): - homo_posterior = homo_mll.model.posterior(train_X.clone()) - homo_predictive_posterior = homo_mll.model.posterior(train_X.clone(), - observation_noise=True) + botorch.fit.fit_gpytorch_model(homo_mll) + + # get estimates of noise + homo_mll.eval() + with torch.no_grad(): + homo_posterior = homo_mll.model.posterior(train_X.clone()) + homo_predictive_posterior = homo_mll.model.posterior(train_X.clone(), + observation_noise=True) sampler = IIDNormalSampler(num_samples=num_var_samples, resample=True) predictive_samples = sampler(homo_predictive_posterior) observed_var = 0.5 * ((predictive_samples - train_Y.reshape(-1,1))**2).mean(dim=0) @@ -260,26 +261,26 @@ def fit_most_likely_HeteroskedasticGP( try: botorch.fit.fit_gpytorch_model(hetero_mll) except Exception as e: - msg = f'Fitting failed on iteration {i}. Returning the current MLL' - warnings.warn(msg, e) + msg = f'Fitting failed on iteration {i}. Returning the current MLL' + warnings.warn(msg, e) return saved_hetero_mll hetero_mll.eval() with torch.no_grad(): hetero_posterior = hetero_mll.model.posterior(train_X.clone()) hetero_predictive_posterior = hetero_mll.model.posterior(train_X.clone(), - observation_noise=True) + observation_noise=True) new_mean = hetero_posterior.mean new_var = hetero_posterior.variance - - means_equal = torch.all(torch.lt(torch.abs(torch.add(saved_mean, -new_mean)), atol_mean)) + + mean_equality = torch.all(torch.lt(torch.abs(torch.add(saved_mean, -new_mean)), atol_mean)) max_change_in_means = torch.max(torch.abs(torch.add(saved_mean, -new_mean))) - var_equal = torch.all(torch.lt(torch.abs(torch.add(saved_var, -new_var)), atol_var)) - max_change_in_var = torch.max(torch.abs(torch.add(saved_var -new_var))) + var_equality = torch.all(torch.lt(torch.abs(torch.add(saved_var, -new_var)), atol_var)) + max_change_in_var = torch.max(torch.abs(torch.add(saved_var, -new_var))) - if means_eq and variances_eq: + if mean_equality and var_equality: return hetero_mll else: saved_hetero_mll = hetero_mll @@ -291,16 +292,7 @@ def fit_most_likely_HeteroskedasticGP( sampler = IIDNormalSampler(num_samples=num_var_samples, resample=True) predictive_samples = sampler(hetero_predictive_posterior) observed_var = 0.5 * ((predictive_samples - train_Y.reshape(-1,1))**2).mean(dim=0) - - + msg = f'Did not reach convergence after {max_iter} iterations. Returning the current MLL.' warnings.warn(msg) - return hetero_mll - - - - - - - - + return hetero_mll \ No newline at end of file From 7b93b45793452a51590cfdb3998654693723bf84 Mon Sep 17 00:00:00 2001 From: robimc14 Date: Mon, 26 Aug 2019 16:34:08 +0100 Subject: [PATCH 3/4] added Module import --- botorch/models/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/botorch/models/utils.py b/botorch/models/utils.py index 68cc8f878d..90fcb800e8 100644 --- a/botorch/models/utils.py +++ b/botorch/models/utils.py @@ -7,7 +7,7 @@ """ import warnings -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Any import torch from gpytorch.utils.broadcasting import _mul_broadcast_shape @@ -20,6 +20,7 @@ from ..sampling import IIDNormalSampler from gpytorch.constraints import GreaterThan from gpytorch.mlls import ExactMarginalLogLikelihood +from gpytorch.module import Module from gpytorch.kernels.scale_kernel import ScaleKernel from gpytorch.kernels.rbf_kernel import RBFKernel From 190592f190b350a932eb5fdeae039fdb74409634 Mon Sep 17 00:00:00 2001 From: robimc14 Date: Mon, 26 Aug 2019 16:37:34 +0100 Subject: [PATCH 4/4] used black formatter --- botorch/models/utils.py | 76 ++++++++++++++++++++++++----------------- 1 file changed, 45 insertions(+), 31 deletions(-) diff --git a/botorch/models/utils.py b/botorch/models/utils.py index 90fcb800e8..ef55795a44 100644 --- a/botorch/models/utils.py +++ b/botorch/models/utils.py @@ -26,7 +26,6 @@ from gpytorch.kernels.rbf_kernel import RBFKernel - def _make_X_full(X: Tensor, output_indices: List[int], tf: int) -> Tensor: r"""Helper to construct input tensor with task indices. @@ -200,7 +199,7 @@ def fit_most_likely_HeteroskedasticGP( max_iter: int = 10, atol_mean: float = 1e-04, atol_var: float = 1e-04, - ) -> HeteroskedasticSingleTaskGP: +) -> HeteroskedasticSingleTaskGP: r"""Fit the Most Likely Heteroskedastic GP. The original algorithm is described in @@ -230,70 +229,85 @@ def fit_most_likely_HeteroskedasticGP( # check_standardization(train_Y) # fit initial homoskedastic model used to estimate noise levels - homo_model = SingleTaskGP(train_X=train_X, train_Y=train_Y, - covar_module=covar_module) - homo_model.likelihood.noise_covar.register_constraint("raw_noise", - GreaterThan(1e-5)) - homo_mll = gpytorch.mlls.ExactMarginalLogLikelihood(homo_model.likelihood, - homo_model) + homo_model = SingleTaskGP( + train_X=train_X, train_Y=train_Y, covar_module=covar_module + ) + homo_model.likelihood.noise_covar.register_constraint( + "raw_noise", GreaterThan(1e-5) + ) + homo_mll = gpytorch.mlls.ExactMarginalLogLikelihood( + homo_model.likelihood, homo_model + ) botorch.fit.fit_gpytorch_model(homo_mll) - # get estimates of noise + # get estimates of noise homo_mll.eval() with torch.no_grad(): homo_posterior = homo_mll.model.posterior(train_X.clone()) - homo_predictive_posterior = homo_mll.model.posterior(train_X.clone(), - observation_noise=True) + homo_predictive_posterior = homo_mll.model.posterior( + train_X.clone(), observation_noise=True + ) sampler = IIDNormalSampler(num_samples=num_var_samples, resample=True) predictive_samples = sampler(homo_predictive_posterior) - observed_var = 0.5 * ((predictive_samples - train_Y.reshape(-1,1))**2).mean(dim=0) + observed_var = 0.5 * ((predictive_samples - train_Y.reshape(-1, 1)) ** 2).mean( + dim=0 + ) # save mean and variance to check if they change later saved_mean = homo_posterior.mean saved_var = homo_posterior.variance - for i in range(max_iter): + for i in range(max_iter): # now train hetero model using computed noise - hetero_model = HeteroskedasticSingleTaskGP(train_X=train_X, train_Y=train_Y, - train_Yvar=observed_var) - hetero_mll = gpytorch.mlls.ExactMarginalLogLikelihood(hetero_model.likelihood, - hetero_model) + hetero_model = HeteroskedasticSingleTaskGP( + train_X=train_X, train_Y=train_Y, train_Yvar=observed_var + ) + hetero_mll = gpytorch.mlls.ExactMarginalLogLikelihood( + hetero_model.likelihood, hetero_model + ) try: botorch.fit.fit_gpytorch_model(hetero_mll) except Exception as e: - msg = f'Fitting failed on iteration {i}. Returning the current MLL' + msg = f"Fitting failed on iteration {i}. Returning the current MLL" warnings.warn(msg, e) return saved_hetero_mll hetero_mll.eval() with torch.no_grad(): hetero_posterior = hetero_mll.model.posterior(train_X.clone()) - hetero_predictive_posterior = hetero_mll.model.posterior(train_X.clone(), - observation_noise=True) - + hetero_predictive_posterior = hetero_mll.model.posterior( + train_X.clone(), observation_noise=True + ) + new_mean = hetero_posterior.mean new_var = hetero_posterior.variance - - mean_equality = torch.all(torch.lt(torch.abs(torch.add(saved_mean, -new_mean)), atol_mean)) + + mean_equality = torch.all( + torch.lt(torch.abs(torch.add(saved_mean, -new_mean)), atol_mean) + ) max_change_in_means = torch.max(torch.abs(torch.add(saved_mean, -new_mean))) - var_equality = torch.all(torch.lt(torch.abs(torch.add(saved_var, -new_var)), atol_var)) + var_equality = torch.all( + torch.lt(torch.abs(torch.add(saved_var, -new_var)), atol_var) + ) max_change_in_var = torch.max(torch.abs(torch.add(saved_var, -new_var))) - + if mean_equality and var_equality: return hetero_mll else: saved_hetero_mll = hetero_mll - + saved_mean = new_mean saved_var = new_var - + # get new noise estimate sampler = IIDNormalSampler(num_samples=num_var_samples, resample=True) predictive_samples = sampler(hetero_predictive_posterior) - observed_var = 0.5 * ((predictive_samples - train_Y.reshape(-1,1))**2).mean(dim=0) - - msg = f'Did not reach convergence after {max_iter} iterations. Returning the current MLL.' + observed_var = 0.5 * ((predictive_samples - train_Y.reshape(-1, 1)) ** 2).mean( + dim=0 + ) + + msg = f"Did not reach convergence after {max_iter} iterations. Returning the current MLL." warnings.warn(msg) - return hetero_mll \ No newline at end of file + return hetero_mll