#include "relapack.h"

static void RELAPACK_cgemmt_rec(const char *uplo, const char *transA,
    const char *transB, const blasint *n, const blasint *k,
    const float *alpha, const float *A, const blasint *ldA,
    const float *B, const blasint *ldB,
    const float *beta, float *C, const blasint *ldC);

static void RELAPACK_cgemmt_rec2(const char *uplo, const char *transA,
    const char *transB, const blasint *n, const blasint *k,
    const float *alpha, const float *A, const blasint *ldA,
    const float *B, const blasint *ldB,
    const float *beta, float *C, const blasint *ldC);


/** CGEMMT computes a matrix-matrix product with general matrices but updates
 * only the upper or lower triangular part of the result matrix.
 *
 * This routine performs the same operation as the BLAS routine
 * cgemm(transA, transB, n, n, k, alpha, A, ldA, B, ldB, beta, C, ldC)
 * but only updates the triangular part of C specified by uplo:
 * If (*uplo == 'L'), only the lower triangular part of C is updated,
 * otherwise the upper triangular part is updated.
 *
 * All arguments are passed by pointer (Fortran-style).  Matrices are
 * addressed column by column with the given leading dimensions, and every
 * complex entry occupies two consecutive floats.
 *
 * uplo    'L' or 'U': which triangle of C is updated.
 * transA  'N', 'T' or 'C': op(A) = A, A^T or A^H.
 * transB  'N', 'T' or 'C': op(B) = B, B^T or B^H.
 * n       order of C (must be >= 0).
 * k       inner dimension of the product (must be >= 0).
 * alpha   complex scalar multiplying op(A) * op(B).
 * A, ldA  n-by-k ('N') or k-by-n (otherwise) matrix and its leading dimension.
 * B, ldB  k-by-n ('N') or n-by-k (otherwise) matrix and its leading dimension.
 * beta    complex scalar multiplying C.
 * C, ldC  n-by-n matrix (updated in place) and its leading dimension.
 *
 * On an invalid argument, xerbla is called with the 1-based position of the
 * offending argument and the routine returns without touching C.
 * */
void RELAPACK_cgemmt(
    const char *uplo, const char *transA, const char *transB,
    const blasint *n, const blasint *k,
    const float *alpha, const float *A, const blasint *ldA,
    const float *B, const blasint *ldB,
    const float *beta, float *C, const blasint *ldC
) {

#if HAVE_XGEMMT
    // The BLAS library provides cgemmt itself: forward the call unchanged
    BLAS(cgemmt)(uplo, transA, transB, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);
    return;
#else

    // Check arguments
    const blasint lower = LAPACK(lsame)(uplo, "L");
    const blasint upper = LAPACK(lsame)(uplo, "U");
    const blasint notransA = LAPACK(lsame)(transA, "N");
    const blasint tranA = LAPACK(lsame)(transA, "T");
    const blasint ctransA = LAPACK(lsame)(transA, "C");
    const blasint notransB = LAPACK(lsame)(transB, "N");
    const blasint tranB = LAPACK(lsame)(transB, "T");
    const blasint ctransB = LAPACK(lsame)(transB, "C");
    blasint info = 0;
    if (!lower && !upper)
        info = 1;
    else if (!tranA && !ctransA && !notransA)
        info = 2;
    else if (!tranB && !ctransB && !notransB)
        info = 3;
    else if (*n < 0)
        info = 4;
    else if (*k < 0)
        info = 5;
    else if (*ldA < MAX(1, notransA ? *n : *k))
        info = 8;
    else if (*ldB < MAX(1, notransB ? *k : *n))
        info = 10;
    else if (*ldC < MAX(1, *n))
        info = 13;
    if (info) {
        LAPACK(xerbla)("CGEMMT", &info, strlen("CGEMMT"));
        return;
    }

    // Clean char * arguments
    // (the kernels below compare against these exact upper-case characters)
    const char cleanuplo = lower ? 'L' : 'U';
    const char cleantransA = notransA ? 'N' : (tranA ? 'T' : 'C');
    const char cleantransB = notransB ? 'N' : (tranB ? 'T' : 'C');

    // Recursive kernel
    RELAPACK_cgemmt_rec(&cleanuplo, &cleantransA, &cleantransB, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);

#endif
}


/** cgemmt's recursive compute kernel
 *
 * Splits C into a 2-by-2 block structure at n1 = CREC_SPLIT(n), recurses on
 * the two diagonal blocks and updates the off-diagonal block selected by uplo
 * with a single cgemm.  Once n is at most MAX(CROSSOVER_CGEMMT, 1) the
 * unblocked kernel takes over.
 *
 * Expects *uplo to be exactly 'L' or 'U' and *transA / *transB to be exactly
 * 'N', 'T' or 'C', as produced by RELAPACK_cgemmt.
 * */
