Skip to content

Commit b432db8

Browse files
committed
fix connection problem.zoo
1 parent 6b355b4 commit b432db8

File tree

1 file changed

+57
-25
lines changed

1 file changed

+57
-25
lines changed

pina/problem/zoo/inverse_poisson_2d_square.py

Lines changed: 57 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,49 @@
11
"""Formulation of the inverse Poisson problem in a square domain."""
22

33
import requests
4+
import warnings
45
import torch
56
from io import BytesIO
7+
from requests.exceptions import RequestException
68
from ... import Condition
79
from ... import LabelTensor
810
from ...operator import laplacian
911
from ...domain import CartesianDomain
1012
from ...equation import Equation, FixedValue
1113
from ...problem import SpatialProblem, InverseProblem
14+
from ...utils import custom_warning_format
15+
16+
warnings.formatwarning = custom_warning_format
17+
warnings.filterwarnings("always", category=ResourceWarning)
18+
19+
20+
def _load_tensor_from_url(url, labels):
21+
"""
22+
Downloads a tensor file from a URL and wraps it in a LabelTensor.
23+
24+
This function fetches a `.pth` file containing tensor data, extracts it,
25+
and returns it as a LabelTensor using the specified labels. If the file
26+
cannot be retrieved (e.g., no internet connection), a warning is issued
27+
and None is returned.
28+
29+
:param str url: URL to the remote `.pth` tensor file.
30+
:param list[str] | tuple[str] labels: Labels for the resulting LabelTensor.
31+
:return: A LabelTensor object if successful, otherwise None.
32+
:rtype: LabelTensor | None
33+
"""
34+
try:
35+
response = requests.get(url, timeout=5)
36+
tensor = torch.load(
37+
BytesIO(response.content), weights_only=False
38+
).tensor.detach()
39+
return LabelTensor(tensor, labels)
40+
except RequestException as e:
41+
warnings.warn(
42+
f"Could not download data from '{url}'. "
43+
f"Reason: {e}. Skipping data loading.",
44+
ResourceWarning,
45+
)
46+
return None
1247

1348

1449
def laplace_equation(input_, output_, params_):
@@ -29,35 +64,13 @@ def laplace_equation(input_, output_, params_):
2964
return delta_u - force_term
3065

3166

32-
# URL of the file
33-
url = "https://github.com/mathLab/PINA/raw/refs/heads/master/tutorials/tutorial7/data/pts_0.5_0.5"
34-
# Download the file
35-
response = requests.get(url)
36-
response.raise_for_status()
37-
file_like_object = BytesIO(response.content)
38-
# Set the data
39-
input_data = LabelTensor(
40-
torch.load(file_like_object, weights_only=False).tensor.detach(),
41-
["x", "y", "mu1", "mu2"],
42-
)
43-
44-
# URL of the file
45-
url = "https://github.com/mathLab/PINA/raw/refs/heads/master/tutorials/tutorial7/data/pinn_solution_0.5_0.5"
46-
# Download the file
47-
response = requests.get(url)
48-
response.raise_for_status()
49-
file_like_object = BytesIO(response.content)
50-
# Set the data
51-
output_data = LabelTensor(
52-
torch.load(file_like_object, weights_only=False).tensor.detach(), ["u"]
53-
)
54-
55-
5667
class InversePoisson2DSquareProblem(SpatialProblem, InverseProblem):
5768
r"""
5869
Implementation of the inverse 2-dimensional Poisson problem in the square
5970
domain :math:`[0, 1] \times [0, 1]`,
6071
with unknown parameter domain :math:`[-1, 1] \times [-1, 1]`.
72+
The `"data"` condition is added only if the required files are
73+
downloaded successfully.
6174
6275
:Example:
6376
>>> problem = InversePoisson2DSquareProblem()
@@ -83,5 +96,24 @@ class InversePoisson2DSquareProblem(SpatialProblem, InverseProblem):
8396
"g3": Condition(domain="g3", equation=FixedValue(0.0)),
8497
"g4": Condition(domain="g4", equation=FixedValue(0.0)),
8598
"D": Condition(domain="D", equation=Equation(laplace_equation)),
86-
"data": Condition(input=input_data, target=output_data),
8799
}
100+
101+
def __init__(self):
102+
"""
103+
Initialization of the :class:`InversePoisson2DSquareProblem` class.
104+
105+
:param alpha: Parameter of the forcing term.
106+
:type alpha: float | int
107+
"""
108+
super().__init__()
109+
110+
input_url = "https://github.com/mathLab/PINA/raw/refs/heads/master/tutorials/tutorial7/data/pts_0.5_0.5"
111+
output_url = "https://github.com/mathLab/PINA/raw/refs/heads/master/tutorials/tutorial7/data/pinn_solution_0.5_0.5"
112+
113+
input_data = _load_tensor_from_url(input_url, ["x", "y", "mu1", "mu2"])
114+
output_data = _load_tensor_from_url(output_url, ["u"])
115+
116+
if input_data is not None and output_data is not None:
117+
self.conditions["data"] = Condition(
118+
input=input_data, target=output_data
119+
)

0 commit comments

Comments
 (0)