diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index 90a0a81..7f4764d 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -139,6 +139,7 @@
 #define CC_PASCAL 600
 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define CC_VOLTA 700
+#define CC_TURING 750
 #define CC_AMPERE 800
 #define CC_OFFSET_AMD 1000000
 #define CC_RDNA1 (CC_OFFSET_AMD + 1010)
@@ -326,9 +327,17 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int
 #endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
 #endif // defined(GGML_USE_HIPBLAS)
 
-#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
+#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
+#define FP16_AVAILABLE
+#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
 
-#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+#define FP16_MMA_AVAILABLE
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
+#define INT8_MMA_AVAILABLE
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
 
 static bool fast_fp16_available(const int cc) {
     return cc >= CC_PASCAL && cc != 610;
@@ -338,6 +347,10 @@ static bool fp16_mma_available(const int cc) {
     return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
 }
 
+static bool int8_mma_available(const int cc) {
+    return cc < CC_OFFSET_AMD && cc >= CC_TURING;
+}
+
 [[noreturn]]
 static __device__ void no_device_code(
     const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
@@ -379,7 +392,7 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 }
 
 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
 
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 #pragma unroll
@@ -412,7 +425,7 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
 }
 
 static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
 
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
     return __float2half(fmaxf(__half2float(a), __half2float(b)));
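The compute-capability constants and the *_available() helpers above are what host code uses for dispatch, while the new #ifdef-style macros gate device code at compile time. A minimal host-side sketch of how the new int8_mma_available() check could drive the choice between the tensor-core and __dp4a code paths; pick_mmq_impl and the enum are hypothetical names used only for illustration, only the *_available() helpers and CC_* constants come from common.cuh:

    // Hypothetical host-side selection, not part of this patch.
    enum class mmq_impl { dp4a, int8_mma };

    static mmq_impl pick_mmq_impl(const int cc) {
        if (int8_mma_available(cc)) {
            return mmq_impl::int8_mma; // NVIDIA Turing (CC_TURING == 750) or newer
        }
        return mmq_impl::dp4a;         // older NVIDIA GPUs and AMD (cc offset by CC_OFFSET_AMD)
    }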
diff --git a/ggml-cuda/fattn-common.cuh b/ggml-cuda/fattn-common.cuh
index c00f860..37b3b99 100644
--- a/ggml-cuda/fattn-common.cuh
+++ b/ggml-cuda/fattn-common.cuh
@@ -74,7 +74,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
 
         const int sumi = __dp4a(v, u, 0);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
 
@@ -122,7 +122,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
 
         const int sumi = __dp4a(v, u, 0);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
 
@@ -181,7 +181,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
 
         const int sumi = __dp4a(v, u, 0);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
 
@@ -236,7 +236,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
 
         const int sumi = __dp4a(v, u, 0);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
 
@@ -314,7 +314,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     if (std::is_same<T, half>::value) {
         const half2 * Q_h2 = (const half2 *) Q_v;
 
@@ -407,7 +407,7 @@ static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__
     const int q0 = x[ib].qs[iqs];
     const int q = ((q0 >> (4*shift)) & 0x0F) - 8;
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     if (std::is_same<T, half>::value) {
         return ((half) d)*((half) q);
     }
@@ -428,7 +428,7 @@ static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__
     const int q0 = x[ib].qs[iqs];
     const int q = ((q0 >> (4*shift)) & 0x0F);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     if (std::is_same<T, half>::value) {
         return __low2half(dm)*((half) q) + __high2half(dm);
     }
@@ -453,7 +453,7 @@ static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__
     const int qh = ((qh0 >> idq) << 4) & 0x10;
     const int q = (ql | qh) - 16;
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     if (std::is_same<T, half>::value) {
         return ((half) d)*((half) q);
     }
@@ -478,7 +478,7 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__
     const int qh = ((qh0 >> idq) << 4) & 0x10;
     const int q = (ql | qh);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     if (std::is_same<T, half>::value) {
         return __low2half(dm)*((half) q) + __high2half(dm);
     }
@@ -497,7 +497,7 @@ static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__
     const T d = x[ib].d;
     const int q = x[ib].qs[iqs];
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     if (std::is_same<T, half>::value) {
         return ((half) d)*((half) q);
     }
diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu
index cb11d72..c6c3513 100644
--- a/ggml-cuda/fattn-tile-f16.cu
+++ b/ggml-cuda/fattn-tile-f16.cu
@@ -43,7 +43,7 @@ static __global__ void flash_attn_tile_ext_f16(
     const int ne1,
     const int ne2,
     const int ne3) {
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
diff --git a/ggml-cuda/fattn-vec-f16.cuh b/ggml-cuda/fattn-vec-f16.cuh
index 9e1aa2c..02a4ad0 100644
--- a/ggml-cuda/fattn-vec-f16.cuh
+++ b/ggml-cuda/fattn-vec-f16.cuh
@@ -40,7 +40,7 @@ static __global__ void flash_attn_vec_ext_f16(
     const int ne1,
     const int ne2,
     const int ne3) {
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16(type_K);
diff --git a/ggml-cuda/fattn-wmma-f16.cuh b/ggml-cuda/fattn-wmma-f16.cuh
index 59cd30d..ae23222 100644
--- a/ggml-cuda/fattn-wmma-f16.cuh
+++ b/ggml-cuda/fattn-wmma-f16.cuh
@@ -1,9 +1,9 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
 
-#if FP16_MMA_AVAILABLE
+#ifdef FP16_MMA_AVAILABLE
 #include <mma.h>
-#endif
+#endif // FP16_MMA_AVAILABLE
 
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template
@@ -45,7 +45,7 @@ static __global__ void flash_attn_ext_f16(
     const int ne1,
     const int ne2,
     const int ne3) {
-#if FP16_MMA_AVAILABLE
+#ifdef FP16_MMA_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
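Throughout these files the change is mechanical: FP16_AVAILABLE and FP16_MMA_AVAILABLE are no longer expression-like macros evaluated with #if, they are now either defined or undefined, so every user switches to #ifdef. A small sketch (not from the patch) of the resulting guard pattern for device code, with the NO_DEVICE_CODE fallback from common.cuh on architectures where the feature is missing; the kernel itself is a made-up example:

    // Illustrative only: the body compiles where FP16_AVAILABLE is defined (Pascal+ or HIP),
    // otherwise the kernel degenerates to NO_DEVICE_CODE, just like the guarded kernels above.
    static __global__ void fp16_double_sketch(const half * x, half * dst, const int n) {
    #ifdef FP16_AVAILABLE
        const int i = blockIdx.x*blockDim.x + threadIdx.x;
        if (i < n) {
            dst[i] = __hadd(x[i], x[i]); // plain half-precision arithmetic
        }
    #else
        GGML_UNUSED(x);
        GGML_UNUSED(dst);
        GGML_UNUSED(n);
        NO_DEVICE_CODE;
    #endif // FP16_AVAILABLE
    }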
diff --git a/ggml-cuda/mma.cuh b/ggml-cuda/mma.cuh
new file mode 100644
index 0000000..71e8e34
--- /dev/null
+++ b/ggml-cuda/mma.cuh
@@ -0,0 +1,95 @@
+#include "common.cuh"
+
+struct mma_int_A_I16K8 {
+    static constexpr int I = 16;
+    static constexpr int K = 8;
+    static constexpr int ne = 4;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret < I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int l) {
+        const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret < K);
+        return ret;
+    }
+};
+
+struct mma_int_B_J8K8 {
+    static constexpr int J = 8;
+    static constexpr int K = 8;
+    static constexpr int ne = 2;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_j(const int /* l */) {
+        const int ret = threadIdx.x / (K/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret < J);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int l) {
+        const int ret = l * (K/2) + threadIdx.x % (K/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret < K);
+        return ret;
+    }
+};
+
+struct mma_int_C_I16J8 {
+    static constexpr int I = 16;
+    static constexpr int J = 8;
+    static constexpr int ne = 4;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret < I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_j(const int l) {
+        const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret < J);
+        return ret;
+    }
+
+    __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
+#ifdef INT8_MMA_AVAILABLE
+#if __CUDA_ARCH__ >= CC_AMPERE
+        asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
+#else
+        // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[0]), "+r"(x[1])
+            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[0]), "+r"(x[1])
+            : "r"(mma_A.x[2]), "r"(mma_B.x[1]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[3]), "r"(mma_B.x[1]));
+#endif // __CUDA_ARCH__ >= CC_AMPERE
+#else
+        GGML_UNUSED(mma_A);
+        GGML_UNUSED(mma_B);
+        NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+    }
+};
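mma.cuh wraps the PTX mma instructions in three warp-level tile fragments: A is 16x8 ints (a 16x32 int8 tile), B is 8x8 ints holding the second operand with its K dimension along rows (the "col" layout of the instruction), and C is the 16x8 int32 accumulator. Each of the 32 lanes owns ne register values, and get_i/get_j/get_k say which tile element register l of the current lane corresponds to. A minimal usage sketch, not from the patch (tile_A, tile_B, tile_C and the function name are illustrative); it computes C[i][j] += sum over the 32 packed int8 values of A[i][k]*B[j][k], as a single m16n8k32 mma on Ampere or four m8n8k16 mmas on Turing:

    #include "mma.cuh"

    // One warp multiplies a 16x32 int8 tile (stored as 16x8 ints, row-major) by an
    // 8x32 int8 tile (stored as 8x8 ints, one row per output column j) and adds the
    // result into a 16x8 int32 tile, also row-major.
    static __device__ __forceinline__ void warp_int8_tile_mma_sketch(
            const int * tile_A, const int * tile_B, int * tile_C) {
        mma_int_A_I16K8 A;
        mma_int_B_J8K8  B;
        mma_int_C_I16J8 C; // accumulator, zero-initialized by the default member initializer

    #pragma unroll
        for (int l = 0; l < mma_int_A_I16K8::ne; ++l) {
            // Each lane loads the 4 packed ints of A that the mma layout assigns to it:
            A.x[l] = tile_A[mma_int_A_I16K8::get_i(l)*mma_int_A_I16K8::K + mma_int_A_I16K8::get_k(l)];
        }
    #pragma unroll
        for (int l = 0; l < mma_int_B_J8K8::ne; ++l) {
            // Each lane loads its 2 packed ints of B:
            B.x[l] = tile_B[mma_int_B_J8K8::get_j(l)*mma_int_B_J8K8::K + mma_int_B_J8K8::get_k(l)];
        }

        C.mma_K8(A, B);

    #pragma unroll
        for (int l = 0; l < mma_int_C_I16J8::ne; ++l) {
            // The int32 result is likewise distributed; each lane writes its 4 outputs.
            tile_C[mma_int_C_I16J8::get_i(l)*mma_int_C_I16J8::J + mma_int_C_I16J8::get_j(l)] += C.x[l];
        }
    }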
diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh
index 3ccae8a..62111f3 100644
--- a/ggml-cuda/mmq.cuh
+++ b/ggml-cuda/mmq.cuh
@@ -2,6 +2,7 @@
 
 #include "common.cuh"
 #include "vecdotq.cuh"
+#include "mma.cuh"
 
 #include
 #include
@@ -14,6 +15,7 @@ typedef void (*load_tiles_mmq_t)(
 typedef void (*vec_dot_mmq_t)(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0);
+typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1);
 
 struct block_q8_1_mmq {
     half2 ds[4];
@@ -141,15 +143,15 @@ template static __device__ __forceinlin
 }
 
 template
-static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
     GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
 
-    const float * x_dmf = (const float *) x_dm;
-    const int   * y_qs  = (const int   *) y + 4;
-    const half2 * y_ds  = (const half2 *) y;
+    const float * x_df = (const float *) x_dm;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -170,12 +172,76 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
             }
 
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl
-                (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dmf[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
+                (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
                 y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
         }
     }
 }
 
+template
+static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8 mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const float * x_df = (const float *) x_dm;
+    const int * y_qs = (const int *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+    mma_A A;
+    float dA[mma_C::ne/2];
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+#pragma unroll
+    for (int l = 0; l < mma_A::ne; ++l) {
+        const int i = i0 + mma_A::get_i(l);
+        const int k = k0 + mma_A::get_k(l) % QI4_0;
+        const int shift = 4*(mma_A::get_k(l) / QI4_0);
+
+        A.x[l] = __vsubss4((x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808);
+    }
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dA[l] = x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
+    }
+
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        mma_C C;
+        mma_B B;
+        half2 dsB[mma_C::ne/2];
+
+#pragma unroll
+        for (int l = 0; l < mma_B::ne; ++l) {
+            const int j = j0 + mma_B::get_j(l);
+            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+
+            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+        }
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int j = j0 + mma_C::get_j(l);
+
+            dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
+        }
+
+        C.mma_K8(A, B);
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l];
+        }
+    }
+}
+
 template static __device__ __forceinline__ void load_tiles_q4_1(
     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
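The A fragment load for q4_0 above packs the unpacking and re-centering of four 4-bit quants into one step. A small standalone sketch of that step (the helper name is made up, the constants are the ones used above); __vsubss4 is the CUDA per-byte subtraction intrinsic with signed saturation, so subtracting 0x08080808 turns the four unsigned nibbles 0..15 into the signed q4_0 range -8..7:

    // Illustrative helper, not part of the patch.
    static __device__ __forceinline__ int unpack_q4_0_sketch(const int packed, const int shift) {
        const int nibbles = (packed >> shift) & 0x0F0F0F0F; // four 4-bit quants, one per byte
        return __vsubss4(nibbles, 0x08080808);              // byte-wise x - 8 with signed saturation
    }

Here the -8 offset has to be applied up front because the raw int8 values are fed directly to the tensor core instruction.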
@@ -215,7 +281,7 @@ template static __device__ __forceinlin
 }
 
 template
-static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
@@ -249,6 +315,70 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
     }
 }
 
+template
+static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8 mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const int * y_qs = (const int *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+    mma_A A;
+    half2 dmA[mma_C::ne/2];
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+#pragma unroll
+    for (int l = 0; l < mma_A::ne; ++l) {
+        const int i = i0 + mma_A::get_i(l);
+        const int k = k0 + mma_A::get_k(l) % QI4_0;
+        const int shift = 4*(mma_A::get_k(l) / QI4_0);
+
+        A.x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F;
+    }
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dmA[l] = x_dm[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
+    }
+
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        mma_C C;
+        mma_B B;
+        half2 dsB[mma_C::ne/2];
+
+#pragma unroll
+        for (int l = 0; l < mma_B::ne; ++l) {
+            const int j = j0 + mma_B::get_j(l);
+            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+
+            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+        }
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int j = j0 + mma_C::get_j(l);
+
+            dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
+        }
+
+        C.mma_K8(A, B);
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
+            sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+        }
+    }
+}
+
 template static __device__ __forceinline__ void load_tiles_q5_0(
     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@@ -308,7 +438,7 @@ template static __device__ __forceinlin
 }
 
 template
-static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
@@ -343,6 +473,68 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
     }
 }
 
+template
+static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8 mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const float * x_df = (const float *) x_dm;
+    const int * y_qs = (const int *) y + 4;
+    const float * y_df = (const float *) y;
+
+    mma_A A;
+    float dA[mma_C::ne/2];
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+#pragma unroll
+    for (int l = 0; l < mma_A::ne; ++l) {
+        const int i = i0 + mma_A::get_i(l);
+        const int k = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0;
+
+        A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k];
+    }
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dA[l] = x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0];
+    }
+
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        mma_C C;
+        mma_B B;
+        float dB[mma_C::ne/2];
+
+#pragma unroll
+        for (int l = 0; l < mma_B::ne; ++l) {
+            const int j = j0 + mma_B::get_j(l);
+            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+
+            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+        }
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int j = j0 + mma_C::get_j(l);
+
+            dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
+        }
+
+        C.mma_K8(A, B);
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l];
+        }
+    }
+}
 
 template static __device__ __forceinline__ void load_tiles_q5_1(
     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
@@ -400,7 +592,7 @@ template static __device__ __forceinlin
 }
 
 template
-static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
@@ -434,6 +626,69 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
     }
 }
 
+template
+static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8 mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const int * y_qs = (const int *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+    mma_A A;
+    half2 dmA[mma_C::ne/2];
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+#pragma unroll
+    for (int l = 0; l < mma_A::ne; ++l) {
+        const int i = i0 + mma_A::get_i(l);
+        const int k = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1;
+
+        A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k];
+    }
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dmA[l] = x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1];
+    }
+
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        mma_C C;
+        mma_B B;
+        half2 dsB[mma_C::ne/2];
+
+#pragma unroll
+        for (int l = 0; l < mma_B::ne; ++l) {
+            const int j = j0 + mma_B::get_j(l);
+            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+
+            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+        }
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int j = j0 + mma_C::get_j(l);
+
+            dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
+        }
+
+        C.mma_K8(A, B);
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
+            sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+        }
+    }
+}
+
 template static __device__ __forceinline__ void load_tiles_q8_0(
     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@@ -475,7 +730,7 @@ template static __device__ __forceinlin
 }
 
 template
-static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
@@ -500,6 +755,69 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
     }
 }
 
+template
+static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8 mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const float * x_df = (const float *) x_dm;
+    const int * y_qs = (const int *) y + 4;
+    const float * y_df = (const float *) y;
+
+    mma_A A;
+    float dA[mma_C::ne/2];
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+#pragma unroll
+    for (int l = 0; l < mma_A::ne; ++l) {
+        const int i = i0 + mma_A::get_i(l);
+        const int k = k0 + mma_A::get_k(l);
+
+        A.x[l] = x_ql[i*(WARP_SIZE + 1) + k];
+    }
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dA[l] = x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0];
+    }
+
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        mma_C C;
+        mma_B B;
+        float dB[mma_C::ne/2];
+
+#pragma unroll
+        for (int l = 0; l < mma_B::ne; ++l) {
+            const int j = j0 + mma_B::get_j(l);
+            const int k = k0 + mma_B::get_k(l);
+
+            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+        }
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int j = j0 + mma_C::get_j(l);
+
+            dB[l] = y_df[j*MMQ_TILE_Y_K + k0/QI8_1];
+        }
+
+        C.mma_K8(A, B);
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2];
+        }
+    }
+}
+
 template static __device__ __forceinline__ void load_tiles_q2_K(
     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@@ -989,6 +1307,57 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
     }
 }
 
+template
+static __device__ __forceinline__ void mmq_write_back_dp4a(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
+
+        if (j >= ne1) {
+            return;
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
+
+            if (need_check && i >= ne0) {
+                continue;
+            }
+
+            dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+        }
+    }
+}
+
+template
+static __device__ __forceinline__ void mmq_write_back_mma(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
+    typedef mma_int_C_I16J8 mma_C;
+
+    const int i0 = threadIdx.y*mma_C::I;
+    static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            const int j = blockIdx.y*mmq_x + j0 + mma_C::get_j(l);
+
+            if (j >= ne1) {
+                continue;
+            }
+
+            const int i = blockIdx.x*mmq_y + i0 + mma_C::get_i(l);
+
+            if (need_check && i >= ne0) {
+                continue;
+            }
+
+            dst[j*ne0 + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
+        }
+    }
+}
+
 // -------------------------------------------------------------------------------------------------------------------------------------
 
 template
@@ -998,35 +1367,65 @@ template
 struct mmq_type_traits {
     static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0;
-    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mul_mat;
+#ifdef INT8_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mma;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_mma;
+#else
+    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_dp4a;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a;
+#endif // INT8_MMA_AVAILABLE
 };
 
 template
 struct mmq_type_traits {
     static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1;
-    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mul_mat;
+#ifdef INT8_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mma;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_mma;
+#else
+    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_dp4a;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a;
+#endif // INT8_MMA_AVAILABLE
 };
 
 template
 struct mmq_type_traits {
     static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0;
-    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mul_mat;
+#ifdef INT8_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mma;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_mma;
+#else
+    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_dp4a;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a;
+#endif // INT8_MMA_AVAILABLE
 };
 
 template
 struct mmq_type_traits {
     static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1;
-    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mul_mat;
+#ifdef INT8_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mma;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_mma;
+#else
+    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_dp4a;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a;
+#endif // INT8_MMA_AVAILABLE
 };
 
 template
 struct mmq_type_traits {
     static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0;
-    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mul_mat;
+#ifdef INT8_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mma;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_mma;
+#else
+    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_dp4a;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a;
+#endif // INT8_MMA_AVAILABLE
 };
 
 template
@@ -1034,6 +1433,7 @@ struct mmq_type_traits {
     static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K;
     static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a;
 };
 
 template
@@ -1041,6 +1441,7 @@ struct mmq_type_traits {
     static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K;
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a;
 };
 
 template
@@ -1048,6 +1449,7 @@ struct mmq_type_traits {
     static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K;
     static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mul_mat;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a;
 };
 
 template
@@ -1055,6 +1457,7 @@ struct mmq_type_traits {
     static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K;
     static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mul_mat;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a;
 };
 
 template
@@ -1062,6 +1465,7 @@ struct mmq_type_traits {
     static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K;
     static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mul_mat;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a;
 };
 
 static int mmq_need_sum(const ggml_type type_x) {
@@ -1118,6 +1522,7 @@ static __global__ void mul_mat_q(
 
     constexpr int vdr = mmq_type_traits::vdr;
     constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles;
     constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot;
+    constexpr mmq_write_back_t write_back = mmq_type_traits::write_back;
 
     constexpr tile_x_sizes txs = get_tile_x_sizes_device(type);
@@ -1137,7 +1542,7 @@ static __global__ void mul_mat_q(
 
     const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
 
-    float sum[(mmq_x/nwarps) * (mmq_y/WARP_SIZE)] = {0.0f};
+    float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
 
     for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
 
@@ -1164,25 +1569,7 @@ static __global__ void mul_mat_q(
         }
     }
 
-#pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
-        const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
-
-        if (j >= ne1) {
-            return;
-        }
-
-#pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
-            const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
-
-            if (need_check && i >= ne0) {
-                continue;
-            }
-
-            dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
-        }
-    }
+    write_back(sum, dst, ne0, ne1);
 }
 
 struct mmq_args {
@@ -1256,10 +1643,10 @@ void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
             launch_mul_mat_q(args, stream);
             break;
         case 16:
-            launch_mul_mat_q(args, stream);
+            launch_mul_mat_q(args, stream);
            break;
        case 24:
-            launch_mul_mat_q(args, stream);
+            launch_mul_mat_q(args, stream);
            break;
        case 32:
            launch_mul_mat_q(args, stream);
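The kernel-side sum[] array is now sized as mmq_x*mmq_y / (nwarps*WARP_SIZE) because the two write-back layouts distribute the same number of partial results per thread, just in different orders. A small consistency sketch (not from the patch, the helper names are illustrative) of the per-thread accumulator counts behind that change:

    // Per-thread accumulator count of the dp4a write-back (the old indexing, kept above):
    constexpr int sum_elems_dp4a(const int mmq_x, const int mmq_y, const int nwarps) {
        return (mmq_x/nwarps) * (mmq_y/WARP_SIZE);
    }

    // Per-thread accumulator count of the mma write-back: J = 8 columns per fragment and
    // ne = 4 values per thread, i.e. mmq_x/2 floats per thread.
    constexpr int sum_elems_mma(const int mmq_x) {
        return (mmq_x/mma_int_C_I16J8::J) * mma_int_C_I16J8::ne;
    }

The dp4a count is algebraically equal to the new declaration, and with mmq_y == nwarps*mma_int_A_I16K8::I (the static_assert used by every *_mma dot product) the mma count mmq_x/2 matches it as well, so a single sum[] declaration serves both code paths.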