
Commit ccdce42

marialyu authored and copybara-github committed
Enable duplication of a constant tensor when it receives different quant params
1. Add DUPLICATE_TENSOR to tensor transformation instructions in params_generator
2. Update end-to-end test

PiperOrigin-RevId: 746556560
1 parent 37457ea commit ccdce42
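
For context, the scenario this commit addresses: a single constant tensor feeds several consumers that request different quantization parameters. Previously the params generator raised an error; now it can duplicate the tensor per consumer. Below is a minimal, hypothetical Python sketch of that idea (ConsumerParams, fix_constant_tensor, and the num_bits stand-in for full quant params are illustrative names, not the library's API):

# Hedged sketch: duplicate a constant tensor whose consumers disagree.
from dataclasses import dataclass, field

@dataclass
class ConsumerParams:
  op_name: str
  num_bits: int  # stand-in for a full set of quantization parameters
  transformations: list = field(default_factory=list)

DUPLICATE_TENSOR = 'DUPLICATE_TENSOR'  # assumed transformation tag

def fix_constant_tensor(is_constant, consumers):
  """Prepends DUPLICATE_TENSOR for every consumer when params disagree."""
  if len({c.num_bits for c in consumers}) <= 1:
    return  # all consumers agree; nothing to fix
  if not is_constant:
    raise RuntimeError('consumers disagree and the tensor is not constant')
  for c in consumers:
    c.transformations.insert(0, DUPLICATE_TENSOR)

consumers = [ConsumerParams('conv_0', num_bits=8), ConsumerParams('fc_1', num_bits=4)]
fix_constant_tensor(is_constant=True, consumers=consumers)
assert all(c.transformations[0] == DUPLICATE_TENSOR for c in consumers)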

File tree

3 files changed: +322 -136 lines changed


ai_edge_quantizer/params_generator.py

Lines changed: 138 additions & 59 deletions
@@ -161,7 +161,7 @@ def _post_process_results(self) -> None:
       RuntimeError: If the tensors sharing the same buffer have different
         quantization settings.
     """
-    self._check_buffer_sharing()
+    self._check_and_fix_buffer_sharing()
 
   def _update_model_quant_results(
       self,
@@ -273,9 +273,11 @@ def _mark_tensors_requiring_buffer_duplication(
     """Mark tensors that require buffer duplication.
 
     Marking a tensor means adding a DUPLICATE_BUFFER transformation as the first
-    transformation to be applied for each consumer of the tensor.
+    transformation to be applied for each consumer of the tensor. This must be
+    done for each consumer to preserve a zero layer and to avoid affecting the
+    horizontal optimization later in the transformation instructions generator.
 
-    Mark all tensors within each of the provided buffers as requiring buffer
+    Marks all tensors within each of the provided buffers as requiring buffer
     duplication, except for the last tensor. The order of tensors is assumed to
     be the same during both the marking and transformation performer steps, as
     determined by `self.buffer_to_tensors`. This allows the final tensor to
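
The "first transformation" requirement above is plain list-prepend semantics on each consumer's ordered instruction list; the duplicate must exist before any later instruction touches that consumer's view of the shared buffer. A tiny illustration (the instruction names are stand-ins, not the quantizer's enum values):

transformations = ['ADD_QUANTIZE']             # assumed pre-existing instruction
transformations.insert(0, 'DUPLICATE_BUFFER')  # prepend, as the diff does
assert transformations == ['DUPLICATE_BUFFER', 'ADD_QUANTIZE']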
@@ -292,65 +294,148 @@ def _mark_tensors_requiring_buffer_duplication(
           0, _QuantTrans.DUPLICATE_BUFFER
       )
 
-  def _check_buffer_sharing(self) -> None:
-    """Check if tensors sharing the same buffer have the same quantization.
+  def _mark_tensors_requiring_tensor_duplication(
+      self, tensor_names_to_duplicate
+  ) -> None:
+    """Mark tensors that require tensor duplication.
+
+    Marking a tensor means adding a DUPLICATE_TENSOR transformation as the first
+    transformation to be applied for each consumer of the tensor. This must be
+    done for each consumer to preserve a zero layer and to avoid affecting the
+    horizontal optimization later in the transformation instructions generator.
+
+    Args:
+      tensor_names_to_duplicate: Names of tensors to duplicate.
+    """
+    for tensor_name in tensor_names_to_duplicate:
+      for consumer_params in self.model_quant_results[tensor_name].consumers:
+        consumer_params.transformations.insert(0, _QuantTrans.DUPLICATE_TENSOR)
+
+  def _check_buffer_sharing_for_tensor(self, tensor: Any) -> bool:
+    """Check buffer sharing for the tensor against itself.
+
+    Args:
+      tensor: The tensor to check.
+
+    Returns:
+      Whether the tensor has compatible quantization parameters.
 
     Raises:
-      RuntimeError: If the tensors sharing the same buffer have different
-        quantization settings.
+      RuntimeError: If the tensor has incompatible quantization parameters
+        and the buffer is not constant.
     """
-    def get_result(tensor: Any):
-      return self.model_quant_results.get(
-          tfl_flatbuffer_utils.get_tensor_name(tensor), None
+    tensor_params = self.model_quant_results.get(
+        tfl_flatbuffer_utils.get_tensor_name(tensor), None
+    )
+    if tensor_params is None:
+      return True
+
+    if _are_tensor_consumer_params_compatible(tensor_params):
+      return True
+    elif _is_constant_tensor(tensor, self.flatbuffer_model.buffers):
+      return False
+    else:
+      error_msg = (
+          f'The consumers of tensor {tensor.name} do not have the same'
+          ' quantization parameters. Please modify your quantization recipe'
+          ' to make sure they have the same quantization settings.'
       )
+      raise RuntimeError(error_msg)
 
-    buffers_to_duplicate = []
-    for tensors in self.buffer_to_tensors.values():
-      if not tensors:
-        continue
+  def _check_buffer_sharing_for_self_compatible_tensors(
+      self, tensor1: Any, tensor2: Any
+  ) -> bool:
+    """Check that a pair of self-compatible tensors share quantization params.
 
-      first_tensor = tensors[0]
-      first_tensor_params = get_result(first_tensor)
-      if first_tensor_params is None:
-        continue
+    Self-compatible means that all of a tensor's consumers have the same
+    quantization parameters.
 
-      for tensor in tensors:  # Also checking against itself.
-        tensor_params = get_result(tensor)
-        if tensor_params is None:
-          continue
+    Args:
+      tensor1: The first tensor to check.
+      tensor2: The second tensor to check.
 
-        if not _compatible_tensor_transformation_params(
-            first_tensor_params, tensor_params
-        ):
-          if _are_distinct_tensors_with_shared_buffer(
-              first_tensor, tensor, self.flatbuffer_model.buffers
-          ):
-            buffers_to_duplicate.append(first_tensor.buffer)
-            break
-          else:
-            error_msg = (
-                f'The tensors {first_tensor.name} and {tensor.name} do not have'
-                ' the same quantization parameters even though they share the'
-                ' same buffer. Please modify your quantization recipe to make'
-                ' sure the two tensors have the same quantization settings.'
-            )
-            raise RuntimeError(error_msg)
+    Returns:
+      Whether the tensors have compatible quantization parameters.
 
-    self._mark_tensors_requiring_buffer_duplication(buffers_to_duplicate)
+    Raises:
+      RuntimeError: If the tensors have incompatible quantization parameters
+        and the buffer is not constant.
+    """
+    tensor1_params = self.model_quant_results.get(
+        tfl_flatbuffer_utils.get_tensor_name(tensor1), None
+    )
+    tensor2_params = self.model_quant_results.get(
+        tfl_flatbuffer_utils.get_tensor_name(tensor2), None
+    )
 
+    if tensor1_params is None or tensor2_params is None:
+      return True
 
-def _compatible_tensor_transformation_params(
-    params1: qtyping.TensorTransformationParams,
-    params2: qtyping.TensorTransformationParams,
-) -> bool:
-  """Check if two tensor transformation params are compatible."""
-  return (
-      _are_tensor_consumer_params_compatible(params1)
-      and _are_tensor_consumer_params_compatible(params2)
-      and _are_self_compatible_tensors_compatible_to_each_other(
-          params1, params2
+    if _are_self_compatible_tensors_compatible_to_each_other(
+        tensor1_params, tensor2_params
+    ):
+      return True
+    elif _is_constant_tensor(tensor1, self.flatbuffer_model.buffers):
+      return False
+    else:
+      error_msg = (
+          f'The tensors {tensor1.name} and {tensor2.name} do not have'
+          ' the same quantization parameters even though they share the'
+          ' same buffer. Please modify your quantization recipe to make'
+          ' sure the two tensors have the same quantization settings.'
       )
-  )
+      raise RuntimeError(error_msg)
+
+  def _check_and_fix_buffer_sharing(self) -> None:
+    """Check and fix tensor/buffer sharing issues when possible.
+
+    This function checks whether tensors sharing the same buffer have the same
+    quantization settings. When possible, it fixes conflicts by marking the
+    offending tensors or buffers for duplication; otherwise it raises an error.
+
+    Possible cases that can be fixed by duplication:
+      1. A constant tensor receives different quantization parameters from its
+        consumers. In this case, the tensor is marked for duplication.
+      2. Two or more tensors share the same constant buffer and have different
+        quantization parameters. In this case, the buffer is marked for
+        duplication.
+
+    Raises:
+      RuntimeError: If the tensors sharing the same buffer have different
+        quantization settings and it can't be resolved by duplicating the
+        buffer/tensor.
+    """
+    buffers_to_duplicate = []
+    tensor_names_to_duplicate = []
+    for buffer_idx, tensors in self.buffer_to_tensors.items():
+      if not tensors:
+        continue
+      # Check if any of the tensors needs to be duplicated.
+      for tensor in tensors:
+        if not self._check_buffer_sharing_for_tensor(tensor):
+          tensor_names_to_duplicate.append(
+              tfl_flatbuffer_utils.get_tensor_name(tensor)
+          )
+      # Check if the buffer needs to be duplicated.
+      tensor_1 = tensors[0]
+      tensor_name_1 = tfl_flatbuffer_utils.get_tensor_name(tensor_1)
+      if tensor_name_1 in tensor_names_to_duplicate:
+        buffers_to_duplicate.append(buffer_idx)
+        continue
+      for tensor_2 in tensors[1:]:
+        tensor_name_2 = tfl_flatbuffer_utils.get_tensor_name(tensor_2)
+        if (
+            tensor_name_2 in tensor_names_to_duplicate
+            or not self._check_buffer_sharing_for_self_compatible_tensors(
+                tensor_1, tensor_2
+            )
+        ):
+          buffers_to_duplicate.append(buffer_idx)
+          break
+
+    # Fix the buffer sharing issues.
+    self._mark_tensors_requiring_buffer_duplication(buffers_to_duplicate)
+    self._mark_tensors_requiring_tensor_duplication(tensor_names_to_duplicate)
 
 
 def _are_tensor_consumer_params_compatible(
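
To make the new control flow easier to follow, here is a condensed, hypothetical sketch of the same decision logic outside the class (plain ints stand in for qtyping.TensorTransformationParams; check_and_fix, params_of, and is_constant are illustrative names, not the library's API):

def check_and_fix(buffer_to_tensors, params_of, is_constant):
  """Sketch: return (buffers_to_duplicate, tensor_names_to_duplicate)."""
  buffers_to_duplicate, tensor_names_to_duplicate = [], []
  for buffer_idx, tensors in buffer_to_tensors.items():
    if not tensors:
      continue
    # Step 1: mark tensors whose own consumers disagree (constant only).
    for name in tensors:
      if len(set(params_of.get(name, []))) > 1:
        if not is_constant(name):
          raise RuntimeError(f'{name}: consumers disagree on a non-constant tensor')
        tensor_names_to_duplicate.append(name)
    # Step 2: mark the buffer if its tensors disagree with each other,
    # or if any of them was already marked in step 1.
    first, rest = tensors[0], tensors[1:]
    if first in tensor_names_to_duplicate:
      buffers_to_duplicate.append(buffer_idx)
      continue
    first_params = params_of.get(first)
    for other in rest:
      other_params = params_of.get(other)
      if other in tensor_names_to_duplicate or (
          first_params is not None
          and other_params is not None
          and first_params[0] != other_params[0]
      ):
        if not is_constant(first):
          raise RuntimeError(f'{first} and {other} share a non-constant buffer')
        buffers_to_duplicate.append(buffer_idx)
        break
  return buffers_to_duplicate, tensor_names_to_duplicate

# Two tensors share constant buffer 0; 'w_a' has consumers that disagree
# (8- vs 4-bit), so the tensor is marked and the buffer is marked too.
buffers, names = check_and_fix(
    buffer_to_tensors={0: ['w_a', 'w_b']},
    params_of={'w_a': [8, 4], 'w_b': [8]},
    is_constant=lambda name: True,
)
assert names == ['w_a'] and buffers == [0]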
@@ -447,12 +532,6 @@ def _compatible_tensor_params(
   return False
 
 
-def _are_distinct_tensors_with_shared_buffer(
-    tensor1: Any, tensor2: Any, buffers: list[Any]
-) -> bool:
-  """Check if two tensors are different and share a constant buffer."""
-  are_different_tensors = tensor1.name != tensor2.name
-  do_share_buffer = tensor1.buffer == tensor2.buffer
-  is_constant_buffer = buffers[tensor1.buffer].data is not None
-
-  return are_different_tensors and do_share_buffer and is_constant_buffer
+def _is_constant_tensor(tensor: Any, buffers: Sequence[Any]) -> bool:
+  """Check if the tensor is a constant tensor."""
+  return buffers[tensor.buffer].data is not None
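
A note on the new helper: in a TFLite flatbuffer, a tensor is generally constant exactly when the buffer it points at carries data, which is all this predicate checks. A self-contained illustration (the SimpleNamespace objects stand in for flatbuffer tensor/buffer objects):

from types import SimpleNamespace

def _is_constant_tensor(tensor, buffers):
  # Constant tensors point at a buffer with data; activations point at an
  # empty buffer (data is None).
  return buffers[tensor.buffer].data is not None

buffers = [SimpleNamespace(data=b'\x01\x02'),  # weight data -> constant
           SimpleNamespace(data=None)]         # no data -> activation
weight, activation = SimpleNamespace(buffer=0), SimpleNamespace(buffer=1)
assert _is_constant_tensor(weight, buffers)
assert not _is_constant_tensor(activation, buffers)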
