@@ -161,7 +161,7 @@ def _post_process_results(self) -> None:
161
161
RuntimeError: If the tensors sharing the same buffer have different
162
162
quantization settings.
163
163
"""
164
- self ._check_buffer_sharing ()
164
+ self ._check_and_fix_buffer_sharing ()
165
165
166
166
def _update_model_quant_results (
167
167
self ,
@@ -273,9 +273,11 @@ def _mark_tensors_requiring_buffer_duplication(
273
273
"""Mark tensors that require buffer duplication.
274
274
275
275
Marking a tensor means adding a DUPLICATE_BUFFER transformation as the first
276
- transformation to be applied for each consumer of the tensor.
276
+ transformation to be applied for each consumer of the tensor. Need to do
277
+ that for each consumer to preserve a zero layer and not affect the
278
+ horizontal optimization later in the transformation instructions generator.
277
279
278
- Mark all tensors within each of the provided buffers as requiring buffer
280
+ Marks all tensors within each of the provided buffers as requiring buffer
279
281
duplication, except for the last tensor. The order of tensors is assumed to
280
282
be the same during both the marking and transformation performer steps, as
281
283
determined by `self.buffer_to_tensors`. This allows the final tensor to
@@ -292,65 +294,148 @@ def _mark_tensors_requiring_buffer_duplication(
292
294
0 , _QuantTrans .DUPLICATE_BUFFER
293
295
)
294
296
295
- def _check_buffer_sharing (self ) -> None :
296
- """Check if tensors sharing the same buffer have the same quantization.
297
+ def _mark_tensors_requiring_tensor_duplication (
298
+ self , tensor_names_to_duplicate
299
+ ) -> None :
300
+ """Mark tensors that require tensor duplication.
301
+
302
+ Marking a tensor means adding a DUPLICATE_TENSOR transformation as the first
303
+ transformation to be applied for each consumer of the tensor. Need to do
304
+ that for each consumer to preserve a zero layer and not affect the
305
+ horizontal optimization later in the transformation instructions generator.
306
+
307
+ Args:
308
+ tensor_names_to_duplicate: Names of tensors to duplicate.
309
+ """
310
+ for tensor_name in tensor_names_to_duplicate :
311
+ for consumer_params in self .model_quant_results [tensor_name ].consumers :
312
+ consumer_params .transformations .insert (0 , _QuantTrans .DUPLICATE_TENSOR )
313
+
314
+ def _check_buffer_sharing_for_tensor (self , tensor : Any ) -> bool :
315
+ """Check buffer sharing for the tensor against itself.
316
+
317
+ Args:
318
+ tensor: The tensor to check.
319
+
320
+ Returns:
321
+ Whether the tensor has compatible quantization parameters.
297
322
298
323
Raises:
299
- RuntimeError: If the tensors sharing the same buffer have different
300
- quantization settings .
324
+ RuntimeError: If the tensor has incompatible quantization parameters
325
+ and the buffer is not constant .
301
326
"""
302
- def get_result (tensor : Any ):
303
- return self .model_quant_results .get (
304
- tfl_flatbuffer_utils .get_tensor_name (tensor ), None
327
+ tensor_params = self .model_quant_results .get (
328
+ tfl_flatbuffer_utils .get_tensor_name (tensor ), None
329
+ )
330
+ if tensor_params is None :
331
+ return True
332
+
333
+ if _are_tensor_consumer_params_compatible (tensor_params ):
334
+ return True
335
+ elif _is_constant_tensor (tensor , self .flatbuffer_model .buffers ):
336
+ return False
337
+ else :
338
+ error_msg = (
339
+ f'The tensor { tensor .name } consumers do not have the same'
340
+ ' quantization parameters. Please modify your quantization recipe to'
341
+ ' make sure the two tensors have the same quantization settings.'
305
342
)
343
+ raise RuntimeError (error_msg )
306
344
307
- buffers_to_duplicate = []
308
- for tensors in self . buffer_to_tensors . values ():
309
- if not tensors :
310
- continue
345
+ def _check_buffer_sharing_for_self_compatible_tensors (
346
+ self , tensor1 : Any , tensor2 : Any
347
+ ) -> bool :
348
+ """Check a pair of self compatible tensors have the same quantization params.
311
349
312
- first_tensor = tensors [0 ]
313
- first_tensor_params = get_result (first_tensor )
314
- if first_tensor_params is None :
315
- continue
350
+ Self compatible means that all tensor's consumers have the same quantization
351
+ parameters.
316
352
317
- for tensor in tensors : # Also checking against itself.
318
- tensor_params = get_result (tensor )
319
- if tensor_params is None :
320
- continue
353
+ Args:
354
+ tensor1: The first tensor to check.
355
+ tensor2: The second tensor to check.
321
356
322
- if not _compatible_tensor_transformation_params (
323
- first_tensor_params , tensor_params
324
- ):
325
- if _are_distinct_tensors_with_shared_buffer (
326
- first_tensor , tensor , self .flatbuffer_model .buffers
327
- ):
328
- buffers_to_duplicate .append (first_tensor .buffer )
329
- break
330
- else :
331
- error_msg = (
332
- f'The tensors { first_tensor .name } and { tensor .name } do not have'
333
- ' the same quantization parameters even though they share the'
334
- ' same buffer. Please modify your quantization recipe to make'
335
- ' sure the two tensors have the same quantization settings.'
336
- )
337
- raise RuntimeError (error_msg )
357
+ Returns:
358
+ Whether the tensors have compatible quantization parameters.
338
359
339
- self ._mark_tensors_requiring_buffer_duplication (buffers_to_duplicate )
360
+ Raises:
361
+ RuntimeError: If the tensors have incompatible quantization parameters
362
+ and the buffer is not constant.
363
+ """
364
+ tensor1_params = self .model_quant_results .get (
365
+ tfl_flatbuffer_utils .get_tensor_name (tensor1 ), None
366
+ )
367
+ tensor2_params = self .model_quant_results .get (
368
+ tfl_flatbuffer_utils .get_tensor_name (tensor2 ), None
369
+ )
340
370
371
+ if tensor1_params is None or tensor2_params is None :
372
+ return True
341
373
342
- def _compatible_tensor_transformation_params (
343
- params1 : qtyping .TensorTransformationParams ,
344
- params2 : qtyping .TensorTransformationParams ,
345
- ) -> bool :
346
- """Check if two tensor transformation params are compatible."""
347
- return (
348
- _are_tensor_consumer_params_compatible (params1 )
349
- and _are_tensor_consumer_params_compatible (params2 )
350
- and _are_self_compatible_tensors_compatible_to_each_other (
351
- params1 , params2
374
+ if _are_self_compatible_tensors_compatible_to_each_other (
375
+ tensor1_params , tensor2_params
376
+ ):
377
+ return True
378
+ elif _is_constant_tensor (tensor1 , self .flatbuffer_model .buffers ):
379
+ return False
380
+ else :
381
+ error_msg = (
382
+ f'The tensors { tensor1 .name } and { tensor2 .name } do not have'
383
+ ' the same quantization parameters even though they share the'
384
+ ' same buffer. Please modify your quantization recipe to make'
385
+ ' sure the two tensors have the same quantization settings.'
352
386
)
353
- )
387
+ raise RuntimeError (error_msg )
388
+
389
+ def _check_and_fix_buffer_sharing (self ) -> None :
390
+ """Check and fix tensor/buffer sharing issues when possible.
391
+
392
+ This function checks if tensors sharing the same buffer have the same
393
+ quantization settings. If not, when it's possible, it will fix it by marking
394
+ such tensors or buffers to be duplicated. Otherwise, it will raise an error.
395
+
396
+ Possible cases that can be fixed by duplication:
397
+ 1. A constant tensor recieves different quantization parameters from its
398
+ consumers. In this case, the tensor is marked for duplication.
399
+ 2. Two or more tensors share the same constant buffer and have different
400
+ quantization parameters. In this case, the buffer is marked for
401
+ duplication.
402
+
403
+ Raises:
404
+ RuntimeError: If the tensors sharing the same buffer have different
405
+ quantization settings and it can't be resolved by duplicating the
406
+ buffer/tensor.
407
+ """
408
+ buffers_to_duplicate = []
409
+ tensor_names_to_duplicate = []
410
+ for buffer_idx , tensors in self .buffer_to_tensors .items ():
411
+ if not tensors :
412
+ continue
413
+ # Check if any of the tensors needs to be duplicated.
414
+ for tensor in tensors :
415
+ if not self ._check_buffer_sharing_for_tensor (tensor ):
416
+ tensor_names_to_duplicate .append (
417
+ tfl_flatbuffer_utils .get_tensor_name (tensor )
418
+ )
419
+ # Check if the buffer needs to be duplicated.
420
+ tensor_1 = tensors [0 ]
421
+ tensor_name_1 = tfl_flatbuffer_utils .get_tensor_name (tensor_1 )
422
+ if tensor_name_1 in tensor_names_to_duplicate :
423
+ buffers_to_duplicate .append (buffer_idx )
424
+ continue
425
+ for tensor_2 in tensors [1 :]:
426
+ tensor_name_2 = tfl_flatbuffer_utils .get_tensor_name (tensor_2 )
427
+ if (
428
+ tensor_name_2 in tensor_names_to_duplicate
429
+ or not self ._check_buffer_sharing_for_self_compatible_tensors (
430
+ tensor_1 , tensor_2
431
+ )
432
+ ):
433
+ buffers_to_duplicate .append (buffer_idx )
434
+ break
435
+
436
+ # Fix the buffer sharing issues.
437
+ self ._mark_tensors_requiring_buffer_duplication (buffers_to_duplicate )
438
+ self ._mark_tensors_requiring_tensor_duplication (tensor_names_to_duplicate )
354
439
355
440
356
441
def _are_tensor_consumer_params_compatible (
@@ -447,12 +532,6 @@ def _compatible_tensor_params(
447
532
return False
448
533
449
534
450
- def _are_distinct_tensors_with_shared_buffer (
451
- tensor1 : Any , tensor2 : Any , buffers : list [Any ]
452
- ) -> bool :
453
- """Check if two tensors are different and share a constant buffer."""
454
- are_different_tensors = tensor1 .name != tensor2 .name
455
- do_share_buffer = tensor1 .buffer == tensor2 .buffer
456
- is_constant_buffer = buffers [tensor1 .buffer ].data is not None
457
-
458
- return are_different_tensors and do_share_buffer and is_constant_buffer
535
+ def _is_constant_tensor (tensor : Any , buffers : Sequence [Any ]) -> bool :
536
+ """Check if the tensor is a constant tensor."""
537
+ return buffers [tensor .buffer ].data is not None
0 commit comments