Skip to content

Add ESS resampler to avoid resampling every iteration#67

Open
jihodori wants to merge 12 commits into
JuliaPOMDP:masterfrom
jihodori:effective-sample-size
Open

Add ESS resampler to avoid resampling every iteration#67
jihodori wants to merge 12 commits into
JuliaPOMDP:masterfrom
jihodori:effective-sample-size

Conversation

@jihodori
Copy link
Copy Markdown

Adding ESS rampler and setting the default threshold to 0.5.

Tested with runtests.jl

Test Summary: | Pass  Total  Time
implemented   |    2      2  0.0s
WARNING: redefinition of constant Main.A. This may fail, cause incorrect answers, or produce other errors.
WARNING: redefinition of constant Main.B. This may fail, cause incorrect answers, or produce other errors.
WARNING: redefinition of constant Main.W. This may fail, cause incorrect answers, or produce other errors.
WARNING: redefinition of constant Main.V. This may fail, cause incorrect answers, or produce other errors.
....................................................................................................Test Summary: |Time
example       | None  0.4s
Test Summary:   | Pass  Total  Time
domain_specific |    7      7  0.2s
Test Summary: | Pass  Total  Time
beliefs       |   24     24  0.1s
Test Summary: | Pass  Total  Time
infer         |    2      2  0.0s
Test Summary:  | Pass  Total  Time
pomdp terminal |    1      1  0.4s
Test Summary: | Pass  Total  Time
alpha         |    2      2  0.4s

@jihodori jihodori changed the title add ess resampler to avoid resampling every iteration Add ESS resampler to avoid resampling every iteration Apr 26, 2024
Copy link
Copy Markdown
Member

@zsunberg zsunberg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! I think we need to make some bigger changes for this to work correctly. See #68.

Comment thread src/basic.jl Outdated
Comment thread src/basic.jl Outdated
Comment thread src/basic.jl Outdated
Comment on lines +56 to +66
if (calculate_ess(wm) < up.resampling_threshold)
resampled_particle_collection = resample(
up.resampler,
WeightedParticleBelief(pm, wm, sum(wm), nothing),
up.predict_model,
up.reweight_model,
b, a, o,
up.rng)
num_particles = n_particles(resampled_particle_collection)
return WeightedParticleBelief(resampled_particle_collection.particles, fill(1.0 / num_particles, num_particles))
end
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually, if you resample before reweight!, the update function does not work as expected

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

now resampling at the end

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite understand this decision. It seems much better to resample before prediction and I don't understand why it didn't work.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite understand this decision. It seems much better to resample before prediction and I don't understand why it didn't work.

I handled the edge case (first iteration) by checking wm is empty. The resample function should not be called when the update function is called for the first time

Copy link
Copy Markdown
Member

@zsunberg zsunberg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this looks pretty good, but please address my inline comment about resampling before predicting.

(also, I am seeing many things that I don't like about the code I wrote - it seems like it might be time for a major revamp after this is merged 😄 )

Comment thread src/basic.jl Outdated
Comment on lines +56 to +66
if (calculate_ess(wm) < up.resampling_threshold)
resampled_particle_collection = resample(
up.resampler,
WeightedParticleBelief(pm, wm, sum(wm), nothing),
up.predict_model,
up.reweight_model,
b, a, o,
up.rng)
num_particles = n_particles(resampled_particle_collection)
return WeightedParticleBelief(resampled_particle_collection.particles, fill(1.0 / num_particles, num_particles))
end
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite understand this decision. It seems much better to resample before prediction and I don't understand why it didn't work.

Copy link
Copy Markdown
Member

@zsunberg zsunberg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks really good. I just have one important comment about checking the weights, and a few minor code comments.

Comment thread test/runtests.jl Outdated
pf = BootstrapFilter(pomdp, 100)
bp = update(pf, initialize_belief(pf, Categorical([0.5, 0.5])), -1, 1.0)
@test all(particles(bp) .== 1)
@test abs(mean(bp) - 1.0) < 1e-5
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be better to test that all particles with nonzero weight are 1? or pdf(bp, 1.0) is approximately 1?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be better to test that all particles with nonzero weight are 1

I agree with this. Addressed this in the newer commit

Comment thread src/basic.jl Outdated
function update(up::BasicParticleFilter, b::AbstractParticleBelief, a, o)
pm = up._particle_memory
wm = up._weight_memory
if (!isempty(wm) && calculate_ess(wm) < up.resampling_threshold)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to check if wm is empty? Shouldn't calculate_ess prevent resampling on the first step? Also, the same updater might be re-used for multiple simulations, so this is not a reliable way to check.

(In general, to write maintainable code, it is not good to check implicit side effects (e.g. wm will probably be empty at the beginning, but that is not guaranteed) it is better to explicitly check whether it is the first step 😄 but, by design, that is not possible)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry it took me some time to figure out why I had this conditional statement. I realized I was testing this after calling the resize! function with an empty list (the initial case), which causes wm to have actual small weights such that calculate_ess is always lower than the resampling_threshold.

Now calculate_ess is called before the resize! function and prevents resampling on the first step as it returns Inf

Comment thread src/basic.jl Outdated
ESS is divided by `num_particles` to make it easier to get a uniform threshold (scale between 0 and 1) for resampling across different particle filters.

M. S. Arulampalam, S. Maskell, N. Gordon and T. Clapp, "A tutorial on particle filters for online nonlinear/non-Gaussian Bayesian tracking," in IEEE Transactions on Signal Processing, vol. 50, no. 2, pp. 174-188, Feb. 2002, doi: 10.1109/78.978374.
keywords: {Tutorial;Particle filters;Nonlinear dynamical systems;Costs;Signal processing;Bayesian methods;Particle tracking;Kalman filters;Filtering;Monte Carlo methods},
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Delete these keywords?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deleted

Comment thread src/basic.jl Outdated
M. S. Arulampalam, S. Maskell, N. Gordon and T. Clapp, "A tutorial on particle filters for online nonlinear/non-Gaussian Bayesian tracking," in IEEE Transactions on Signal Processing, vol. 50, no. 2, pp. 174-188, Feb. 2002, doi: 10.1109/78.978374.
keywords: {Tutorial;Particle filters;Nonlinear dynamical systems;Costs;Signal processing;Bayesian methods;Particle tracking;Kalman filters;Filtering;Monte Carlo methods},
"""
function calculate_ess(weights)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function calculate_ess(weights)
function normalized_ess(weights)

"calculate" does not really add anything to the name of this function. Also, If someone just said "ESS", does it mean a number between 0 and 1 or an actual number of particles? If ESS means a number of particles, I think we should call this function normalized_ess.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed the name to normalized_ess

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants