Skip to content

Commit b6d4e79

Browse files
committed
Implement progressive DFT execution
1 parent 0e964b9 commit b6d4e79

File tree

5 files changed

+298
-50
lines changed

5 files changed

+298
-50
lines changed

include/kfr/dft/fft.hpp

Lines changed: 97 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,13 @@ struct dft_stage
9292
{
9393
do_execute(cinvert_t(), out, in, temp);
9494
}
95+
/// Runtime-direction overload of `execute`.
/// Calls `do_execute` with the `cinvert_t` tag when `inverse` is true and with
/// the `cdirect_t` tag otherwise; `out`, `in` and `temp` are passed through unchanged.
KFR_MEM_INTRINSIC void execute(bool inverse, complex<T>* out, const complex<T>* in, u8* temp)
{
    if (!inverse)
    {
        do_execute(cdirect_t(), out, in, temp);
        return;
    }
    do_execute(cinvert_t(), out, in, temp);
}
95102
virtual ~dft_stage() {}
96103

97104
protected:
@@ -155,6 +162,12 @@ using fn_transpose = void (*)(complex<T>*, const complex<T>*, shape<2>);
155162
template <typename T>
156163
void dft_initialize_transpose(fn_transpose<T>& transpose);
157164

165+
template <typename T>
166+
void dft_progressive_start(const dft_plan<T>& plan, typename dft_plan<T>::progressive& progressive,
167+
bool inverse, complex<T>* out, const complex<T>* in, u8* temp);
168+
template <typename T>
169+
void dft_progressive_step(const dft_plan<T>& plan, typename dft_plan<T>::progressive& progressive);
170+
158171
} // namespace internal_generic
159172

160173
/**
@@ -232,9 +245,10 @@ struct dft_plan
232245
*
233246
* @param size The size of the DFT.
234247
* @param order The order of the DFT samples. See `dft_order`.
248+
* @param progressive_optimized If true, the plan will be optimized for progressive execution.
235249
*/
236-
explicit dft_plan(size_t size, dft_order order = dft_order::normal)
237-
: size(size), temp_size(0), data_size(0), arblen(false)
250+
explicit dft_plan(size_t size, dft_order order = dft_order::normal, bool progressive_optimized = false)
251+
: size(size), temp_size(0), data_size(0), arblen(false), progressive_optimized(progressive_optimized)
238252
{
239253
internal_generic::dft_initialize(*this);
240254
}
@@ -369,6 +383,7 @@ struct dft_plan
369383
std::vector<dft_stage_ptr<T>> all_stages; /**< Internal data. */
370384
std::array<std::vector<dft_stage<T>*>, 2> stages; /**< Internal data. */
371385
bool arblen; /**< True if Bluestein's FFT algorithm is selected. */
386+
bool progressive_optimized; /**< True if the plan is for progressive execution of the DFT. */
372387
using bitset = std::bitset<DFT_MAX_STAGES>; /**< Internal typedef. */
373388
std::array<bitset, 2> disposition_inplace; /**< Internal data. */
374389
std::array<bitset, 2> disposition_outofplace; /**< Internal data. */
@@ -380,12 +395,64 @@ struct dft_plan
380395
static bitset precompute_disposition(int num_stages, bitset can_inplace_per_stage,
381396
bool inplace_requested);
382397

398+
/** Internal data structure for progressive execution of the DFT.
 * Do not access the members directly as they may change in future versions.
 *
 * An instance is value-initialized by `progressive_start`, filled in by
 * `internal_generic::dft_progressive_start` and then handed to `progressive_step`
 * once per stage.
 */
