#include <cub/block/block_radix_sort.cuh>
#include <cub/warp/warp_reduce.cuh>
#include <cub/block/block_load.cuh>
#include <cub/block/block_discontinuity.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/cub.cuh>
#include <math_constants.h>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <mma.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cfloat> // FLT_MAX
#include "helper.h"
#include <iostream>
using namespace std;

#define HLF_MAX 65504 // largest finite fp16 value
#define TH 1024
#define NUM 4
#define NUM_BLOCK 4096
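// Map a normalized float in [-1, 1] to its 4-bit NF4 code by walking a
// balanced comparison tree. Each threshold is the midpoint between two
// adjacent NF4 quantile levels, and each comparison fixes one bit of the
// code, so a lookup costs exactly four branches.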
__device__ unsigned char dQuantizeNF4(float x)
{
  // the thresholds for this tree were generated by test_normal_map_tree
  // in the file tests/test_functional.py
  if(x > 0.03979014977812767f)
    if(x > 0.3893125355243683f) // 1
      if(x > 0.6427869200706482f) // 11
        if(x > 0.8614784181118011f) // 111
          return 0b1111;
        else
          return 0b1110;
      else
        if(x > 0.5016634166240692f) // 110
          return 0b1101;
        else
          return 0b1100;
    else
      if(x > 0.2035212516784668f) // 10
        if(x > 0.2920137718319893f) // 101
          return 0b1011;
        else
          return 0b1010;
      else
        if(x > 0.1202552504837513f) // 100
          return 0b1001;
        else
          return 0b1000;
  else
    if(x > -0.33967943489551544f) // 0
      if(x > -0.13791173323988914f) // 01
        if(x > -0.045525018125772476f) // 011
          return 0b0111;
        else
          return 0b0110;
      else
        if(x > -0.23460740596055984f) // 010
          return 0b0101;
        else
          return 0b0100;
    else
      if(x > -0.6106329262256622f) // 00
        if(x > -0.4599952697753906f) // 001
          return 0b0011;
        else
          return 0b0010;
      else
        if(x > -0.8480964004993439f) // 000
          return 0b0001;
        else
          return 0b0000;
}
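// A minimal sanity-check sketch (not part of the original op): a throwaway
// kernel probing the comparison tree at the extremes and at zero. The kernel
// name and layout are illustrative assumptions.
__global__ void kTestDQuantizeNF4(unsigned char *codes)
{
    codes[0] = dQuantizeNF4(-1.0f); // most negative level -> 0b0000
    codes[1] = dQuantizeNF4(0.0f);  // zero -> 0b0111
    codes[2] = dQuantizeNF4(1.0f);  // most positive level -> 0b1111
}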
template <typename T, int BLOCK_SIZE, int NUM_PER_TH>
//__launch_bounds__(TH, 4)
__global__ void kQuantizeBlockwiseNF4(const T *A, float *absmax, unsigned char *out, const int n)
{
  // total number of elements covered by all CUDA blocks in one grid-stride pass
  const int n_full = gridDim.x * BLOCK_SIZE;
  int valid_items = 0;
  // starting index of the elements handled by this CUDA block
  const int base_idx = (blockIdx.x * BLOCK_SIZE);
  // input elements handled by this CUDA thread
  T vals[NUM_PER_TH];
  // number of output bytes produced by this CUDA thread (two 4-bit codes per byte)
  const int output_num_per_thread = NUM_PER_TH/2;
  // output bytes produced by this CUDA thread
  unsigned char qvals[output_num_per_thread];
  //float local_abs_max = -FLT_MAX;
  float local_abs_max = 0.0f;
  // each block runs BLOCK_SIZE/NUM_PER_TH threads, each loading/storing NUM_PER_TH items
  typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
  typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH/2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
  typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
  typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;

  __shared__ typename LoadT::TempStorage loadt;
  __shared__ typename LoadFloat::TempStorage loadf;
  __shared__ typename StoreChar::TempStorage storec;
  __shared__ typename BlockReduce::TempStorage reduce;
  // absmax of each CUDA block (which is also each quantization block)
  __shared__ float smem_absmax_value[1];

  for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
  {
    valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
    local_abs_max = -FLT_MAX;

    __syncthreads();
    // out-of-range items are padded with 0, which cannot affect the absmax
    LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f);

    // 1. compute local max
    // 2. broadcast local max
    // 3. normalize inputs and quantize

    #pragma unroll NUM_PER_TH
    for(int j = 0; j < NUM_PER_TH; j++)
      local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));

    local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items);
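    // CUB's BlockReduce returns the reduced value only in thread 0, so it is
    // staged through shared memory below to broadcast the block absmax to
    // every thread.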
    if(threadIdx.x == 0)
      smem_absmax_value[0] = local_abs_max;

    __syncthreads();

    if(threadIdx.x == 0)
      // thread 0 already holds the reduced value; write this quantization
      // block's absmax to global memory
      absmax[i/BLOCK_SIZE] = local_abs_max;
    else
      local_abs_max = smem_absmax_value[0];

    __syncwarp();

    // keep the reciprocal so normalization below is a multiply, not a divide
    local_abs_max = 1.0f/local_abs_max;

    unsigned char packed_4bit = 0;

    #pragma unroll NUM_PER_TH
    for(int j = 0; j < NUM_PER_TH/2; j++)
    {
      // pack two 4-bit codes per byte: even element in the high nibble, odd
      // element in the low nibble; reassign (rather than only OR) so codes
      // from the previous pair cannot leak into this byte
      packed_4bit = dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4;
      packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max);
      qvals[j] = packed_4bit;
    }

    __syncthreads();
    // each pair of input elements becomes one output byte
    StoreChar(storec).Store(&(out[i/2]), qvals, (valid_items+1)/2);
  }
}
#define MAKE_kQuantizeBlockwiseNF4(dtype, blocksize, num_per_thread) \
template __global__ void kQuantizeBlockwiseNF4<dtype, blocksize, num_per_thread>(const dtype *A, float *absmax, unsigned char *out, const int n);
// explicit instantiations for every (blocksize, num_per_thread) pair
// dispatched by LaunchQuantizeNF4 below
MAKE_kQuantizeBlockwiseNF4(half, 2048, 4)
MAKE_kQuantizeBlockwiseNF4(half, 1024, 4)
MAKE_kQuantizeBlockwiseNF4(half, 512, 2)
MAKE_kQuantizeBlockwiseNF4(half, 256, 2)
MAKE_kQuantizeBlockwiseNF4(half, 128, 2)
MAKE_kQuantizeBlockwiseNF4(half, 64, 2)

MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 2048, 4)
MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 1024, 4)
MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 512, 2)
MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 256, 2)
MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 128, 2)
MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 64, 2)

MAKE_kQuantizeBlockwiseNF4(float, 2048, 4)
MAKE_kQuantizeBlockwiseNF4(float, 1024, 4)
MAKE_kQuantizeBlockwiseNF4(float, 512, 2)
MAKE_kQuantizeBlockwiseNF4(float, 256, 2)
MAKE_kQuantizeBlockwiseNF4(float, 128, 2)
MAKE_kQuantizeBlockwiseNF4(float, 64, 2)
template <paddle::DataType D>
std::vector<paddle::Tensor> LaunchQuantizeNF4(const paddle::Tensor& input, int block_size) {
    cout << "LaunchQuantizeNF4 begin-------" << endl;
    typedef PDTraits<D> traits_;
    typedef typename traits_::DataType DataType_;
    typedef typename traits_::data_t data_t;
    auto input_shape = input.shape();
    // NOTE: `output` is allocated with the full input shape for simplicity;
    // the kernel only writes the first (n + 1) / 2 packed bytes
    auto output = paddle::full(input_shape, 1, paddle::DataType::UINT8, input.place());
    const int n = input.numel();
    // one absmax entry per quantization block: ceil(n / block_size)
    int num_blocks = n / block_size;
    num_blocks = n % block_size == 0 ? num_blocks : num_blocks + 1;

    auto abs_max = paddle::full({num_blocks}, 1, paddle::DataType::FLOAT32, input.place());

    const DataType_ *in_ptr = reinterpret_cast<const DataType_*>(input.data<data_t>());
    unsigned char *out_ptr = output.mutable_data<unsigned char>();
    float *abs_max_ptr = abs_max.mutable_data<float>();
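    // dispatch on block_size; threads per block is block_size / NUM_PER_TH,
    // matching the BlockLoad/BlockStore configuration inside the kernel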
    if (block_size == 2048) {
        kQuantizeBlockwiseNF4<DataType_, 2048, 4><<<num_blocks, 512>>>(in_ptr, abs_max_ptr, out_ptr, n);
    } else if (block_size == 1024) {
        kQuantizeBlockwiseNF4<DataType_, 1024, 4><<<num_blocks, 256>>>(in_ptr, abs_max_ptr, out_ptr, n);
    } else if (block_size == 512) {
        kQuantizeBlockwiseNF4<DataType_, 512, 2><<<num_blocks, 256>>>(in_ptr, abs_max_ptr, out_ptr, n);
    } else if (block_size == 256) {
        kQuantizeBlockwiseNF4<DataType_, 256, 2><<<num_blocks, 128>>>(in_ptr, abs_max_ptr, out_ptr, n);
    } else if (block_size == 128) {
        kQuantizeBlockwiseNF4<DataType_, 128, 2><<<num_blocks, 64>>>(in_ptr, abs_max_ptr, out_ptr, n);
    } else if (block_size == 64) {
        kQuantizeBlockwiseNF4<DataType_, 64, 2><<<num_blocks, 32>>>(in_ptr, abs_max_ptr, out_ptr, n);
    } else {
        PD_THROW(
            "NOT supported block_size. "
            "Only 2048, 1024, 512, 256, 128 and 64 are supported.");
    }
    return {output, abs_max};
}
std::vector<paddle::Tensor> QuantizeNF4(const paddle::Tensor& input, int block_size) {
    cout << "QuantizeNF4 begin-------" << endl;
    switch (input.type()) {
        case paddle::DataType::BFLOAT16: {
            return LaunchQuantizeNF4<paddle::DataType::BFLOAT16>(input, block_size);
        }
        case paddle::DataType::FLOAT16: {
            return LaunchQuantizeNF4<paddle::DataType::FLOAT16>(input, block_size);
        }
        case paddle::DataType::FLOAT32: {
            return LaunchQuantizeNF4<paddle::DataType::FLOAT32>(input, block_size);
        }
        default: {
            PD_THROW(
                "NOT supported data type. "
                "Only bfloat16, float16 and float32 are supported.");
            break;
        }
    }
}
PD_BUILD_OP(quantize_nf4)
    .Inputs({"input"})
    .Outputs({"out", "abs_max"})
    .Attrs({"block_size: int"})
    .SetKernelFn(PD_KERNEL(QuantizeNF4));
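// Hypothetical Python-side usage once this file is built with
// paddle.utils.cpp_extension (the module name is an assumption, not part of
// this file):
//
//   from custom_setup_ops import quantize_nf4
//   out, abs_max = quantize_nf4(x, 64)  # x: float16/bfloat16/float32 tensor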