Skip to content

Commit 360943e

Browse files
author
Monthly Tag bot
committed
fix tests
1 parent 1e9e920 commit 360943e

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

pina/callback/refinement/refinement_interface.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from lightning.pytorch import Callback
88
from ...utils import check_consistency
99
from ...solver.physics_informed_solver import PINNInterface
10+
from ...condition import DomainEquationCondition
1011

1112

1213
class RefinementInterface(Callback, metaclass=ABCMeta):
@@ -44,7 +45,10 @@ def on_train_start(self, trainer, solver):
4445
"""
4546
# check we have valid conditions names
4647
if self._condition_to_update is None:
47-
self._condition_to_update = list(solver.problem.conditions.keys())
48+
self._condition_to_update = [
49+
name for name, cond in solver.problem.conditions.items()
50+
if hasattr(cond, "domain")
51+
]
4852

4953
for cond in self._condition_to_update:
5054
if cond not in solver.problem.conditions:

tests/test_callback/test_adaptive_refinement_callback.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,20 +31,28 @@ def test_constructor():
3131
R3Refinement(sample_every=10, condition_to_update=3)
3232

3333

34-
def test_sample():
34+
@pytest.mark.parametrize(
35+
"condition_to_update", [["D", "g1"], ["D", "g1", "g2", "g3", "g4"]]
36+
)
37+
def test_sample(condition_to_update):
3538
trainer = Trainer(
3639
solver=solver,
37-
callbacks=[R3Refinement(sample_every=1)],
40+
callbacks=[
41+
R3Refinement(
42+
sample_every=1, condition_to_update=condition_to_update
43+
)
44+
],
3845
accelerator="cpu",
3946
max_epochs=5,
4047
)
4148
before_n_points = {
42-
loc: len(pts) for loc, pts in trainer.solver.problem.input_pts.items()
49+
loc: len(trainer.solver.problem.input_pts[loc])
50+
for loc in condition_to_update
4351
}
4452
trainer.train()
4553
after_n_points = {
46-
loc: len(pts)
47-
for loc, pts in trainer.data_module.train_dataset.input.items()
54+
loc: len(trainer.data_module.train_dataset.input[loc])
55+
for loc in condition_to_update
4856
}
4957
assert before_n_points == trainer.callbacks[0].initial_population_size
5058
assert before_n_points == after_n_points

0 commit comments

Comments
 (0)