401+
struct progressive
402+
{
403+
bool inverse; /**< Direction given at start; indexes `stages[...]` and the disposition arrays. */
404+
complex<T>* out; /**< Output pointer passed to the start call (stored, not owned). */
405+
const complex<T>* in; /**< Input pointer passed to the start call (stored, not owned). */
406+
u8* temp; /**< Temporary buffer passed to the start call; forwarded to every stage. */
407+
bitset disposition; /**< Copy of the plan's in-place or out-of-place disposition, chosen by `in == out`. */
408+
complex<T>* scratch; /**< Points into the tail of `temp`; computed at start from `temp_size`. */
409+
size_t step = 0; /**< Index of the stage the next `progressive_step` call runs. */
410+
};
411+
412+
/// @brief Returns the number of steps for progressive execution of the DFT.
413+
/// @return The number of steps for progressive execution.
414+
size_t progressive_total_steps() const;
415+
416+
/**
417+
* @brief Initiates the progressive execution of the DFT.
418+
* @param inverse If true, applies the inverse DFT.
419+
* @param out Pointer to the output data.
420+
* @param in Pointer to the input data.
421+
* @param temp Temporary (scratch) buffer. A scratch buffer of size
422+
* `plan->temp_size` must be provided.
423+
* @return A `progressive` structure that can be used with `progressive_step`.
424+
* @note Ensure that the entire input data is available in the `in` buffer before calling this function.
425+
* The `out` buffer will contain the result data after the final step of the progressive execution.
426+
*/
427+
KFR_MEM_INTRINSIC progressive progressive_start(bool inverse, complex<T>* out, const complex<T>* in,
428+
u8* temp) const
429+
{
430+
KFR_LOGIC_CHECK(is_initialized(), "dft_plan is not initialized");
431+
KFR_LOGIC_CHECK(temp_size == 0 || temp != nullptr,
432+
"Temporary buffer must be provided for progressive execution");
433+
progressive result{};
434+
internal_generic::dft_progressive_start(*this, result, inverse, out, in, temp);
435+
return result;
436+
}
437+
438+
/**
439+
* @brief Steps the progressive execution of the DFT.
440+
* @param progressive A `progressive` structure returned by `progressive_start`.
441+
* @return `true` if there are more steps to execute, `false` if the DFT is complete.
442+
*/
443+
KFR_MEM_INTRINSIC bool progressive_step(progressive& progressive) const
444+
{
445+
internal_generic::dft_progressive_step(*this, progressive);
446+
return ++progressive.step < stages[progressive.inverse].size();
447+
}
448+
383449
protected:
384450
struct noinit
385451
{
386452
};
387-
explicit dft_plan(noinit, size_t size, dft_order order = dft_order::normal)
388-
: size(size), temp_size(0), data_size(0), arblen(false)
453+
explicit dft_plan(noinit, size_t size, dft_order order = dft_order::normal,
454+
bool progressive_optimized = false)
455+
: size(size), temp_size(0), data_size(0), arblen(false), progressive_optimized(progressive_optimized)
389456
{
390457
}
391458

@@ -426,8 +493,10 @@ struct dft_plan_real : dft_plan<T>
426493
(void)cpu;
427494
}
428495

429-
explicit dft_plan_real(size_t size, dft_pack_format fmt = dft_pack_format::CCs)
430-
: dft_plan<T>(typename dft_plan<T>::noinit{}, size / 2), size(size), fmt(fmt)
496+
explicit dft_plan_real(size_t size, dft_pack_format fmt = dft_pack_format::CCs,
497+
bool progressive_optimized = false)
498+
: dft_plan<T>(typename dft_plan<T>::noinit{}, size / 2, dft_order::normal, progressive_optimized),
499+
size(size), fmt(fmt)
431500
{
432501
KFR_LOGIC_CHECK(is_even(size), "dft_plan_real requires size to be even");
433502
internal_generic::dft_real_initialize(*this);
@@ -488,6 +557,28 @@ struct dft_plan_real : dft_plan<T>
488557
{
489558
this->execute_dft(ctrue, ptr_cast<complex<T>>(out.data()), in.data(), temp);
490559
}
560+
561+
using progressive = typename dft_plan<T>::progressive;
562+
563+
/// Begins a progressive inverse transform (complex input, real output).
/// `out` is reinterpreted as `complex<T>*` and the generic start routine is called
/// with `inverse == true`. `temp` must be non-null unless `temp_size` is zero.
/// Returns the state to pass to `progressive_step`.
KFR_MEM_INTRINSIC progressive progressive_start(T* out, const complex<T>* in, u8* temp) const
{
    KFR_LOGIC_CHECK(is_initialized(), "dft_plan_real is not initialized");
    KFR_LOGIC_CHECK(this->temp_size == 0 || temp != nullptr,
                    "Temporary buffer must be provided for progressive execution");

    const auto complex_out = ptr_cast<complex<T>>(out);
    progressive state{};
    internal_generic::dft_progressive_start(*this, state, true, complex_out, in, temp);
    return state;
}
572+
/// Begins a progressive forward transform (real input, complex output).
/// `in` is reinterpreted as `const complex<T>*` and the generic start routine is called
/// with `inverse == false`. `temp` must be non-null unless `temp_size` is zero.
/// Returns the state to pass to `progressive_step`.
KFR_MEM_INTRINSIC progressive progressive_start(complex<T>* out, const T* in, u8* temp) const
{
    KFR_LOGIC_CHECK(is_initialized(), "dft_plan_real is not initialized");
    KFR_LOGIC_CHECK(this->temp_size == 0 || temp != nullptr,
                    "Temporary buffer must be provided for progressive execution");

    const auto complex_in = ptr_cast<const complex<T>>(in);
    progressive state{};
    internal_generic::dft_progressive_start(*this, state, false, out, complex_in, temp);
    return state;
}
491582
};
492583

