 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
+#include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
 #include "paddle/phi/kernels/funcs/functors.h"
@@ -28,7 +29,7 @@ namespace phi {
 using IndexType = int64_t;
 using phi::PADDLE_CUDA_NUM_THREADS;
 
-inline int GET_BLOCKS(const int N) {
+inline int64_t GET_BLOCKS(const int64_t N) {
   return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
 }
 
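// Editor's sketch (not part of the diff above): a standalone illustration of
// why GET_BLOCKS now takes and returns int64_t. With 32-bit int, the rounded-up
// numerator N + PADDLE_CUDA_NUM_THREADS - 1 overflows once N approaches
// INT_MAX, while the 64-bit ceiling division stays correct for tensors with
// more than 2^31 elements. kThreads and GetBlocks64 are illustrative names,
// assuming the usual 512-thread block size.
#include <cstdint>
#include <cstdio>

constexpr int64_t kThreads = 512;  // stand-in for PADDLE_CUDA_NUM_THREADS

inline int64_t GetBlocks64(const int64_t n) {
  return (n + kThreads - 1) / kThreads;  // 64-bit ceiling division
}

int main() {
  const int64_t n = int64_t{1} << 32;  // 4Gi elements, beyond INT_MAX
  // Prints 8388608; the same expression in 32-bit int would have overflowed.
  std::printf("blocks = %lld\n", static_cast<long long>(GetBlocks64(n)));
  return 0;
}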
@@ -46,18 +47,18 @@ __device__ static IndexType GetBin(T input_value,
 template <typename T, typename IndexType, typename Out_T>
 __global__ void KernelHistogram(const T* input,
                                 const T* weight,
-                                const int total_elements,
+                                const int64_t total_elements,
                                 const int64_t nbins,
                                 const T* min_value,
                                 const T* max_value,
                                 Out_T* output) {
   extern __shared__ float buf_hist[];
-  for (int i = threadIdx.x; i < nbins; i += blockDim.x) {
+  for (int64_t i = threadIdx.x; i < nbins; i += blockDim.x) {
     buf_hist[i] = 0;
   }
   __syncthreads();
 
-  CUDA_KERNEL_LOOP(input_index, total_elements) {
+  CUDA_KERNEL_LOOP_TYPE(input_index, total_elements, IndexType) {
     // const IndexType input_index = threadIdx.x + blockIdx.x * blockDim.x;
     const auto input_value = input[input_index];
     if (input_value >= *min_value && input_value <= *max_value) {
@@ -69,23 +70,23 @@ __global__ void KernelHistogram(const T* input,
   }
   __syncthreads();
 
-  for (int i = threadIdx.x; i < nbins; i += blockDim.x) {
+  for (int64_t i = threadIdx.x; i < nbins; i += blockDim.x) {
     phi::CudaAtomicAdd(&output[i], buf_hist[i]);
   }
 }
 
 template <typename T>
 __global__ void KernelMinMax(const T* input,
-                             const int numel,
-                             const int block_num,
+                             const int64_t numel,
+                             const int64_t block_num,
                              T* min_ptr,
                              T* max_ptr) {
-  int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t index = threadIdx.x + blockIdx.x * static_cast<int64_t>(blockDim.x);
   int64_t i = index;
   T min_value = static_cast<T>(i < numel ? input[i] : input[0]);
   T max_value = static_cast<T>(i < numel ? input[i] : input[0]);
 
-  for (; i < numel; i += blockDim.x * gridDim.x) {
+  for (; i < numel; i += blockDim.x * static_cast<int64_t>(gridDim.x)) {
     T value = static_cast<T>(input[i]);
     min_value = value < min_value ? value : min_value;
     max_value = value > max_value ? value : max_value;
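// Editor's sketch (not part of the diff above): the 64-bit grid-stride pattern
// that CUDA_KERNEL_LOOP_TYPE(i, n, IndexType) and the widened KernelMinMax
// strides rely on, written out by hand on a trivial kernel. The kernel name
// and variables here are illustrative, not Paddle APIs.
__global__ void GridStrideCopy(const float* in, float* out, int64_t n) {
  // Start index and stride are widened to 64-bit so blockIdx.x * blockDim.x
  // cannot overflow when n exceeds INT_MAX.
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < n;
       i += static_cast<int64_t>(gridDim.x) * blockDim.x) {
    out[i] = in[i];  // each thread covers i, i + stride, i + 2 * stride, ...
  }
}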
@@ -106,9 +107,11 @@ __global__ void KernelMinMax(const T* input,
       min_value = min_ptr[0];
       max_value = max_ptr[0];
       for (int64_t i = 1; i < block_num; i++) {
-        min_ptr[0] = min_ptr[i] < min_value ? min_ptr[i] : min_value;
-        max_ptr[0] = max_ptr[i] > max_value ? max_ptr[i] : max_value;
+        min_value = min_ptr[i] < min_value ? min_ptr[i] : min_value;
+        max_value = max_ptr[i] > max_value ? max_ptr[i] : max_value;
       }
+      min_ptr[0] = min_value;
+      max_ptr[0] = max_value;
       if (min_ptr[0] == max_ptr[0]) {
         min_ptr[0] = min_ptr[0] - 1;
         max_ptr[0] = max_ptr[0] + 1;
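// Editor's sketch (not part of the diff above): the corrected combine step
// mirrored on the host, to make the fix explicit. The old loop overwrote
// min_ptr[0]/max_ptr[0] on every iteration while still comparing against the
// never-updated min_value/max_value; the new code folds all per-block partials
// into locals first, writes element 0 back once, and then widens a degenerate
// range so the bin width cannot be zero. CombinePartials is an illustrative
// name, not a Paddle API.
#include <cstdint>

template <typename T>
void CombinePartials(T* min_ptr, T* max_ptr, int64_t block_num) {
  T min_value = min_ptr[0];
  T max_value = max_ptr[0];
  for (int64_t i = 1; i < block_num; ++i) {
    min_value = min_ptr[i] < min_value ? min_ptr[i] : min_value;
    max_value = max_ptr[i] > max_value ? max_ptr[i] : max_value;
  }
  min_ptr[0] = min_value;  // write back once, after the scan
  max_ptr[0] = max_value;
  if (min_ptr[0] == max_ptr[0]) {  // all inputs equal: widen the range
    min_ptr[0] = min_ptr[0] - 1;
    max_ptr[0] = max_ptr[0] + 1;
  }
}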
@@ -128,13 +131,6 @@ __global__ void KernelMinMax(const T min_value,
   }
 }
 
-__global__ void KernelMul(float* data, float* scale, int64_t numel) {
-  size_t index = threadIdx.x + blockIdx.x * blockDim.x;
-  if (index < numel) {
-    data[index] /= *scale;
-  }
-}
-
 template <typename T, typename Context>
 void HistogramKernel(const Context& dev_ctx,
                      const DenseTensor& input,
@@ -149,35 +145,42 @@ void HistogramKernel(const Context& dev_ctx,
   auto& maxval = max;
 
   const T* input_data = input.data<T>();
-  const int input_numel = input.numel();
+  const int64_t input_numel = input.numel();
   auto weight_data = weight.get_ptr() ? weight.get_ptr()->data<T>() : nullptr;
 
   if (input_data == nullptr) return;
 
   T output_min = static_cast<T>(minval);
   T output_max = static_cast<T>(maxval);
   DenseTensor min_max;
-  int block_num = GET_BLOCKS(input_numel);
+  int64_t block_num = GET_BLOCKS(input_numel);
+  block_num = std::min(
+      block_num, static_cast<int64_t>(dev_ctx.GetCUDAMaxGridDimSize()[0]));
   min_max.Resize({2 * block_num});
   auto* min_block_ptr = dev_ctx.template Alloc<T>(&min_max);
   auto* max_block_ptr = min_block_ptr + block_num;
   if (min == max) {
-    KernelMinMax<T><<<GET_BLOCKS(input_numel),
-                      PADDLE_CUDA_NUM_THREADS,
-                      0,
-                      dev_ctx.stream()>>>(
-        input_data, input_numel, block_num, min_block_ptr, max_block_ptr);
+    KernelMinMax<T>
+        <<<block_num, PADDLE_CUDA_NUM_THREADS, 0, dev_ctx.stream()>>>(
+            input_data, input_numel, block_num, min_block_ptr, max_block_ptr);
+    // copy min max value from GPU to CPU
+    phi::memory_utils::Copy(phi::CPUPlace(),
+                            &output_min,
+                            min_max.place(),
+                            min_block_ptr,
+                            sizeof(T),
+                            dev_ctx.stream());
+    phi::memory_utils::Copy(phi::CPUPlace(),
+                            &output_max,
+                            min_max.place(),
+                            max_block_ptr,
+                            sizeof(T),
+                            dev_ctx.stream());
   } else {
     KernelMinMax<T><<<1, 1, 0, dev_ctx.stream()>>>(
         output_min, output_max, min_block_ptr, max_block_ptr);
   }
 
-  // copy min max value from GPU to CPU
-  std::vector<T> min_max_vec;
-  phi::TensorToVector(min_max, dev_ctx, &min_max_vec);
-  output_min = min_max_vec[0];
-  output_max = min_max_vec[1];
-
   // check if out of range
   double range =
       static_cast<double>(output_max) - static_cast<double>(output_min);
@@ -212,46 +215,48 @@ void HistogramKernel(const Context& dev_ctx,
 
   auto stream = dev_ctx.stream();
 
-  if (!density && !weight_data) {
+  if (!density && weight_data == nullptr) {
     int64_t* out_data = dev_ctx.template Alloc<int64_t>(output);
     phi::funcs::SetConstant<Context, int64_t>()(dev_ctx, output, 0);
-    KernelHistogram<T, IndexType, int64_t><<<GET_BLOCKS(input_numel),
-                                             PADDLE_CUDA_NUM_THREADS,
-                                             nbins * sizeof(int64_t),
-                                             stream>>>(input_data,
-                                                       weight_data,
-                                                       input_numel,
-                                                       nbins,
-                                                       min_block_ptr,
-                                                       max_block_ptr,
-                                                       out_data);
-    return;
-
+    KernelHistogram<T, IndexType, int64_t>
+        <<<block_num, PADDLE_CUDA_NUM_THREADS, nbins * sizeof(float), stream>>>(
+            input_data,
+            weight_data,
+            input_numel,
+            nbins,
+            min_block_ptr,
+            max_block_ptr,
+            out_data);
   } else {
     float* out_data = dev_ctx.template Alloc<float>(output);
     phi::funcs::SetConstant<Context, float>()(
         dev_ctx, output, static_cast<float>(0));
-    KernelHistogram<T, IndexType, float><<<GET_BLOCKS(input_numel),
-                                           PADDLE_CUDA_NUM_THREADS,
-                                           nbins * sizeof(int64_t),
-                                           stream>>>(input_data,
-                                                     weight_data,
-                                                     input_numel,
-                                                     nbins,
-                                                     min_block_ptr,
-                                                     max_block_ptr,
-                                                     out_data);
+    KernelHistogram<T, IndexType, float>
+        <<<block_num, PADDLE_CUDA_NUM_THREADS, nbins * sizeof(float), stream>>>(
+            input_data,
+            weight_data,
+            input_numel,
+            nbins,
+            min_block_ptr,
+            max_block_ptr,
+            out_data);
     if (density) {
       DenseTensor sum = phi::Sum<float, Context>(
          dev_ctx, *output, phi::IntArray({0}), phi::DataType::FLOAT32, false);
+      float sum_cpu;
+      phi::memory_utils::Copy(phi::CPUPlace(),
+                              &sum_cpu,
+                              sum.place(),
+                              sum.data<float>(),
+                              sizeof(float),
+                              dev_ctx.stream());
       float gap = static_cast<float>(nbins) /
                   static_cast<float>(output_max - output_min);
       std::vector<const DenseTensor*> ins = {output};
       std::vector<DenseTensor*> outs = {output};
-      auto functor = phi::funcs::ScaleFunctor<float>(gap);
+      float scale = gap / sum_cpu;
+      auto functor = phi::funcs::ScaleFunctor<float>(scale);
       phi::funcs::ElementwiseKernel<float>(dev_ctx, ins, &outs, functor);
-      KernelMul<<<GET_BLOCKS(static_cast<int>(bins)),
-                  PADDLE_CUDA_NUM_THREADS>>>(out_data, sum.data<float>(), bins);
     }
   }
 }
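// Editor's sketch (not part of the diff above): why a single ScaleFunctor pass
// with scale = gap / sum_cpu is enough for density mode. With
//   bin_width = (output_max - output_min) / nbins   and
//   total     = sum of all (weighted) counts,
// a density bin is count[i] / (total * bin_width)
//               = count[i] * (nbins / (output_max - output_min)) / total
//               = count[i] * gap / sum_cpu,
// so the former KernelMul division by the sum folds into the scale factor.
// NormalizeToDensity is an illustrative host-side mirror, not a Paddle API.
#include <cstddef>
#include <vector>

std::vector<float> NormalizeToDensity(const std::vector<float>& counts,
                                      float output_min,
                                      float output_max,
                                      float sum_cpu) {
  const float gap =
      static_cast<float>(counts.size()) / (output_max - output_min);
  const float scale = gap / sum_cpu;  // combines 1/bin_width and 1/total
  std::vector<float> density(counts.size());
  for (std::size_t i = 0; i < counts.size(); ++i) {
    density[i] = counts[i] * scale;
  }
  return density;
}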