Skip to content

Commit 6e8322d

Browse files
authored
Add gemv_op_mt for DSP (#7009)
1 parent 5d834e9 commit 6e8322d

File tree

5 files changed

+454
-0
lines changed

5 files changed

+454
-0
lines changed

source/source_base/kernels/dsp/dsp_connector.cpp

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,118 @@ void cgemm_mt_(const char* transa,
187187
cluster_id);
188188
} // cgemm that needn't malloc_ht or free_ht
189189

190+
void sgemv_mt_(const char* transa,
191+
const int* m,
192+
const int* n,
193+
const float* alpha,
194+
const float* a,
195+
const int* lda,
196+
const float* x,
197+
const int* incx,
198+
const float* beta,
199+
float* y,
200+
const int* incy,
201+
int cluster_id)
202+
{
203+
mtblas_sgemv(CBLAS_ORDER::CblasColMajor,
204+
convertBLASTranspose(transa),
205+
*m,
206+
*n,
207+
*alpha,
208+
a,
209+
*lda,
210+
x,
211+
*incx,
212+
*beta,
213+
y,
214+
*incy,
215+
cluster_id);
216+
}
217+
218+
void dgemv_mt_(const char* transa,
219+
const int* m,
220+
const int* n,
221+
const double* alpha,
222+
const double* a,
223+
const int* lda,
224+
const double* x,
225+
const int* incx,
226+
const double* beta,
227+
double* y,
228+
const int* incy,
229+
int cluster_id)
230+
{
231+
mtblas_dgemv(CBLAS_ORDER::CblasColMajor,
232+
convertBLASTranspose(transa),
233+
*m,
234+
*n,
235+
*alpha,
236+
a,
237+
*lda,
238+
x,
239+
*incx,
240+
*beta,
241+
y,
242+
*incy,
243+
cluster_id);
244+
}
245+
246+
void zgemv_mt_(const char* transa,
247+
const int* m,
248+
const int* n,
249+
const std::complex<double>* alpha,
250+
const std::complex<double>* a,
251+
const int* lda,
252+
const std::complex<double>* x,
253+
const int* incx,
254+
const std::complex<double>* beta,
255+
std::complex<double>* y,
256+
const int* incy,
257+
int cluster_id)
258+
{
259+
mtblas_zgemv(CBLAS_ORDER::CblasColMajor,
260+
convertBLASTranspose(transa),
261+
*m,
262+
*n,
263+
(const void*)alpha,
264+
(const void*)a,
265+
*lda,
266+
(const void*)x,
267+
*incx,
268+
(const void*)beta,
269+
(void*)y,
270+
*incy,
271+
cluster_id);
272+
}
273+
274+
void cgemv_mt_(const char* transa,
275+
const int* m,
276+
const int* n,
277+
const std::complex<float>* alpha,
278+
const std::complex<float>* a,
279+
const int* lda,
280+
const std::complex<float>* x,
281+
const int* incx,
282+
const std::complex<float>* beta,
283+
std::complex<float>* y,
284+
const int* incy,
285+
int cluster_id)
286+
{
287+
mtblas_cgemv(CBLAS_ORDER::CblasColMajor,
288+
convertBLASTranspose(transa),
289+
*m,
290+
*n,
291+
(const void*)alpha,
292+
(const void*)a,
293+
*lda,
294+
(const void*)x,
295+
*incx,
296+
(const void*)beta,
297+
(void*)y,
298+
*incy,
299+
cluster_id);
300+
}
301+
190302
// Used to replace original free
191303

