-
Notifications
You must be signed in to change notification settings - Fork 220
【Hackathon 8th No.16】 data_efficient_nopt 论文复现 #1111
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Conversation
Thanks for your contribution! |
如果有可复现的精度结果,可以日志截图到github+上传log,这边可以开始测试 |
ok,已经参考实现了paddle的gaussian_blur |
poisson fno推理结果,采用官方提供权重.
|
helmholtz_64 fno和possion_64 fno一致,采用相同模型结构。 |
@@ -0,0 +1,49 @@ | |||
import paddle | |||
import torch |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
清理下torch相关内容
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个脚本用于ckpt权重转换,因为目前没有资源去完整预训练一个模型。
@@ -0,0 +1,13 @@ | |||
# Automatically generated by https://github.com/damnever/pigar. | |||
|
|||
adan-pytorch==0.1.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done. thx.
""" | ||
loss functions | ||
# """ | ||
# import logging |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要清理注释内容,准备合入
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done. thx.
|
||
|
||
def get_forcing(S): | ||
# x1 = ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done. thx.
@@ -0,0 +1,40 @@ | |||
# Usage | |||
|
|||
## 1. Data Download |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/development/#3
文档需要按照复现指南进行编写(ReadME 改为 doc的形式)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
文档部分完成,请问一下结果展示部分怎么写?
import paddle | ||
import paddle.nn as nn | ||
import paddle.nn.functional as F | ||
from timm.models.layers import DropPath |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
此处的torch库会报错:
Traceback (most recent call last):
File "/workspace/PaddleScience_repo/data_efficient_nopt/examples/data_efficient_nopt/inference_fno_helmholtz_poisson.py", line 17, in <module>
from models.fno import build_fno
File "/workspace/PaddleScience_repo/data_efficient_nopt/examples/data_efficient_nopt/models/fno.py", line 9, in <module>
from timm.models.layers import DropPath
ModuleNotFoundError: No module named 'timm'
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已移除相关依赖
目前模型文件夹需要迁移到arch |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To fix
|
||
论文通过以下方案解决上述提到的问题: | ||
|
||
1. 无监督预训练 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
文档对科学计算案例是至关重要的,对于案例和代码的可传播性非常关键
需要写的更详细一些(参考drivaernetpluplu的文档,写的非常详细),文档出现的图片打包发给我,我帮你制作链接
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已补充
k = k.replace("running_var", "_variance") | ||
k = k.replace("running_mean", "_mean") | ||
k = k.replace("module.", "") | ||
# 添加到飞桨权重字典中 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
中文注释需要清理一下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
and len(self.masking) == 2 | ||
): # and self.masking[1] > 0.: | ||
mask = self.mask_generator() | ||
# return x, file_idx, paddle.to_tensor(self.subset_dict[self.sub_dsets[file_idx].get_name()]), bcs, y, mask, x_blur |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要清理注释
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
@@ -0,0 +1,376 @@ | |||
default: &DEFAULT |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这部分yaml文件感觉没办法再进一步优化了,都是一些训练相关的不同配置参数,主体部分我放在了data_efficient_nopt.yaml中
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
paddle.set_device(device) | ||
|
||
# Modify params | ||
params["batch_size"] = int(params.batch_size // world_size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参数需要整理到对应的yaml文件中,以提升代码可读性
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
from ruamel.yaml.comments import CommentedMap as ruamelDict | ||
from scipy.stats import linregress | ||
from tqdm import tqdm | ||
from utils import logging_utils |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
from tqdm import tqdm | ||
|
||
|
||
def _get_act(activation): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
激活函数的编码是否可以复用以下文件的代码, 使得复现代码更简洁、紧凑
https://github.com/PaddlePaddle/PaddleScience/blob/develop/ppsci/arch/activation.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
# Update the file paths in `cexamples/data_efficient_nopt/config/data_efficient_nopt.yaml`, specify to mode in `train`, and then specify to `train_path`, `val_path`, `test_path`, `scales_path` and `train_rand_idx_path` | ||
|
||
# pretrain or finetune, for possion_64 or helmholtz_64. | ||
# specify config_name to fno_possion using `data_efficient_nopt_fno_poisson`, or to fno_helmholtz using `data_efficient_nopt_fno_helmholtz` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要补充复现的精度指标
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
|
||
=== "模型评估命令" | ||
|
||
暂无 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要补充模型推理指标,补充案例Checkpoint
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
|
||
下方展示了部分实验结果: | ||
|
||
| Model | Checkpoint | **$RMSE$** | **RMSE (normalized)$** | **R2** | **Slope** | |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
移动到开头
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
|
||
总而言之,这篇论文提出了一种创新且高效的神经算子学习框架,通过无监督预训练在大量廉价的无标签物理数据上学习通用表示,并通过情境学习在推理阶段利用少量相似案例来提升OOD泛化能力。这一框架显著降低了对昂贵模拟数据的需求,并提高了模型在复杂物理问题中的适应性和泛化性,为科学机器学习的数据高效发展开辟了新途径。 | ||
|
||
下方展示了部分实验结果: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要补充可视化的结果对比
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
源码中没有可以使用的可视化脚本
self.split_offset = 0 | ||
self.len = self.offsets[-1] | ||
else: | ||
print("Using train/val/test split: {}".format(self.train_val_test)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
判断注释是否无用?可以考虑去掉或者改为logger进行打印
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
for d in queue: | ||
yield d | ||
except Exception as err: | ||
print("ERRRR", err) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
清理print
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
except Exception as err: | ||
print("ERRRR", err) | ||
sampler_choices.pop(index_sampled) | ||
print( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
清理print
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
try: | ||
x, y = self.sub_dsets[file_idx][local_idx] | ||
except: # noqa | ||
print( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
""" | ||
self.model.eval() | ||
if full: | ||
cutoff = 999999999999 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要对硬编码进行处理
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
有一处小问题,还麻烦修改下
|
||
logger = logging.getLogger(__name__) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
请不要使用自定义的logger,ppsci.utils下有logger
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thx. 已修改。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
to fix
count = 0 | ||
for _, data in enumerate(temp_loader): | ||
if count > cutoff: | ||
del temp_loader |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
导致内存泄露
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是因为dataloader的num workers设置为0导致,修改为1或4正常。
|
||
1. 通过预测的相似性。 论文通过计算它们在输出空间中的距离来找到空间和时间上的相似演示 。这意味着,对于空间和时间域上的两个输入位置,如果论文发现它们经过训练的神经算子的输出相似,那么论文就将它们视为相似样本 。遵循 [24, 25],论文假设演示与查询共享相同的物理参数分布 。 | ||
|
||
2. 聚合。 对于查询的每个空间-时间位置,在找到其在演示中的相似样本后,论文聚合并平均它们的解作为预测 。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
文档加大力度,继续翻译+贴图
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已补充。
e57264b
to
5d858dc
Compare
Signed-off-by: WG <39621324+wangguan1995@users.noreply.github.com>
Signed-off-by: WG <39621324+wangguan1995@users.noreply.github.com>
Update data_efficient_nopt.py
PR types
New Features
PR changes
Others
Describe
support data_efficient_nopt