fix: avoid clearing in-flight pipeline states in custom kernel cache#3434
fix: avoid clearing in-flight pipeline states in custom kernel cache#3434Ziqiao-git wants to merge 1 commit intoml-explore:mainfrom
Conversation
write_signature() generates different source code for the same kernel name when input dtypes change. The old code detected this source mismatch and called clear_library(), which deallocates cached pipeline states that may still be referenced by an in-flight command buffer, causing use-after-free (Metal validation: 'command buffer references deallocated object'). Fix: use source-hash-dependent library cache keys so different source variants coexist without evicting each other. Removes the clear_library path entirely. Fixes ml-explore#3347
|
Update: z-lab/paroquant#38 has landed and fixes the actual source of the dtype mismatch — PR #19 of ParoQuant accidentally promoted the quantized scales from fp16 to fp32 while keeping biases fp16, and that asymmetry is what caused This PR still targets a latent issue — any caller that invokes the same |
zcbenz
left a comment
There was a problem hiding this comment.
I think a better fix would be putting the dtypes of inputs in kernel_name?
Root cause update
After further debugging I found this is not a use-after-free caused by
commandBufferWithUnretainedReferences()(as I initially suspected in #3378). The real root cause is incustom_kernel.cpp'sCustomKernelCache.What actually happens
mx.fast.metal_kernelreturns a callable that regenerates the full kernel source (write_signature()+ user body) on every call, using the current inputs' dtype to form the function signature.float16input and another has afloat32input,write_signature()emitsconst device float16_t* xvsconst device float* x.CustomKernel::eval_gpu()looks up the kernel inCustomKernelCacheby name only. When it detectsit->second != source_it callsd.clear_library(name_), which drops theMTL::Libraryand all cachedNS::SharedPtr<MTL::ComputePipelineState>for that library.clear_libraryreleases it out from under the command buffer → Metal validation reportscommand buffer references deallocated object. Without validation you get non-deterministic NaN.Concrete trigger in ParoQuant
The dtype change comes from ParoQuant storing scales as float32 (biases as float16):
mx.quantized_matmul(float16 x, uint32 w, float32 scales, float16 biases)promotes output to float32. That propagates throughconv1d→gated_delta_update→RMSNormGated→out_proj's input. Samerotation.metalbody, different signature.Proposed fix
Keep per-source variants coexisting in the cache instead of clearing the old one:
This:
clear_librarypath entirely, so no pipeline state is released while still in flight.name_as the kernel-function symbol (no ABI / dispatch change).commandBufferWithUnretainedReferences, so no performance trade-off on older hardware.Trade-offs / open questions
MTL::Libraryand pipeline state. In practice, variants are bounded by the small number of (dtype, shape-info) combinations used in a model (ParoQuant sees 2–3 variants per model). But for pathological callers this is unbounded — happy to add an LRU bound if preferred.std::hash<std::string>is 64-bit. Collisions are astronomically unlikely for real source strings but not impossible. Using the source string itself as the map key (instead of a hash) is safer at the cost of memory; I defaulted to hash for compactness.name_forget_kernel: the Metal function symbol inside each library iscustom_kernel_<name>, which is the same across variants, soget_kernel(name_, lib)correctly resolves per-library.Verification
MTL_DEBUG_LAYER=1reproduction from [BUG] mx.fast.metal_kernel: use-after-free when multiple custom kernels compose in lazy graph #3347 no longer triggers thedeallocated objectassertion._force_evalworkaround).