192304
void sgemm_mth_(const char* transa,
@@ -330,4 +442,132 @@ void cgemm_mth_(const char* transa,
330442
free_ht(alp);
331443
free_ht(bet);
332444
} // cgemm that needn't malloc_ht or free_ht
445+
446+
void sgemv_mth_(const char* transa,
447+
const int* m,
448+
const int* n,
449+
const float* alpha,
450+
const float* a,
451+
const int* lda,
452+
const float* x,
453+
const int* incx,
454+
const float* beta,
455+
float* y,
456+
const int* incy,
457+
int cluster_id)
458+
{
459+
mt_hthread_sgemv(CBLAS_ORDER::CblasColMajor,
460+
convertBLASTranspose(transa),
461+
*m,
462+
*n,
463+
*alpha,
464+
a,
465+
*lda,
466+
x,
467+
*incx,
468+
*beta,
469+
y,
470+
*incy,
471+
cluster_id);
472+
}
473+
474+
void dgemv_mth_(const char* transa,
475+
const int* m,
476+
const int* n,
477+
const double* alpha,
478+
const double* a,
479+
const int* lda,
480+
const double* x,
481+
const int* incx,
482+
const double* beta,
483+
double* y,
484+
const int* incy,
485+
int cluster_id)
486+
{
487+
mt_hthread_dgemv(CBLAS_ORDER::CblasColMajor,
488+
convertBLASTranspose(transa),
489+
*m,
490+
*n,
491+
*alpha,
492+
a,
493+
*lda,
494+
x,
495+
*incx,
496+
*beta,
497+
y,
498+
*incy,
499+
cluster_id);
500+
}
501+
502+
void zgemv_mth_(const char* transa,
503+
const int* m,
504+
const int* n,
505+
const std::complex<double>* alpha,
506+
const std::complex<double>* a,
507+
const int* lda,
508+
const std::complex<double>* x,
509+
const int* incx,
510+
const std::complex<double>* beta,
511+
std::complex<double>* y,
512+
const int* incy,
513+
int cluster_id)
514+
{
515+
std::complex<double>* alp = (std::complex<double>*)malloc_ht(sizeof(std::complex<double>), cluster_id);
516+
*alp = *alpha;
517+
std::complex<double>* bet = (std::complex<double>*)malloc_ht(sizeof(std::complex<double>), cluster_id);
518+
*bet = *beta;
519+
520+
mt_hthread_zgemv(CBLAS_ORDER::CblasColMajor,
521+
convertBLASTranspose(transa),
522+
*m,
523+
*n,
524+
(const void*)alp,
525+
(const void*)a,
526+
*lda,
527+
(const void*)x,
528+
*incx,
529+
(const void*)bet,
530+
(void*)y,
531+
*incy,
532+
cluster_id);
533+
534+
free_ht(alp);
535+
free_ht(bet);
536+
}
537+
538+
void cgemv_mth_(const char* transa,
539+
const int* m,
540+
const int* n,
541+
const std::complex<float>* alpha,
542+
const std::complex<float>* a,
543+
const int* lda,
544+
const std::complex<float>* x,
545+
const int* incx,
546+
const std::complex<float>* beta,
547+
std::complex<float>* y,
548+
const int* incy,
549+
int cluster_id)
550+
{
551+
std::complex<float>* alp = (std::complex<float>*)malloc_ht(sizeof(std::complex<float>), cluster_id);
552+
*alp = *alpha;
553+
std::complex<float>* bet = (std::complex<float>*)malloc_ht(sizeof(std::complex<float>), cluster_id);
554+
*bet = *beta;
555+
556+
mt_hthread_cgemv(CBLAS_ORDER::CblasColMajor,
557+
convertBLASTranspose(transa),
558+
*m,
559+
*n,
560+
(const void*)alp,
561+
(const void*)a,
562+
*lda,
563+
(const void*)x,
564+
*incx,
565+
(const void*)bet,
566+
(void*)y,
567+
*incy,
568+
cluster_id);
569+
570+
free_ht(alp);
571+
free_ht(bet);
572+
}
333573
} // namespace mtfunc

source/source_base/kernels/dsp/dsp_connector.h

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,58 @@ void cgemm_mt_(const char* transa,
7676
const int* ldc,
7777
int cluster_id);
7878

