Skip to content

Commit 86c8dff

Browse files
authored
[SYCLomatic oneapi-src#1844] Adjust blas_utils_parameter_wrapper_buf to test more cases (oneapi-src#667)
Signed-off-by: Jiang, Zhiwei <zhiwei.jiang@intel.com>
1 parent fb64662 commit 86c8dff

File tree

1 file changed

+15
-21
lines changed

1 file changed

+15
-21
lines changed

help_function/src/blas_utils_parameter_wrapper_buf.cpp

+15-21
Original file line numberDiff line numberDiff line change
@@ -111,17 +111,14 @@ void test_iamax2() {
111111
void test_rotg1() {
112112
dpct::queue_ptr handle;
113113
handle = &dpct::get_out_of_order_queue();
114-
float *a = (float *)dpct::dpct_malloc(sizeof(float) * 1);
115-
float *b = (float *)dpct::dpct_malloc(sizeof(float) * 1);
116-
float *c = (float *)dpct::dpct_malloc(sizeof(float) * 1);
117-
float *s = (float *)dpct::dpct_malloc(sizeof(float) * 1);
118-
dpct::get_host_ptr<float>(a)[0] = 1;
119-
dpct::get_host_ptr<float>(b)[0] = 1.73205;
114+
float *four_args = (float *)dpct::dpct_malloc(sizeof(float) * 4);
115+
dpct::get_host_ptr<float>(four_args)[0] = 1;
116+
dpct::get_host_ptr<float>(four_args)[1] = 1.73205;
120117
[&]() {
121-
dpct::blas::wrapper_float_inout a_m(*handle, a);
122-
dpct::blas::wrapper_float_inout b_m(*handle, b);
123-
dpct::blas::wrapper_float_out c_m(*handle, c);
124-
dpct::blas::wrapper_float_out s_m(*handle, s);
118+
dpct::blas::wrapper_float_inout a_m(*handle, four_args);
119+
dpct::blas::wrapper_float_inout b_m(*handle, four_args + 1);
120+
dpct::blas::wrapper_float_out c_m(*handle, four_args + 2);
121+
dpct::blas::wrapper_float_out s_m(*handle, four_args + 3);
125122
oneapi::mkl::blas::column_major::rotg(
126123
*handle,
127124
dpct::rvalue_ref_to_lvalue_ref(dpct::get_buffer<float>(a_m.get_ptr())),
@@ -131,22 +128,19 @@ void test_rotg1() {
131128
}();
132129
dpct::get_current_device().queues_wait_and_throw();
133130
handle = nullptr;
134-
if (std::abs(dpct::get_host_ptr<float>(a)[0] - 2.0f) < 0.01 &&
135-
std::abs(dpct::get_host_ptr<float>(b)[0] - 2.0f) < 0.01 &&
136-
std::abs(dpct::get_host_ptr<float>(c)[0] - 0.5f) < 0.01 &&
137-
std::abs(dpct::get_host_ptr<float>(s)[0] - 0.866025f) < 0.01) {
131+
if (std::abs(dpct::get_host_ptr<float>(four_args)[0] - 2.0f) < 0.01 &&
132+
std::abs(dpct::get_host_ptr<float>(four_args)[1] - 2.0f) < 0.01 &&
133+
std::abs(dpct::get_host_ptr<float>(four_args)[2] - 0.5f) < 0.01 &&
134+
std::abs(dpct::get_host_ptr<float>(four_args)[3] - 0.866025f) < 0.01) {
138135
printf("test_rotg1 pass\n");
139136
} else {
140137
printf("test_rotg1 fail:\n");
141-
printf("%f,%f,%f,%f\n", dpct::get_host_ptr<float>(a)[0],
142-
dpct::get_host_ptr<float>(b)[0], dpct::get_host_ptr<float>(c)[0],
143-
dpct::get_host_ptr<float>(s)[0]);
138+
printf("%f,%f,%f,%f\n", dpct::get_host_ptr<float>(four_args)[0],
139+
dpct::get_host_ptr<float>(four_args)[1], dpct::get_host_ptr<float>(four_args)[2],
140+
dpct::get_host_ptr<float>(four_args)[3]);
144141
pass = false;
145142
}
146-
dpct::dpct_free(a);
147-
dpct::dpct_free(b);
148-
dpct::dpct_free(c);
149-
dpct::dpct_free(s);
143+
dpct::dpct_free(four_args);
150144
}
151145

152146
void test_rotg2() {

0 commit comments

Comments
 (0)