493584
/// @brief Multidimensional DFT

src/dft/dft.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ void dft_plan<T>::dump() const
3939
}
4040
}
4141

42+
template <typename T>
43+
size_t dft_plan<T>::progressive_total_steps() const
44+
{
45+
return stages[0].size();
46+
}
47+
4248
template <typename T>
4349
void dft_plan<T>::calc_disposition()
4450
{
@@ -131,6 +137,13 @@ CMT_MULTI_PROTO(namespace impl {
131137
u8* temp);
132138
template <typename T>
133139
void dft_initialize_transpose(internal_generic::fn_transpose<T> & transpose);
140+
141+
template <typename T>
142+
void dft_progressive_start(const dft_plan<T>& plan, typename dft_plan<T>::progressive& progressive,
143+
bool inverse, complex<T>* out, const complex<T>* in, u8* temp);
144+
145+
template <typename T>
146+
void dft_progressive_step(const dft_plan<T>& plan, typename dft_plan<T>::progressive& progressive);
134147
})
135148

136149
#ifdef CMT_MULTI_NEEDS_GATE
@@ -159,6 +172,19 @@ void dft_initialize_transpose(fn_transpose<T>& transpose)
159172
CMT_MULTI_GATE(ns::impl::dft_initialize_transpose(transpose));
160173
}
161174

175+
/// Generic entry point for starting a progressive DFT: passes every argument
/// unchanged to `ns::impl::dft_progressive_start` through `CMT_MULTI_GATE`.
/// NOTE(review): presumably the gate picks the architecture-specific build — confirm.
template <typename T>
176+
void dft_progressive_start(const dft_plan<T>& plan, typename dft_plan<T>::progressive& progressive,
177+
bool inverse, complex<T>* out, const complex<T>* in, u8* temp)
178+
{
179+
CMT_MULTI_GATE(ns::impl::dft_progressive_start(plan, progressive, inverse, out, in, temp));
180+
}
181+
182+
/// Generic entry point for running one progressive DFT step: passes `plan` and
/// `progressive` unchanged to `ns::impl::dft_progressive_step` through `CMT_MULTI_GATE`.
/// NOTE(review): presumably the gate picks the architecture-specific build — confirm.
template <typename T>
183+
void dft_progressive_step(const dft_plan<T>& plan, typename dft_plan<T>::progressive& progressive)
184+
{
185+
CMT_MULTI_GATE(ns::impl::dft_progressive_step(plan, progressive));
186+
}
187+
162188
template void dft_initialize<float>(dft_plan<float>&);
163189
template void dft_initialize<double>(dft_plan<double>&);
164190
template void dft_real_initialize<float>(dft_plan_real<float>&);
@@ -173,6 +199,16 @@ template void dft_execute<double>(const dft_plan<double>&, cbool_t<true>, comple
173199
const complex<double>*, u8*);
174200
template void dft_initialize_transpose<float>(fn_transpose<float>&);
175201
template void dft_initialize_transpose<double>(fn_transpose<double>&);
202+
template void dft_progressive_start(const dft_plan<float>& plan,
203+
typename dft_plan<float>::progressive& progressive, bool inverse,
204+
complex<float>* out, const complex<float>* in, u8* temp);
205+
template void dft_progressive_start(const dft_plan<double>& plan,
206+
typename dft_plan<double>::progressive& progressive, bool inverse,
207+
complex<double>* out, const complex<double>* in, u8* temp);
208+
template void dft_progressive_step(const dft_plan<float>& plan,
209+
typename dft_plan<float>::progressive& progressive);
210+
template void dft_progressive_step(const dft_plan<double>& plan,
211+
typename dft_plan<double>::progressive& progressive);
176212

