-
Notifications
You must be signed in to change notification settings - Fork 75
Fix adaptive refinement #571
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
90701f6
Fix adaptive refinement
FilippoOlivo eb18138
Reimplement refinement
FilippoOlivo 518264d
Fix doc
FilippoOlivo 1a05642
Fix doc
FilippoOlivo abb8718
Fix test
FilippoOlivo 1e9e920
clean code
360943e
fix tests
36d4924
Fix codacy
FilippoOlivo 0666cf4
Fix docstring
FilippoOlivo 9e69796
add white space
dario-coscia eb0b17c
update doc
dario-coscia 2a316a8
Fixes
FilippoOlivo File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
.../callback/adaptive_refinment_callback.rst → ...rst/callback/refinement/r3_refinement.rst
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
Refinments callbacks | ||
======================= | ||
|
||
.. currentmodule:: pina.callback.adaptive_refinement_callback | ||
.. currentmodule:: pina.callback.refinement | ||
.. autoclass:: R3Refinement | ||
:members: | ||
:show-inheritance: |
7 changes: 7 additions & 0 deletions
7
docs/source/_rst/callback/refinement/refinement_interface.rst
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Refinement Interface | ||
======================= | ||
|
||
.. currentmodule:: pina.callback.refinement | ||
.. autoclass:: RefinementInterface | ||
:members: | ||
:show-inheritance: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
""" | ||
Module for Pina Refinement callbacks. | ||
""" | ||
|
||
__all__ = [ | ||
"RefinementInterface", | ||
"R3Refinement", | ||
] | ||
|
||
from .refinement_interface import RefinementInterface | ||
from .r3_refinement import R3Refinement |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
"""Module for the R3Refinement callback.""" | ||
|
||
import torch | ||
from torch import nn | ||
from torch.nn.modules.loss import _Loss | ||
from .refinement_interface import RefinementInterface | ||
from ...label_tensor import LabelTensor | ||
from ...utils import check_consistency | ||
from ...loss import LossInterface | ||
|
||
|
||
class R3Refinement(RefinementInterface): | ||
""" | ||
PINA Implementation of an R3 Refinement Callback. | ||
""" | ||
|
||
def __init__( | ||
self, sample_every, residual_loss=nn.L1Loss, condition_to_update=None | ||
): | ||
""" | ||
This callback implements the R3 (Retain-Resample-Release) routine for | ||
sampling new points based on adaptive search. | ||
The algorithm incrementally accumulates collocation points in regions | ||
of high PDE residuals, and releases those with low residuals. | ||
Points are sampled uniformly in all regions where sampling is needed. | ||
|
||
.. seealso:: | ||
|
||
Original Reference: Daw, Arka, et al. *Mitigating Propagation | ||
Failures in Physics-informed Neural Networks | ||
using Retain-Resample-Release (R3) Sampling. (2023)*. | ||
DOI: `10.48550/arXiv.2207.02338 | ||
<https://doi.org/10.48550/arXiv.2207.02338>`_ | ||
|
||
:param int sample_every: Frequency for sampling. | ||
:param loss: Loss function | ||
:type loss: LossInterface | ~torch.nn.modules.loss._Loss | ||
:param condition_to_update: The conditions to update during the | ||
refinement process. If None, all conditions with a conditions will | ||
be updated. Default is None. | ||
:type condition_to_update: list(str) | tuple(str) | str | ||
:raises ValueError: If the condition_to_update is not a string or | ||
iterable of strings. | ||
:raises TypeError: If the residual_loss is not a subclass of | ||
torch.nn.Module. | ||
|
||
|
||
Example: | ||
>>> r3_callback = R3Refinement(sample_every=5) | ||
""" | ||
super().__init__(sample_every, condition_to_update) | ||
# check consistency loss | ||
check_consistency(residual_loss, (LossInterface, _Loss), subclass=True) | ||
self.loss_fn = residual_loss(reduction="none") | ||
|
||
def sample(self, current_points, condition_name, solver): | ||
""" | ||
Sample new points based on the R3 refinement strategy. | ||
|
||
:param current_points: Current points in the domain. | ||
:param condition_name: Name of the condition to update. | ||
:param PINNInterface solver: The solver object. | ||
:return: New points sampled based on the R3 strategy. | ||
:rtype: LabelTensor | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please, add parameters' type to this docstring. |
||
# Compute residuals for the given condition (average over fields) | ||
condition = solver.problem.conditions[condition_name] | ||
target = solver.compute_residual( | ||
current_points.requires_grad_(True), condition.equation | ||
) | ||
residuals = self.loss_fn(target, torch.zeros_like(target)).mean( | ||
dim=tuple(range(1, target.ndim)) | ||
) | ||
|
||
# Prepare new points | ||
labels = current_points.labels | ||
domain_name = solver.problem.conditions[condition_name].domain | ||
domain = solver.problem.domains[domain_name] | ||
num_old_points = self.initial_population_size[condition_name] | ||
mask = (residuals > residuals.mean()).flatten() | ||
|
||
if mask.any(): # Use high-residual points | ||
pts = current_points[mask] | ||
pts.labels = labels | ||
retain_pts = len(pts) | ||
samples = domain.sample(num_old_points - retain_pts, "random") | ||
return LabelTensor.cat([pts, samples]) | ||
return domain.sample(num_old_points, "random") |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.