@@ -111,17 +111,14 @@ void test_iamax2() {
111
111
void test_rotg1 () {
112
112
dpct::queue_ptr handle;
113
113
handle = &dpct::get_out_of_order_queue ();
114
- float *a = (float *)dpct::dpct_malloc (sizeof (float ) * 1 );
115
- float *b = (float *)dpct::dpct_malloc (sizeof (float ) * 1 );
116
- float *c = (float *)dpct::dpct_malloc (sizeof (float ) * 1 );
117
- float *s = (float *)dpct::dpct_malloc (sizeof (float ) * 1 );
118
- dpct::get_host_ptr<float >(a)[0 ] = 1 ;
119
- dpct::get_host_ptr<float >(b)[0 ] = 1.73205 ;
114
+ float *four_args = (float *)dpct::dpct_malloc (sizeof (float ) * 4 );
115
+ dpct::get_host_ptr<float >(four_args)[0 ] = 1 ;
116
+ dpct::get_host_ptr<float >(four_args)[1 ] = 1.73205 ;
120
117
[&]() {
121
- dpct::blas::wrapper_float_inout a_m (*handle, a );
122
- dpct::blas::wrapper_float_inout b_m (*handle, b );
123
- dpct::blas::wrapper_float_out c_m (*handle, c );
124
- dpct::blas::wrapper_float_out s_m (*handle, s );
118
+ dpct::blas::wrapper_float_inout a_m (*handle, four_args );
119
+ dpct::blas::wrapper_float_inout b_m (*handle, four_args + 1 );
120
+ dpct::blas::wrapper_float_out c_m (*handle, four_args + 2 );
121
+ dpct::blas::wrapper_float_out s_m (*handle, four_args + 3 );
125
122
oneapi::mkl::blas::column_major::rotg (
126
123
*handle,
127
124
dpct::rvalue_ref_to_lvalue_ref (dpct::get_buffer<float >(a_m.get_ptr ())),
@@ -131,22 +128,19 @@ void test_rotg1() {
131
128
}();
132
129
dpct::get_current_device ().queues_wait_and_throw ();
133
130
handle = nullptr ;
134
- if (std::abs (dpct::get_host_ptr<float >(a )[0 ] - 2 .0f ) < 0.01 &&
135
- std::abs (dpct::get_host_ptr<float >(b)[ 0 ] - 2 .0f ) < 0.01 &&
136
- std::abs (dpct::get_host_ptr<float >(c)[ 0 ] - 0 .5f ) < 0.01 &&
137
- std::abs (dpct::get_host_ptr<float >(s)[ 0 ] - 0 .866025f ) < 0.01 ) {
131
+ if (std::abs (dpct::get_host_ptr<float >(four_args )[0 ] - 2 .0f ) < 0.01 &&
132
+ std::abs (dpct::get_host_ptr<float >(four_args)[ 1 ] - 2 .0f ) < 0.01 &&
133
+ std::abs (dpct::get_host_ptr<float >(four_args)[ 2 ] - 0 .5f ) < 0.01 &&
134
+ std::abs (dpct::get_host_ptr<float >(four_args)[ 3 ] - 0 .866025f ) < 0.01 ) {
138
135
printf (" test_rotg1 pass\n " );
139
136
} else {
140
137
printf (" test_rotg1 fail:\n " );
141
- printf (" %f,%f,%f,%f\n " , dpct::get_host_ptr<float >(a )[0 ],
142
- dpct::get_host_ptr<float >(b)[ 0 ], dpct::get_host_ptr<float >(c)[ 0 ],
143
- dpct::get_host_ptr<float >(s)[ 0 ]);
138
+ printf (" %f,%f,%f,%f\n " , dpct::get_host_ptr<float >(four_args )[0 ],
139
+ dpct::get_host_ptr<float >(four_args)[ 1 ], dpct::get_host_ptr<float >(four_args)[ 2 ],
140
+ dpct::get_host_ptr<float >(four_args)[ 3 ]);
144
141
pass = false ;
145
142
}
146
- dpct::dpct_free (a);
147
- dpct::dpct_free (b);
148
- dpct::dpct_free (c);
149
- dpct::dpct_free (s);
143
+ dpct::dpct_free (four_args);
150
144
}
151
145
152
146
void test_rotg2 () {
0 commit comments