177213
} // namespace internal_generic
178214

src/dft/fft-impl.hpp

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -862,7 +862,6 @@ KFR_INTRINSIC ctrue_t radix4_pass(csize_t<4>, size_t blocks, csize_t<width>, cfa
862862
template <typename T>
863863
struct fft_config
864864
{
865-
constexpr static inline const bool recursion = true;
866865
#ifdef CMT_ARCH_NEON
867866
constexpr static inline const bool prefetch = false;
868867
#else
@@ -881,7 +880,7 @@ struct fft_stage_impl : dft_stage<T>
881880
this->radix = 4;
882881
this->stage_size = stage_size;
883882
this->repeats = 4;
884-
this->recursion = fft_config<T>::recursion;
883+
this->recursion = true;
885884
this->data_size =
886885
align_up(sizeof(complex<T>) * stage_size / 4 * 3, platform<>::native_cache_alignment);
887886
}
@@ -922,7 +921,7 @@ struct fft_final_stage_impl : dft_stage<T>
922921
this->stage_size = size;
923922
this->out_offset = size;
924923
this->repeats = 4;
925-
this->recursion = fft_config<T>::recursion;
924+
this->recursion = true;
926925
this->data_size = align_up(sizeof(complex<T>) * size * 3 / 2, platform<>::native_cache_alignment);
927926
}
928927

@@ -1706,9 +1705,9 @@ void make_fft_stages(dft_plan<T>* self, cbool_t<autosort>, size_t stage_size, cb
17061705
} // namespace intrinsics
17071706

17081707
template <bool is_even, typename T>
1709-
void make_fft(dft_plan<T>* self, size_t stage_size, cbool_t<is_even>)
1708+
void make_fft(dft_plan<T>* self, size_t stage_size, cbool_t<is_even>, bool autosort)
17101709
{
1711-
if (use_autosort<T>(ilog2(stage_size)))
1710+
if (autosort)
17121711
{
17131712
intrinsics::make_fft_stages(self, ctrue, stage_size, cbool<is_even>, ctrue);
17141713
}
@@ -1776,7 +1775,8 @@ KFR_INTRINSIC void initialize_order(dft_plan<T>* self)
17761775
template <typename T>
17771776
KFR_INTRINSIC void init_fft(dft_plan<T>* self, size_t size, dft_order)
17781777
{
1779-
const size_t log2n = ilog2(size);
1778+
const size_t log2n = ilog2(size);
1779+
const bool autosort = use_autosort<T>(ilog2(size)) || self->progressive_optimized;
17801780
cswitch(
17811781
csizes_t<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
17821782
#ifdef KFR_AUTOSORT_FOR_2048
@@ -1791,8 +1791,10 @@ KFR_INTRINSIC void init_fft(dft_plan<T>* self, size_t size, dft_order)
17911791
constexpr size_t log2nv = val_of(decltype(log2n)());
17921792
add_stage<intrinsics::fft_specialization<T, log2nv>>(self, size);
17931793
},
1794-
[&]()
1795-
{ cswitch(cfalse_true, is_even(log2n), [&](auto is_even) { make_fft(self, size, is_even); }); });
1794+
[&]() {
1795+
cswitch(cfalse_true, is_even(log2n),
1796+
[&](auto is_even) { make_fft(self, size, is_even, autosort); });
1797+
});
17961798
}
17971799