static void RELAPACK_cgemmt_rec(
    const char *uplo, const char *transA, const char *transB,
    const blasint *n, const blasint *k,
    const float *alpha, const float *A, const blasint *ldA,
    const float *B, const blasint *ldB,
    const float *beta, float *C, const blasint *ldC
) {

    if (*n <= MAX(CROSSOVER_CGEMMT, 1)) {
        // Unblocked
        RELAPACK_cgemmt_rec2(uplo, transA, transB, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);
        return;
    }

    // Splitting
    const blasint n1 = CREC_SPLIT(*n);
    const blasint n2 = *n - n1;

    // A_T
    // A_B
    // (rows of op(A): a row offset in A for 'N', a column offset otherwise)
    const float *const A_T = A;
    const float *const A_B = A + 2 * ((*transA == 'N') ? n1 : *ldA * n1);

    // B_L B_R
    // (columns of op(B): a column offset in B for 'N', a row offset otherwise)
    const float *const B_L = B;
    const float *const B_R = B + 2 * ((*transB == 'N') ? *ldB * n1 : n1);

    // C_TL C_TR
    // C_BL C_BR
    float *const C_TL = C;
    float *const C_TR = C + 2 * *ldC * n1;
    float *const C_BL = C             + 2 * n1;
    float *const C_BR = C + 2 * *ldC * n1 + 2 * n1;

    // recursion(C_TL)
    RELAPACK_cgemmt_rec(uplo, transA, transB, &n1, k, alpha, A_T, ldA, B_L, ldB, beta, C_TL, ldC);

    if (*uplo == 'L')
        // C_BL = alpha A_B B_L + beta C_BL
        BLAS(cgemm)(transA, transB, &n2, &n1, k, alpha, A_B, ldA, B_L, ldB, beta, C_BL, ldC);
    else
        // C_TR = alpha A_T B_R + beta C_TR
        BLAS(cgemm)(transA, transB, &n1, &n2, k, alpha, A_T, ldA, B_R, ldB, beta, C_TR, ldC);

    // recursion(C_BR)
    RELAPACK_cgemmt_rec(uplo, transA, transB, &n2, k, alpha, A_B, ldA, B_R, ldB, beta, C_BR, ldC);
}


/** cgemmt's unblocked compute kernel
 *
 * Updates the selected triangle of C one column at a time.  Column i of C
 * receives alpha * op(A)(rows) * op(B)(:, i) + beta * C(rows, i), where
 * "rows" is i..n-1 for the lower triangle and 0..i for the upper triangle.
 *
 * For transB 'N' or 'T' the column of op(B) is a plain strided vector of B
 * and the update is a cgemv.  For transB 'C' the column must be conjugated,
 * which cgemv cannot do to its x argument; that case is handled by a
 * one-column cgemm so that B^H is applied correctly.
 *
 * Expects the same cleaned character arguments as RELAPACK_cgemmt_rec.
 * */
static void RELAPACK_cgemmt_rec2(
    const char *uplo, const char *transA, const char *transB,
    const blasint *n, const blasint *k,
    const float *alpha, const float *A, const blasint *ldA,
    const float *B, const blasint *ldB,
    const float *beta, float *C, const blasint *ldC
) {

    // Stride between consecutive entries of one column of op(B)
    const blasint incB = (*transB == 'N') ? 1 : *ldB;
    const blasint incC = 1;
    // Column count for the one-column cgemm used when op(B) = B^H
    const blasint one = 1;
    const int conjB = (*transB == 'C');

    blasint i;
    for (i = 0; i < *n; i++) {
        // A_0
        // A_i
        const float *const A_0 = A;
        const float *const A_i = A + 2 * ((*transA == 'N') ? i : *ldA * i);

        // * B_i *
        const float *const B_i = B + 2 * ((*transB == 'N') ? *ldB * i : i);

        // * C_0i *
        // * C_ii *
        float *const C_0i = C + 2 * *ldC * i;
        float *const C_ii = C + 2 * *ldC * i + 2 * i;

        if (*uplo == 'L') {
            // Rows i..n-1 of column i
            const blasint nmi = *n - i;
            if (conjB)
                // B_i is viewed as a 1-by-k matrix; cgemm conjugate-transposes it
                BLAS(cgemm)(transA, transB, &nmi, &one, k, alpha, A_i, ldA, B_i, ldB, beta, C_ii, ldC);
            else if (*transA == 'N')
                BLAS(cgemv)(transA, &nmi, k, alpha, A_i, ldA, B_i, &incB, beta, C_ii, &incC);
            else
                BLAS(cgemv)(transA, k, &nmi, alpha, A_i, ldA, B_i, &incB, beta, C_ii, &incC);
        } else {
            // Rows 0..i of column i
            const blasint ip1 = i + 1;
            if (conjB)
                // B_i is viewed as a 1-by-k matrix; cgemm conjugate-transposes it
                BLAS(cgemm)(transA, transB, &ip1, &one, k, alpha, A_0, ldA, B_i, ldB, beta, C_0i, ldC);
            else if (*transA == 'N')
                BLAS(cgemv)(transA, &ip1, k, alpha, A_0, ldA, B_i, &incB, beta, C_0i, &incC);
            else
                BLAS(cgemv)(transA, k, &ip1, alpha, A_0, ldA, B_i, &incB, beta, C_0i, &incC);
        }
    }
}