diff --git a/maslibpy/reasoning/mathematical.py b/maslibpy/reasoning/mathematical.py index 7f63c04..affa3b8 100644 --- a/maslibpy/reasoning/mathematical.py +++ b/maslibpy/reasoning/mathematical.py @@ -18,7 +18,7 @@ class Mathematical(): def _softmax(x: np.ndarray) -> np.ndarray: """Compute softmax values for each set of logits.""" exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True)) - return exp_x / np.sum(exp_x, axis=-1, keepdims=True) + 1e-8 + return exp_x / (np.sum(exp_x, axis=-1, keepdims=True) + 1e-8) def __init__(self, use_gpu: bool = True, model_weights: dict = None): self.device = torch.device('cuda' if torch.cuda.is_available() and use_gpu else 'cpu')