Skip to content

Commit 323c817

Browse files
authored
Support ComputeFn where output type differs from input type (#1771)
This is useful for e.g. function taking in 2 float inputs and turn them to complex
1 parent 82f5075 commit 323c817

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,14 +181,20 @@ struct Sm90Compute {
181181
},
182182
[&] (auto&&... cvt_frg_inputs) {
183183
using ComputeOutput = ComputeFn<Array<ElementCompute, FragmentSize>>;
184-
using ConvertOutput = NumericArrayConverter<ElementOutput, ElementCompute, FragmentSize, RoundStyle>;
185184
ComputeOutput compute_output{};
186-
ConvertOutput convert_output{};
187185

188186
if constexpr (cute::is_same_v<Arguments, EmptyArguments>) {
187+
using ElementComputeOutput =
188+
typename cute::remove_cvref_t<decltype(compute_output(cvt_frg_inputs...))>::Element;
189+
using ConvertOutput = NumericArrayConverter<ElementOutput, ElementComputeOutput, FragmentSize, RoundStyle>;
190+
ConvertOutput convert_output{};
189191
return convert_output(compute_output(cvt_frg_inputs...));
190192
}
191193
else {
194+
using ElementComputeOutput =
195+
typename cute::remove_cvref_t<decltype(compute_output(cvt_frg_inputs..., params))>::Element;
196+
using ConvertOutput = NumericArrayConverter<ElementOutput, ElementComputeOutput, FragmentSize, RoundStyle>;
197+
ConvertOutput convert_output{};
192198
return convert_output(compute_output(cvt_frg_inputs..., params));
193199
}
194200
}

0 commit comments

Comments
 (0)