Fix slow vector field tests #1657
Codecov Report❌ Patch coverage is
Coverage diff (main vs. #1657):

           main     #1657    +/-
Coverage   86.59%   86.88%   +0.28%
Files      135      134      -1
Lines      10931    11909    +978
Hits       9466     10347    +881
Misses     1465     1562     +97
CD is running here: https://github.com/sbi-dev/sbi/actions/runs/17409710472
@@ -130,7 +94,7 @@ def set_x(
     self,
     x_o: Optional[Tensor],
     x_is_iid: Optional[bool] = False,
-    iid_method: Literal["fnpe", "gauss", "auto_gauss", "jac_gauss"] = "auto_gauss",
+    iid_method: Optional[str] = None,
Why remove the Literals?
They were causing type checker issues. As this is mostly used internally, I think it's fine to relax the constraint to just `str`.
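If the annotation is relaxed to `str`, the type checker no longer catches misspelled method names, so a runtime check could take over that job. The following is only a sketch: `resolve_iid_method` and `_IID_METHODS` are hypothetical names that do not come from this PR, the allowed values are copied from the removed `Literal`, and falling back to `"auto_gauss"` when `None` is passed is an assumption based on the old default.

```python
from typing import Optional

# Hypothetical helper, not part of sbi. The allowed names are copied from
# the Literal annotation that this PR removes.
_IID_METHODS = ("fnpe", "gauss", "auto_gauss", "jac_gauss")


def resolve_iid_method(
    iid_method: Optional[str] = None,
    default: str = "auto_gauss",  # assumption: mirrors the previous default
) -> str:
    """Return a validated method name, using `default` when None is given."""
    method = default if iid_method is None else iid_method
    if method not in _IID_METHODS:
        raise ValueError(
            f"Unknown iid_method {method!r}; expected one of {_IID_METHODS}."
        )
    return method
```

With a check like this, a typo such as `"auto-gauss"` would fail immediately with a `ValueError` instead of surfacing later during sampling.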
sbi/samplers/score/diffuser.py
Outdated
if save_intermediate:
    intermediate_samples.append(samples)

# Check for NaN values after predictor
if torch.isnan(samples).any():
    raise RuntimeError(
This is the new runtime error that fails the test, right?
It is already raised if a single sample in the batch becomes NaN, which might be too strict.
Yes, good point. But it can happen that all samples are NaN, and it then takes a very long time until accept_reject detects this. Thus, it's probably better to fix this detection problem and allow NaNs for some samples here. Will look into that.
I moved this check to the posterior level where the final samples are passed on. Also, I changed it to `.all()`, because this caused the main issue: FMPE with `auto_gauss` was returning only NaNs. It seems NPSE with iid-sampling returns NaNs only sometimes.
Therefore, NPSE iid-score tests are passing now. FMPE iid-score tests are skipped, except for `fnpe`.
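The difference between the `.any()` and `.all()` variants of the check can be sketched as follows. This is not the code from the PR: `drop_nan_samples` is a hypothetical helper, and plain Python lists stand in for tensors so the snippet runs without torch (with tensors, the "all NaN" condition would correspond to `torch.isnan(samples).all()`). It fails only when every sample is unusable, and merely warns when some are.

```python
import math
import warnings
from typing import List, Sequence


def drop_nan_samples(samples: Sequence[Sequence[float]]) -> List[List[float]]:
    """Hypothetical helper: keep samples without NaN entries.

    Raises only if *every* sample contains a NaN (the `.all()` behaviour);
    a partially invalid batch produces a warning instead of an error.
    """
    valid = [list(s) for s in samples if not any(math.isnan(v) for v in s)]
    num_bad = len(samples) - len(valid)
    if len(samples) > 0 and not valid:
        raise RuntimeError("All samples are NaN.")
    if num_bad > 0:
        warnings.warn(
            f"{num_bad} of {len(samples)} samples contain NaN and were dropped."
        )
    return valid
```

A stricter `.any()`-style check would instead raise as soon as `num_bad > 0`, which is the behaviour the review comment above considered too strict.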
Looks good overall, I will have a closer look at it later.
It could be that a single sample or so becomes NaN and the newly introduced RuntimeError is raised. Previously, this wasn't a problem because it would be handled by accept_reject, no?
… sampling warning
Great, thanks.
Mhh, from your description it could be some of the terminal points in FMPE, which can have a singularity in the drift (i.e. in diffusion we add a nugget i.e. we go to exact 0, but not sure for FMPE). Will look into this in #1656, but fine to merge this already.
Edit: Nvm, if it's prior dependent then it's likely something numerical with the marginal moments given a Uniform distribution.
Edit: Nvm, it is this.
The slow vf tests were running for hours and were eventually killed by the GH Actions runners. I believe some nested fixture calls were causing this issue. This is now fixed and the runtime is "down" to 30 min for the slow vf tests.
I also refactored the vf utils here and there to make them more transparent. IID inference for FMPE is working in parts, see #1656.