diff --git a/llm/run_finetune.py b/llm/run_finetune.py index ecac55927e39..0954c2e14b0b 100644 --- a/llm/run_finetune.py +++ b/llm/run_finetune.py @@ -164,6 +164,13 @@ def main(): qlora_weight_blocksize=model_args.qlora_weight_blocksize, qlora_weight_double_quant=model_args.qlora_weight_double_quant, qlora_weight_double_quant_block_size=model_args.qlora_weight_double_quant_block_size, + apply_hadamard=model_args.apply_hadamard, + hadamard_block_size=model_args.hadamard_block_size, + quant_input_grad=model_args.quant_input_grad, + quant_weight_grad=model_args.quant_weight_grad, + apply_online_actscale_step=model_args.apply_online_actscale_step, + actscale_moving_rate=model_args.actscale_moving_rate, + fp8_format_type=model_args.fp8_format_type, ) model_config = AutoConfig.from_pretrained( @@ -293,6 +300,7 @@ def neft_post_hook(module, input, output): logging.info("Using ReFT with layers: ", reft_layers) # init chat_template for tokenizer init_chat_template(tokenizer, model_args.model_name_or_path, data_args.chat_template) + tokenizer.chat_template = None # if using chat_template, data_args.eval_with_do_generation must be false if tokenizer.chat_template is not None: @@ -445,7 +453,9 @@ def compute_metrics_do_generation(eval_preds): gen_args=gen_args, data_args=data_args, ) - trainable_parameters = [p for p in model.parameters() if not p.stop_gradient] + trainable_parameters = [ + p for p in model.parameters() if not p.stop_gradient or ("quantization_linear" in p.name and "w_1" in p.name) + ] trainer.set_optimizer_grouped_parameters(trainable_parameters) # Train diff --git a/paddlenlp/quantization/hadamard_utils.py b/paddlenlp/quantization/hadamard_utils.py index 531d8f208826..f31beda512d4 100644 --- a/paddlenlp/quantization/hadamard_utils.py +++ b/paddlenlp/quantization/hadamard_utils.py @@ -14,6 +14,8 @@ import paddle +from paddlenlp.utils import infohub + def matmul_hadU(X): @@ -31,22 +33,43 @@ def matmul_hadU(X): return input.reshape(X.shape) -def random_hadamard_matrix(size, dtype, is_block=False): - if not is_block: - A = paddle.randint(low=0, high=2, shape=[size, size]).astype("float32") * 2 - 1 - Q, _ = paddle.linalg.qr(A) - return Q.astype(dtype), 1 +def random_hadamard_matrix(block_size, dtype): + Q = paddle.diag(paddle.ones((block_size), dtype=dtype)) + block = matmul_hadU(Q) + return block + + +def create_hadamard_matrix(block_size, dtype): + Q = paddle.diag(paddle.ones((block_size), dtype=dtype)) + block = matmul_hadU(Q) + return block + + +def hadamard_matmul(input, side, hadamard_matrix, block_size): + # left -> H.T@input right -> input@H + origin_shape = input.shape + input = input.reshape([-1, origin_shape[-1]]) + if side == "left": + # H.T@input -> (input.T@H).T + input = input.transpose([1, 0]) + block_num = input.shape[-1] // block_size + output = input.reshape([-1, block_num, block_size]) @ hadamard_matrix + output = output.reshape([-1, block_num * block_size]) + if side == "left": + output = output.transpose([1, 0]) + output = output.reshape(origin_shape) + + return output + + +def apply_hadamard_matmul(x, side, block_size): + if getattr(infohub, "hadamard") is None: + setattr(infohub, "hadamard", {}) + + if block_size in infohub.hadamard: + hadamard_matrix = infohub.hadamard[block_size] else: - num_blocks = size - while not (num_blocks % 2): - num_blocks = num_blocks // 2 - block_size = size // num_blocks - Q = paddle.diag(paddle.ones((block_size,), dtype="float32")) - block = matmul_hadU(Q) - large_matrix = paddle.zeros([size, size]) - - for i in range(num_blocks): - 
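# Illustration (a minimal NumPy sketch under an assumed normalization, not the
# patch's own code) of what the new blockwise hadamard_matmul() above does for
# side="right": the last axis is split into hadamard_block_size chunks and each
# chunk is rotated by the same Hadamard matrix, which is equivalent to
# multiplying by a block-diagonal matrix blockdiag(H, ..., H).
import numpy as np

def sylvester_hadamard(block_size):
    # Sylvester construction; block_size is assumed to be a power of two (e.g. 32).
    H = np.array([[1.0]])
    while H.shape[0] < block_size:
        H = np.block([[H, H], [H, -H]])
    return H / np.sqrt(H.shape[0])  # orthonormal scaling is an assumption here

def blockwise_hadamard_right(x, block_size):
    H = sylvester_hadamard(block_size)
    blocks = x.reshape(-1, x.shape[-1] // block_size, block_size)
    return (blocks @ H).reshape(x.shape)

x = np.random.randn(4, 64)
H = sylvester_hadamard(32)
full = np.kron(np.eye(2), H)  # blockdiag(H, H) for a 64-wide input
assert np.allclose(blockwise_hadamard_right(x, 32), x @ full)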
start_row = i * block_size - start_col = i * block_size - large_matrix[start_row : start_row + block_size, start_col : start_col + block_size] = block - return large_matrix.cast(dtype), block_size + hadamard_matrix = create_hadamard_matrix(block_size, x.dtype) + infohub.hadamard[block_size] = hadamard_matrix + target_x = hadamard_matmul(x, side, hadamard_matrix, block_size) + return target_x diff --git a/paddlenlp/quantization/qat_utils.py b/paddlenlp/quantization/qat_utils.py index 693a3af6f8a3..ab6b2665e3c9 100644 --- a/paddlenlp/quantization/qat_utils.py +++ b/paddlenlp/quantization/qat_utils.py @@ -12,14 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy import paddle from paddle.autograd import PyLayer -from paddlenlp.utils import infohub - -from .hadamard_utils import random_hadamard_matrix +from .hadamard_utils import apply_hadamard_matmul try: from transformer_engine import transformer_engine_paddle as tex @@ -35,147 +32,150 @@ except ImportError: USE_FP8_GEMM = False - -def quantize_tensorwise(x, quantization_config=None, bit_length=8, state=0, training=False, act_scale=None): - qmax = (1 << (bit_length - 1)) - 1 - qmin = -1 * qmax - 1 - if quantization_config.apply_hadamard: - target_x = x @ infohub.hadamard[x.shape[-1]][0] +QMIN_QMAX_MAPPING = { + "a8w8linear_activation": (-128, 127), + "a8w4linear_activation": (-128, 127), + "a8w8linear_weight": (-128, 127), + "a8w4linear_weight": (-8, 7), + "float8_e4m3fn": (-488, 488), + "float8_e5m2": (-57344, 57344), +} + + +def quantize( + x, + weight_quantize_algo, + tensor_type, + quantization_config, + side="right", + apply_hadamard=False, + act_scale=None, + state=0, + training=False, + group=None, +): + if apply_hadamard: + target_x = apply_hadamard_matmul(x, side, quantization_config.hadamard_block_size) + hadamard_scale = quantization_config.hadamard_block_size else: - target_x = x - - if act_scale is not None: - if training: - scale = paddle.max(paddle.abs(target_x)) / qmax + quantization_config.epsilon - if state < quantization_config.skip_first_act_scale_step: - act_scale.set_value((state * act_scale + scale) / (state + 1)) + target_x, hadamard_scale = x, 1.0 + if weight_quantize_algo in ["fp8linear"]: + qmin, qmax = QMIN_QMAX_MAPPING[quantization_config.fp8_format[tensor_type]] + else: + qmin, qmax = QMIN_QMAX_MAPPING[weight_quantize_algo + "_" + tensor_type] + if tensor_type == "activation": + if act_scale is not None: + if training: + scale = paddle.max(paddle.abs(target_x)) / qmax + if group is not None: + paddle.distributed.all_reduce(scale, op=paddle.distributed.ReduceOp.MAX, group=group, sync_op=True) + if state < quantization_config.apply_online_actscale_step: + act_scale[:] = (state * act_scale + scale) / (state + 1) + else: + scale = ( + 1 - quantization_config.actscale_moving_rate + ) * act_scale + quantization_config.actscale_moving_rate * scale + act_scale[:] = scale else: - act_scale.set_value( - (1 - quantization_config.moving_rate) * act_scale + quantization_config.moving_rate * scale - ) scale = act_scale else: - scale = act_scale - else: - scale = paddle.max(paddle.abs(target_x)) / qmax + quantization_config.epsilon - - x_int8 = paddle.clip((target_x / scale).round(), qmin, qmax).astype("int8") - return x_int8, scale - - -def dequantize_tensorwise(x_int8, scale, apply_hadamard=False): - x = x_int8.astype(scale.dtype) * scale - if apply_hadamard: - x = x @ infohub.hadamard[x.shape[-1]][0].T - return x - - -def 
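# Simplified sketch (assumptions: int8 activations only, no Hadamard, no FP8,
# no distributed all-reduce) of the "activation" branch of the new quantize()
# above: per-tensor absmax scaling plus the online / EMA act_scale update
# governed by apply_online_actscale_step and actscale_moving_rate.
import paddle

def quantize_activation_int8(x, act_scale, state, training,
                             online_steps=200, moving_rate=0.01):
    qmin, qmax = -128, 127
    if training:
        scale = paddle.max(paddle.abs(x)) / qmax
        if state < online_steps:
            # warm-up: keep a running mean, but quantize with the fresh scale
            act_scale[:] = (state * act_scale + scale) / (state + 1)
        else:
            # afterwards: EMA-update the stored scale and quantize with it
            act_scale[:] = (1 - moving_rate) * act_scale + moving_rate * scale
            scale = act_scale
    else:
        scale = act_scale
    q = paddle.clip((x / scale).round(), qmin, qmax).astype("int8")
    return q, scale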
fp8_quantize_tensorwise(x, tensor_type, quantization_config=None, state=0, training=False, act_scale=None): - assert tensor_type in ["weight", "activation", "grad_output"], "Only support weight, activation and grad_output" - fp8_format = quantization_config.fp8_format[tensor_type] - qmin, qmax = (-448, 448) if fp8_format == "float8_e4m3fn" else (-57344, 57344) - tensor_type_to_shape_index = {"weight": 0, "activation": -1, "grad_output": -2} - - if quantization_config is not None and quantization_config.apply_hadamard: - if getattr(infohub, "hadamard") is None: - setattr(infohub, "hadamard", {}) - - hadamard_matrix_shape = x.shape[tensor_type_to_shape_index[tensor_type]] - if hadamard_matrix_shape in infohub.hadamard: - hadamard_matrix, block_size = infohub.hadamard[hadamard_matrix_shape] + scale = paddle.max(paddle.abs(target_x)) / qmax + if weight_quantize_algo in ["a8w8linear", "a8w4linear"]: + quant_x = paddle.clip((target_x / scale).round(), qmin, qmax).astype("int8") + elif weight_quantize_algo in ["fp8linear"]: + quant_x = (target_x / scale).astype(quantization_config.fp8_format[tensor_type]).view("int8") else: - hadamard_matrix, block_size = random_hadamard_matrix(hadamard_matrix_shape, x.dtype, is_block=True) - infohub.hadamard[hadamard_matrix_shape] = (hadamard_matrix, block_size) - target_x = hadamard_matrix.T @ x if tensor_type in ["weight", "grad_output"] else x @ hadamard_matrix - else: - target_x = x - block_size = 1 - - if act_scale is not None: - if training: - scale = paddle.max(paddle.abs(target_x)) / qmax + quantization_config.epsilon - if state < quantization_config.skip_first_act_scale_step: - act_scale.set_value((state * act_scale + scale) / (state + 1)) - else: - act_scale.set_value( - (1 - quantization_config.moving_rate) * act_scale + quantization_config.moving_rate * scale - ) - # scale = act_scale + raise NotImplementedError(f"Unknown {weight_quantize_algo}.") + elif tensor_type == "weight": + if weight_quantize_algo in ["a8w8linear", "a8w4linear"]: + # channelwise + scale = paddle.max(paddle.abs(target_x), axis=0, keepdim=True) / qmax + if group is not None: + paddle.distributed.all_reduce(scale, op=paddle.distributed.ReduceOp.MAX, group=group, sync_op=True) + quant_x = paddle.clip((target_x / scale).round(), qmin, qmax).astype("int8").T + scale = scale.squeeze(0) / hadamard_scale + elif weight_quantize_algo in ["fp8linear"]: + scale = paddle.max(paddle.abs(target_x)) / qmax + if group is not None: + paddle.distributed.all_reduce(scale, op=paddle.distributed.ReduceOp.MAX, group=group, sync_op=True) + quant_x = (target_x / scale).astype(quantization_config.fp8_format[tensor_type]).view("int8").T + scale = (scale / hadamard_scale).reshape([1]) else: - scale = act_scale + raise NotImplementedError(f"Unknown {weight_quantize_algo}.") + elif tensor_type == "grad_output": + if weight_quantize_algo in ["fp8linear"]: + scale = paddle.max(paddle.abs(target_x)) / qmax + quant_x = (target_x / scale).astype(quantization_config.fp8_format[tensor_type]) + scale = scale / hadamard_scale + else: + raise NotImplementedError(f"Unknown {weight_quantize_algo}.") else: - scale = paddle.max(paddle.abs(target_x)) / qmax + quantization_config.epsilon - - x_fp8 = target_x / scale - x_fp8 = x_fp8.astype(fp8_format).view("int8") - x_fp8 = x_fp8.T if tensor_type == "weight" else x_fp8 + raise NotImplementedError(f"Unknown {tensor_type}.") scale.stop_gradient = True - scale = scale / block_size if tensor_type in ["weight", "grad_output"] else scale - return x_fp8, scale - - -def 
fp8_dequantize_tensorwise(x_fp8, scale, tensor_type, quantization_config=None): - x_fp8 = x_fp8.view(quantization_config.fp8_format[tensor_type]) - x_fp8 = x_fp8.T if tensor_type == "weight" else x_fp8 - x = x_fp8.astype(scale.dtype) * scale - if quantization_config.apply_hadamard: - hadamard_matrix_shape = x.shape[0] if tensor_type == "weight" else x.shape[-1] - hadamard_matrix, _ = infohub.hadamard[hadamard_matrix_shape] - x = hadamard_matrix @ x if tensor_type == "weight" else x @ hadamard_matrix.T - return x + return quant_x, scale -def quantize_channelwise(w, apply_hadamard=False, bit_length=8): - qmax = (1 << (bit_length - 1)) - 1 - qmin = -1 * qmax - 1 - if apply_hadamard: - if getattr(infohub, "hadamard") is None: - setattr(infohub, "hadamard", {}) - if w.shape[0] in infohub.hadamard: - hadamard_matrix, block_size = infohub.hadamard[w.shape[0]] +def dequantize( + quant_x, scale, tensor_type, weight_quantize_algo, quantization_config, apply_hadamard=False, side="left" +): + if tensor_type == "weight": + if weight_quantize_algo in ["a8w8linear", "a8w4linear"]: + x = quant_x.T.astype(scale.dtype) + elif weight_quantize_algo in ["fp8linear"]: + x = quant_x.view(quantization_config.fp8_format[tensor_type]).T.astype(scale.dtype) else: - hadamard_matrix, block_size = random_hadamard_matrix(w.shape[0], w.dtype, is_block=True) - infohub.hadamard[w.shape[0]] = (hadamard_matrix, block_size) - w = hadamard_matrix.T @ w + raise NotImplementedError(f"Unknown weight_quantize_algo: {weight_quantize_algo}") + if apply_hadamard: + x = apply_hadamard_matmul(x, side, quantization_config.hadamard_block_size) + x *= scale else: - block_size = 1 - scale = paddle.max(paddle.abs(w), axis=0, keepdim=True) / qmax - w_int8 = paddle.clip((w / scale).round(), qmin, qmax).astype("int8") - scale.stop_gradient = True - return w_int8.T, scale.squeeze(0) / block_size - - -def dequantize_channelwise(w_int8, scale, apply_hadamard=False): - w = w_int8.T.astype(scale.dtype) * scale - if apply_hadamard: - w = infohub.hadamard[w_int8.shape[1]][0] @ w - return w + raise NotImplementedError(f"Unknown {tensor_type}.") + return x -def a8w8_forward( - x, w_int8, w_scale=None, bias=None, dtype=None, quantization_config=None, state=0, training=False, act_scale=None +def int8_forward( + x, + quant_w, + scale_w, + weight_quantize_algo, + bias=None, + quantization_config=None, + state=0, + training=False, + act_scale=None, + group=None, ): - x_int8, x_scale = quantize_tensorwise( - x, quantization_config, bit_length=8, state=state, training=training, act_scale=act_scale + quant_x, scale_x = quantize( + x=x, + weight_quantize_algo=weight_quantize_algo, + tensor_type="activation", + quantization_config=quantization_config, + side="right", + apply_hadamard=quantization_config.apply_hadamard, + act_scale=act_scale, + state=state, + training=training, + group=group, ) - out = paddle.matmul(x_int8, w_int8.T).astype(dtype) * (x_scale * w_scale.unsqueeze(0)) + + out = paddle.matmul(quant_x, quant_w.T).astype(scale_w.dtype) * (scale_x * scale_w) if bias is not None: out += bias - return out, x_int8, x_scale + return out, quant_x, scale_x -def a8w8_backward(ctx, x, grad_output, quant_weight, quant_scale, quant_x, x_scale): +def int8_backward(ctx, x, grad_output, quant_weight, quant_scale, quant_x, x_scale): if not ctx.x_stop_gradient: - if ctx.quantization_config.quant_input_grad: - grad_output_int8, grad_output_scale = quantize_tensorwise(grad_output * quant_scale) - input_grad = paddle.matmul(grad_output_int8, 
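# Quick numerical check (NumPy emulation; scales and shapes are illustrative)
# of the algebra behind int8_forward() above: with a per-tensor activation
# scale and a per-output-channel weight scale, dequantize-then-matmul equals
# matmul-then-rescale, so the GEMM can run directly on the quantized tensors.
import numpy as np

qx = np.random.randint(-128, 128, size=(4, 16)).astype(np.float32)  # int8 activations
qw = np.random.randint(-128, 128, size=(8, 16)).astype(np.float32)  # int8 weight, [out, in]
sx = np.float32(0.05)                        # per-tensor activation scale
sw = np.random.rand(8).astype(np.float32)    # per-channel weight scale, [out]

dequant_then_matmul = (qx * sx) @ (qw * sw[:, None]).T
matmul_then_rescale = (qx @ qw.T) * (sx * sw)   # what int8_forward computes
assert np.allclose(dequant_then_matmul, matmul_then_rescale, rtol=1e-4)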
quant_weight).astype(ctx.dtype) * grad_output_scale - if ctx.quantization_config.apply_hadamard: - input_grad = input_grad @ infohub.hadamard[quant_weight.shape[-1]][0].T - else: - qdq_weight = dequantize_channelwise( - quant_weight, quant_scale, apply_hadamard=ctx.quantization_config.apply_hadamard - ) - input_grad = paddle.matmul(grad_output, qdq_weight.T) + qdq_weight = dequantize( + quant_weight, + quant_scale, + "weight", + ctx.weight_quantize_algo, + ctx.quantization_config, + ctx.quantization_config.apply_hadamard, + "left", + ) + input_grad = paddle.matmul(grad_output, qdq_weight.T) else: input_grad = None @@ -193,15 +193,29 @@ def a8w8_backward(ctx, x, grad_output, quant_weight, quant_scale, quant_x, x_sca def fp8_forward( - x, w_fp8, w_scale=None, bias=None, dtype=None, quantization_config=None, state=0, training=False, act_scale=None + x, + w_fp8, + w_scale, + weight_quantize_algo, + bias=None, + dtype=None, + quantization_config=None, + state=0, + training=False, + act_scale=None, + group=None, ): - x_fp8, x_scale = fp8_quantize_tensorwise( + x_fp8, x_scale = quantize( x, - tensor_type="activation", - quantization_config=quantization_config, + weight_quantize_algo, + "activation", + quantization_config, + side="right", + apply_hadamard=quantization_config.apply_hadamard, + act_scale=act_scale, state=state, training=training, - act_scale=act_scale, + group=group, ) x_fp8 = x_fp8.view(quantization_config.fp8_format["activation"]) w_fp8 = w_fp8.view(quantization_config.fp8_format["weight"]) @@ -240,10 +254,13 @@ def fp8_forward( def fp8_backward(ctx, x, grad_output, quant_weight, quant_scale, quant_x, x_scale): if not ctx.x_stop_gradient: if ctx.quantization_config.quant_input_grad: - grad_output_fp8, grad_output_scale = fp8_quantize_tensorwise( + grad_output_fp8, grad_output_scale = quantize( grad_output, - tensor_type="grad_output", - quantization_config=ctx.quantization_config, + ctx.weight_quantize_algo, + "grad_output", + ctx.quantization_config, + side="left", + apply_hadamard=False, ) grad_output_fp8 = grad_output_fp8.view(ctx.quantization_config.fp8_format["grad_output"]) quant_weight = quant_weight.view(ctx.quantization_config.fp8_format["weight"]) @@ -271,11 +288,16 @@ def fp8_backward(ctx, x, grad_output, quant_weight, quant_scale, quant_x, x_scal weight_ = quant_weight.astype(ctx.dtype) * quant_scale input_grad = paddle.matmul(grad_output_, weight_).astype(ctx.dtype) if ctx.quantization_config.apply_hadamard: - input_grad = infohub.hadamard[grad_output.shape[-2]][0] @ input_grad - input_grad = input_grad @ infohub.hadamard[quant_weight.shape[-1]][0].T + input_grad = apply_hadamard_matmul(input_grad, "right", ctx.quantization_config.hadamard_block_size) else: - qdq_weight = fp8_dequantize_tensorwise( - quant_weight, quant_scale, tensor_type="weight", quantization_config=ctx.quantization_config + qdq_weight = dequantize( + quant_weight, + quant_scale, + "weight", + ctx.weight_quantize_algo, + ctx.quantization_config, + apply_hadamard=ctx.quantization_config.apply_hadamard, + side="left", ) input_grad = paddle.matmul(grad_output, qdq_weight.T) else: @@ -283,14 +305,13 @@ def fp8_backward(ctx, x, grad_output, quant_weight, quant_scale, quant_x, x_scal if not ctx.w_stop_gradient: if ctx.quantization_config.quant_weight_grad: - quantization_config_ = deepcopy(ctx.quantization_config) - quantization_config_.apply_hadamard = False - grad_output_fp8, grad_output_scale = fp8_quantize_tensorwise( - grad_output, + grad_output_fp8, grad_output_scale = quantize( + x=grad_output, + 
weight_quantize_algo=ctx.weight_quantize_algo, tensor_type="grad_output", - quantization_config=quantization_config_, + quantization_config=ctx.quantization_config, + apply_hadamard=False, ) - grad_output_fp8 = grad_output_fp8.view(ctx.quantization_config.fp8_format["grad_output"]) quant_x = quant_x.view(ctx.quantization_config.fp8_format["activation"]) if USE_FP8_GEMM: quant_x = quant_x.view((-1, quant_x.shape[-1])) @@ -321,9 +342,8 @@ def fp8_backward(ctx, x, grad_output, quant_weight, quant_scale, quant_x, x_scal grad_output_.reshape([-1, grad_output_.shape[-1]]), ).astype(ctx.dtype) if ctx.quantization_config.apply_hadamard: - hadamard_matrix, block_size = infohub.hadamard[quant_x.shape[-1]] - weight_grad = weight_grad / block_size - weight_grad = hadamard_matrix @ weight_grad + weight_grad = weight_grad / ctx.quantization_config.hadamard_block_size + weight_grad = apply_hadamard_matmul(weight_grad, "left", ctx.quantization_config.hadamard_block_size) else: if len(x.shape) == 2: weight_grad = paddle.matmul(x.transpose([1, 0]), grad_output) @@ -350,33 +370,39 @@ def forward( state, training, act_scale, + weight_quantize_algo, + group, ): quant_x, x_scale = None, None - if quantization_config.weight_quantize_algo in ["fp8linear"]: + if weight_quantize_algo in ["fp8linear"]: output, quant_x, x_scale = fp8_forward( x, quant_weight, w_scale=quant_scale, + weight_quantize_algo=weight_quantize_algo, bias=bias, dtype=dtype, quantization_config=quantization_config, state=state, training=training, act_scale=act_scale, + group=group, ) else: - output, quant_x, x_scale = a8w8_forward( + output, quant_x, x_scale = int8_forward( x, - quant_weight, - w_scale=quant_scale, + quant_w=quant_weight, + scale_w=quant_scale, + weight_quantize_algo=weight_quantize_algo, bias=bias, - dtype=dtype, quantization_config=quantization_config, state=state, training=training, act_scale=act_scale, + group=group, ) ctx.quantization_config = quantization_config + ctx.weight_quantize_algo = weight_quantize_algo ctx.dtype = dtype ctx.x_stop_gradient = x.stop_gradient ctx.w_stop_gradient = quant_weight.stop_gradient @@ -399,7 +425,7 @@ def backward(ctx, grad_output): if ctx.quantization_config.weight_quantize_algo in ["fp8linear"]: input_grad, weight_grad = fp8_backward(ctx, x, grad_output, quant_weight, quant_scale, quant_x, x_scale) else: - input_grad, weight_grad = a8w8_backward(ctx, x, grad_output, quant_weight, quant_scale, quant_x, x_scale) + input_grad, weight_grad = int8_backward(ctx, x, grad_output, quant_weight, quant_scale, quant_x, x_scale) if not ctx.b_stop_gradient: bias_grad = grad_output.sum(axis=[0, 1]) diff --git a/paddlenlp/quantization/quantization_config.py b/paddlenlp/quantization/quantization_config.py index 59e15cd63d0c..e68dd894d427 100644 --- a/paddlenlp/quantization/quantization_config.py +++ b/paddlenlp/quantization/quantization_config.py @@ -64,12 +64,12 @@ def __init__( dtype=None, ignore_modules=None, group_size=-1, - apply_hadamard=True, + apply_hadamard=False, + hadamard_block_size=32, quant_input_grad=False, quant_weight_grad=False, - skip_first_act_scale_step=20, - moving_rate=0.01, - epsilon=1e-8, + apply_online_actscale_step=200, + actscale_moving_rate=0.01, fp8_format_type="hybrid", **kwargs, ): @@ -139,11 +139,12 @@ def __init__( self.ignore_modules = ignore_modules self.group_size = group_size self.apply_hadamard = apply_hadamard + self.hadamard_block_size = hadamard_block_size self.quant_input_grad = quant_input_grad self.quant_weight_grad = quant_weight_grad - 
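# Sketch (assumes a Paddle build with float8 dtypes, as the fp8linear path
# above requires; values are illustrative) of the FP8 storage trick used by
# quantize()/dequantize(): cast to float8_e4m3fn (max finite magnitude 448) or
# float8_e5m2 (max 57344), reinterpret the bytes as int8 for storage, and view
# them back as FP8 before dequantizing.
import paddle

x = paddle.randn([4, 8])
qmax = 448.0                                   # e4m3fn max finite magnitude
scale = paddle.max(paddle.abs(x)) / qmax
x_fp8 = (x / scale).astype("float8_e4m3fn")    # lossy cast
stored = x_fp8.view("int8")                    # raw bytes, storable as an int8 parameter
x_back = stored.view("float8_e4m3fn").astype("float32") * scale
print(float(paddle.max(paddle.abs(x - x_back))))  # small quantization error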
self.skip_first_act_scale_step = skip_first_act_scale_step - self.moving_rate = moving_rate - self.epsilon = epsilon + self.apply_online_actscale_step = apply_online_actscale_step + self.actscale_moving_rate = actscale_moving_rate + self.fp8_format_type = fp8_format_type self.fp8_format = fp8_format_mapping[fp8_format_type] def is_weight_quantize(self): @@ -214,6 +215,7 @@ def to_diff_dict(self): config_dict = self.to_dict() # get the default config dict + default_config_dict = QuantizationConfig().to_dict() serializable_config_dict = {} diff --git a/paddlenlp/quantization/quantization_linear.py b/paddlenlp/quantization/quantization_linear.py index 70e3ee80951a..1709d7a534ec 100644 --- a/paddlenlp/quantization/quantization_linear.py +++ b/paddlenlp/quantization/quantization_linear.py @@ -15,6 +15,7 @@ import paddle import paddle.nn as nn from paddle.autograd import PyLayer +from paddle.distributed import fleet from paddle.distributed.fleet.base import topology as tp from paddle.distributed.fleet.layers.mpu import mp_ops from paddle.distributed.fleet.utils.sequence_parallel_utils import ( @@ -23,6 +24,8 @@ ) from paddle.nn.quant import llm_int8_linear, weight_dequantize, weight_only_linear +from paddlenlp.utils import infohub + from .qat_utils import QATFunc try: @@ -209,10 +212,20 @@ def quant_weight_linear( ): if weight_quantize_algo in ["a8w8linear", "a8w4linear", "fp8linear"]: - state, training, act_scale = act_state + state, training, act_scale, group = act_state return QATFunc.apply( - x, quant_weight, bias, quant_scale, quantization_config, dtype, state, training, act_scale + x, + quant_weight, + bias, + quant_scale, + quantization_config, + dtype, + state, + training, + act_scale, + weight_quantize_algo, + group, ) else: return QuantizationLinearFunc.apply( @@ -228,6 +241,19 @@ def quant_weight_linear( ) +def get_act_scale_group(is_row=False): + if not paddle.distributed.is_initialized() or not is_row: + return None + + if getattr(infohub, "scale_group") is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + setattr(infohub, "scale_group", group) + else: + group = infohub.scale_group + return group + + class QuantizationLinear(nn.Layer): """Quantization Linear layer.""" @@ -278,9 +304,10 @@ def __init__( raise NotImplementedError("Not yet support grouwise weightonly quantization.") if self.weight_quantize_algo in ["a8w8linear", "a8w4linear", "fp8linear"]: self.act_scale = self.create_parameter( - shape=[], dtype=self._dtype, is_bias=False, default_initializer=nn.initializer.Constant(value=0.0) + shape=[1], dtype=self._dtype, is_bias=False, default_initializer=nn.initializer.Constant(value=0.0) ) self.act_scale.stop_gradient = True + self.group = get_act_scale_group() elif self.weight_quantize_algo in ["fp4", "nf4"]: if qlora_weight_linear is None: @@ -340,6 +367,7 @@ def __init__( for p in self.parameters(): p.is_distributed = is_distributed p.mp_moe = mp_moe + self.quant_weight.weight_quantize_algo = self.weight_quantize_algo def forward(self, x): output = quant_weight_linear( @@ -354,7 +382,7 @@ def forward(self, x): if (self.weight_quantize_algo in ["fp4", "nf4"] and self.quantization_config.qlora_weight_double_quant) else None, bias=self.bias, - act_state=(self.state, self.training, self.act_scale) + act_state=(self.state, self.training, self.act_scale, self.group) if self.weight_quantize_algo in ["a8w8linear", "a8w4linear", "fp8linear"] else None, ) @@ -434,7 +462,10 @@ def __init__( is_bias=False, ) self.quant_scale.stop_gradient = 
True - self.quant_scale.is_distributed = True if self.is_mp else False + if self.weight_quantize_algo not in ["fp8linear", "a8w4linear", "fp8linear"]: + self.quant_scale.is_distributed = False + else: + self.quant_scale.is_distributed = True if self.is_mp else False if self.quant_scale.is_distributed: self.quant_scale.split_axis = 0 else: @@ -442,10 +473,11 @@ def __init__( raise NotImplementedError("Not yet support grouwise weightonly quantization.") if self.weight_quantize_algo in ["a8w8linear", "a8w4linear", "fp8linear"]: self.act_scale = self.create_parameter( - shape=[], dtype=self._dtype, is_bias=False, default_initializer=nn.initializer.Constant(value=0.0) + shape=[1], dtype=self._dtype, is_bias=False, default_initializer=nn.initializer.Constant(value=0.0) ) - self.act_scale.is_distributed = True if self.is_mp else False + self.act_scale.is_distributed = False self.act_scale.stop_gradient = True + self.group = get_act_scale_group() else: raise NotImplementedError(f"Not yet support weight_quantize_algo: {self.weight_quantize_algo}") if bias_attr is False: @@ -460,6 +492,7 @@ def __init__( self.bias.is_distributed = True if self.is_mp else False if self.bias.is_distributed: self.bias.split_axis = 0 + self.quant_weight.weight_quantize_algo = self.weight_quantize_algo def forward(self, x): if self.is_mp: @@ -486,7 +519,7 @@ def forward(self, x): if (self.weight_quantize_algo in ["fp4", "nf4"] and self.quantization_config.qlora_weight_double_quant) else None, bias=self.bias, - act_state=(self.state, self.training, self.act_scale) + act_state=(self.state, self.training, self.act_scale, self.group) if self.weight_quantize_algo in ["a8w8linear", "a8w4linear", "fp8linear"] else None, ) @@ -573,7 +606,10 @@ def __init__( is_bias=False, ) self.quant_scale.stop_gradient = True - self.quant_scale.is_distributed = True if self.is_mp else False + if self.weight_quantize_algo not in ["fp8linear", "a8w4linear", "fp8linear"]: + self.quant_scale.is_distributed = False + else: + self.quant_scale.is_distributed = True if self.is_mp else False if self.quant_scale.is_distributed: self.quant_scale.split_axis = 0 else: @@ -583,8 +619,9 @@ def __init__( self.act_scale = self.create_parameter( shape=[1], dtype=self._dtype, is_bias=False, default_initializer=nn.initializer.Constant(value=0.0) ) - self.act_scale.is_distributed = True if self.is_mp else False + self.act_scale.is_distributed = False self.act_scale.stop_gradient = True + self.group = get_act_scale_group(is_row=True) else: raise NotImplementedError(f"Not yet support weight_quantize_algo: {self.weight_quantize_algo}") @@ -598,6 +635,8 @@ def __init__( is_bias=True, ) + self.quant_weight.weight_quantize_algo = self.weight_quantize_algo + def forward(self, x): if self.input_is_parallel or (not self.is_mp): input_parallel = x @@ -619,7 +658,7 @@ def forward(self, x): if (self.weight_quantize_algo in ["fp4", "nf4"] and self.quantization_config.qlora_weight_double_quant) else None, bias=None, - act_state=(self.state, self.training, self.act_scale) + act_state=(self.state, self.training, self.act_scale, self.group) if self.weight_quantize_algo in ["a8w8linear", "a8w4linear", "fp8linear"] else None, ) @@ -647,7 +686,7 @@ def forward(self, x): if (self.weight_quantize_algo in ["fp4", "nf4"] and self.quantization_config.qlora_weight_double_quant) else None, bias=self.bias, - act_state=(self.state, self.training, self.act_scale) + act_state=(self.state, self.training, self.act_scale, self.group) if self.weight_quantize_algo in ["a8w8linear", "a8w4linear", 
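# Sketch (assumes fleet's hybrid topology is already initialized; the group
# itself comes from the new get_act_scale_group() in quantization_linear.py)
# of why quantize() all-reduces the absmax scale over the model-parallel
# group: every tensor-parallel shard of one logical tensor must quantize with
# an identical scale so the sharded result matches the unsharded model.
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet

hcg = fleet.get_hybrid_communicate_group()
mp_group = hcg.get_model_parallel_group()

local_shard = paddle.randn([1024, 512])               # this rank's weight slice
scale = paddle.max(paddle.abs(local_shard)) / 127.0   # local absmax scale
# take the max across shards so all ranks share one scale
dist.all_reduce(scale, op=dist.ReduceOp.MAX, group=mp_group, sync_op=True)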
"fp8linear"] else None, ) diff --git a/paddlenlp/quantization/quantization_utils.py b/paddlenlp/quantization/quantization_utils.py index 4d320a979956..d5b0bf395ff4 100644 --- a/paddlenlp/quantization/quantization_utils.py +++ b/paddlenlp/quantization/quantization_utils.py @@ -34,7 +34,7 @@ qlora_weight_quantize = None from ..utils.log import logger -from .qat_utils import fp8_quantize_tensorwise, quantize_channelwise +from .qat_utils import quantize from .quantization_linear import ( ColumnParallelQuantizationLinear, QuantizationLinear, @@ -155,25 +155,17 @@ def convert_to_weight_quantize_state_dict(state_dict, name, quantization_config, if weight_name in state_dict: # gpu weight_quantize will fix in future target_weight = state_dict.pop(weight_name).cast(dtype).cuda() - if weight_quantize_algo in ["a8w8linear"]: - quant_weight, quant_scale = quantize_channelwise( - target_weight, quantization_config.apply_hadamard, bit_length=8 - ) - act_scale = paddle.zeros([], dtype="bfloat16").cuda() - act_scale.stop_gradient = True - state_dict[act_scale_name] = act_scale - elif weight_quantize_algo in ["a8w4linear"]: - quant_weight, quant_scale = quantize_channelwise( - target_weight, quantization_config.apply_hadamard, bit_length=4 - ) - act_scale = paddle.zeros([], dtype="bfloat16").cuda() - act_scale.stop_gradient = True - state_dict[act_scale_name] = act_scale - elif weight_quantize_algo in ["fp8linear"]: - quant_weight, quant_scale = fp8_quantize_tensorwise( - target_weight, tensor_type="weight", quantization_config=quantization_config + + if weight_quantize_algo in ["a8w8linear", "a8w4linear", "fp8linear"]: + quant_weight, quant_scale = quantize( + target_weight, + weight_quantize_algo, + "weight", + quantization_config, + side="left", + apply_hadamard=quantization_config.apply_hadamard, ) - act_scale = paddle.zeros([], dtype="bfloat16").cuda() + act_scale = paddle.ones([1], dtype=dtype).cuda() act_scale.stop_gradient = True state_dict[act_scale_name] = act_scale else: diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index df1cae02f341..aaea0fdbc4ea 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -1952,8 +1952,10 @@ def apply_decay_param_fun(x): return x in decay_parameters optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) - if self.args.optim == OptimizerNames.AdamW_Qweight: + if self.args.optim == OptimizerNames.ADAMW_CUSTOM: optimizer_kwargs["quantization_config"] = self.model.config.quantization_config + optimizer_kwargs["use_lowprecision_moment"] = self.args.use_lowprecision_moment + optimizer_kwargs["tensorwise_offload_optimizer"] = self.args.tensorwise_offload_optimizer if hasattr(optimizer_cls, "_create_master_weight") and self.args.fp16_opt_level == "O2": optimizer_kwargs["multi_precision"] = True @@ -2107,16 +2109,6 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]: optimizer_cls = AdamWCustom optimizer_kwargs.update(adam_kwargs) - elif args.optim == OptimizerNames.ADAMW_16BIT_MOMENT: - from ..utils import AdamW_16Bit - - optimizer_cls = AdamW_16Bit - optimizer_kwargs.update(adam_kwargs) - elif args.optim == OptimizerNames.AdamW_Qweight: - from ..utils import AdamWQweight - - optimizer_cls = AdamWQweight - optimizer_kwargs.update(adam_kwargs) else: raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}") return optimizer_cls, optimizer_kwargs @@ -2755,7 +2747,7 @@ def _save_checkpoint(self, model, metrics=None): optimizer_name = 
_add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix) saved_signal_path = os.path.join(output_dir, f"saved_signal_{dist.get_rank()}") - if self.args.unified_checkpoint and self.args.offload_optim: + if self.args.unified_checkpoint and (self.args.offload_optim or self.args.tensorwise_offload_optimizer): self._reload_optimizer() if self.args.use_hybrid_parallel: @@ -2838,7 +2830,7 @@ def _save_checkpoint(self, model, metrics=None): ): paddle.save(global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}")) - if self.args.unified_checkpoint and self.args.offload_optim: + if self.args.unified_checkpoint and (self.args.offload_optim or self.args.tensorwise_offload_optimizer): self._offload_optimizer() self.runtime_timer.stop() diff --git a/paddlenlp/trainer/trainer_utils.py b/paddlenlp/trainer/trainer_utils.py index 6eb401f2c262..629ef81dd451 100644 --- a/paddlenlp/trainer/trainer_utils.py +++ b/paddlenlp/trainer/trainer_utils.py @@ -318,8 +318,6 @@ class OptimizerNames(ExplicitEnum): ADAFACTOR = "adafactor" ADAMW_MINI = "adamw_mini" ADAMW_CUSTOM = "adamw_custom" - ADAMW_16BIT_MOMENT = "adamw_16bit_moment" - AdamW_Qweight = "adamw_qweight" class ShardingOption(ExplicitEnum): diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 6f528f939d85..ef5dc7851818 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -868,6 +868,10 @@ class TrainingArguments: default="adamw", metadata={"help": "The optimizer to use."}, ) + use_lowprecision_moment: bool = field( + default=False, + metadata={"help": "AdamW use lowbit moment as parameter."}, + ) report_to: Optional[List[str]] = field( default=None, metadata={"help": "The list of integrations to report the results and logs to."} ) @@ -996,6 +1000,10 @@ class TrainingArguments: default=False, metadata={"help": "Offload optimizer after optimizer.step()"}, ) + tensorwise_offload_optimizer: Optional[bool] = field( + default=False, + metadata={"help": "Offload optimizer tensor by tensor"}, + ) save_sharding_stage1_model_include_freeze_params: Optional[bool] = field( default=False, metadata={"help": "Save Sharding Stage1 Model Exclude Freeze Params"} ) diff --git a/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py b/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py index dec974b2f5e8..deb85d6c7c87 100644 --- a/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py +++ b/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py @@ -286,7 +286,9 @@ def load_resolved_archive_file( if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): tp_actions = model._get_tensor_parallel_convert_actions(model_keys, is_split=True, ignore_error=True) else: - tp_actions = model.get_tensor_parallel_convert_actions(model.config, model_keys, ignore_error=True) + tp_actions = model.get_tensor_parallel_convert_actions( + model.config, model_keys, ignore_error=True, is_optim=True + ) if not is_master_weights: tp_actions = mapping_optimizer_tp_actions(tp_actions, expected_keys) diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py index 30b0c9b1c0e1..b63f8c4266cc 100644 --- a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py +++ b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py @@ -648,7 +648,7 @@ def unified_optimizer_into_shards( tp_actions = model._get_tensor_parallel_convert_actions(model_keys, 
is_split=False, ignore_error=True) else: tp_actions = model.get_tensor_parallel_convert_actions( - model.config, model_keys, is_split=False, ignore_error=True + model.config, model_keys, is_split=False, ignore_error=True, is_optim=True ) logger.info("Unified optimizer tensor parallel in shards") optim_state_dict = merge_tensor_parallel_for_optimizer( diff --git a/paddlenlp/transformers/conversion_utils.py b/paddlenlp/transformers/conversion_utils.py index 6f4d69e157ce..6f3ab32ff9de 100644 --- a/paddlenlp/transformers/conversion_utils.py +++ b/paddlenlp/transformers/conversion_utils.py @@ -17,8 +17,10 @@ import inspect import json import os +import re from copy import deepcopy from dataclasses import dataclass +from functools import partial from typing import ( TYPE_CHECKING, Callable, @@ -58,6 +60,45 @@ PytorchTensor = TypeVar("PytorchTensor") +def add_quant_mapping(name_action_mappings, quantization_config, is_optim=False): + mapping_keys = list(name_action_mappings.keys()) + pattern = r"^(?:.*\.)?layers(\.[a-zA-Z0-9_]+)*\.weight$" + for key in mapping_keys: + if re.match(pattern, key): + quant_key = key.replace("weight", "quant_weight") + quant_scale_key = key.replace("weight", "quant_scale") + fn = name_action_mappings.pop(key) + if is_optim: + name_action_mappings[quant_key] = fn + else: + if isinstance(fn, partial): + if "is_column" in fn.keywords: + old_value = fn.keywords["is_column"] + new_value = not old_value + name_action_mappings[quant_key] = partial( + fn.func, *fn.args, **{**fn.keywords, "is_column": new_value} + ) + if quantization_config.weight_quantize_algo not in ["fp8linear"] and old_value: + name_action_mappings[quant_scale_key] = partial( + fn.func, *fn.args, **{**fn.keywords, "is_column": new_value} + ) + elif "is_quant" in fn.keywords: + old_value = fn.keywords["is_quant"] + new_value = not old_value + name_action_mappings[quant_key] = partial( + fn.func, *fn.args, **{**fn.keywords, "is_quant": new_value} + ) + if quantization_config.weight_quantize_algo not in ["fp8linear"]: + name_action_mappings[quant_scale_key] = split_or_merge_func( + is_split=fn.keywords["tensor_parallel_degree"], + tensor_parallel_degree=fn.keywords["tensor_parallel_degree"], + tensor_parallel_rank=fn.keywords["tensor_parallel_rank"], + num_attention_heads=fn.keywords["num_attention_head"], + ) + + return name_action_mappings + + def tensor_summary(tensor: Union[str, Tensor, PytorchTensor, tuple, list, ndarray]): """get summary of values which can be some of different values @@ -1207,8 +1248,12 @@ def get_tensor_parallel_convert_actions( is_split=True, ignore_error=False, base_model_prefix=None, + post_quantize=False, + is_optim=False, ): name_action_mappings = cls._get_tensor_parallel_mappings(config, is_split=is_split) + if config.quantization_config.is_weight_quantize() and not post_quantize: + name_action_mappings = add_quant_mapping(name_action_mappings, config.quantization_config, is_optim) state_keys_map = cls._resolve_prefix_keys( name_action_mappings.keys(), loaded_state_dict_keys, ignore_error, base_model_prefix=base_model_prefix ) @@ -1230,6 +1275,8 @@ def convert_tensor_parallel( """ name_action_mappings = cls._get_tensor_parallel_mappings(config) + if config.quantization_config.is_weight_quantize(): + name_action_mappings = add_quant_mapping(name_action_mappings, config.quantization_config) if state_dict is None: with device_guard("cpu"): state_dict = paddle.load(weight_file, return_numpy=False) @@ -1261,6 +1308,8 @@ def merge_tensor_parallel(cls, state_dict, config) -> None: 
config (PretrainedConfig): the PretrainedConfig instance of model """ name_action_mappings = cls._get_tensor_parallel_mappings(config, is_split=False) + if config.quantization_config.is_weight_quantize(): + name_action_mappings = add_quant_mapping(name_action_mappings, config.quantization_config) state_keys_map = cls._resolve_prefix_keys(name_action_mappings.keys(), state_dict.keys()) for k, v in state_keys_map.items(): diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py index 87b97671e4d7..c2644a53cfce 100644 --- a/paddlenlp/transformers/model_utils.py +++ b/paddlenlp/transformers/model_utils.py @@ -77,6 +77,8 @@ from ..generation import GenerationConfig, GenerationMixin from ..quantization.quantization_utils import ( convert_to_quantize_state_dict, + convert_to_weight_quantize_state_dict, + parse_weight_quantize_algo, replace_with_quantization_linear, update_loaded_state_dict_keys, ) @@ -360,7 +362,14 @@ def _split_keys_evenly(keys: list, n: int) -> list: def _load_part_state_dict( - keys, checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping, fliter_dict_keys, device + keys, + checkpoint_file: Union[str, os.PathLike], + tensor_parallel_split_mapping, + fliter_dict_keys, + device, + quantization_linear_list=None, + quantization_config=None, + dtype=None, ): """load part state dict from checkpoint file. @@ -391,15 +400,44 @@ def _load_part_state_dict( continue py_safe_slice_ = f.get_slice(key) - if key in tensor_parallel_split_mapping: - weight = tensor_parallel_split_mapping[key](py_safe_slice_) + if quantization_linear_list is not None and key.split(".weight")[0] in quantization_linear_list: + # numpy.array -> paddle.tensor + weight = paddle.Tensor.__call__(py_safe_slice_[:], zero_copy=True) + key_name = key.split(".weight")[0] + quant_key_name = key_name + ".quant_weight" + quant_scale_name = key_name + ".quant_scale" + # 16bit -> 4/8bit + quant_state_dict = convert_to_weight_quantize_state_dict( + state_dict={key: weight}, + name=key_name, + quantization_config=quantization_config, + dtype=dtype, + weight_quantize_algo=parse_weight_quantize_algo(quantization_config, quant_key_name), + ) + for key in list(quant_state_dict.keys()): + quant_state_dict[key] = quant_state_dict[key].numpy() + if quant_key_name in tensor_parallel_split_mapping: + quant_state_dict[quant_key_name] = tensor_parallel_split_mapping[quant_key_name]( + quant_state_dict[quant_key_name] + ) + if quant_scale_name in tensor_parallel_split_mapping: + quant_state_dict[quant_scale_name] = tensor_parallel_split_mapping[quant_scale_name]( + quant_state_dict[quant_scale_name] + ) + part_state_dict.update(quant_state_dict) else: - weight = py_safe_slice_[:] - if device == "expected": - with device_guard(): - weight = paddle.Tensor.__call__(weight, zero_copy=True) - weight = weight._copy_to(paddle.framework._current_expected_place(), False) - part_state_dict[key] = weight + if key in tensor_parallel_split_mapping: + weight = tensor_parallel_split_mapping[key](py_safe_slice_.get()) + else: + if len(py_safe_slice_.shape) == 0: + weight = py_safe_slice_.get() + else: + weight = py_safe_slice_[:] + if device == "expected": + with device_guard(): + weight = paddle.Tensor.__call__(weight, zero_copy=True) + weight = weight._copy_to(paddle.framework._current_expected_place(), False) + part_state_dict[key] = weight for key in keys: if ( key.endswith(SYMMETRY_QUANT_SCALE) @@ -420,6 +458,9 @@ def load_state_dict( fliter_dict_keys=None, device="cpu", ckpt_quant_stage="O0", + 
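# Small demo (key names are hypothetical) of the key rewrite performed by
# add_quant_mapping() above: tensor-parallel split/merge actions registered
# for "...layers....weight" are re-registered under the quant_weight /
# quant_scale names that QuantizationLinear actually stores.
import re

pattern = r"^(?:.*\.)?layers(\.[a-zA-Z0-9_]+)*\.weight$"
for key in ["llama.layers.0.self_attn.q_proj.weight",  # matches -> remapped
            "llama.embed_tokens.weight"]:               # no "layers" segment -> untouched
    if re.match(pattern, key):
        print(key, "->", key.replace("weight", "quant_weight"),
              "/", key.replace("weight", "quant_scale"))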
quantization_linear_list=None, + quantization_config=None, + dtype=None, ): """ Reads a PaddlePaddle checkpoint file, returning properly formatted errors if they arise. @@ -455,6 +496,9 @@ def load_state_dict( tensor_parallel_split_mapping, fliter_dict_keys, device, + quantization_linear_list, + quantization_config, + dtype, ) else: # Load state dict in multi-thread to speed up loading @@ -469,6 +513,9 @@ def load_state_dict( tensor_parallel_split_mapping, fliter_dict_keys, device, + quantization_linear_list, + quantization_config, + dtype, ): keys for keys in keys_groups } @@ -478,8 +525,8 @@ def load_state_dict( scale_dict.update(res_scale_dict) if device == "cpu": - for k in list(state_dict.keys()): - with device_guard(): + with device_guard(): + for k in list(state_dict.keys()): state_dict[k] = paddle.Tensor.__call__(state_dict.pop(k), zero_copy=True) if len(scale_dict) != 0: @@ -1981,7 +2028,22 @@ def _load_pretrained_model( # Weight quantization if not yet quantized & update loaded_keys if quantization_linear_list is not None: - origin_loaded_keys = copy.deepcopy(loaded_keys) + if isinstance(config.quantization_config.weight_quantize_algo, str): + post_quantize = config.quantization_config.weight_quantize_algo in [ + "weight_only_int4", + "weight_only_int8", + ] + elif isinstance(config.quantization_config.weight_quantize_algo, dict): + post_quantize = any( + key in ["weight_only_int4", "weight_only_int8"] + for key in config.quantization_config.weight_quantize_algo.keys() + ) + else: + post_quantize = False + if post_quantize: + origin_loaded_keys = copy.deepcopy(loaded_keys) + else: + origin_loaded_keys = list(model.state_dict()) loaded_keys = update_loaded_state_dict_keys( loaded_keys, quantization_linear_list, config.quantization_config ) @@ -2186,19 +2248,35 @@ def _fuse_or_split_keys( pre_tensor_parallel_split = True assert origin_loaded_keys is not None, "loaded_keys is not None." tp_actions = cls.get_tensor_parallel_convert_actions( - config, origin_loaded_keys, ignore_error=True, base_model_prefix=prefix + config, + origin_loaded_keys, + ignore_error=True, + base_model_prefix=prefix, + post_quantize=post_quantize, + ) + if post_quantize: + # Split -> quantize(Not support mdoel save) + state_dict = load_state_dict( + shard_file, + tp_actions if pre_tensor_parallel_split else None, + None, + ) + state_dict = convert_to_quantize_state_dict( + state_dict, + quantization_linear_list, + config.quantization_config, + dtype, + ) + else: + # quantize -> split(Support mdoel save) + state_dict = load_state_dict( + shard_file, + tp_actions if pre_tensor_parallel_split else None, + None, + quantization_linear_list=quantization_linear_list, + quantization_config=config.quantization_config, + dtype=dtype, ) - state_dict = load_state_dict( - shard_file, - tp_actions if pre_tensor_parallel_split else None, - None, - ) - state_dict = convert_to_quantize_state_dict( - state_dict, - quantization_linear_list, - config.quantization_config, - dtype, - ) else: if ( shard_file.endswith(".safetensors") @@ -2537,6 +2615,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): # load pt weights early so that we know which dtype to init the model under if not is_sharded and state_dict is None: # 4. 
loading non-sharded ckpt from the state dict + # Quantization: Loading non-sharded ckpt does not support saving with merge_tensor_parallel if config.tensor_parallel_degree > 1 and resolved_archive_file.endswith("model_state.pdparams"): state_dict = cls.convert_tensor_parallel(resolved_archive_file, config) elif config.tensor_parallel_degree > 1 and resolved_archive_file.endswith("model.safetensors"): diff --git a/paddlenlp/transformers/utils.py b/paddlenlp/transformers/utils.py index 76b3279b87db..07bcf9edc2c6 100644 --- a/paddlenlp/transformers/utils.py +++ b/paddlenlp/transformers/utils.py @@ -850,6 +850,8 @@ def dtype_byte_size(dtype): """ if dtype == paddle.bool: return 1 / 8 + if "float8" in str(dtype): + return 1 bit_search = re.search(r"[^\d](\d+)$", str(dtype)) if bit_search is None: raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") diff --git a/paddlenlp/trl/model_config.py b/paddlenlp/trl/model_config.py index 2aa5c5ebc7ac..1e60114ba463 100644 --- a/paddlenlp/trl/model_config.py +++ b/paddlenlp/trl/model_config.py @@ -47,25 +47,6 @@ class ModelConfig: "help": "Whether to train from existing paddlenlp model weights. If set True, the model_name_or_path argument must exist in the paddlenlp models." }, ) - weight_quantize_algo: str = field( - default=None, - metadata={ - "help": "Model weight quantization algorithm including 'nf4', 'fp4','weight_only_int4', 'weight_only_int8'." - }, - ) - qlora_weight_blocksize: int = field( - default=64, - metadata={"help": "Block size for weight quantization(Only available for nf4 or fp4 quant_scale.)."}, - ) - qlora_weight_double_quant: bool = field( - default=False, metadata={"help": "Whether apply double quant(Only available for nf4 or fp4 quant_scale.)."} - ) - qlora_weight_double_quant_block_size: int = field( - default=256, - metadata={ - "help": "Block size for quant_scale of weight quant_scale(Only available for nf4 or fp4 quant_scale.)" - }, - ) # LoRA related parameters lora: bool = field(default=False, metadata={"help": "Whether to use LoRA technique"}) @@ -123,3 +104,33 @@ class ModelConfig: rope_scaling_factor: float = field(default=1.0, metadata={"help": "Rope extension scaling factor"}) strategy_type: str = field(default=None, metadata={"help": "Long sequence strategy type"}) strategy_name: str = field(default=None, metadata={"help": "Long sequence strategy name"}) + + # Quantization Training Related + weight_quantize_algo: str = field( + default=None, + metadata={ + "help": "Model weight quantization algorithm including 'nf4', 'fp4','weight_only_int4', 'weight_only_int8'." 
+ }, + ) + qlora_weight_blocksize: int = field( + default=64, + metadata={"help": "Block size for weight quantization(Only available for nf4 or fp4 quant_scale.)."}, + ) + qlora_weight_double_quant: bool = field( + default=False, metadata={"help": "Whether apply double quant(Only available for nf4 or fp4 quant_scale.)."} + ) + qlora_weight_double_quant_block_size: int = field( + default=256, + metadata={ + "help": "Block size for quant_scale of weight quant_scale(Only available for nf4 or fp4 quant_scale.)" + }, + ) + apply_hadamard: bool = field(default=False, metadata={"help": "Whether to apply hadamard"}) + hadamard_block_size: int = field(default=32, metadata={"help": "hadamard block size"}) + quant_input_grad: bool = field(default=False, metadata={"help": "Whether to quantize input grad"}) + quant_weight_grad: bool = field(default=False, metadata={"help": "Whether to quantize weight grad"}) + apply_online_actscale_step: int = field( + default=200, metadata={"help": "Use online activation scale for first N step to keep stable training."} + ) + actscale_moving_rate: float = field(default=0.01, metadata={"help": "EMA moving_rate for activation scale"}) + fp8_format_type: str = field(default="hybrid", metadata={"help": "FP8 Format"}) diff --git a/paddlenlp/utils/adamw_triton.py b/paddlenlp/utils/adamw_triton.py new file mode 100644 index 000000000000..08a53544dc63 --- /dev/null +++ b/paddlenlp/utils/adamw_triton.py @@ -0,0 +1,194 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
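# Eager NumPy reference (an illustration, not part of this file) of the update
# rule the fused kernels below implement: decoupled weight decay first, then
# Adam moment updates with bias correction folded into the step, matching
# "param += (moment1 / denom) * (-lr / (1 - beta1_pow))".
import numpy as np

def adamw_reference(param, grad, m1, m2, lr, beta1, beta2, eps, coeff,
                    beta1_pow, beta2_pow):
    param = param * (1.0 - lr * coeff)                   # decoupled weight decay
    m1 = beta1 * m1 + (1.0 - beta1) * grad
    m2 = beta2 * m2 + (1.0 - beta2) * grad * grad
    denom = np.sqrt(m2) / np.sqrt(1.0 - beta2_pow) + eps
    param = param + (m1 / denom) * (-lr / (1.0 - beta1_pow))
    return param, m1, m2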
+ +import paddle +import triton +import triton.language as tl + +DTYPE_MAPPING = { + paddle.bfloat16: tl.bfloat16, + paddle.float32: tl.float32, + paddle.float16: tl.float16, +} + + +@triton.jit +def adamw_kernel( + param_ptr, + grad_ptr, + moment1_ptr, + moment2_ptr, + lr_ptr, + beta1, + beta2, + epsilon, + coeff, + beta1_pow_ptr, + beta2_pow_ptr, + master_weight_ptr, + N, + skip_update_param, + param_dtype: tl.constexpr, + moment_dtype: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < N + + if master_weight_ptr is not None: + param = tl.load(master_weight_ptr + offsets, mask=mask) + else: + param = tl.load(param_ptr + offsets, mask=mask).to(tl.float32) + grad = tl.load(grad_ptr + offsets, mask=mask).to(tl.float32) + + moment1 = tl.load(moment1_ptr + offsets, mask=mask).to(tl.float32) + moment2 = tl.load(moment2_ptr + offsets, mask=mask).to(tl.float32) + lr = tl.load(lr_ptr) + beta1_pow = tl.load(beta1_pow_ptr) + beta2_pow = tl.load(beta2_pow_ptr) + + # Weight Decay + param *= 1.0 - lr * coeff + + # AdamW + moment1 = beta1 * moment1 + (1.0 - beta1) * grad + moment2 = beta2 * moment2 + (1.0 - beta2) * grad * grad + denom = tl.sqrt(moment2) / tl.sqrt(1.0 - beta2_pow) + epsilon + param += (moment1 / denom) * (-lr / (1 - beta1_pow)) + # Update param + if master_weight_ptr is not None: + tl.store(master_weight_ptr + offsets, param, mask=mask) + if not skip_update_param: + tl.store(param_ptr + offsets, param.to(param_dtype), mask=mask) + else: + tl.store(param_ptr + offsets, param.to(param_dtype), mask=mask) + tl.store(moment1_ptr + offsets, moment1.to(moment_dtype), mask=mask) + tl.store(moment2_ptr + offsets, moment2.to(moment_dtype), mask=mask) + + +@triton.jit +def adamw_kernel_skip( + grad_ptr, + moment1_ptr, + moment2_ptr, + lr_ptr, + beta1, + beta2, + epsilon, + coeff, + beta1_pow_ptr, + beta2_pow_ptr, + master_weight_ptr, + N, + skip_update_param, + moment_dtype: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < N + + param = tl.load(master_weight_ptr + offsets, mask=mask) + grad = tl.load(grad_ptr + offsets, mask=mask).to(tl.float32) + + moment1 = tl.load(moment1_ptr + offsets, mask=mask).to(tl.float32) + moment2 = tl.load(moment2_ptr + offsets, mask=mask).to(tl.float32) + lr = tl.load(lr_ptr) + beta1_pow = tl.load(beta1_pow_ptr) + beta2_pow = tl.load(beta2_pow_ptr) + + # Weight Decay + param *= 1.0 - lr * coeff + + # AdamW + moment1 = beta1 * moment1 + (1.0 - beta1) * grad + moment2 = beta2 * moment2 + (1.0 - beta2) * grad * grad + denom = tl.sqrt(moment2) / tl.sqrt(1.0 - beta2_pow) + epsilon + param += (moment1 / denom) * (-lr / (1 - beta1_pow)) + # Update param + tl.store(master_weight_ptr + offsets, param, mask=mask) + tl.store(moment1_ptr + offsets, moment1.to(moment_dtype), mask=mask) + tl.store(moment2_ptr + offsets, moment2.to(moment_dtype), mask=mask) + + +def adamw_triton( + param, + grad, + learning_rate, + moment1, + moment2, + beta1_pow, + beta2_pow, + master_weight, + skip_update, + beta1, + beta2, + epsilon, + lr_ratio, + coeff, + with_decay, + multi_precision, + skip_update_param=False, +): + if skip_update: + return + if not with_decay: + coeff = 0.0 + if not multi_precision: + master_weight = None + lr = learning_rate * lr_ratio + + N = param.numel().item() + BLOCK_SIZE = 512 + grid = lambda meta: (triton.cdiv(N, BLOCK_SIZE),) + if skip_update_param: + adamw_kernel_skip[grid]( + 
grad, + moment1, + moment2, + lr, + beta1, + beta2, + epsilon, + coeff, + beta1_pow, + beta2_pow, + master_weight, + N, + skip_update_param, + DTYPE_MAPPING[moment1.dtype], + BLOCK_SIZE, + ) + else: + adamw_kernel[grid]( + param, + grad, + moment1, + moment2, + lr, + beta1, + beta2, + epsilon, + coeff, + beta1_pow, + beta2_pow, + master_weight, + N, + skip_update_param, + tl.float32 if skip_update_param else DTYPE_MAPPING[param.dtype], # no meaning for tl.float32 + DTYPE_MAPPING[moment1.dtype], + BLOCK_SIZE, + ) + beta1_pow[:], beta2_pow[:] = beta1 * beta1_pow[:], beta2 * beta2_pow[:] diff --git a/paddlenlp/utils/optimizer.py b/paddlenlp/utils/optimizer.py index fb62a485913c..8d98612445f8 100644 --- a/paddlenlp/utils/optimizer.py +++ b/paddlenlp/utils/optimizer.py @@ -11,27 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import warnings import paddle from paddle import pir from paddle.base import core, framework from paddle.base.framework import Variable, in_dynamic_or_pir_mode, in_pir_mode from paddle.base.libpaddle import DataType +from paddle.distributed import fleet from paddle.optimizer.adamw import AdamW from paddle.pir import Value try: - from paddlenlp_kernel.triton.optimizer import adamw_16bit_moment + from .adamw_triton import adamw_triton except: - adamw_16bit_moment = None + adamw_triton = None + print("Please install triton to use faster optimizer") -from ..quantization.qat_utils import ( - dequantize_channelwise, - fp8_dequantize_tensorwise, - fp8_quantize_tensorwise, - quantize_channelwise, -) + +from ..quantization.qat_utils import dequantize, quantize class AdamWMini(AdamW): @@ -164,108 +161,19 @@ def adamw_python( class AdamWCustom(AdamW): - def _append_optimize_op(self, block, param_and_grad): - assert isinstance(block, (framework.Block, pir.Block)) - if isinstance(param_and_grad, dict): - param_and_grad = self._update_param_group(param_and_grad) - param, grad = param_and_grad - - # Whether we should do weight decay for the parameter. 
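# Hypothetical usage sketch (assumes a CUDA device with Triton installed and
# that Paddle GPU tensors are passed to the kernels directly, as adamw_triton
# above does) showing the expected tensor layout for one fused update step.
import paddle
from paddlenlp.utils.adamw_triton import adamw_triton  # module added by this patch

shape = [4096]
param = paddle.zeros(shape, dtype="bfloat16")
master = param.astype("float32")                 # master weights for multi_precision
grad = paddle.randn(shape).astype("bfloat16")
moment1 = paddle.zeros(shape, dtype="float32")   # or bf16 when use_lowprecision_moment
moment2 = paddle.zeros(shape, dtype="float32")
lr = paddle.to_tensor([1e-4], dtype="float32")
beta1_pow = paddle.to_tensor([0.9], dtype="float32")    # accumulated beta powers
beta2_pow = paddle.to_tensor([0.999], dtype="float32")

adamw_triton(param, grad, lr, moment1, moment2, beta1_pow, beta2_pow, master,
             skip_update=False, beta1=0.9, beta2=0.999, epsilon=1e-8,
             lr_ratio=1.0, coeff=0.01, with_decay=True, multi_precision=True)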
- with_decay = True - if self._apply_decay_param_fun is not None and not self._apply_decay_param_fun(param.name): - with_decay = False - - moment1 = self._get_accumulator_master(self._moment1_acc_str, param_and_grad[0]) - moment2 = self._get_accumulator_master(self._moment2_acc_str, param_and_grad[0]) - beta1_pow_acc = self._get_accumulator_master(self._beta1_pow_acc_str, param_and_grad[0]) - beta2_pow_acc = self._get_accumulator_master(self._beta2_pow_acc_str, param_and_grad[0]) - find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(param_and_grad[0].dtype) - master_weight = self._master_weights[param_and_grad[0].name] if find_master else None - lr = self._create_param_lr(param_and_grad) - # create the adamw optimize op - if in_dynamic_or_pir_mode(): - lr_ratio_ = 1.0 if self._lr_ratio is None else self._lr_ratio(param_and_grad[0]) - - _beta1 = self._beta1 if not isinstance(self._beta1, Variable) else self._beta1.item(0) - _beta2 = self._beta2 if not isinstance(self._beta2, Variable) else self._beta2.item(0) - - found_inf = self._get_auxiliary_var("found_inf") if in_pir_mode() else None - self.adamw_custom( - param_and_grad[0], - param_and_grad[1], - lr, - moment1, - moment2, - beta1_pow_acc, - beta2_pow_acc, - master_weight, - found_inf, - _beta1, - _beta2, - self._epsilon, - lr_ratio_, - self._weight_decay, - with_decay, - find_master, - ) - return None - else: - raise NotImplementedError("Not implemented yet.") - - def adamw_custom( - self, - param, - grad, - learning_rate, - moment1, - moment2, - beta1_pow, - beta2_pow, - master_weight, - skip_update, - beta1, - beta2, - epsilon, - lr_ratio, - coeff, - with_decay, - multi_precision, - ): - if skip_update: - return - if not with_decay: - coeff = 0.0 - if not multi_precision: - master_weight = None - lr = learning_rate * lr_ratio - if master_weight is not None: - p = master_weight - else: - p = param - p *= 1.0 - lr * coeff - mom1 = moment1 - mom2 = moment2 - - mom1 = beta1 * mom1 + (1.0 - beta1) * grad - mom2 = beta2 * mom2 + (1.0 - beta2) * grad * grad - denom = mom2.sqrt() / (1.0 - beta2_pow).sqrt() + epsilon - p += (mom1 / denom) * (-(lr / (1.0 - beta1_pow))) - if master_weight is not None: - master_weight[:] = p - param[:] = p.astype(param.dtype) - else: - param[:] = p - moment1[:] = mom1 - moment2[:] = mom2 - beta1_pow[:], beta2_pow[:] = beta1 * beta1_pow[:], beta2 * beta2_pow[:] - return - + def __init__(self, quantization_config, tensorwise_offload_optimizer, *args, **kwargs): + super().__init__(*args, **kwargs) + self.quant_scale_mapping = {} + for p in self._param_groups: + if "quantization_linear" in p.name and "w_1" in p.name: + self.quant_scale_mapping[p.name.replace("w_1", "w_0")] = p + self.quantization_config = quantization_config + self._hcg = fleet.get_hybrid_communicate_group() + self.mp_group = self._hcg.get_model_parallel_group() + self.tensorwise_offload_optimizer = tensorwise_offload_optimizer -class AdamW_16Bit(AdamW): def _add_moments_pows(self, p, moment_dtype=core.VarDesc.VarType.FP32): acc_dtype = p.dtype - if self._is_dtype_fp16_or_bf16(acc_dtype): - acc_dtype = DataType.FLOAT32 if in_pir_mode() else core.VarDesc.VarType.FP32 self._add_accumulator(self._moment1_acc_str, p, dtype=moment_dtype) self._add_accumulator(self._moment2_acc_str, p, dtype=moment_dtype) @@ -301,130 +209,31 @@ def _create_accumulators(self, block, parameters): continue if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) - if str(p.dtype) == "paddle.float16": - 
@@ -301,130 +209,31 @@ def _create_accumulators(self, block, parameters):
                 continue
             if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
                 master_p = self._create_master_weight(p)
-                if str(p.dtype) == "paddle.float16":
-                    moment_dtype = core.VarDesc.VarType.FP16
-                elif str(p.dtype) == "paddle.bfloat16":
-                    moment_dtype = core.VarDesc.VarType.BF16
+                if self._use_lowprecision_moment:
+                    if p.name in self.quant_scale_mapping:
+                        p_scale = self.quant_scale_mapping[p.name]
+                        if str(p_scale.dtype) == "paddle.float16":
+                            moment_dtype = core.VarDesc.VarType.FP16
+                        elif str(p_scale.dtype) == "paddle.bfloat16":
+                            moment_dtype = core.VarDesc.VarType.BF16
+                    else:
+                        if str(p.dtype) == "paddle.float16":
+                            moment_dtype = core.VarDesc.VarType.FP16
+                        elif str(p.dtype) == "paddle.bfloat16":
+                            moment_dtype = core.VarDesc.VarType.BF16
+                else:
+                    moment_dtype = core.VarDesc.VarType.FP32
                 self._add_moments_pows(master_p, moment_dtype)
                 self._already_create_accumulator.add(p.name)
-                continue
-            if self._is_dtype_fp16_or_bf16(p.dtype) and not self._multi_precision:
-                warnings.warn(
-                    "Accumulating with FP16 or BF16 in optimizer can lead to poor accuracy or slow convergence."
-                    "Consider using multi_precision=True option of the Adam optimizer."
-                )
-                self._add_moments_pows(p)
-                self._already_create_accumulator.add(p.name)
-
-    def _append_optimize_op(self, block, param_and_grad):
-        assert isinstance(block, (framework.Block, pir.Block))
-        if isinstance(param_and_grad, dict):
-            param_and_grad = self._update_param_group(param_and_grad)
-        param, grad = param_and_grad
-        # Whether we should do weight decay for the parameter.
-        with_decay = True
-        if self._apply_decay_param_fun is not None and not self._apply_decay_param_fun(param.name):
-            with_decay = False
-
-        moment1 = self._get_accumulator_master(self._moment1_acc_str, param_and_grad[0])
-        moment2 = self._get_accumulator_master(self._moment2_acc_str, param_and_grad[0])
-        beta1_pow_acc = self._get_accumulator_master(self._beta1_pow_acc_str, param_and_grad[0])
-        beta2_pow_acc = self._get_accumulator_master(self._beta2_pow_acc_str, param_and_grad[0])
-        find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(param_and_grad[0].dtype)
-        master_weight = self._master_weights[param_and_grad[0].name] if find_master else None
-        lr = self._create_param_lr(param_and_grad)
-        # create the adamw optimize op
-        if in_dynamic_or_pir_mode():
-            lr_ratio_ = 1.0 if self._lr_ratio is None else self._lr_ratio(param_and_grad[0])
-
-            _beta1 = self._beta1 if not isinstance(self._beta1, Variable) else self._beta1.item(0)
-            _beta2 = self._beta2 if not isinstance(self._beta2, Variable) else self._beta2.item(0)
-
-            found_inf = self._get_auxiliary_var("found_inf") if in_pir_mode() else None
-            apply_adamw = self.adamw_16bit_moment if adamw_16bit_moment is None else adamw_16bit_moment
-            apply_adamw(
-                param_and_grad[0],
-                param_and_grad[1],
-                lr,
-                moment1,
-                moment2,
-                beta1_pow_acc,
-                beta2_pow_acc,
-                master_weight,
-                found_inf,
-                _beta1,
-                _beta2,
-                self._epsilon,
-                lr_ratio_,
-                self._weight_decay,
-                with_decay,
-                find_master,
-            )
-            return None
-        else:
-            raise NotImplementedError("Not implemented yet.")
-
-    def adamw_16bit_moment(
-        self,
-        param,
-        grad,
-        learning_rate,
-        moment1,
-        moment2,
-        beta1_pow,
-        beta2_pow,
-        master_weight,
-        skip_update,
-        beta1,
-        beta2,
-        epsilon,
-        lr_ratio,
-        coeff,
-        with_decay,
-        multi_precision,
-    ):
-        if skip_update:
-            return
-        if not with_decay:
-            coeff = 0.0
-        if not multi_precision:
-            master_weight = None
-        lr = learning_rate * lr_ratio
-        if master_weight is not None:
-            p = master_weight
-        else:
-            p = param
-        p *= 1.0 - lr * coeff
-        moment_dtype = moment1.dtype
-        mom1 = moment1
-        mom2 = moment2
-
-        mom1 = beta1 * mom1 + (1.0 - beta1) * grad
-        mom2 = beta2 * mom2 + (1.0 - beta2) * grad * grad
-        denom = mom2.sqrt() / (1.0 - beta2_pow).sqrt() + epsilon
-        p += (mom1 / denom) * (-(lr / (1.0 - beta1_pow))).astype("float32")
-        if master_weight is not None:
-            master_weight[:] = p
-            param[:] = p.astype(param.dtype)
-        else:
-            param[:] = p
-        moment1[:] = mom1.astype(moment_dtype)
-        moment2[:] = mom2.astype(moment_dtype)
-        beta1_pow[:], beta2_pow[:] = beta1 * beta1_pow[:], beta2 * beta2_pow[:]
-        return
-
-
-class AdamWQweight(AdamW):
-    def __init__(self, quantization_config, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.quant_scale_mapping = {}
-        for p in self._param_groups:
-            if "quantization_linear" in p.name and "w_1" in p.name:
-                self.quant_scale_mapping[p.name.replace("w_1", "w_0")] = p
-
-        self.quantization_config = quantization_config
+            elif self._is_dtype_fp16_or_bf16(p.dtype) and not self._multi_precision:
+                raise NotImplementedError("AdamWCustom only support AMP training")
+            else:
+                self._add_moments_pows(p)
+                self._already_create_accumulator.add(p.name)
+            if self.tensorwise_offload_optimizer:
+                self.offload_optim(p)
 
     def _create_master_weight(self, param):
         if param.name in self._master_weights:
@@ -433,14 +242,20 @@ def _create_master_weight(self, param):
             var_name = self._gen_master_weight_var_name(param)
             if param.name in self.quant_scale_mapping:
                 quant_scale = self.quant_scale_mapping[param.name]
-                if self.quantization_config.weight_quantize_algo in ["fp8linear"]:
-                    var = fp8_dequantize_tensorwise(
-                        param, quant_scale, tensor_type="weight", quantization_config=self.quantization_config
+                if self.quantization_config.weight_quantize_algo in ["a8w8linear", "a8w4linear", "fp8linear"]:
+                    var = dequantize(
+                        param,
+                        quant_scale,
+                        "weight",
+                        self.quantization_config.weight_quantize_algo,
+                        self.quantization_config,
+                        apply_hadamard=self.quantization_config.apply_hadamard,
+                        side="left",
                     ).astype("float32")
                 else:
-                    var = dequantize_channelwise(
-                        param, quant_scale, apply_hadamard=self.quantization_config.apply_hadamard
-                    ).astype("float32")
+                    raise NotImplementedError(
+                        f"Unknown weight_quantize_algo {self.quantization_config.weight_quantize_algo}"
+                    )
             else:
                 var = paddle.cast(param, "float32")
             var.name = var_name
@@ -474,6 +289,9 @@ def _append_optimize_op(self, block, param_and_grad):
         if self._apply_decay_param_fun is not None and not self._apply_decay_param_fun(param.name):
             with_decay = False
 
+        if self.tensorwise_offload_optimizer:
+            self.reload_optim(param)
+
         moment1 = self._get_accumulator_master(self._moment1_acc_str, param_and_grad[0])
         moment2 = self._get_accumulator_master(self._moment2_acc_str, param_and_grad[0])
         beta1_pow_acc = self._get_accumulator_master(self._beta1_pow_acc_str, param_and_grad[0])
@@ -493,7 +311,9 @@ def _append_optimize_op(self, block, param_and_grad):
             _beta2 = self._beta2 if not isinstance(self._beta2, Variable) else self._beta2.item(0)
 
             found_inf = self._get_auxiliary_var("found_inf") if in_pir_mode() else None
-            self.adamw_custom(
+            skip_update_param = quant_scale is not None
+            apply_adamw = self.adamw_custom if adamw_triton is None else adamw_triton
+            apply_adamw(
                 param_and_grad[0],
                 param_and_grad[1],
                 lr,
@@ -510,8 +330,33 @@ def _append_optimize_op(self, block, param_and_grad):
                 self._weight_decay,
                 with_decay,
                 find_master,
-                quant_scale,
+                skip_update_param,
             )
+            if skip_update_param:
+                if param.weight_quantize_algo in ["a8w8linear", "a8w4linear", "fp8linear"]:
+                    if "parallel_quantization_linear" not in param.name:
+                        group = None
+                    elif param.weight_quantize_algo in ["a8w8linear", "a8w4linear"] and "row" in param.name:
+                        group = None
+                    else:
+                        group = self.mp_group
+                    print(param.name, master_weight.shape, param.shape, moment1.shape, moment2.shape)
+                    param[:], quant_scale[:] = quantize(
+                        x=master_weight.astype(quant_scale.dtype),
+                        weight_quantize_algo=self.quantization_config.weight_quantize_algo,
+                        tensor_type="weight",
+                        quantization_config=self.quantization_config,
+                        side="left",
+                        apply_hadamard=self.quantization_config.apply_hadamard,
+                        group=group,
+                    )
+                else:
+                    raise NotImplementedError(
+                        f"Please check your weight_quantize_algo {self.quantization_config.weight_quantize_algo}."
+                    )
+            if self.tensorwise_offload_optimizer:
+                self.offload_optim(param)
+
             return None
         else:
             raise NotImplementedError("Not implemented yet.")
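For quantized parameters, the step above operates on a dequantized fp32 master weight and then writes freshly quantized values back into the low-bit weight and its scale. A conceptual, self-contained sketch of such a symmetric per-channel int8 round trip (the real quantize()/dequantize() in qat_utils additionally handle fp8, 4-bit, Hadamard transforms, and tensor-parallel groups; all names below are illustrative, not part of the patch):

import numpy as np

def quantize_channelwise_sketch(w):
    # one scale per output channel (column); guard against all-zero columns
    scale = np.maximum(np.abs(w).max(axis=0), 1e-8) / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)

def dequantize_channelwise_sketch(q, scale):
    # broadcast the per-channel scale back over the rows
    return q.astype(np.float32) * scale

w = np.random.randn(64, 32).astype(np.float32)
q, s = quantize_channelwise_sketch(w)
w_restored = dequantize_channelwise_sketch(q, s)  # close to w, up to rounding error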
@@ -534,7 +379,7 @@ def adamw_custom(
         coeff,
         with_decay,
         multi_precision,
-        quant_scale,
+        skip_update_param,
     ):
         if skip_update:
             return
@@ -549,8 +394,9 @@ def adamw_custom(
             p = param
         p *= 1.0 - lr * coeff
-        mom1 = moment1
-        mom2 = moment2
+        moment_dtype = moment1.dtype
+        mom1 = moment1.astype("float32")
+        mom2 = moment2.astype("float32")
 
         mom1 = beta1 * mom1 + (1.0 - beta1) * grad
         mom2 = beta2 * mom2 + (1.0 - beta2) * grad * grad
@@ -559,27 +405,35 @@ def adamw_custom(
         if master_weight is not None:
             master_weight[:] = p
-            if quant_scale is None:
+            if not skip_update_param:
                 param[:] = p.astype(param.dtype)
-            else:
-                if self.quantization_config.weight_quantize_algo in ["fp8linear"]:
-                    param[:], new_quant_scale = fp8_quantize_tensorwise(
-                        p.astype("bfloat16"), tensor_type="weight", quantization_config=self.quantization_config
-                    )
-                    quant_scale.set_value(new_quant_scale)
-                else:
-                    if p.shape[1] == param.shape[0]:
-                        bit_length = 8
-                    elif p.shape[1] / 2 == param.shape[0]:
-                        bit_length = 4
-                    param[:], quant_scale[:] = quantize_channelwise(
-                        p.astype("bfloat16"),
-                        apply_hadamard=self.quantization_config.apply_hadamard,
-                        bit_length=bit_length,
-                    )
         else:
             param[:] = p
-        moment1[:] = mom1
-        moment2[:] = mom2
+        moment1[:] = mom1.astype(moment_dtype)
+        moment2[:] = mom2.astype(moment_dtype)
         beta1_pow[:], beta2_pow[:] = beta1 * beta1_pow[:], beta2 * beta2_pow[:]
         return
+
+    def offload_optim(self, p):
+        find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype)
+        if find_master:
+            self._master_weights[p.name] = self._master_weights[p.name].pin_memory()
+            target_name = self._master_weights[p.name].name
+        else:
+            target_name = p.name
+        for name in [self._moment1_acc_str, self._moment2_acc_str]:
+            if self._name is not None:
+                name = self._name + "_" + name
+            self._accumulators[name][target_name] = self._accumulators[name][target_name].pin_memory()
+
+    def reload_optim(self, p):
+        find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype)
+        if find_master:
+            self._master_weights[p.name] = self._master_weights[p.name].cuda()
+            target_name = self._master_weights[p.name].name
+        else:
+            target_name = p.name
+        for name in [self._moment1_acc_str, self._moment2_acc_str]:
+            if self._name is not None:
+                name = self._name + "_" + name
+            self._accumulators[name][target_name] = self._accumulators[name][target_name].cuda()
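The offload_optim/reload_optim pair keeps master weights and moment accumulators in pinned host memory between steps and moves them back to the GPU only while their parameter is being updated. A minimal sketch of that round trip, assuming a CUDA build of Paddle (the tensor here is an illustrative stand-in for one optimizer state entry):

import paddle

# stand-in for one optimizer state tensor, e.g. a moment accumulator
state = paddle.zeros([1024, 1024], dtype="float32")

# offload_optim: park the tensor in pinned host memory after this parameter's update
state = state.pin_memory()

# reload_optim: copy it back to the GPU right before this parameter's next update
state = state.cuda()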