tests/schedulers/test_scheduler_flax.py:304 (FlaxDDPMSchedulerTest.test_full_loop_no_noise)
Array(3.7847595, dtype=float32) != 0.01
Expected :0.01
Actual   :Array(3.7847595, dtype=float32)
self = <test_scheduler_flax.FlaxDDPMSchedulerTest testMethod=test_full_loop_no_noise>
    def test_full_loop_no_noise(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()
    
        num_trained_timesteps = len(scheduler)
    
        model = self.dummy_model()
        sample = self.dummy_sample_deter
        key1, key2 = random.split(random.PRNGKey(0))
    
        for t in reversed(range(num_trained_timesteps)):
            # 1. predict noise residual
            residual = model(sample, t)
    
            # 2. predict previous mean of sample x_t-1
            output = scheduler.step(state, residual, t, sample, key1)
            pred_prev_sample = output.prev_sample
            state = output.state
            key1, key2 = random.split(key2)
    
            # if t > 0:
            #     noise = self.dummy_sample_deter
            #     variance = scheduler.get_variance(t) ** (0.5) * noise
            #
            # sample = pred_prev_sample + variance
            sample = pred_prev_sample
    
        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))
    
        if jax_device == "tpu":
            assert abs(result_sum - 251.26245) < 1e-2
            assert abs(result_mean - 0.32716465) < 1e-3
        else:
>           assert abs(result_sum - 255.1113) < 1e-2
E           assert Array(3.7847595, dtype=float32) < 0.01
E            +  where Array(3.7847595, dtype=float32) = abs((Array(251.32654, dtype=float32) - 255.1113))
schedulers/test_scheduler_flax.py:341: AssertionError
I'm running this on an M3 Pro, without a TPU or GPU.
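The failing line is the else-branch assertion, so on this machine `jax_device` is not "tpu". It may help to compare the observed sum against both hard-coded references. The sketch below uses only numbers copied from the traceback and does not run JAX. It shows that the CPU result (251.32654) is roughly 0.064 from the TPU reference and roughly 3.78 from the non-TPU reference. That misses the 1e-2 tolerance either way, but it is much closer to the TPU value.

```python
# Reference sums hard-coded in test_full_loop_no_noise (copied from the traceback).
REFERENCE_SUMS = {"tpu": 251.26245, "non-tpu": 255.1113}
TOLERANCE = 1e-2  # the test's assert threshold for result_sum

# result_sum observed on an M3 Pro (CPU backend), taken from the assertion message.
observed_sum = 251.32654

# Absolute distance from the observed sum to each reference value.
diffs = {name: abs(observed_sum - ref) for name, ref in REFERENCE_SUMS.items()}
for name, diff in diffs.items():
    status = "pass" if diff < TOLERANCE else "fail"
    print(f"{name:>8}: |diff| = {diff:.5f} ({status})")
```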
I'm planning to go through all of your tests and dependencies until Python 3.10, 3.11, 3.12, and 3.13 are supported in addition to the existing 3.8 and 3.9 support.
PS: Your grain-nightly dependency doesn't seem to support Python 3.8 or 3.9:
ERROR: Ignored the following versions that require a different python version: 0.0.1 Requires-Python >=3.10; 0.0.2 Requires-Python >=3.10; 0.0.3 Requires-Python >=3.10; 0.0.4 Requires-Python >=3.10
ERROR: Could not find a version that satisfies the requirement grain-nightly (from versions: none)
ERROR: No matching distribution found for grain-nightly
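To illustrate why pip reports no matching distribution, here is a small model of the Requires-Python floors listed in the error above. It is a sketch. It only knows about the four versions pip reported, and it treats each constraint as a simple lower bound on (major, minor).

```python
from __future__ import annotations

# Requires-Python floors for each grain-nightly release pip reported,
# copied from the error above (every one is ">=3.10").
REQUIRES_PYTHON = {
    "0.0.1": (3, 10),
    "0.0.2": (3, 10),
    "0.0.3": (3, 10),
    "0.0.4": (3, 10),
}


def installable_versions(python: tuple[int, int]) -> list[str]:
    """Return the releases whose Requires-Python floor admits `python`."""
    return [v for v, floor in REQUIRES_PYTHON.items() if python >= floor]


# Check every interpreter version mentioned in this issue (3.8 through 3.13).
for minor in range(8, 14):
    found = installable_versions((3, minor))
    print(f"Python 3.{minor}: {', '.join(found) if found else 'no matching distribution'}")
```

If grain-nightly is only needed on newer interpreters, one option is a PEP 508 environment marker in setup.py, such as `install_requires=['grain-nightly; python_version >= "3.10"']`. This assumes the package can run without grain on older Pythons. With that marker, pip skips the dependency on 3.8 and 3.9 instead of failing.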
Is your setup.py up to date? Which Python (CPython) versions are you testing on?