Skip to content

Commit cdcde1a

Browse files
committed
Fix progress bar
1 parent 4f7a0fa commit cdcde1a

4 files changed

Lines changed: 46 additions & 6 deletions

File tree

deepecho/models/base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,17 @@
66
from deepecho.sequences import assemble_sequences
77

88

9+
def _format_score(score):
10+
"""Format a score as a fixed-length string ``±XX.XX``.
11+
12+
Values are clipped to the range ``[-99.99, +99.99]`` so the result
13+
is always exactly 6 characters.
14+
"""
15+
score = max(-99.99, min(99.99, score))
16+
sign = '+' if score >= 0 else '-'
17+
return f'{sign}{abs(score):05.2f}'
18+
19+
920
class DeepEcho:
1021
"""The base class for DeepEcho models."""
1122

deepecho/models/basic_gan.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88
from tqdm import tqdm
99

10-
from deepecho.models.base import DeepEcho
10+
from deepecho.models.base import DeepEcho, _format_score
1111

1212
LOGGER = logging.getLogger(__name__)
1313

@@ -547,7 +547,10 @@ def fit_sequences(self, sequences, context_types, data_types):
547547
if self._verbose:
548548
d_loss = discriminator_score.item()
549549
g_loss = generator_score.item()
550-
iterator.set_description(f'Epoch {epoch + 1} | D Loss {d_loss} | G Loss {g_loss}')
550+
iterator.set_description(
551+
f'Epoch {epoch + 1} | D Loss {_format_score(d_loss)}'
552+
f' | G Loss {_format_score(g_loss)}'
553+
)
551554

552555
def sample_sequence(self, context, sequence_length=None):
553556
"""Sample a single sequence conditioned on context.

deepecho/models/par.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88
from tqdm import tqdm
99

10-
from deepecho.models.base import DeepEcho
10+
from deepecho.models.base import DeepEcho, _format_score
1111

1212
LOGGER = logging.getLogger(__name__)
1313

@@ -336,8 +336,8 @@ def fit_sequences(self, sequences, context_types, data_types):
336336

337337
iterator = tqdm(range(self.epochs), disable=(not self.verbose))
338338
if self.verbose:
339-
pbar_description = 'Loss ({loss:.3f})'
340-
iterator.set_description(pbar_description.format(loss=0))
339+
pbar_description = 'Loss ({loss})'
340+
iterator.set_description(pbar_description.format(loss=_format_score(0)))
341341

342342
# Reset loss_values dataframe
343343
self.loss_values = pd.DataFrame(columns=['Epoch', 'Loss'])
@@ -364,7 +364,7 @@ def fit_sequences(self, sequences, context_types, data_types):
364364
self.loss_values = epoch_loss_df
365365

366366
if self.verbose:
367-
iterator.set_description(pbar_description.format(loss=loss.item()))
367+
iterator.set_description(pbar_description.format(loss=_format_score(loss.item())))
368368

369369
optimizer.step()
370370

tests/unit/test_base.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""Unit tests for the ``base`` module."""
2+
3+
import pytest
4+
5+
from deepecho.models.base import _format_score
6+
7+
8+
@pytest.mark.parametrize(
9+
'score, expected',
10+
[
11+
(0, '+00.00'),
12+
(1.233434, '+01.23'),
13+
(-0.93, '-00.93'),
14+
(0.01, '+00.01'),
15+
(-1.21, '-01.21'),
16+
(99.99, '+99.99'),
17+
(-99.99, '-99.99'),
18+
(150, '+99.99'),
19+
(-200, '-99.99'),
20+
],
21+
)
22+
def test__format_score(score, expected):
23+
"""Test the ``_format_score`` method."""
24+
result = _format_score(score)
25+
assert result == expected
26+
assert len(result) == 6

0 commit comments

Comments
 (0)