Skip to content

fix: avoid clearing in-flight pipeline states in custom kernel cache#3434

Open
Ziqiao-git wants to merge 1 commit intoml-explore:mainfrom
Ziqiao-git:fix/custom-kernel-cache
Open

fix: avoid clearing in-flight pipeline states in custom kernel cache#3434
Ziqiao-git wants to merge 1 commit intoml-explore:mainfrom
Ziqiao-git:fix/custom-kernel-cache

Conversation

@Ziqiao-git
Copy link
Copy Markdown
Contributor

@Ziqiao-git Ziqiao-git commented Apr 21, 2026

Root cause update

After further debugging I found this is not a use-after-free caused by commandBufferWithUnretainedReferences() (as I initially suspected in #3378). The real root cause is in custom_kernel.cpp's CustomKernelCache.

What actually happens

  1. mx.fast.metal_kernel returns a callable that regenerates the full kernel source (write_signature() + user body) on every call, using the current inputs' dtype to form the function signature.
  2. The same kernel name can therefore produce different source strings across calls — e.g. if one call has a float16 input and another has a float32 input, write_signature() emits const device float16_t* x vs const device float* x.
  3. CustomKernel::eval_gpu() looks up the kernel in CustomKernelCache by name only. When it detects it->second != source_ it calls d.clear_library(name_), which drops the MTL::Library and all cached NS::SharedPtr<MTL::ComputePipelineState> for that library.
  4. If a previous dispatch with the old pipeline state is still sitting in an uncommitted command buffer, clear_library releases it out from under the command buffer → Metal validation reports command buffer references deallocated object. Without validation you get non-deterministic NaN.

Concrete trigger in ParoQuant

GET_KERNEL paro_rotate_r1 lib=A kernel=K1          # in_proj_qkv: x is float16
GET_KERNEL paro_rotate_r1 lib=A kernel=K1          # in_proj_z:   x is float16
GET_KERNEL gated_delta_step lib=B kernel=K2
SOURCE_MISMATCH paro_rotate_r1 cached=3131 new=3123 # out_proj now receives float32 x
CLEAR_LIBRARY paro_rotate_r1                       # releases K1 while a command buffer still references it
GET_KERNEL paro_rotate_r1 lib=A' kernel=K3         # new float-variant pipeline state

The dtype change comes from ParoQuant storing scales as float32 (biases as float16): mx.quantized_matmul(float16 x, uint32 w, float32 scales, float16 biases) promotes output to float32. That propagates through conv1dgated_delta_updateRMSNormGatedout_proj's input. Same rotation.metal body, different signature.

Proposed fix

Keep per-source variants coexisting in the cache instead of clearing the old one:

-  {
-    // Clear kernels from the device library cache if needed
-    auto& kernel_cache = cache();
-    if (auto it = kernel_cache.libraries.find(name_);
-        it != kernel_cache.libraries.end()) {
-      if (it->second != source_) {
-        auto& d = metal::device(s.device);
-        d.clear_library(name_);
-        it->second = source_;
-      }
-    } else {
-      kernel_cache.libraries.emplace(name_, source_);
-    }
-  }
-
-  auto lib = d.get_library(name_, [this] { return metal::utils() + source_; });
-  auto kernel = d.get_kernel(name_, lib);
+  // Use a source-dependent library key so different source variants
+  // (e.g. from write_signature picking different dtype qualifiers across
+  // calls) coexist without evicting each other. Clearing a library while
+  // its pipeline states are still referenced by an in-flight command
+  // buffer causes use-after-free.
+  auto source_hash = std::hash<std::string>{}(source_);
+  auto lib_key = name_ + "_" + std::to_string(source_hash);
+
+  auto lib = d.get_library(lib_key, [this] { return metal::utils() + source_; });
+  auto kernel = d.get_kernel(name_, lib);

This:

  • Removes the clear_library path entirely, so no pipeline state is released while still in flight.
  • Keeps name_ as the kernel-function symbol (no ABI / dispatch change).
  • Doesn't touch commandBufferWithUnretainedReferences, so no performance trade-off on older hardware.

Trade-offs / open questions

  • Cache growth: each unique source variant now keeps a live MTL::Library and pipeline state. In practice, variants are bounded by the small number of (dtype, shape-info) combinations used in a model (ParoQuant sees 2–3 variants per model). But for pathological callers this is unbounded — happy to add an LRU bound if preferred.
  • Hash collisions: std::hash<std::string> is 64-bit. Collisions are astronomically unlikely for real source strings but not impossible. Using the source string itself as the map key (instead of a hash) is safer at the cost of memory; I defaulted to hash for compactness.
  • Why name_ for get_kernel: the Metal function symbol inside each library is custom_kernel_<name>, which is the same across variants, so get_kernel(name_, lib) correctly resolves per-library.

Verification

write_signature() generates different source code for the same kernel
name when input dtypes change. The old code detected this source
mismatch and called clear_library(), which deallocates cached pipeline
states that may still be referenced by an in-flight command buffer,
causing use-after-free (Metal validation: 'command buffer references
deallocated object').

Fix: use source-hash-dependent library cache keys so different source
variants coexist without evicting each other. Removes the clear_library
path entirely.

Fixes ml-explore#3347
@Ziqiao-git
Copy link
Copy Markdown
Contributor Author

Update: z-lab/paroquant#38 has landed and fixes the actual source of the dtype mismatch — PR #19 of ParoQuant accidentally promoted the quantized scales from fp16 to fp32 while keeping biases fp16, and that asymmetry is what caused mx.quantized_matmul to produce fp32 output and trigger the source-mismatch path in the custom kernel cache. With the ParoQuant fix, the repro in #3347 no longer hits clear_library.

This PR still targets a latent issue — any caller that invokes the same mx.fast.metal_kernel with varying input dtypes will hit the same use-after-free — but it is no longer required for the ParoQuant use case. I'll leave the decision on whether to keep or close up to you.

Copy link
Copy Markdown
Collaborator

@zcbenz zcbenz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a better fix would be putting the dtypes of inputs in kernel_name?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants