     FP8_E4M3_DATA,
     ActivationOrdering,
     DynamicType,
-    KVCacheScaleType,
     QuantizationArgs,
     QuantizationMetadata,
     QuantizationScheme,
     disable_hf_hook,
     get_execution_device,
     get_head_dim,
-<<<<<<< HEAD
     get_num_attn_heads,
     get_num_kv_heads,
-=======
->>>>>>> 05ec17e (WIP)
     register_offload_parameter,
 )
 from torch.nn import Module, Parameter
-from transformers import PretrainedConfig
 
 
 __all__ = [
@@ -294,7 +289,6 @@ def initialize_attn_qparams(
     kv_cache: Optional[QuantizedKVCache] = getattr(module, KV_CACHE_ATTR, None)
 
     if impl is None and kv_cache is None:
-<<<<<<< HEAD
         raise ValueError("Attention module has quantization scheme but no attached")
 
     _validate_attention_scheme(scheme)
@@ -310,27 +304,12 @@ def initialize_attn_qparams(
     kv_observed_shape = (num_kv_heads, None, head_dim)
     observed_dtype = next(module.parameters()).dtype
 
-=======
-        raise ValueError("Attention module has quantization scheme but no attached ")
-
-    config: PretrainedConfig = getattr(impl, "config", None) or getattr(
-        kv_cache, "config", None
-    )
-    head_dim = get_head_dim(config)
-    observed_shape = (head_dim,)  # (batch_size, num_attention_heads, slen, head_dim)
-    observed_dtype = next(module.parameters()).dtype
-
->>>>>>> 05ec17e (WIP)
     if impl is not None:
         initialize_qparams(
             module,
             "q",
             scheme.input_activations,
-<<<<<<< HEAD
             observed_shape=q_observed_shape,
-=======
-            observed_shape=observed_shape,
->>>>>>> 05ec17e (WIP)
             observed_dtype=observed_dtype,
             force_zero_point=force_zero_point,
         )
@@ -340,19 +319,14 @@ def initialize_attn_qparams(
             module,
             "k",
             scheme.input_activations,
-<<<<<<< HEAD
             observed_shape=kv_observed_shape,
-=======
-            observed_shape=observed_shape,
->>>>>>> 05ec17e (WIP)
             observed_dtype=observed_dtype,
             force_zero_point=force_zero_point,
         )
         initialize_qparams(
             module,
             "v",
             scheme.input_activations,
-<<<<<<< HEAD
             observed_shape=kv_observed_shape,
             observed_dtype=observed_dtype,
             force_zero_point=force_zero_point,
@@ -373,9 +347,3 @@ def _validate_attention_scheme(scheme: QuantizationScheme):
 
     if scheme.output_activations is not None:
         raise ValueError("Cannot apply output quantization to attention")
-=======
-            observed_shape=observed_shape,
-            observed_dtype=observed_dtype,
-            force_zero_point=force_zero_point,
-        )
->>>>>>> 05ec17e (WIP)
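For context on the shapes this resolution keeps, here is a minimal standalone sketch (not part of the PR; the config numbers are made up and `torch` is only used to build dummy activations). It contrasts the per-head observed shapes retained from the HEAD side with the `(head_dim,)` shape on the discarded WIP side, assuming the `(batch_size, num_attention_heads, slen, head_dim)` activation layout noted in the removed comment.

```python
import torch

# Hypothetical GQA config values, chosen only for illustration.
num_attn_heads, num_kv_heads, head_dim = 32, 8, 128

# Observed shapes kept by the resolution (HEAD side); None marks the
# variable sequence-length dimension.
q_observed_shape = (num_attn_heads, None, head_dim)
kv_observed_shape = (num_kv_heads, None, head_dim)

# Shape from the discarded WIP side, which only described the last dimension.
wip_observed_shape = (head_dim,)

# Attention activations are laid out as (batch, num_heads, seq_len, head_dim),
# so the HEAD shapes line up with the trailing three dimensions of q/k/v states.
query_states = torch.randn(2, num_attn_heads, 16, head_dim)
key_states = torch.randn(2, num_kv_heads, 16, head_dim)
assert query_states.shape[1:] == (num_attn_heads, 16, head_dim)
assert key_states.shape[1:] == (num_kv_heads, 16, head_dim)
```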