added dilate loss #18

Open · wants to merge 2 commits into base: master

2 changes: 1 addition & 1 deletion pts/model/n_beats/n_beats_ensemble.py
@@ -118,7 +118,7 @@ class NBEATSEnsembleEstimator(Estimator):
meta_loss_function
The different 'loss_function' (also known as metric) to use for training the models.
Unlike other models in GluonTS this network does not use a distribution.
Default and recommended value: ["sMAPE", "MASE", "MAPE"]
Default and recommended value: ["sMAPE", "MASE", "MAPE", "DILATE"]
meta_bagging_size
The number of models that share the parameter combination of 'context_length'
and 'loss_function'. Each of these models gets a different random initialization.
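For reviewers who want to try the new option, a minimal usage sketch of the updated ensemble default follows. Everything except meta_loss_function (the freq, prediction_length, Trainer settings, and the exact import paths) is illustrative rather than taken from this diff:

```python
# Illustrative only: constructor arguments besides meta_loss_function are placeholders.
from pts import Trainer
from pts.model.n_beats import NBEATSEnsembleEstimator

estimator = NBEATSEnsembleEstimator(
    freq="H",
    prediction_length=24,
    meta_loss_function=["sMAPE", "MASE", "MAPE", "DILATE"],  # documented default per this PR
    trainer=Trainer(epochs=10),
)
```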
67 changes: 17 additions & 50 deletions pts/model/n_beats/n_beats_network.py
@@ -6,9 +6,10 @@
import torch.nn.functional as F

from pts.feature import get_seasonality
from pts.modules import smape_loss, mape_loss, mase_loss, dilate_loss

VALID_N_BEATS_STACK_TYPES = "G", "S", "T"
VALID_LOSS_FUNCTIONS = "sMAPE", "MASE", "MAPE"
VALID_LOSS_FUNCTIONS = "sMAPE", "MASE", "MAPE", "DILATE"


def linspace(
@@ -251,51 +252,6 @@ def forward(self, past_target: torch.Tensor):
_, last_forecast = self.net_blocks[-1](backcast)
return forecast + last_forecast

def smape_loss(
self, forecast: torch.Tensor, future_target: torch.Tensor
) -> torch.Tensor:
denominator = (torch.abs(future_target) + torch.abs(forecast)).detach()
flag = denominator == 0

return (200 / self.prediction_length) * torch.mean(
(torch.abs(future_target - forecast) * torch.logical_not(flag)) / (denominator + flag),
dim=1,
)

def mape_loss(
self, forecast: torch.Tensor, future_target: torch.Tensor
) -> torch.Tensor:
denominator = torch.abs(future_target)
flag = denominator == 0

return (100 / self.prediction_length) * torch.mean(
(torch.abs(future_target - forecast) * torch.logical_not(flag)) / (denominator + flag),
dim=1,
)

def mase_loss(
self,
forecast: torch.Tensor,
future_target: torch.Tensor,
past_target: torch.Tensor,
periodicity: int,
) -> torch.Tensor:
factor = 1 / (self.context_length + self.prediction_length - periodicity)

whole_target = torch.cat((past_target, future_target), dim=1)
seasonal_error = factor * torch.mean(
torch.abs(
whole_target[:, periodicity:, ...]
- whole_target[:, :-periodicity:, ...]
),
dim=1,
)
flag = seasonal_error == 0

return (torch.mean(torch.abs(future_target - forecast), dim=1) * torch.logical_not(flag)) / (
seasonal_error + flag
)


class NBEATSTrainingNetwork(NBEATSNetwork):
def __init__(self, loss_function: str, freq: str, *args, **kwargs) -> None:
@@ -317,13 +273,24 @@ def forward(
forecast = super().forward(past_target=past_target)

if self.loss_function == "sMAPE":
loss = self.smape_loss(forecast, future_target)
loss = smape_loss(
forecast, future_target, prediction_length=self.prediction_length
)
elif self.loss_function == "MAPE":
loss = self.mape_loss(forecast, future_target)
loss = mape_loss(
forecast, future_target, prediction_length=self.prediction_length
)
elif self.loss_function == "MASE":
loss = self.mase_loss(
forecast, future_target, past_target, self.periodicity
loss = mase_loss(
forecast,
future_target,
past_target,
context_length=self.context_length,
prediction_length=self.prediction_length,
periodicity=self.periodicity,
)
elif self.loss_function == "DILATE":
loss = dilate_loss(forecast, future_target)
else:
raise ValueError(
f"Invalid value {self.loss_function} for argument loss_function."
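For completeness, the same value can presumably be selected on the single-model estimator as well. This sketch assumes NBEATSEstimator (not touched by this diff) exposes a loss_function argument mirroring its GluonTS counterpart; treat the argument names outside loss_function as placeholders:

```python
# Assumption: NBEATSEstimator accepts loss_function, as the GluonTS version does.
from pts import Trainer
from pts.model.n_beats import NBEATSEstimator

estimator = NBEATSEstimator(
    freq="H",
    prediction_length=24,
    loss_function="DILATE",   # new option added by this PR
    trainer=Trainer(epochs=10),
)
```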
3 changes: 3 additions & 0 deletions pts/modules/__init__.py
@@ -20,3 +20,6 @@
from .flows import RealNVP, MAF
from .lambda_layer import LambdaLayer
from .scaler import MeanScaler, NOPScaler
from .soft_dtw import SoftDTWBatch
from .path_soft_dtw import PathDTWBatch
from .loss import smape_loss, mape_loss, mase_loss, dilate_loss
81 changes: 81 additions & 0 deletions pts/modules/loss.py
@@ -0,0 +1,81 @@
import torch
import torch.nn as nn

from .soft_dtw import SoftDTWBatch, pairwise_distances
from .path_soft_dtw import PathDTWBatch


def smape_loss(
forecast: torch.Tensor, future_target: torch.Tensor, prediction_length: int
) -> torch.Tensor:
denominator = (torch.abs(future_target) + torch.abs(forecast)).detach()
flag = denominator == 0

return (200 / prediction_length) * torch.mean(
(torch.abs(future_target - forecast) * torch.logical_not(flag))
/ (denominator + flag),
dim=1,
)


def mape_loss(
forecast: torch.Tensor, future_target: torch.Tensor, prediction_length: int
) -> torch.Tensor:
denominator = torch.abs(future_target)
flag = denominator == 0

return (100 / prediction_length) * torch.mean(
(torch.abs(future_target - forecast) * torch.logical_not(flag))
/ (denominator + flag),
dim=1,
)


def mase_loss(
forecast: torch.Tensor,
future_target: torch.Tensor,
past_target: torch.Tensor,
context_length: int,
prediction_length: int,
periodicity: int,
) -> torch.Tensor:
factor = 1 / (context_length + prediction_length - periodicity)

whole_target = torch.cat((past_target, future_target), dim=1)
seasonal_error = factor * torch.mean(
torch.abs(
whole_target[:, periodicity:, ...] - whole_target[:, :-periodicity:, ...]
),
dim=1,
)
flag = seasonal_error == 0

return (
torch.mean(torch.abs(future_target - forecast), dim=1) * torch.logical_not(flag)
) / (seasonal_error + flag)


def dilate_loss(forecast, future_target, alpha=0.5, gamma=0.01):
batch_size, N_output = forecast.shape

pairwise_distance = torch.zeros((batch_size, N_output, N_output)).to(
forecast.device
)
for k in range(batch_size):
Dk = pairwise_distances(
future_target[k, :].view(-1, 1), forecast[k, :].view(-1, 1)
)
pairwise_distance[k : k + 1, :, :] = Dk

softdtw_batch = SoftDTWBatch.apply
loss_shape = softdtw_batch(pairwise_distance, gamma)

path_dtw = PathDTWBatch.apply
path = path_dtw(pairwise_distance, gamma)

omega = pairwise_distances(torch.arange(1, N_output + 1).view(N_output, 1)).to(
forecast.device
)
loss_temporal = torch.sum(path * omega) / (N_output * N_output)

return alpha * loss_shape + (1 - alpha) * loss_temporal
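Since the losses now live in pts.modules as standalone functions, here is a quick sanity-check sketch exercising all four of them. The shapes and the context_length/prediction_length/periodicity values are made up for illustration, and the expectation that dilate_loss reduces to a single value assumes SoftDTWBatch (defined in soft_dtw.py, which this diff does not show) averages over the batch as in the reference DILATE implementation:

```python
# Toy shapes only: [batch, time]. smape/mape/mase return one value per series;
# dilate_loss mixes the soft-DTW "shape" term and the "temporal" term via alpha.
import torch
from pts.modules import smape_loss, mape_loss, mase_loss, dilate_loss

forecast = torch.rand(8, 24, requires_grad=True)  # [batch, prediction_length]
future_target = torch.rand(8, 24)
past_target = torch.rand(8, 48)                   # [batch, context_length]

print(smape_loss(forecast, future_target, prediction_length=24).shape)  # torch.Size([8])
print(mape_loss(forecast, future_target, prediction_length=24).shape)   # torch.Size([8])
print(
    mase_loss(
        forecast,
        future_target,
        past_target,
        context_length=48,
        prediction_length=24,
        periodicity=24,
    ).shape
)  # torch.Size([8])

dilate = dilate_loss(forecast, future_target, alpha=0.5, gamma=0.01)
dilate.mean().backward()  # .mean() is a no-op if the result is already 0-dim
print(dilate, forecast.grad.shape)
```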
144 changes: 144 additions & 0 deletions pts/modules/path_soft_dtw.py
@@ -0,0 +1,144 @@
import numpy as np
import torch
from torch.autograd import Function
from numba import jit


@jit(nopython=True)
def my_max(x, gamma):
# use the log-sum-exp trick
max_x = np.max(x)
exp_x = np.exp((x - max_x) / gamma)
Z = np.sum(exp_x)
return gamma * np.log(Z) + max_x, exp_x / Z


@jit(nopython=True)
def my_min(x, gamma):
min_x, argmax_x = my_max(-x, gamma)
return -min_x, argmax_x


@jit(nopython=True)
def my_max_hessian_product(p, z, gamma):
return (p * z - p * np.sum(p * z)) / gamma


@jit(nopython=True)
def my_min_hessian_product(p, z, gamma):
return -my_max_hessian_product(p, z, gamma)


@jit(nopython=True)
def dtw_grad(theta, gamma):
m = theta.shape[0]
n = theta.shape[1]
V = np.zeros((m + 1, n + 1))
V[:, 0] = 1e10
V[0, :] = 1e10
V[0, 0] = 0

Q = np.zeros((m + 2, n + 2, 3))

for i in range(1, m + 1):
for j in range(1, n + 1):
# theta is indexed starting from 0.
v, Q[i, j] = my_min(
np.array([V[i, j - 1], V[i - 1, j - 1], V[i - 1, j]]), gamma
)
V[i, j] = theta[i - 1, j - 1] + v

E = np.zeros((m + 2, n + 2))
E[m + 1, :] = 0
E[:, n + 1] = 0
E[m + 1, n + 1] = 1
Q[m + 1, n + 1] = 1

for i in range(m, 0, -1):
for j in range(n, 0, -1):
E[i, j] = (
Q[i, j + 1, 0] * E[i, j + 1]
+ Q[i + 1, j + 1, 1] * E[i + 1, j + 1]
+ Q[i + 1, j, 2] * E[i + 1, j]
)

return V[m, n], E[1 : m + 1, 1 : n + 1], Q, E


@jit(nopython=True)
def dtw_hessian_prod(theta, Z, Q, E, gamma):
m = Z.shape[0]
n = Z.shape[1]

V_dot = np.zeros((m + 1, n + 1))
V_dot[0, 0] = 0

Q_dot = np.zeros((m + 2, n + 2, 3))
for i in range(1, m + 1):
for j in range(1, n + 1):
# theta is indexed starting from 0.
V_dot[i, j] = (
Z[i - 1, j - 1]
+ Q[i, j, 0] * V_dot[i, j - 1]
+ Q[i, j, 1] * V_dot[i - 1, j - 1]
+ Q[i, j, 2] * V_dot[i - 1, j]
)

v = np.array([V_dot[i, j - 1], V_dot[i - 1, j - 1], V_dot[i - 1, j]])
Q_dot[i, j] = my_min_hessian_product(Q[i, j], v, gamma)
E_dot = np.zeros((m + 2, n + 2))

for j in range(n, 0, -1):
for i in range(m, 0, -1):
E_dot[i, j] = (
Q_dot[i, j + 1, 0] * E[i, j + 1]
+ Q[i, j + 1, 0] * E_dot[i, j + 1]
+ Q_dot[i + 1, j + 1, 1] * E[i + 1, j + 1]
+ Q[i + 1, j + 1, 1] * E_dot[i + 1, j + 1]
+ Q_dot[i + 1, j, 2] * E[i + 1, j]
+ Q[i + 1, j, 2] * E_dot[i + 1, j]
)

return V_dot[m, n], E_dot[1 : m + 1, 1 : n + 1]


class PathDTWBatch(Function):
@staticmethod
def forward(ctx, D, gamma): # D.shape: [batch_size, N , N]
batch_size, N, N = D.shape
device = D.device
D_cpu = D.detach().cpu().numpy()
gamma_gpu = torch.FloatTensor([gamma]).to(device)

grad_gpu = torch.zeros((batch_size, N, N)).to(device)
Q_gpu = torch.zeros((batch_size, N + 2, N + 2, 3)).to(device)
E_gpu = torch.zeros((batch_size, N + 2, N + 2)).to(device)

for k in range(0, batch_size): # loop over all D in the batch
_, grad_cpu_k, Q_cpu_k, E_cpu_k = dtw_grad(D_cpu[k, :, :], gamma)
grad_gpu[k, :, :] = torch.FloatTensor(grad_cpu_k).to(device)
Q_gpu[k, :, :, :] = torch.FloatTensor(Q_cpu_k).to(device)
E_gpu[k, :, :] = torch.FloatTensor(E_cpu_k).to(device)
ctx.save_for_backward(grad_gpu, D, Q_gpu, E_gpu, gamma_gpu)
return torch.mean(grad_gpu, dim=0)

@staticmethod
def backward(ctx, grad_output):
device = grad_output.device
grad_gpu, D_gpu, Q_gpu, E_gpu, gamma = ctx.saved_tensors
D_cpu = D_gpu.detach().cpu().numpy()
Q_cpu = Q_gpu.detach().cpu().numpy()
E_cpu = E_gpu.detach().cpu().numpy()
gamma = gamma.detach().cpu().numpy()[0]
Z = grad_output.detach().cpu().numpy()

batch_size, N, N = D_cpu.shape
Hessian = torch.zeros((batch_size, N, N)).to(device)
for k in range(0, batch_size):
_, hess_k = dtw_hessian_prod(
D_cpu[k, :, :], Z, Q_cpu[k, :, :, :], E_cpu[k, :, :], gamma
)
Hessian[k : k + 1, :, :] = torch.FloatTensor(hess_k).to(device)

return Hessian, None
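A few notes on the pieces above, plus a small differentiability check. my_max and my_min are log-sum-exp smoothed versions of max and min (my_min(x, gamma) computes -gamma * log(sum(exp(-x / gamma))) with the usual stabilization, and the second return value is the softmin weight vector). dtw_grad runs the soft-DTW recursion V[i, j] = theta[i-1, j-1] + min_gamma(V[i, j-1], V[i-1, j-1], V[i-1, j]), and its backward sweep accumulates E, the soft alignment matrix that serves as the gradient of the soft-DTW value with respect to the cost matrix; PathDTWBatch exposes that alignment as a differentiable op. The sketch below is illustrative only (shapes, gamma, and the squared-distance construction are made up); it checks that gradients flow back to the forecast, which is what dilate_loss relies on for its temporal term:

```python
# Illustration: PathDTWBatch returns the batch-averaged soft alignment path [N, N]
# and backpropagates through the pairwise-distance input.
import numpy as np
import torch
from pts.modules import PathDTWBatch
from pts.modules.path_soft_dtw import my_min

# Soft-min approaches the hard min as gamma shrinks; weights approach one-hot.
x = np.array([3.0, 1.0, 2.0])
print(my_min(x, 1.0))   # smoothed: value a bit below min(x) = 1.0, weights spread out
print(my_min(x, 0.01))  # ~ (1.0, array([0., 1., 0.]))

batch_size, n = 4, 16
target = torch.rand(batch_size, n)
forecast = torch.rand(batch_size, n, requires_grad=True)

# D[k, i, j] = (target[k, i] - forecast[k, j]) ** 2, built differentiably
D = (target.unsqueeze(2) - forecast.unsqueeze(1)) ** 2

path = PathDTWBatch.apply(D, 0.01)  # gamma = 0.01, as in dilate_loss
path.sum().backward()
print(path.shape, forecast.grad.shape)  # torch.Size([16, 16]) torch.Size([4, 16])
```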