79+
// Single-precision real GEMV through the mtblas backend.
// The argument list mirrors the Fortran BLAS sgemv convention (dimensions,
// strides and scalars passed by pointer) with an extra trailing cluster_id.
// The definition dereferences the scalars and forwards to mtblas_sgemv using
// column-major ordering; transa is translated via convertBLASTranspose().
// y is the only non-const buffer; presumably it receives
// alpha*op(A)*x + beta*y as in standard BLAS -- confirm against mtblas docs.
void sgemv_mt_(const char* transa,
80+
const int* m,
81+
const int* n,
82+
const float* alpha,
83+
const float* a,
84+
const int* lda,
85+
const float* x,
86+
const int* incx,
87+
const float* beta,
88+
float* y,
89+
const int* incy,
90+
int cluster_id);
91+
92+
// Double-precision real GEMV through the mtblas backend.
// The argument list mirrors the Fortran BLAS dgemv convention (dimensions,
// strides and scalars passed by pointer) with an extra trailing cluster_id.
// The definition dereferences the scalars and forwards to mtblas_dgemv using
// column-major ordering; transa is translated via convertBLASTranspose().
// y is the only non-const buffer; presumably it receives
// alpha*op(A)*x + beta*y as in standard BLAS -- confirm against mtblas docs.
void dgemv_mt_(const char* transa,
93+
const int* m,
94+
const int* n,
95+
const double* alpha,
96+
const double* a,
97+
const int* lda,
98+
const double* x,
99+
const int* incx,
100+
const double* beta,
101+
double* y,
102+
const int* incy,
103+
int cluster_id);
104+
105+
// Double-precision complex GEMV through the mtblas backend.
// The argument list mirrors the Fortran BLAS zgemv convention (dimensions,
// strides and scalars passed by pointer) with an extra trailing cluster_id.
// The definition dereferences the integer arguments and forwards the complex
// pointers (alpha, a, x, beta, y) to mtblas_zgemv as void pointers, using
// column-major ordering; transa is translated via convertBLASTranspose().
// y is the only non-const buffer; presumably it receives
// alpha*op(A)*x + beta*y as in standard BLAS -- confirm against mtblas docs.
void zgemv_mt_(const char* transa,
106+
const int* m,
107+
const int* n,
108+
const std::complex<double>* alpha,
109+
const std::complex<double>* a,
110+
const int* lda,
111+
const std::complex<double>* x,
112+
const int* incx,
113+
const std::complex<double>* beta,
114+
std::complex<double>* y,
115+
const int* incy,
116+
int cluster_id);
117+
118+
// Single-precision complex GEMV through the mtblas backend.
// The argument list mirrors the Fortran BLAS cgemv convention (dimensions,
// strides and scalars passed by pointer) with an extra trailing cluster_id.
// The definition dereferences the integer arguments and forwards the complex
// pointers (alpha, a, x, beta, y) to mtblas_cgemv as void pointers, using
// column-major ordering; transa is translated via convertBLASTranspose().
// y is the only non-const buffer; presumably it receives
// alpha*op(A)*x + beta*y as in standard BLAS -- confirm against mtblas docs.
void cgemv_mt_(const char* transa,
119+
const int* m,
120+
const int* n,
121+
const std::complex<float>* alpha,
122+
const std::complex<float>* a,
123+
const int* lda,
124+
const std::complex<float>* x,
125+
const int* incx,
126+
const std::complex<float>* beta,
127+
std::complex<float>* y,
128+
const int* incy,
129+
int cluster_id);
130+
79131
void sgemm_mth_(const char* transa,
80132
const char* transb,
81133
const int* m,
@@ -136,6 +188,58 @@ void cgemm_mth_(const char* transa,
136188
const int* ldc,
137189
int cluster_id);
138190

191+
// Single-precision real GEMV through the mt_hthread backend.
// Same Fortran-BLAS-style argument list as sgemv_mt_ (scalars by pointer,
// trailing cluster_id). The definition dereferences the scalars and forwards
// to mt_hthread_sgemv using column-major ordering; transa is translated via
// convertBLASTranspose().
// NOTE(review): how the buffers a/x/y must have been allocated for this
// backend is not visible here -- confirm with the callers.
void sgemv_mth_(const char* transa,
192+
const int* m,
193+
const int* n,
194+
const float* alpha,
195+
const float* a,
196+
const int* lda,
197+
const float* x,
198+
const int* incx,
199+
const float* beta,
200+
float* y,
201+
const int* incy,
202+
int cluster_id);
203+
204+
// Double-precision real GEMV through the mt_hthread backend.
// Same Fortran-BLAS-style argument list as dgemv_mt_ (scalars by pointer,
// trailing cluster_id). The definition dereferences the scalars and forwards
// to mt_hthread_dgemv using column-major ordering; transa is translated via
// convertBLASTranspose().
// NOTE(review): how the buffers a/x/y must have been allocated for this
// backend is not visible here -- confirm with the callers.
void dgemv_mth_(const char* transa,
205+
const int* m,
206+
const int* n,
207+
const double* alpha,
208+
const double* a,
209+
const int* lda,
210+
const double* x,
211+
const int* incx,
212+
const double* beta,
213+
double* y,
214+
const int* incy,
215+
int cluster_id);
216+
217+
// Double-precision complex GEMV through the mt_hthread backend.
// Same Fortran-BLAS-style argument list as zgemv_mt_ (scalars by pointer,
// trailing cluster_id). The definition copies *alpha and *beta into temporary
// buffers obtained from malloc_ht(..., cluster_id), passes those buffers and
// the caller's a/x/y pointers to mt_hthread_zgemv (column-major ordering),
// and releases the temporaries with free_ht before returning.
// NOTE(review): how the buffers a/x/y must have been allocated for this
// backend is not visible here -- confirm with the callers.
void zgemv_mth_(const char* transa,
218+
const int* m,
219+
const int* n,
220+
const std::complex<double>* alpha,
221+
const std::complex<double>* a,
222+
const int* lda,
223+
const std::complex<double>* x,
224+
const int* incx,
225+
const std::complex<double>* beta,
226+
std::complex<double>* y,
227+
const int* incy,
228+
int cluster_id);
229+
230+
// Single-precision complex GEMV through the mt_hthread backend.
// Same Fortran-BLAS-style argument list as cgemv_mt_ (scalars by pointer,
// trailing cluster_id). The definition copies *alpha and *beta into temporary
// buffers obtained from malloc_ht(..., cluster_id), passes those buffers and
// the caller's a/x/y pointers to mt_hthread_cgemv (column-major ordering),
// and releases the temporaries with free_ht before returning.
// NOTE(review): how the buffers a/x/y must have been allocated for this
// backend is not visible here -- confirm with the callers.
void cgemv_mth_(const char* transa,
231+
const int* m,
232+
const int* n,
233+
const std::complex<float>* alpha,
234+
const std::complex<float>* a,
235+
const int* lda,
236+
const std::complex<float>* x,
237+
const int* incx,
238+
const std::complex<float>* beta,
239+
std::complex<float>* y,
240+
const int* incy,
241+
int cluster_id);
242+
139243
// #define zgemm_ zgemm_mt
140244

141245
// The next is dsp utils. It may be moved to other files if this file get too huge

0 commit comments

Comments
 (0)