@@ -407,14 +407,18 @@ def get_observer(self) -> str:
407407
408408
409409def round_to_quantized_type_dtype (
410- tensor : torch .Tensor , dtype : torch .dtype
410+ tensor : torch .Tensor ,
411+ dtype : torch .dtype ,
412+ cast_to_original_dtype : Optional [bool ] = True ,
411413) -> torch .Tensor :
412414 """
413415 Rounds an input tensor to the nearest quantized representation given a dtype.
414416 The original dtype is kept post-rounding.
415417
416418 :param tensor: tensor to round
417419 :param dtype: dtype to use for rounding
420+ :param cast_to_original_dtype: whether or not we cast the rounded tensor to
421+ the original dtype
418422 :return: rounded tensor
419423 """
420424 original_dtype = tensor .dtype
@@ -425,14 +429,17 @@ def round_to_quantized_type_dtype(
425429 iinfo = torch .iinfo (dtype )
426430 rounded = torch .round (torch .clamp (tensor , iinfo .min , iinfo .max ))
427431
428- return rounded .to (original_dtype )
432+ if cast_to_original_dtype :
433+ return rounded .to (original_dtype )
434+ return rounded
429435
430436
431437def round_to_quantized_type_args (
432438 tensor : torch .Tensor ,
433439 args : QuantizationArgs ,
434440 min : torch .Tensor ,
435441 max : torch .Tensor ,
442+ cast_to_original_dtype : Optional [bool ] = True ,
436443) -> torch .Tensor :
437444 """
438445 Rounds an input tensor to the nearest quantized representation given
@@ -442,6 +449,8 @@ def round_to_quantized_type_args(
442449 :param args: quantization args to use for rounding
443450 :param min: min value to use for clamping
444451 :param max: max value to use for clamping
452+ :param cast_to_original_dtype: whether or not we cast the rounded tensor to
453+ the original dtype
445454 :return: rounded tensor
446455 """
447456
@@ -459,4 +468,6 @@ def round_to_quantized_type_args(
459468 else :
460469 raise ValueError (f"Invalid quantization type { args .type } " )
461470
462- return rounded .to (original_dtype )
471+ if cast_to_original_dtype :
472+ return rounded .to (original_dtype )
473+ return rounded
0 commit comments