-
Notifications
You must be signed in to change notification settings - Fork 381
Added EXP3 Scoring function in continuation with pr #2358 #3013
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3013
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG
Just need tests, docstrings and add it to the docs (see docs/ directory where you'll need to manually add the classes where they fit, I can help if it's unclear).
pass | ||
|
||
|
||
class PUCTScore(MCTSScore): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add some docstrings there!
Should include an example
return node | ||
|
||
|
||
class UCBScore(MCTSScore): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto
return node | ||
|
||
|
||
class EXP3Score(MCTSScore): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto
node.set(self.weights_key, weights) | ||
|
||
|
||
class MCTSScores(Enum): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@vmoens ?? Any changes needed here ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry but I didn't get that 😅
|
||
def update_weights( | ||
self, node: TensorDictBase, action_idx: int, reward: float | ||
) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Public methods need docstrings too
@vmoens I branched out from your PR branch. I will be adding more docstrings, tests and also the other two methods which are yet to be implemented |
Description
Added EXP3 Scoring function in continuation with pr #2358
Motivation and Context
Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax
close #15213
if this solves the issue #15213Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an
x
in all the boxes that apply.If you are unsure about any of these, don't hesitate to ask. We are here to help!