@@ -33,31 +33,8 @@ using Generator = DefaultSystemGenerator<>;
3333
3434struct ConversionBenchmark : Benchmark<gko::device_matrix_data<etype, itype>> {
3535 std::string name;
36- std::vector<std::string> operations;
3736
38- ConversionBenchmark () : name{" conversion" }
39- {
40- auto ref_exec = gko::ReferenceExecutor::create ();
41- auto formats = split (FLAGS_formats);
42- for (const auto & from_format : formats) {
43- operations.push_back (from_format + " -read" );
44- auto from_mtx =
45- formats::matrix_type_factory.at (from_format)(ref_exec);
46- // all pairs of conversions that are supported by Ginkgo
47- for (const auto & to_format : formats) {
48- if (from_format == to_format) {
49- continue ;
50- }
51- auto to_mtx =
52- formats::matrix_type_factory.at (to_format)(ref_exec);
53- try {
54- to_mtx->copy_from (from_mtx);
55- operations.push_back (from_format + " -" + to_format);
56- } catch (const std::exception& e) {
57- }
58- }
59- }
60- }
37+ ConversionBenchmark () : name{" conversion" } {}
6138
6239 const std::string& get_name () const override { return name; }
6340
@@ -83,55 +60,80 @@ struct ConversionBenchmark : Benchmark<gko::device_matrix_data<etype, itype>> {
8360 gko::device_matrix_data<etype, itype>& data,
8461 const json& operation_case, json& result_case) const override
8562 {
86- for (const auto & operation_name : operations) {
87- result_case[operation_name] = json::object ();
88- auto & op_result_case = result_case[operation_name];
89-
90- auto split_it =
91- std::find (operation_name.begin (), operation_name.end (), ' -' );
92- std::string from_name{operation_name.begin (), split_it};
93- std::string to_name{split_it + 1 , operation_name.end ()};
94- auto mtx_from = formats::matrix_type_factory.at (from_name)(exec);
95- auto readable = gko::as<gko::ReadableFromMatrixData<etype, itype>>(
96- mtx_from.get ());
97- IterationControl ic{timer};
98- if (to_name == " read" ) {
99- // warm run
100- {
101- auto range = annotate (" warmup" , FLAGS_warmup > 0 );
102- for (auto _ : ic.warmup_run ()) {
103- exec->synchronize ();
104- readable->read (data);
105- exec->synchronize ();
106- }
107- }
108- // timed run
109- for (auto _ : ic.run ()) {
110- auto range = annotate (" repetition" );
63+ std::string from_name = operation_case[" from" ].get <std::string>();
64+ std::string to_name = operation_case[" to" ].get <std::string>();
65+ auto mtx_from = formats::matrix_type_factory.at (from_name)(exec);
66+ auto readable =
67+ gko::as<gko::ReadableFromMatrixData<etype, itype>>(mtx_from.get ());
68+
69+ // check if conversion is supported on empty matrix first
70+ if (from_name != to_name) {
71+ auto to_mtx = formats::matrix_type_factory.at (to_name)(exec);
72+ to_mtx->copy_from (mtx_from);
73+ }
74+
75+ IterationControl ic{timer};
76+ if (to_name == from_name) {
77+ // warm run
78+ {
79+ auto range = annotate (" warmup" , FLAGS_warmup > 0 );
80+ for (auto _ : ic.warmup_run ()) {
81+ exec->synchronize ();
11182 readable->read (data);
83+ exec->synchronize ();
11284 }
113- } else {
85+ }
86+ // timed run
87+ for (auto _ : ic.run ()) {
88+ auto range = annotate (" repetition" );
11489 readable->read (data);
115- auto mtx_to = formats::matrix_type_factory.at (to_name)(exec);
116-
117- // warm run
118- {
119- auto range = annotate (" warmup" , FLAGS_warmup > 0 );
120- for (auto _ : ic.warmup_run ()) {
121- exec->synchronize ();
122- mtx_to->copy_from (mtx_from);
123- exec->synchronize ();
124- }
125- }
126- // timed run
127- for (auto _ : ic.run ()) {
128- auto range = annotate (" repetition" );
90+ }
91+ } else {
92+ readable->read (data);
93+ auto mtx_to = formats::matrix_type_factory.at (to_name)(exec);
94+
95+ // warm run
96+ {
97+ auto range = annotate (" warmup" , FLAGS_warmup > 0 );
98+ for (auto _ : ic.warmup_run ()) {
99+ exec->synchronize ();
129100 mtx_to->copy_from (mtx_from);
101+ exec->synchronize ();
130102 }
131103 }
132- op_result_case[" time" ] = ic.compute_time (FLAGS_timer_method);
133- op_result_case[" repetitions" ] = ic.get_num_repetitions ();
104+ // timed run
105+ for (auto _ : ic.run ()) {
106+ auto range = annotate (" repetition" );
107+ mtx_to->copy_from (mtx_from);
108+ }
109+ }
110+ result_case[" time" ] = ic.compute_time (FLAGS_timer_method);
111+ result_case[" repetitions" ] = ic.get_num_repetitions ();
112+ }
113+
114+ void postprocess (json& test_cases) const override
115+ {
116+ std::map<json, json> same_operators;
117+ for (const auto & test_case : test_cases) {
118+ if (test_case[name].contains (" error_type" ) &&
119+ test_case[name][" error_type" ] == " gko::NotSupported" ) {
120+ continue ;
121+ }
122+ auto case_operator = test_case;
123+ case_operator.erase (" to" );
124+ case_operator.erase (" from" );
125+ case_operator.erase (name);
126+ same_operators.try_emplace (case_operator, json::array ());
127+ same_operators[case_operator].push_back (test_case[name]);
128+ same_operators[case_operator].back ()[" to" ] = test_case[" to" ];
129+ same_operators[case_operator].back ()[" from" ] = test_case[" from" ];
134130 }
131+ auto merged_cases = json::array ();
132+ for (auto & [case_operator, results] : same_operators) {
133+ merged_cases.push_back (case_operator);
134+ merged_cases.back ()[name] = results;
135+ }
136+ test_cases = std::move (merged_cases);
135137 }
136138};
137139
@@ -146,12 +148,8 @@ int main(int argc, char* argv[])
146148
147149 initialize_argument_parsing (&argc, &argv, header, schema[" examples" ]);
148150
149- std::string extra_information =
150- std::string () + " The formats are " + FLAGS_formats;
151-
152151 auto exec = executor_factory.at (FLAGS_executor)(FLAGS_gpu_timer);
153- print_general_information (extra_information, exec);
154- auto formats = split (FLAGS_formats, ' ,' );
152+ print_general_information (" " , exec);
155153
156154 auto test_cases = json::parse (get_input_stream ());
157155
0 commit comments