17981800
template <typename T>
@@ -2027,14 +2029,14 @@ void dft_execute(const dft_plan<T>& plan, cbool_t<inverse>, complex<T>* out, con
20272029
}
20282030
else
20292031
{
2030-
size_t offset = 0;
2032+
size_t offset = 0;
2033+
complex<T>* cur_out = select_out(plan, disposition, depth, stages.size(), out, scratch);
2034+
const complex<T>* cur_in = select_in(plan, disposition, depth, out, in, scratch);
2035+
dft_stage<T>* stage = stages[depth];
20312036
while (offset < plan.size)
20322037
{
2033-
stages[depth]->execute(cbool<inverse>,
2034-
select_out(plan, disposition, depth, stages.size(), out, scratch) +
2035-
offset,
2036-
select_in(plan, disposition, depth, out, in, scratch) + offset, temp);
2037-
offset += stages[depth]->stage_size;
2038+
stage->execute(cbool<inverse>, cur_out + offset, cur_in + offset, temp);
2039+
offset += stage->stage_size;
20382040
}
20392041
depth++;
20402042
}
@@ -2045,6 +2047,47 @@ void dft_initialize_transpose(internal_generic::fn_transpose<T>& transpose)
20452047
{
20462048
transpose = &kfr::CMT_ARCH_NAME::matrix_transpose;
20472049
}
2050+
2051+
template <typename T>
2052+
void dft_progressive_start(const dft_plan<T>& plan, typename dft_plan<T>::progressive& prog, bool inverse,
2053+
complex<T>* out, const complex<T>* in, u8* temp)
2054+
{
2055+
prog.inverse = inverse;
2056+
prog.out = out;
2057+
prog.in = in;
2058+
prog.temp = temp;
2059+
prog.scratch = ptr_cast<complex<T>>(
2060+
temp + plan.temp_size -
2061+
align_up(sizeof(complex<T>) * (plan.size + 1), platform<>::native_cache_alignment));
2062+
2063+
prog.disposition = in == out ? plan.disposition_inplace[inverse] : plan.disposition_outofplace[inverse];
2064+
2065+
bool in_scratch = prog.disposition.test(0);
2066+
if (in_scratch)
2067+
{
2068+
plan.stages[inverse][0]->copy_input(inverse, prog.scratch, in, plan.size);
2069+
}
2070+
prog.step = 0;
2071+
}
2072+
2073+
template <typename T>
2074+
void dft_progressive_step(const dft_plan<T>& plan, typename dft_plan<T>::progressive& progressive)
2075+
{
2076+
auto&& stages = plan.stages[progressive.inverse];
2077+
uint32_t depth = progressive.step;
2078+
complex<T>* cur_out =
2079+
select_out(plan, progressive.disposition, depth, stages.size(), progressive.out, progressive.scratch);
2080+
const complex<T>* cur_in =
2081+
select_in(plan, progressive.disposition, depth, progressive.out, progressive.in, progressive.scratch);
2082+
2083+
size_t offset = 0;
2084+
dft_stage<T>* stage = stages[depth];
2085+
while (offset < plan.size)
2086+
{
2087+
stage->execute(progressive.inverse, cur_out + offset, cur_in + offset, progressive.temp);
2088+
offset += stage->stage_size;
2089+
}
2090+
}
20482091
} // namespace impl
20492092

20502093
namespace intrinsics

src/dft/fft-templates.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,16 @@ inline namespace CMT_ARCH_NAME
3333
{
3434
namespace impl
3535
{
36-
template void dft_initialize<FLOAT>(dft_plan<FLOAT>& plan);
37-
template void dft_real_initialize<FLOAT>(dft_plan_real<FLOAT>& plan);
36+
template void dft_initialize<FLOAT>(dft_plan<FLOAT>&);
37+
template void dft_real_initialize<FLOAT>(dft_plan_real<FLOAT>&);
3838
template void dft_execute<FLOAT>(const dft_plan<FLOAT>&, cbool_t<false>, complex<FLOAT>*,
3939
const complex<FLOAT>*, u8*);
4040
template void dft_execute<FLOAT>(const dft_plan<FLOAT>&, cbool_t<true>, complex<FLOAT>*,
4141
const complex<FLOAT>*, u8*);
42-
template void dft_initialize_transpose<FLOAT>(internal_generic::fn_transpose<FLOAT>& transpose);
42+
template void dft_initialize_transpose<FLOAT>(internal_generic::fn_transpose<FLOAT>&);
43+
template void dft_progressive_start(const dft_plan<FLOAT>&, typename dft_plan<FLOAT>::progressive&, bool,
44+
complex<FLOAT>*, const complex<FLOAT>*, u8*);
45+
template void dft_progressive_step(const dft_plan<FLOAT>&, typename dft_plan<FLOAT>::progressive&);
4346
} // namespace impl
4447
} // namespace CMT_ARCH_NAME
4548
} // namespace kfr

0 commit comments

Comments
 (0)