From 2dd15de1840acc7745548be953252c2c4c1596f6 Mon Sep 17 00:00:00 2001
From: "Dr. Kashif Rasul"
Date: Wed, 22 Jul 2020 11:50:33 +0200
Subject: [PATCH 1/2] added dilate loss

---
 pts/model/n_beats/n_beats_ensemble.py |   2 +-
 pts/model/n_beats/n_beats_network.py  |  54 ++--------
 pts/modules/__init__.py               |   3 +
 pts/modules/loss.py                   |  75 ++++++++++++++
 pts/modules/path_soft_dtw.py          | 144 ++++++++++++++++++++++++++
 pts/modules/soft_dtw.py               | 103 ++++++++++++++++++
 setup.py                              |   1 +
 7 files changed, 333 insertions(+), 49 deletions(-)
 create mode 100644 pts/modules/loss.py
 create mode 100644 pts/modules/path_soft_dtw.py
 create mode 100644 pts/modules/soft_dtw.py

diff --git a/pts/model/n_beats/n_beats_ensemble.py b/pts/model/n_beats/n_beats_ensemble.py
index 861519c..b6e3231 100644
--- a/pts/model/n_beats/n_beats_ensemble.py
+++ b/pts/model/n_beats/n_beats_ensemble.py
@@ -118,7 +118,7 @@ class NBEATSEnsembleEstimator(Estimator):
     meta_loss_function
         The different 'loss_function' (also known as metric) to use for training the models.
         Unlike other models in GluonTS, this network does not use a distribution.
-        Default and recommended value: ["sMAPE", "MASE", "MAPE"]
+        Default and recommended value: ["sMAPE", "MASE", "MAPE", "DILATE"]
     meta_bagging_size
         The number of models that share the parameter combination of 'context_length'
         and 'loss_function'. Each of these models gets a different random initialization.
diff --git a/pts/model/n_beats/n_beats_network.py b/pts/model/n_beats/n_beats_network.py
index af5fcea..80411c6 100644
--- a/pts/model/n_beats/n_beats_network.py
+++ b/pts/model/n_beats/n_beats_network.py
@@ -6,6 +6,7 @@ import torch.nn.functional as F
 
 from pts.feature import get_seasonality
+from pts.modules import smape_loss, mape_loss, mase_loss, dilate_loss
 
 VALID_N_BEATS_STACK_TYPES = "G", "S", "T"
 VALID_LOSS_FUNCTIONS = "sMAPE", "MASE", "MAPE"
@@ -251,51 +252,6 @@ def forward(self, past_target: torch.Tensor):
         _, last_forecast = self.net_blocks[-1](backcast)
         return forecast + last_forecast
 
-    def smape_loss(
-        self, forecast: torch.Tensor, future_target: torch.Tensor
-    ) -> torch.Tensor:
-        denominator = (torch.abs(future_target) + torch.abs(forecast)).detach()
-        flag = denominator == 0
-
-        return (200 / self.prediction_length) * torch.mean(
-            (torch.abs(future_target - forecast) * torch.logical_not(flag)) / (denominator + flag),
-            dim=1,
-        )
-
-    def mape_loss(
-        self, forecast: torch.Tensor, future_target: torch.Tensor
-    ) -> torch.Tensor:
-        denominator = torch.abs(future_target)
-        flag = denominator == 0
-
-        return (100 / self.prediction_length) * torch.mean(
-            (torch.abs(future_target - forecast) * torch.logical_not(flag)) / (denominator + flag),
-            dim=1,
-        )
-
-    def mase_loss(
-        self,
-        forecast: torch.Tensor,
-        future_target: torch.Tensor,
-        past_target: torch.Tensor,
-        periodicity: int,
-    ) -> torch.Tensor:
-        factor = 1 / (self.context_length + self.prediction_length - periodicity)
-
-        whole_target = torch.cat((past_target, future_target), dim=1)
-        seasonal_error = factor * torch.mean(
-            torch.abs(
-                whole_target[:, periodicity:, ...]
-                - whole_target[:, :-periodicity:, ...]
-            ),
-            dim=1,
-        )
-        flag = seasonal_error == 0
-
-        return (torch.mean(torch.abs(future_target - forecast), dim=1) * torch.logical_not(flag)) / (
-            seasonal_error + flag
-        )
-
 
 class NBEATSTrainingNetwork(NBEATSNetwork):
     def __init__(self, loss_function: str, freq: str, *args, **kwargs) -> None:
@@ -317,13 +273,15 @@ def forward(
         forecast = super().forward(past_target=past_target)
 
         if self.loss_function == "sMAPE":
-            loss = self.smape_loss(forecast, future_target)
+            loss = smape_loss(forecast, future_target)
         elif self.loss_function == "MAPE":
-            loss = self.mape_loss(forecast, future_target)
+            loss = mape_loss(forecast, future_target)
         elif self.loss_function == "MASE":
-            loss = self.mase_loss(
+            loss = mase_loss(
                 forecast, future_target, past_target, self.periodicity
             )
+        elif self.loss_function == "DILATE":
+            loss = dilate_loss(forecast, future_target)
         else:
             raise ValueError(
                 f"Invalid value {self.loss_function} for argument loss_function."
diff --git a/pts/modules/__init__.py b/pts/modules/__init__.py
index fa20d1f..f406cfc 100644
--- a/pts/modules/__init__.py
+++ b/pts/modules/__init__.py
@@ -20,3 +20,6 @@ from .flows import RealNVP, MAF
 from .lambda_layer import LambdaLayer
 from .scaler import MeanScaler, NOPScaler
+from .soft_dtw import SoftDTWBatch
+from .path_soft_dtw import PathDTWBatch
+from .loss import smape_loss, mape_loss, mase_loss, dilate_loss
diff --git a/pts/modules/loss.py b/pts/modules/loss.py
new file mode 100644
index 0000000..b361cb7
--- /dev/null
+++ b/pts/modules/loss.py
@@ -0,0 +1,75 @@
+import torch
+import torch.nn as nn
+
+from .soft_dtw import SoftDTWBatch, pairwise_distances
+from .path_soft_dtw import PathDTWBatch
+
+
+def smape_loss(forecast: torch.Tensor, future_target: torch.Tensor) -> torch.Tensor:
+    denominator = (torch.abs(future_target) + torch.abs(forecast)).detach()
+    flag = denominator == 0
+
+    return (200 / self.prediction_length) * torch.mean(
+        (torch.abs(future_target - forecast) * torch.logical_not(flag))
+        / (denominator + flag),
+        dim=1,
+    )
+
+
+def mape_loss(forecast: torch.Tensor, future_target: torch.Tensor) -> torch.Tensor:
+    denominator = torch.abs(future_target)
+    flag = denominator == 0
+
+    return (100 / self.prediction_length) * torch.mean(
+        (torch.abs(future_target - forecast) * torch.logical_not(flag))
+        / (denominator + flag),
+        dim=1,
+    )
+
+
+def mase_loss(
+    forecast: torch.Tensor,
+    future_target: torch.Tensor,
+    past_target: torch.Tensor,
+    periodicity: int,
+) -> torch.Tensor:
+    factor = 1 / (self.context_length + self.prediction_length - periodicity)
+
+    whole_target = torch.cat((past_target, future_target), dim=1)
+    seasonal_error = factor * torch.mean(
+        torch.abs(
+            whole_target[:, periodicity:, ...] - whole_target[:, :-periodicity:, ...]
+        ),
+        dim=1,
+    )
+    flag = seasonal_error == 0
+
+    return (
+        torch.mean(torch.abs(future_target - forecast), dim=1) * torch.logical_not(flag)
+    ) / (seasonal_error + flag)
+
+
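+# Descriptive note (added in editing, not in the original patch): DILATE
+# (Le Guen & Thome, "Shape and Time Distortion Loss for Training Deep Time
+# Series Forecasting Models", NeurIPS 2019) combines two soft-DTW based terms:
+# a shape term (the soft-DTW discrepancy itself) and a temporal term that
+# penalizes the expected alignment path for straying off the diagonal, i.e.
+# for placing changes at the wrong time step.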
+def dilate_loss(forecast, future_target, alpha=0.5, gamma=0.01):
+    batch_size, N_output = forecast.shape
+
+    pairwise_distance = torch.zeros((batch_size, N_output, N_output)).to(
+        forecast.device
+    )
+    for k in range(batch_size):
+        Dk = pairwise_distances(
+            future_target[k, :].view(-1, 1), forecast[k, :].view(-1, 1)
+        )
+        pairwise_distance[k : k + 1, :, :] = Dk
+
+    softdtw_batch = SoftDTWBatch.apply
+    loss_shape = softdtw_batch(pairwise_distance, gamma)
+
+    path_dtw = PathDTWBatch.apply
+    path = path_dtw(pairwise_distance, gamma)
+
+    omega = pairwise_distances(torch.arange(1, N_output + 1).view(N_output, 1)).to(
+        forecast.device
+    )
+    loss_temporal = torch.sum(path * omega) / (N_output * N_output)
+
+    return alpha * loss_shape + (1 - alpha) * loss_temporal
diff --git a/pts/modules/path_soft_dtw.py b/pts/modules/path_soft_dtw.py
new file mode 100644
index 0000000..2ec2709
--- /dev/null
+++ b/pts/modules/path_soft_dtw.py
@@ -0,0 +1,144 @@
+import numpy as np
+import torch
+from torch.autograd import Function
+from numba import jit
+
+
+@jit(nopython=True)
+def my_max(x, gamma):
+    # use the log-sum-exp trick
+    max_x = np.max(x)
+    exp_x = np.exp((x - max_x) / gamma)
+    Z = np.sum(exp_x)
+    return gamma * np.log(Z) + max_x, exp_x / Z
+
+
+@jit(nopython=True)
+def my_min(x, gamma):
+    min_x, argmax_x = my_max(-x, gamma)
+    return -min_x, argmax_x
+
+
+@jit(nopython=True)
+def my_max_hessian_product(p, z, gamma):
+    return (p * z - p * np.sum(p * z)) / gamma
+
+
+@jit(nopython=True)
+def my_min_hessian_product(p, z, gamma):
+    return -my_max_hessian_product(p, z, gamma)
+
+
+@jit(nopython=True)
+def dtw_grad(theta, gamma):
+    m = theta.shape[0]
+    n = theta.shape[1]
+    V = np.zeros((m + 1, n + 1))
+    V[:, 0] = 1e10
+    V[0, :] = 1e10
+    V[0, 0] = 0
+
+    Q = np.zeros((m + 2, n + 2, 3))
+
+    for i in range(1, m + 1):
+        for j in range(1, n + 1):
+            # theta is indexed starting from 0.
+            v, Q[i, j] = my_min(
+                np.array([V[i, j - 1], V[i - 1, j - 1], V[i - 1, j]]), gamma
+            )
+            V[i, j] = theta[i - 1, j - 1] + v
+
+    E = np.zeros((m + 2, n + 2))
+    E[m + 1, :] = 0
+    E[:, n + 1] = 0
+    E[m + 1, n + 1] = 1
+    Q[m + 1, n + 1] = 1
+
+    for i in range(m, 0, -1):
+        for j in range(n, 0, -1):
+            E[i, j] = (
+                Q[i, j + 1, 0] * E[i, j + 1]
+                + Q[i + 1, j + 1, 1] * E[i + 1, j + 1]
+                + Q[i + 1, j, 2] * E[i + 1, j]
+            )
+
+    return V[m, n], E[1 : m + 1, 1 : n + 1], Q, E
+
+
+@jit(nopython=True)
+def dtw_hessian_prod(theta, Z, Q, E, gamma):
+    m = Z.shape[0]
+    n = Z.shape[1]
+
+    V_dot = np.zeros((m + 1, n + 1))
+    V_dot[0, 0] = 0
+
+    Q_dot = np.zeros((m + 2, n + 2, 3))
+    for i in range(1, m + 1):
+        for j in range(1, n + 1):
+            # theta is indexed starting from 0.
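+            # Forward-mode sweep (comment added in editing): V_dot carries the
+            # directional derivative of the soft-DTW value along Z, reusing the
+            # softmin weights Q computed by dtw_grad in the forward pass.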
+            V_dot[i, j] = (
+                Z[i - 1, j - 1]
+                + Q[i, j, 0] * V_dot[i, j - 1]
+                + Q[i, j, 1] * V_dot[i - 1, j - 1]
+                + Q[i, j, 2] * V_dot[i - 1, j]
+            )
+
+            v = np.array([V_dot[i, j - 1], V_dot[i - 1, j - 1], V_dot[i - 1, j]])
+            Q_dot[i, j] = my_min_hessian_product(Q[i, j], v, gamma)
+    E_dot = np.zeros((m + 2, n + 2))
+
+    for j in range(n, 0, -1):
+        for i in range(m, 0, -1):
+            E_dot[i, j] = (
+                Q_dot[i, j + 1, 0] * E[i, j + 1]
+                + Q[i, j + 1, 0] * E_dot[i, j + 1]
+                + Q_dot[i + 1, j + 1, 1] * E[i + 1, j + 1]
+                + Q[i + 1, j + 1, 1] * E_dot[i + 1, j + 1]
+                + Q_dot[i + 1, j, 2] * E[i + 1, j]
+                + Q[i + 1, j, 2] * E_dot[i + 1, j]
+            )
+
+    return V_dot[m, n], E_dot[1 : m + 1, 1 : n + 1]
+
+
+class PathDTWBatch(Function):
+    @staticmethod
+    def forward(ctx, D, gamma):  # D.shape: [batch_size, N , N]
+        batch_size, N, N = D.shape
+        device = D.device
+        D_cpu = D.detach().cpu().numpy()
+        gamma_gpu = torch.FloatTensor([gamma]).to(device)
+
+        grad_gpu = torch.zeros((batch_size, N, N)).to(device)
+        Q_gpu = torch.zeros((batch_size, N + 2, N + 2, 3)).to(device)
+        E_gpu = torch.zeros((batch_size, N + 2, N + 2)).to(device)
+
+        for k in range(0, batch_size):  # loop over all D in the batch
+            _, grad_cpu_k, Q_cpu_k, E_cpu_k = dtw_grad(D_cpu[k, :, :], gamma)
+            grad_gpu[k, :, :] = torch.FloatTensor(grad_cpu_k).to(device)
+            Q_gpu[k, :, :, :] = torch.FloatTensor(Q_cpu_k).to(device)
+            E_gpu[k, :, :] = torch.FloatTensor(E_cpu_k).to(device)
+        ctx.save_for_backward(grad_gpu, D, Q_gpu, E_gpu, gamma_gpu)
+        return torch.mean(grad_gpu, dim=0)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        device = grad_output.device
+        grad_gpu, D_gpu, Q_gpu, E_gpu, gamma = ctx.saved_tensors
+        D_cpu = D_gpu.detach().cpu().numpy()
+        Q_cpu = Q_gpu.detach().cpu().numpy()
+        E_cpu = E_gpu.detach().cpu().numpy()
+        gamma = gamma.detach().cpu().numpy()[0]
+        Z = grad_output.detach().cpu().numpy()
+
+        batch_size, N, N = D_cpu.shape
+        Hessian = torch.zeros((batch_size, N, N)).to(device)
+        for k in range(0, batch_size):
+            _, hess_k = dtw_hessian_prod(
+                D_cpu[k, :, :], Z, Q_cpu[k, :, :, :], E_cpu[k, :, :], gamma
+            )
+            Hessian[k : k + 1, :, :] = torch.FloatTensor(hess_k).to(device)
+
+        return Hessian, None
diff --git a/pts/modules/soft_dtw.py b/pts/modules/soft_dtw.py
new file mode 100644
index 0000000..3f7c572
--- /dev/null
+++ b/pts/modules/soft_dtw.py
@@ -0,0 +1,103 @@
+import numpy as np
+import torch
+from numba import jit
+from torch.autograd import Function
+
+
+def pairwise_distances(x, y=None):
+    """
+    Input: x is an Nxd matrix
+           y is an optional Mxd matrix
+    Output: dist is an NxM matrix where dist[i,j] is the squared norm between
+            x[i,:] and y[j,:]; if y is not given, y=x is used instead,
+            i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
+    """
+    x_norm = (x ** 2).sum(1).view(-1, 1)
+    if y is not None:
+        y_t = torch.transpose(y, 0, 1)
+        y_norm = (y ** 2).sum(1).view(1, -1)
+    else:
+        y_t = torch.transpose(x, 0, 1)
+        y_norm = x_norm.view(1, -1)
+
+    dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
+    return torch.clamp(dist, 0.0, float("inf"))
+
+
+@jit(nopython=True)
+def compute_softdtw(D, gamma):
+    N = D.shape[0]
+    M = D.shape[1]
+    R = np.zeros((N + 2, M + 2)) + 1e8
+    R[0, 0] = 0
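+    # Dynamic program over the padded cost matrix (comment added in editing):
+    # each cell takes the soft-minimum of its three predecessors, evaluated
+    # with the log-sum-exp trick for numerical stability:
+    #   R[i, j] = D[i-1, j-1] + softmin_gamma(R[i-1, j-1], R[i-1, j], R[i, j-1])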
+    for j in range(1, M + 1):
+        for i in range(1, N + 1):
+            r0 = -R[i - 1, j - 1] / gamma
+            r1 = -R[i - 1, j] / gamma
+            r2 = -R[i, j - 1] / gamma
+            rmax = max(max(r0, r1), r2)
+            rsum = np.exp(r0 - rmax) + np.exp(r1 - rmax) + np.exp(r2 - rmax)
+            softmin = -gamma * (np.log(rsum) + rmax)
+            R[i, j] = D[i - 1, j - 1] + softmin
+    return R
+
+
+@jit(nopython=True)
+def compute_softdtw_backward(D_, R, gamma):
+    N = D_.shape[0]
+    M = D_.shape[1]
+    D = np.zeros((N + 2, M + 2))
+    E = np.zeros((N + 2, M + 2))
+    D[1 : N + 1, 1 : M + 1] = D_
+    E[-1, -1] = 1
+    R[:, -1] = -1e8
+    R[-1, :] = -1e8
+    R[-1, -1] = R[-2, -2]
+    for j in range(M, 0, -1):
+        for i in range(N, 0, -1):
+            a0 = (R[i + 1, j] - R[i, j] - D[i + 1, j]) / gamma
+            b0 = (R[i, j + 1] - R[i, j] - D[i, j + 1]) / gamma
+            c0 = (R[i + 1, j + 1] - R[i, j] - D[i + 1, j + 1]) / gamma
+            a = np.exp(a0)
+            b = np.exp(b0)
+            c = np.exp(c0)
+            E[i, j] = E[i + 1, j] * a + E[i, j + 1] * b + E[i + 1, j + 1] * c
+    return E[1 : N + 1, 1 : M + 1]
+
+
+class SoftDTWBatch(Function):
+    @staticmethod
+    def forward(ctx, D, gamma=1.0):  # D.shape: [batch_size, N , N]
+        dev = D.device
+        batch_size, N, N = D.shape
+        gamma = torch.FloatTensor([gamma]).to(dev)
+        D_ = D.detach().cpu().numpy()
+        g_ = gamma.item()
+
+        total_loss = 0
+        R = torch.zeros((batch_size, N + 2, N + 2)).to(dev)
+        for k in range(0, batch_size):  # loop over all D in the batch
+            Rk = torch.FloatTensor(compute_softdtw(D_[k, :, :], g_)).to(dev)
+            R[k : k + 1, :, :] = Rk
+            total_loss = total_loss + Rk[-2, -2]
+        ctx.save_for_backward(D, R, gamma)
+        return total_loss / batch_size
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        dev = grad_output.device
+        D, R, gamma = ctx.saved_tensors
+        batch_size, N, N = D.shape
+        D_ = D.detach().cpu().numpy()
+        R_ = R.detach().cpu().numpy()
+        g_ = gamma.item()
+
+        E = torch.zeros((batch_size, N, N)).to(dev)
+        for k in range(batch_size):
+            Ek = torch.FloatTensor(
+                compute_softdtw_backward(D_[k, :, :], R_[k, :, :], g_)
+            ).to(dev)
+            E[k : k + 1, :, :] = Ek
+
+        return grad_output * E, None
diff --git a/setup.py b/setup.py
index c6adc2f..0a61951 100644
--- a/setup.py
+++ b/setup.py
@@ -26,6 +26,7 @@
         'matplotlib',
         'python-rapidjson',
         'tensorboard',
+        'numba',
     ],
 
     test_suite='tests',

From 75d1f739922cd7eae63c456997a65bb76236cc35 Mon Sep 17 00:00:00 2001
From: "Dr. Kashif Rasul"
Kashif Rasul" Date: Wed, 22 Jul 2020 12:35:18 +0200 Subject: [PATCH 2/2] fix losses --- pts/model/n_beats/n_beats_network.py | 17 +++++++++++++---- pts/modules/loss.py | 16 +++++++++++----- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/pts/model/n_beats/n_beats_network.py b/pts/model/n_beats/n_beats_network.py index 80411c6..1bfece5 100644 --- a/pts/model/n_beats/n_beats_network.py +++ b/pts/model/n_beats/n_beats_network.py @@ -9,7 +9,7 @@ from pts.modules import smape_loss, mape_loss, mase_loss, dilate_loss VALID_N_BEATS_STACK_TYPES = "G", "S", "T" -VALID_LOSS_FUNCTIONS = "sMAPE", "MASE", "MAPE" +VALID_LOSS_FUNCTIONS = "sMAPE", "MASE", "MAPE", "DILATE" def linspace( @@ -273,12 +273,21 @@ def forward( forecast = super().forward(past_target=past_target) if self.loss_function == "sMAPE": - loss = smape_loss(forecast, future_target) + loss = smape_loss( + forecast, future_target, prediction_length=self.prediction_length + ) elif self.loss_function == "MAPE": - loss = mape_loss(forecast, future_target) + loss = mape_loss( + forecast, future_target, prediction_length=self.prediction_length + ) elif self.loss_function == "MASE": loss = mase_loss( - forecast, future_target, past_target, self.periodicity + forecast, + future_target, + past_target, + context_length=self.context_length, + prediction_length=self.prediction_length, + periodicity=self.periodicity, ) elif self.loss_function == "DILATE": loss = dilate_loss(forecast, future_target) diff --git a/pts/modules/loss.py b/pts/modules/loss.py index b361cb7..ed77d6e 100644 --- a/pts/modules/loss.py +++ b/pts/modules/loss.py @@ -5,22 +5,26 @@ from .path_soft_dtw import PathDTWBatch -def smape_loss(forecast: torch.Tensor, future_target: torch.Tensor) -> torch.Tensor: +def smape_loss( + forecast: torch.Tensor, future_target: torch.Tensor, prediction_length: int +) -> torch.Tensor: denominator = (torch.abs(future_target) + torch.abs(forecast)).detach() flag = denominator == 0 - return (200 / self.prediction_length) * torch.mean( + return (200 / prediction_length) * torch.mean( (torch.abs(future_target - forecast) * torch.logical_not(flag)) / (denominator + flag), dim=1, ) -def mape_loss(forecast: torch.Tensor, future_target: torch.Tensor) -> torch.Tensor: +def mape_loss( + forecast: torch.Tensor, future_target: torch.Tensor, prediction_length: int +) -> torch.Tensor: denominator = torch.abs(future_target) flag = denominator == 0 - return (100 / self.prediction_length) * torch.mean( + return (100 / prediction_length) * torch.mean( (torch.abs(future_target - forecast) * torch.logical_not(flag)) / (denominator + flag), dim=1, @@ -31,9 +35,11 @@ def mase_loss( forecast: torch.Tensor, future_target: torch.Tensor, past_target: torch.Tensor, + context_length: int, + prediction_length: int, periodicity: int, ) -> torch.Tensor: - factor = 1 / (self.context_length + self.prediction_length - periodicity) + factor = 1 / (context_length + prediction_length - periodicity) whole_target = torch.cat((past_target, future_target), dim=1) seasonal_error = factor * torch.mean(