
Commit f1eaf2f

fix histogram kernel, including accuracy and big tensor handling
1 parent e7007d1 commit f1eaf2f

File tree

1 file changed: +62 −57 lines

paddle/phi/kernels/gpu/histogram_kernel.cu

Lines changed: 62 additions & 57 deletions
@@ -16,6 +16,7 @@
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
+#include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
 #include "paddle/phi/kernels/funcs/functors.h"
@@ -28,7 +29,7 @@ namespace phi {
 using IndexType = int64_t;
 using phi::PADDLE_CUDA_NUM_THREADS;

-inline int GET_BLOCKS(const int N) {
+inline int64_t GET_BLOCKS(const int64_t N) {
   return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
 }
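
For context: with a 32-bit N, GET_BLOCKS silently truncates any tensor larger than INT_MAX elements before the ceiling division even runs. A standalone sketch of the failure (not Paddle code; assumes PADDLE_CUDA_NUM_THREADS is 512):

  #include <cstdint>
  #include <cstdio>

  int main() {
    const int64_t threads = 512;                    // PADDLE_CUDA_NUM_THREADS
    const int64_t numel = 3000000000LL;             // > INT_MAX
    const int truncated = static_cast<int>(numel);  // wraps negative on typical ABIs
    const int64_t blocks = (numel + threads - 1) / threads;
    printf("as int: %d, as int64_t: %lld blocks\n", truncated,
           static_cast<long long>(blocks));         // -1294967296 vs 5859375
    return 0;
  }
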
@@ -46,18 +47,18 @@ __device__ static IndexType GetBin(T input_value,
 template <typename T, typename IndexType, typename Out_T>
 __global__ void KernelHistogram(const T* input,
                                 const T* weight,
-                                const int total_elements,
+                                const int64_t total_elements,
                                 const int64_t nbins,
                                 const T* min_value,
                                 const T* max_value,
                                 Out_T* output) {
   extern __shared__ float buf_hist[];
-  for (int i = threadIdx.x; i < nbins; i += blockDim.x) {
+  for (int64_t i = threadIdx.x; i < nbins; i += blockDim.x) {
     buf_hist[i] = 0;
   }
   __syncthreads();

-  CUDA_KERNEL_LOOP(input_index, total_elements) {
+  CUDA_KERNEL_LOOP_TYPE(input_index, total_elements, IndexType) {
     // const IndexType input_index = threadIdx.x + blockIdx.x * blockDim.x;
     const auto input_value = input[input_index];
     if (input_value >= *min_value && input_value <= *max_value) {
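
CUDA_KERNEL_LOOP_TYPE is Paddle's grid-stride loop macro with a caller-chosen index type; switching to it with IndexType = int64_t keeps the flat index from wrapping past 2^31 elements. A sketch of the pattern it expands to (not the macro's literal definition):

  // Grid-stride loop with a 64-bit index: each thread strides over the whole
  // range, so correctness does not depend on launching one thread per element.
  __global__ void GridStrideCopy(const float* in, float* out, int64_t n) {
    for (int64_t i =
             static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         i < n;
         i += static_cast<int64_t>(blockDim.x) * gridDim.x) {
      out[i] = in[i];
    }
  }
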
@@ -69,23 +70,23 @@ __global__ void KernelHistogram(const T* input,
   }
   __syncthreads();

-  for (int i = threadIdx.x; i < nbins; i += blockDim.x) {
+  for (int64_t i = threadIdx.x; i < nbins; i += blockDim.x) {
     phi::CudaAtomicAdd(&output[i], buf_hist[i]);
   }
 }

 template <typename T>
 __global__ void KernelMinMax(const T* input,
-                             const int numel,
-                             const int block_num,
+                             const int64_t numel,
+                             const int64_t block_num,
                              T* min_ptr,
                              T* max_ptr) {
-  int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t index = threadIdx.x + blockIdx.x * static_cast<int64_t>(blockDim.x);
   int64_t i = index;
   T min_value = static_cast<T>(i < numel ? input[i] : input[0]);
   T max_value = static_cast<T>(i < numel ? input[i] : input[0]);

-  for (; i < numel; i += blockDim.x * gridDim.x) {
+  for (; i < numel; i += blockDim.x * static_cast<int64_t>(gridDim.x)) {
     T value = static_cast<T>(input[i]);
     min_value = value < min_value ? value : min_value;
     max_value = value > max_value ? value : max_value;
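
KernelHistogram follows the standard two-level histogram pattern: each block accumulates into a shared-memory copy of the bins, then merges it into the global result with one atomic per bin. A self-contained sketch of the same idea, simplified relative to the Paddle kernel (unweighted, integer inputs already in [0, nbins)):

  // Per-block shared-memory histogram; launch with nbins * sizeof(float)
  // bytes of dynamic shared memory, e.g.
  //   BlockHistogram<<<blocks, threads, nbins * sizeof(float)>>>(...);
  __global__ void BlockHistogram(const int* data, int64_t n, int nbins,
                                 float* global_hist) {
    extern __shared__ float local_hist[];
    for (int b = threadIdx.x; b < nbins; b += blockDim.x) local_hist[b] = 0.f;
    __syncthreads();
    for (int64_t i =
             static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         i < n; i += static_cast<int64_t>(blockDim.x) * gridDim.x) {
      atomicAdd(&local_hist[data[i]], 1.f);  // contention stays in the block
    }
    __syncthreads();
    for (int b = threadIdx.x; b < nbins; b += blockDim.x) {
      atomicAdd(&global_hist[b], local_hist[b]);  // one global atomic per bin
    }
  }

Since buf_hist holds floats, the launch only needs nbins * sizeof(float) bytes of dynamic shared memory; this commit fixes both KernelHistogram launch sites, which previously requested nbins * sizeof(int64_t), twice the needed size.
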
@@ -106,9 +107,11 @@ __global__ void KernelMinMax(const T* input,
     min_value = min_ptr[0];
     max_value = max_ptr[0];
     for (int64_t i = 1; i < block_num; i++) {
-      min_ptr[0] = min_ptr[i] < min_value ? min_ptr[i] : min_value;
-      max_ptr[0] = max_ptr[i] > max_value ? max_ptr[i] : max_value;
+      min_value = min_ptr[i] < min_value ? min_ptr[i] : min_value;
+      max_value = max_ptr[i] > max_value ? max_ptr[i] : max_value;
     }
+    min_ptr[0] = min_value;
+    max_ptr[0] = max_value;
     if (min_ptr[0] == max_ptr[0]) {
       min_ptr[0] = min_ptr[0] - 1;
       max_ptr[0] = max_ptr[0] + 1;
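
The old merge wrote partial results into min_ptr[0] / max_ptr[0] while still comparing each candidate against the stale initial min_value / max_value, so a correct intermediate result could be overwritten by a later, worse one. A host-side repro of the old logic (standalone sketch, hypothetical values):

  #include <cstdio>

  int main() {
    float min_ptr[3] = {5.f, 3.f, 4.f};  // per-block minima; true min is 3
    float min_value = min_ptr[0];        // never updated by the old loop
    for (int i = 1; i < 3; i++) {
      min_ptr[0] = min_ptr[i] < min_value ? min_ptr[i] : min_value;
    }
    printf("%g\n", min_ptr[0]);  // prints 4: the 3 written at i == 1 is lost
    return 0;
  }

The fix keeps the running extremum in a local variable and writes min_ptr[0] / max_ptr[0] once, after the loop.
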
@@ -128,13 +131,6 @@ __global__ void KernelMinMax(const T min_value,
   }
 }

-__global__ void KernelMul(float* data, float* scale, int64_t numel) {
-  size_t index = threadIdx.x + blockIdx.x * blockDim.x;
-  if (index < numel) {
-    data[index] /= *scale;
-  }
-}
-
 template <typename T, typename Context>
 void HistogramKernel(const Context& dev_ctx,
                      const DenseTensor& input,
@@ -149,35 +145,42 @@ void HistogramKernel(const Context& dev_ctx,
   auto& maxval = max;

   const T* input_data = input.data<T>();
-  const int input_numel = input.numel();
+  const int64_t input_numel = input.numel();
   auto weight_data = weight.get_ptr() ? weight.get_ptr()->data<T>() : nullptr;

   if (input_data == nullptr) return;

   T output_min = static_cast<T>(minval);
   T output_max = static_cast<T>(maxval);
   DenseTensor min_max;
-  int block_num = GET_BLOCKS(input_numel);
+  int64_t block_num = GET_BLOCKS(input_numel);
+  block_num = std::min(
+      block_num, static_cast<int64_t>(dev_ctx.GetCUDAMaxGridDimSize()[0]));
   min_max.Resize({2 * block_num});
   auto* min_block_ptr = dev_ctx.template Alloc<T>(&min_max);
   auto* max_block_ptr = min_block_ptr + block_num;
   if (min == max) {
-    KernelMinMax<T><<<GET_BLOCKS(input_numel),
-                      PADDLE_CUDA_NUM_THREADS,
-                      0,
-                      dev_ctx.stream()>>>(
-        input_data, input_numel, block_num, min_block_ptr, max_block_ptr);
+    KernelMinMax<T>
+        <<<block_num, PADDLE_CUDA_NUM_THREADS, 0, dev_ctx.stream()>>>(
+            input_data, input_numel, block_num, min_block_ptr, max_block_ptr);
+    // copy min max value from GPU to CPU
+    phi::memory_utils::Copy(phi::CPUPlace(),
+                            &output_min,
+                            min_max.place(),
+                            min_block_ptr,
+                            sizeof(T),
+                            dev_ctx.stream());
+    phi::memory_utils::Copy(phi::CPUPlace(),
+                            &output_max,
+                            min_max.place(),
+                            max_block_ptr,
+                            sizeof(T),
+                            dev_ctx.stream());
   } else {
     KernelMinMax<T><<<1, 1, 0, dev_ctx.stream()>>>(
         output_min, output_max, min_block_ptr, max_block_ptr);
   }

-  // copy min max value from GPU to CPU
-  std::vector<T> min_max_vec;
-  phi::TensorToVector(min_max, dev_ctx, &min_max_vec);
-  output_min = min_max_vec[0];
-  output_max = min_max_vec[1];
-
   // check if out of range
   double range =
       static_cast<double>(output_max) - static_cast<double>(output_min);
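
Two changes matter here. First, block_num is now clamped to the device's maximum grid size in x (commonly 2^31 - 1), since GET_BLOCKS on a big enough tensor can request more blocks than a launch allows; the grid-stride loops above keep a smaller grid correct. Second, two scalar memory_utils::Copy calls replace the old TensorToVector round-trip, which copied the whole 2 * block_num scratch buffer to the host and then read min_max_vec[1] as the maximum, even though in the reduction path the reduced maximum lives at index block_num (max_block_ptr = min_block_ptr + block_num), not index 1. A sketch of the same clamping in plain CUDA (standalone, device 0 assumed):

  #include <algorithm>
  #include <cstdint>
  #include <cuda_runtime.h>

  int64_t ClampedBlocks(int64_t numel, int threads) {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, /*device=*/0);
    int64_t blocks = (numel + threads - 1) / threads;
    // A grid-stride kernel stays correct with fewer blocks than elements.
    return std::min(blocks, static_cast<int64_t>(prop.maxGridSize[0]));
  }
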
@@ -212,46 +215,48 @@ void HistogramKernel(const Context& dev_ctx,

   auto stream = dev_ctx.stream();

-  if (!density && !weight_data) {
+  if (!density && weight_data == nullptr) {
     int64_t* out_data = dev_ctx.template Alloc<int64_t>(output);
     phi::funcs::SetConstant<Context, int64_t>()(dev_ctx, output, 0);
-    KernelHistogram<T, IndexType, int64_t><<<GET_BLOCKS(input_numel),
-                                             PADDLE_CUDA_NUM_THREADS,
-                                             nbins * sizeof(int64_t),
-                                             stream>>>(input_data,
-                                                       weight_data,
-                                                       input_numel,
-                                                       nbins,
-                                                       min_block_ptr,
-                                                       max_block_ptr,
-                                                       out_data);
-    return;
+    KernelHistogram<T, IndexType, int64_t>
+        <<<block_num, PADDLE_CUDA_NUM_THREADS, nbins * sizeof(float), stream>>>(
+            input_data,
+            weight_data,
+            input_numel,
+            nbins,
+            min_block_ptr,
+            max_block_ptr,
+            out_data);
   } else {
     float* out_data = dev_ctx.template Alloc<float>(output);
     phi::funcs::SetConstant<Context, float>()(
         dev_ctx, output, static_cast<float>(0));
-    KernelHistogram<T, IndexType, float><<<GET_BLOCKS(input_numel),
-                                           PADDLE_CUDA_NUM_THREADS,
-                                           nbins * sizeof(int64_t),
-                                           stream>>>(input_data,
-                                                     weight_data,
-                                                     input_numel,
-                                                     nbins,
-                                                     min_block_ptr,
-                                                     max_block_ptr,
-                                                     out_data);
+    KernelHistogram<T, IndexType, float>
+        <<<block_num, PADDLE_CUDA_NUM_THREADS, nbins * sizeof(float), stream>>>(
+            input_data,
+            weight_data,
+            input_numel,
+            nbins,
+            min_block_ptr,
+            max_block_ptr,
+            out_data);
     if (density) {
       DenseTensor sum = phi::Sum<float, Context>(
           dev_ctx, *output, phi::IntArray({0}), phi::DataType::FLOAT32, false);
+      float sum_cpu;
+      phi::memory_utils::Copy(phi::CPUPlace(),
+                              &sum_cpu,
+                              sum.place(),
+                              sum.data<float>(),
+                              sizeof(float),
+                              dev_ctx.stream());
       float gap = static_cast<float>(nbins) /
                   static_cast<float>(output_max - output_min);
       std::vector<const DenseTensor*> ins = {output};
       std::vector<DenseTensor*> outs = {output};
-      auto functor = phi::funcs::ScaleFunctor<float>(gap);
+      float scale = gap / sum_cpu;
+      auto functor = phi::funcs::ScaleFunctor<float>(scale);
       phi::funcs::ElementwiseKernel<float>(dev_ctx, ins, &outs, functor);
-      KernelMul<<<GET_BLOCKS(static_cast<int>(bins)),
-                  PADDLE_CUDA_NUM_THREADS>>>(out_data, sum.data<float>(), bins);
     }
   }
 }
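
The density branch now folds both normalization factors into a single elementwise pass. A density histogram satisfies density[i] = count[i] / (total_weight * bin_width) with bin_width = (output_max - output_min) / nbins, so gap = nbins / (output_max - output_min) is 1 / bin_width and the single multiplier is scale = gap / sum. The deleted KernelMul, which performed the division by the sum in a second kernel (launched without a stream argument, i.e. on the default stream rather than dev_ctx.stream()), is no longer needed. A host-side sketch of the arithmetic (ToDensity is a hypothetical helper, not Paddle code):

  #include <cstddef>
  #include <vector>

  std::vector<float> ToDensity(const std::vector<float>& counts, float min_v,
                               float max_v) {
    float sum = 0.f;
    for (float c : counts) sum += c;  // total weight across all bins
    const float gap =
        static_cast<float>(counts.size()) / (max_v - min_v);  // 1 / bin_width
    const float scale = gap / sum;    // the one multiplier the kernel applies
    std::vector<float> density(counts.size());
    for (std::size_t i = 0; i < counts.size(); ++i) {
      density[i] = counts[i] * scale;
    }
    return density;
  }

With this scaling the densities integrate to one: sum(density[i]) * bin_width == 1.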
