-
Notifications
You must be signed in to change notification settings - Fork 94
🎨 Refactor ModelABC
to Help Use Default Torch Models
#867
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev-define-engines-abc
Are you sure you want to change the base?
🎨 Refactor ModelABC
to Help Use Default Torch Models
#867
Conversation
Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com>
Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com>
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## dev-define-engines-abc #867 +/- ##
==========================================================
- Coverage 91.77% 91.69% -0.09%
==========================================================
Files 73 73
Lines 9354 9359 +5
Branches 1224 1224
==========================================================
- Hits 8585 8582 -3
- Misses 756 764 +8
Partials 13 13 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
with torch.inference_mode(): | ||
output = model(img_patches_device) | ||
# Output should be a single tensor or scalar | ||
return {"probabilities": output.cpu().numpy()} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the current develop
branch, neither CNNModel
, nor CNNBackbone
returned dictionaries as output of their infer_batch()
methods. Also, CNNModel
currently returns an array, while CNNBackbone
returns a list with the array. It might be fine, just wanted to highlight this.
CNNModel
return output.cpu().numpy() |
CNNBackbone
return [output.cpu().numpy()] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. We are aware of this. Our preference is to use torch nn models but to generalise for multi modal output we may need dictionaries. This PR is to check if we can move to generic torch models or we will need a sub class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense.
…rch-nn-model # Conflicts: # tests/models/test_arch_vanilla.py # tiatoolbox/models/architecture/vanilla.py
…rch-nn-model # Conflicts: # tiatoolbox/models/engine/engine_abc.py
ModelABC
to Help Use Default Torch Modelsinfer_batch
fromModelABC