Skip to content

Commit 787b0df

Browse files
committed
feat(ppsci): support data_efficient_nopt
1 parent f6a6192 commit 787b0df

File tree

3 files changed

+16
-15
lines changed

3 files changed

+16
-15
lines changed

examples/data_efficient_nopt/config/data_efficient_nopt_fno_poisson.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@ run_name: r0
77
use_ddp: False
88
config: pois-64-pretrain-e1_20_m3
99
sweep_id: ''
10+
logdir: exp
1011

1112
train_config:
1213
default: &DEFAULT
13-
num_data_workers: 0
14+
num_data_workers: 4
1415
# model
1516
model: 'fno'
1617
depth: 5
@@ -289,7 +290,7 @@ infer_config:
289290
scales_path: 'data/possion_64/poisson_64_e5_15_train_scale.npy'
290291
ckpt_path: checkpoint/finetune_b01_m0_n8192.pdparams
291292

292-
num_data_workers: 0
293+
num_data_workers: 4
293294
subsample: 1
294295
num_demos: 0
295296
shuffle: False

examples/data_efficient_nopt/data_efficient_nopt.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
# limitations under the License.
1414

1515
import gc
16-
import logging
1716
import os
1817
import random
1918
from argparse import Namespace
2019
from collections import OrderedDict
20+
from os import path as osp
2121

2222
import hydra
2323
import numpy as np
@@ -44,8 +44,7 @@
4444
from ppsci.arch.data_efficient_nopt_model import param_norm
4545
from ppsci.data.dataset.data_efficient_nopt_dataset import MixedDatasetLoader
4646
from ppsci.data.dataset.data_efficient_nopt_dataset import PoisHelmDatasetLoader
47-
48-
logger = logging.getLogger(__name__)
47+
from ppsci.utils import logger
4948

5049

5150
class Trainer:
@@ -437,7 +436,8 @@ def single_dset_val(self, subset, logs, cutoff=40):
437436
del temp_loader
438437
break
439438
count += 1
440-
input, label = data
439+
input = data[0]
440+
label = data[1]
441441

442442
# unsupervised pretrain
443443
if self.params.mode == "train":
@@ -699,6 +699,7 @@ def inference(config):
699699
config_name="data_efficient_nopt_fno_poisson",
700700
)
701701
def main(config: DictConfig):
702+
logger.init_logger("ppsci", osp.join(config.logdir, f"{config.mode}.log"), "info")
702703
if config.mode == "train" or config.mode == "finetune":
703704
train(config)
704705
elif config.mode == "infer":

ppsci/data/dataset/data_efficient_nopt_dataset.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
# refs: https://github.com/delta-lab-ai/data_efficient_nopt
1616

1717
import glob
18-
import logging
1918
import os
2019
from typing import Iterator
2120
from typing import TypeVar
@@ -29,7 +28,7 @@
2928
from paddle.io import RandomSampler
3029
from paddle.io import Sampler
3130

32-
logger = logging.getLogger(__name__)
31+
from ppsci.utils import logger
3332

3433
__all__ = [
3534
"MultisetSampler",
@@ -556,12 +555,12 @@ def __init__(self, params, location, transform, train):
556555
if self.train:
557556
if hasattr(self.params, "train_rand_idx_path"):
558557
self.train_rand_idx = np.load(self.params.train_rand_idx_path)
559-
logging.info("Randomizing train dataset using given random index path")
558+
logger.info("Randomizing train dataset using given random index path")
560559
else:
561560
self.train_rand_idx = range(self.data.shape[0])
562561
self.train_rand_idx = self.train_rand_idx[self.pt_idxs[0] : self.pt_idxs[1]]
563562
self.data = self.data[()][self.train_rand_idx, ...]
564-
logging.info(
563+
logger.info(
565564
"Getting only data idx for training set for length: {}".format(
566565
len(self.train_rand_idx)
567566
)
@@ -576,7 +575,7 @@ def __init__(self, params, location, transform, train):
576575
def _get_files_stats(self):
577576
self.file = self.location
578577
with h5py.File(self.file, "r") as _f:
579-
logging.info("Getting file stats from {}".format(self.file))
578+
logger.info("Getting file stats from {}".format(self.file))
580579
if len(_f["fields"].shape) == 4:
581580
self.n_demos = None
582581
self.n_samples = _f["fields"].shape[0]
@@ -599,7 +598,7 @@ def _get_files_stats(self):
599598
self.pt_split = self.params.pt_split
600599
else:
601600
self.pt_split = [0.9, 0.1]
602-
logging.info(
601+
logger.info(
603602
"Split training set into {} for pretrain, {} for train. ".format(
604603
self.pt_split[0], self.pt_split[1]
605604
)
@@ -619,7 +618,7 @@ def _get_files_stats(self):
619618
)
620619
self.n_samples /= self.subsample
621620
self.n_samples = int(self.n_samples)
622-
logging.info(
621+
logger.info(
623622
"Found data at path {}. Number of examples: {}. Image Shape: {} x {}".format(
624623
self.location, self.n_samples, self.img_shape_x, self.img_shape_y
625624
)
@@ -631,12 +630,12 @@ def _get_files_stats(self):
631630
measure_x = self.scales[-2] / self.img_shape_x
632631
measure_y = self.scales[-1] / self.img_shape_y
633632
self.measure = measure_x * measure_y
634-
logging.info(
633+
logger.info(
635634
"Scales for PDE are (source, tensor, sol, domain): {}".format(
636635
self.scales
637636
)
638637
)
639-
logging.info(
638+
logger.info(
640639
"Measure of the set is lx/nx * ly/ny = {}/{} * {}/{}".format(
641640
self.scales[-2], self.img_shape_x, self.scales[-1], self.img_shape_y
642641
)

0 commit comments

Comments
 (0)