You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
validate forward_fun output shape in FeatureAblation (#1091)
Summary:
Correctly validate the `forward_fun`'s output shape in `FeatureAblation` (and `FeaturePermutation`)
Abandoned previous flawed assumption of "aggregation mode", which forbid the support for multi-outputs models (ref #1047)
New logic does not care output shape when `perturbations_per_eval == 1`. Only when `perturbations_per_eval > 1`, it require "Non-Aggregation mode", which is defined as the 1st dim of the model's output should grow with the input's batch size in the same ratio. This is achieved by actually comparing the output shape of 2 different inputs instead of making any assumption based on other user config:
- The baseline output is from the initial eval with the original inputs which we have to run anyway.
- The expanded output is from the 1st ablated eval whose input batch size has been expanded for more feature perturbation.
This way does not even introduce any extra forward calls.
Pull Request resolved: #1091
Reviewed By: vivekmig
Differential Revision: D42027843
Pulled By: aobo-y
fbshipit-source-id: cbcbad64bb1695e7be9c9447ebde6ec3f0cb8a90
0 commit comments