diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4c50e2a..8b800e0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,4 +1,4 @@
-cmake_minimum_required (VERSION 3.0)
+cmake_minimum_required (VERSION 3.5)
project(whisper.cpp VERSION 1.4.2)
@@ -35,6 +35,12 @@ endif()
# options
+if (APPLE)
+ set(WHISPER_METAL_DEFAULT ON)
+else()
+ set(WHISPER_METAL_DEFAULT OFF)
+endif()
+
option(BUILD_SHARED_LIBS "whisper: build shared libs" ${BUILD_SHARED_LIBS_DEFAULT})
option(WHISPER_ALL_WARNINGS "whisper: enable all compiler warnings" ON)
@@ -58,6 +64,8 @@ option(WHISPER_OPENVINO "whisper: support for OpenVINO" OFF)
if (APPLE)
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
+ option(WHISPER_METAL "whisper: use Metal" ${WHISPER_METAL_DEFAULT})
+ option(WHISPER_METAL_NDEBUG "whisper: disable Metal debugging" OFF)
option(WHISPER_COREML "whisper: enable Core ML framework" OFF)
option(WHISPER_COREML_ALLOW_FALLBACK "whisper: allow non-CoreML fallback" OFF)
else()
@@ -113,6 +121,34 @@ if (APPLE)
endif()
endif()
+ if (WHISPER_METAL)
+ find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
+ find_library(METAL_FRAMEWORK Metal REQUIRED)
+ find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
+
+ if (METAL_FRAMEWORK)
+ message(STATUS "Metal framework found")
+
+ set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS}
+ ${FOUNDATION_LIBRARY}
+ ${METAL_FRAMEWORK}
+ ${METALKIT_FRAMEWORK}
+ )
+ set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_METAL)
+
+ if (WHISPER_METAL_NDEBUG)
+ set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_METAL_NDEBUG)
+ endif()
+ else()
+ message(WARNING "Metal framework not found")
+ endif()
+
+ set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h)
+
+ # copy ggml-metal.metal to bin directory
+ configure_file(ggml-metal.metal bin/ggml-metal.metal COPYONLY)
+ endif()
+
if (WHISPER_COREML)
find_library(FOUNDATION_FRAMEWORK Foundation)
find_library(COREML_FRAMEWORK CoreML)
@@ -177,7 +213,7 @@ if (WHISPER_CUBLAS)
enable_language(CUDA)
- set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
+ set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h)
add_compile_definitions(GGML_USE_CUBLAS)
@@ -228,7 +264,7 @@ if (WHISPER_CLBLAST)
if (CLBlast_FOUND)
message(STATUS "CLBlast found")
- set(GGML_OPENCL_SOURCES ggml-opencl.cpp ggml-opencl.h)
+ set(GGML_SOURCES_OPENCL ggml-opencl.cpp ggml-opencl.h)
add_compile_definitions(GGML_USE_CLBLAST)
@@ -426,8 +462,11 @@ set(TARGET whisper)
add_library(${TARGET}
ggml.h
ggml.c
- ${GGML_CUDA_SOURCES}
- ${GGML_OPENCL_SOURCES}
+ ggml-alloc.h
+ ggml-alloc.c
+ ${GGML_SOURCES_METAL}
+ ${GGML_SOURCES_CUDA}
+ ${GGML_SOURCES_OPENCL}
whisper.h
whisper.cpp
)
@@ -468,9 +507,15 @@ if (BUILD_SHARED_LIBS)
WHISPER_BUILD
GGML_BUILD
)
+
+ if (WHISPER_METAL)
+ # TODO: I think this should make ggml-metal.m "see" the ggml-metal.metal file from the "bin" directory
+ # but for some reason it does not work here like it does in llama.cpp
+ set_target_properties(${TARGET} PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
+ endif()
endif()
-if (GGML_CUDA_SOURCES)
+if (GGML_SOURCES_CUDA)
message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
set_property(TARGET whisper PROPERTY CUDA_ARCHITECTURES OFF)
set_property(TARGET whisper PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
@@ -486,10 +531,13 @@ target_compile_definitions(${TARGET} PUBLIC
set_target_properties(${TARGET} PROPERTIES PUBLIC_HEADER "whisper.h")
+include(GNUInstallDirs)
+
install(TARGETS ${TARGET}
- LIBRARY DESTINATION lib
- ARCHIVE DESTINATION lib/static
- RUNTIME DESTINATION bin
+ LIBRARY DESTINATION lib
+ ARCHIVE DESTINATION lib/static
+ RUNTIME DESTINATION bin
+ RESOURCE DESTINATION bin
PUBLIC_HEADER DESTINATION include
)
diff --git a/Makefile b/Makefile
index ecbbcff..2df5111 100644
--- a/Makefile
+++ b/Makefile
@@ -18,7 +18,7 @@ ifndef NVCC_VERSION
endif
endif
-CCV := $(shell $(CC) --version | head -n 1)
+CCV := $(shell $(CC) --version | head -n 1)
CXXV := $(shell $(CXX) --version | head -n 1)
# Mac OS + Arm can report x86_64
@@ -182,6 +182,15 @@ ifdef WHISPER_COREML_ALLOW_FALLBACK
endif
endif
+ifndef WHISPER_NO_METAL
+ ifeq ($(UNAME_S),Darwin)
+ WHISPER_METAL := 1
+
+ CXXFLAGS += -DGGML_USE_METAL
+ LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
+ endif
+endif
+
ifdef WHISPER_OPENBLAS
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas -I/usr/include/openblas
LDFLAGS += -lopenblas
@@ -288,6 +297,11 @@ $(info )
ggml.o: ggml.c ggml.h ggml-cuda.h
$(CC) $(CFLAGS) -c $< -o $@
+ggml-alloc.o: ggml-alloc.c ggml.h ggml-alloc.h
+ $(CC) $(CFLAGS) -c $< -o $@
+
+WHISPER_OBJ += ggml-alloc.o
+
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
$(CXX) $(CXXFLAGS) -c $< -o $@
@@ -303,6 +317,13 @@ whisper-encoder-impl.o: coreml/whisper-encoder-impl.m coreml/whisper-encoder-imp
WHISPER_OBJ += whisper.o whisper-encoder.o whisper-encoder-impl.o
endif
+ifdef WHISPER_METAL
+ggml-metal.o: ggml-metal.m ggml-metal.h
+ $(CC) $(CFLAGS) -c $< -o $@
+
+WHISPER_OBJ += ggml-metal.o
+endif
+
libwhisper.a: ggml.o $(WHISPER_OBJ)
$(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ)
diff --git a/README.md b/README.md
index 5f18060..3707b93 100644
--- a/README.md
+++ b/README.md
@@ -11,14 +11,14 @@ Beta: [v1.4.2](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.4.2) / S
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
- Plain C/C++ implementation without dependencies
-- Apple silicon first-class citizen - optimized via ARM NEON, Accelerate framework and [Core ML](https://github.com/ggerganov/whisper.cpp#core-ml-support)
+- Apple Silicon first-class citizen - optimized via ARM NEON, Accelerate framework, Metal and [Core ML](https://github.com/ggerganov/whisper.cpp#core-ml-support)
- AVX intrinsics support for x86 architectures
- VSX intrinsics support for POWER architectures
- Mixed F16 / F32 precision
- [4-bit and 5-bit integer quantization support](https://github.com/ggerganov/whisper.cpp#quantization)
- Low memory usage (Flash Attention)
- Zero memory allocations at runtime
-- Runs on the CPU
+- Support for CPU-only inference
- [Partial GPU support for NVIDIA via cuBLAS](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas)
- [Partial OpenCL GPU support via CLBlast](https://github.com/ggerganov/whisper.cpp#opencl-gpu-support-via-clblast)
- [BLAS CPU support via OpenBLAS](https://github.com/ggerganov/whisper.cpp#blas-cpu-support-via-openblas)
@@ -50,6 +50,10 @@ You can also easily make your own offline voice assistant application: [command]
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
+On Apple Silicon, the inference runs fully on the GPU via Metal:
+
+https://github.com/ggerganov/whisper.cpp/assets/1991296/c82e8f86-60dc-49f2-b048-d2fdbd6b5225
+
Or you can even run it straight in the browser: [talk.wasm](examples/talk.wasm)
## Implementation details
diff --git a/bindings/ios b/bindings/ios
index de46d9e..22a9eef 160000
--- a/bindings/ios
+++ b/bindings/ios
@@ -1 +1 @@
-Subproject commit de46d9e7817fe851c109d66080239d415812d32a
+Subproject commit 22a9eef021afc67f2154bc9811ed620b26299d1b
diff --git a/coreml/whisper-encoder.mm b/coreml/whisper-encoder.mm
index 6cd90ed..499edae 100644
--- a/coreml/whisper-encoder.mm
+++ b/coreml/whisper-encoder.mm
@@ -22,7 +22,13 @@ struct whisper_coreml_context * whisper_coreml_init(const char * path_model) {
NSURL * url_model = [NSURL fileURLWithPath: path_model_str];
- const void * data = CFBridgingRetain([[whisper_encoder_impl alloc] initWithContentsOfURL:url_model error:nil]);
+ // select which device to run the Core ML model on
+ MLModelConfiguration *config = [[MLModelConfiguration alloc] init];
+ config.computeUnits = MLComputeUnitsCPUAndGPU;
+ //config.computeUnits = MLComputeUnitsCPUAndNeuralEngine;
+ //config.computeUnits = MLComputeUnitsAll;
+
+ const void * data = CFBridgingRetain([[whisper_encoder_impl alloc] initWithContentsOfURL:url_model configuration:config error:nil]);
if (data == NULL) {
return NULL;
diff --git a/examples/bench/bench.cpp b/examples/bench/bench.cpp
index 49daaa0..ac0e6bb 100644
--- a/examples/bench/bench.cpp
+++ b/examples/bench/bench.cpp
@@ -44,13 +44,13 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what);
- fprintf(stderr, " %-7s 0 - whisper encoder\n", "");
+ fprintf(stderr, " %-7s 0 - whisper\n", "");
fprintf(stderr, " %-7s 1 - memcpy\n", "");
fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", "");
fprintf(stderr, "\n");
}
-int whisper_bench_encoder(const whisper_params & params) {
+int whisper_bench_full(const whisper_params & params) {
// whisper init
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
@@ -69,12 +69,49 @@ int whisper_bench_encoder(const whisper_params & params) {
fprintf(stderr, "error: failed to set mel: %d\n", ret);
return 3;
}
-
+ // heat up the encoder
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);
return 4;
}
+ whisper_token tokens[512];
+ memset(tokens, 0, sizeof(tokens));
+
+ // heat up prompt processing
+ if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
+ fprintf(stderr, "error: failed to decode: %d\n", ret);
+ return 4;
+ }
+
+ // heat up text generation
+ if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) {
+ fprintf(stderr, "error: failed to decode: %d\n", ret);
+ return 4;
+ }
+
+ whisper_reset_timings(ctx);
+
+ // actual run
+ if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
+ fprintf(stderr, "error: failed to encode model: %d\n", ret);
+ return 4;
+ }
+
+ for (int i = 0; i < 16; i++) {
+ if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
+ fprintf(stderr, "error: failed to encode model: %d\n", ret);
+ return 4;
+ }
+ }
+
+ for (int i = 0; i < 256; i++) {
+ if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
+ fprintf(stderr, "error: failed to encode model: %d\n", ret);
+ return 4;
+ }
+ }
+
whisper_print_timings(ctx);
whisper_free(ctx);
@@ -103,7 +140,7 @@ int main(int argc, char ** argv) {
int ret = -1;
switch (params.what) {
- case 0: ret = whisper_bench_encoder(params); break;
+ case 0: ret = whisper_bench_full(params); break;
case 1: ret = whisper_bench_memcpy(params.n_threads); break;
case 2: ret = whisper_bench_ggml_mul_mat(params.n_threads); break;
default: fprintf(stderr, "error: unknown benchmark: %d\n", params.what); break;
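Note on the updated benchmark flow: `whisper_bench_full` now warms up the encoder and decoder, resets the timings, and only then measures one encode, 16 prompt decodes of 256 tokens, and 256 single-token decodes. A minimal standalone C sketch of the same warm-up/measure pattern (illustrative only, not code from this patch; `work()` is a placeholder for the operation being timed):

```c
#include <stdio.h>
#include <time.h>

// placeholder for the operation being benchmarked
static void work(void) {
    volatile double x = 0.0;
    for (int i = 0; i < 1000000; ++i) x += i * 0.5;
}

int main(void) {
    // warm-up run: populates caches and lazily-initialized state, result is discarded
    work();

    // timed run
    const int n_iter = 16;
    clock_t t0 = clock();
    for (int i = 0; i < n_iter; ++i) {
        work();
    }
    clock_t t1 = clock();

    printf("avg time per iteration: %.3f ms\n",
           1000.0 * (double)(t1 - t0) / CLOCKS_PER_SEC / n_iter);
    return 0;
}
```

Discarding the first run keeps one-time costs (cache warm-up, backend initialization) out of the reported averages.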
diff --git a/examples/talk-llama/CMakeLists.txt b/examples/talk-llama/CMakeLists.txt
index cbdfb41..af5b547 100644
--- a/examples/talk-llama/CMakeLists.txt
+++ b/examples/talk-llama/CMakeLists.txt
@@ -7,7 +7,7 @@ if (WHISPER_SDL2)
# TODO: this is temporary
# need to export ggml symbols for MSVC, but too lazy ..
- add_executable(${TARGET} talk-llama.cpp llama.cpp ../common.cpp ../common-sdl.cpp ../../ggml.c ../../whisper.cpp)
+ add_executable(${TARGET} talk-llama.cpp llama.cpp ../common.cpp ../common-sdl.cpp ../../ggml.c ../../ggml-alloc.c ../../whisper.cpp)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
diff --git a/examples/whisper.android/app/src/main/jni/whisper/CMakeLists.txt b/examples/whisper.android/app/src/main/jni/whisper/CMakeLists.txt
index 55a4725..eac718a 100644
--- a/examples/whisper.android/app/src/main/jni/whisper/CMakeLists.txt
+++ b/examples/whisper.android/app/src/main/jni/whisper/CMakeLists.txt
@@ -8,6 +8,7 @@ set(WHISPER_LIB_DIR ${CMAKE_SOURCE_DIR}/../../../../../../../)
set(
SOURCE_FILES
${WHISPER_LIB_DIR}/ggml.c
+ ${WHISPER_LIB_DIR}/ggml-alloc.c
${WHISPER_LIB_DIR}/whisper.cpp
${CMAKE_SOURCE_DIR}/jni.c
)
@@ -20,7 +21,7 @@ function(build_library target_name)
SHARED
${SOURCE_FILES}
)
-
+
target_link_libraries(${target_name} ${LOG_LIB} android)
if (${target_name} STREQUAL "whisper_v8fp16_va")
diff --git a/examples/whisper.objc/README.md b/examples/whisper.objc/README.md
index 6833ebb..bb55653 100644
--- a/examples/whisper.objc/README.md
+++ b/examples/whisper.objc/README.md
@@ -28,6 +28,8 @@ This can significantly improve the performance of the transcription:
+## Core ML
+
If you want to enable Core ML support, you can add the `-DWHISPER_USE_COREML -DWHISPER_COREML_ALLOW_FALLBACK` compiler flag for `whisper.cpp` in Build Phases:
@@ -35,3 +37,13 @@ If you want to enable Core ML support, you can add the `-DWHISPER_USE_COREML -DW
Then follow the [`Core ML support` section of readme](../../README.md#core-ml-support) for convert the model.
In this project, it also added `-O3 -DNDEBUG` to `Other C Flags`, but adding flags to app proj is not ideal in real world (applies to all C/C++ files), consider splitting xcodeproj in workspace in your own project.
+
+## Metal
+
+You can also enable Metal to make the inference run on the GPU of your device. This may or may not be more efficient
+than Core ML, depending on the model and the device that you use.
+
+To enable Metal, simply add `-DGGML_USE_METAL` instead of the `-DWHISPER_USE_COREML` flag and you are ready.
+This will make both the Encoder and the Decoder run on the GPU.
+
+If you want to run the Encoder with Core ML and the Decoder with Metal then simply add both `-DWHISPER_USE_COREML -DGGML_USE_METAL` flags. That's all!
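For context, `-DGGML_USE_METAL` is an ordinary preprocessor define, so the Metal code path is selected at compile time rather than at runtime. A trivial, hypothetical C sketch of how such a define gates a backend choice (illustrative only, not code from this project):

```c
#include <stdio.h>

#ifdef GGML_USE_METAL
#define BACKEND "Metal (GPU)"
#else
#define BACKEND "CPU"
#endif

int main(void) {
    // compile with:  cc -DGGML_USE_METAL demo.c   -> prints "Metal (GPU)"
    printf("selected backend: %s\n", BACKEND);
    return 0;
}
```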
diff --git a/examples/whisper.objc/whisper.objc.xcodeproj/project.pbxproj b/examples/whisper.objc/whisper.objc.xcodeproj/project.pbxproj
index 49bd74e..f34b9c5 100644
--- a/examples/whisper.objc/whisper.objc.xcodeproj/project.pbxproj
+++ b/examples/whisper.objc/whisper.objc.xcodeproj/project.pbxproj
@@ -7,6 +7,9 @@
objects = {
/* Begin PBXBuildFile section */
+ 1844471A2AB211A2007D6BFE /* ggml-alloc.c in Sources */ = {isa = PBXBuildFile; fileRef = 184447182AB211A2007D6BFE /* ggml-alloc.c */; };
+ 1844471C2AB21655007D6BFE /* ggml-metal.m in Sources */ = {isa = PBXBuildFile; fileRef = 1844471B2AB21655007D6BFE /* ggml-metal.m */; settings = {COMPILER_FLAGS = "-framework Foundation -framework Metal -framework MetalKit -fno-objc-arc"; }; };
+ 184447212AB21B43007D6BFE /* ggml-metal.metal in CopyFiles */ = {isa = PBXBuildFile; fileRef = 1844471D2AB2195F007D6BFE /* ggml-metal.metal */; };
18627C7B29052BDF00BD2A04 /* AppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C7A29052BDF00BD2A04 /* AppDelegate.m */; };
18627C7E29052BDF00BD2A04 /* SceneDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C7D29052BDF00BD2A04 /* SceneDelegate.m */; };
18627C8129052BDF00BD2A04 /* ViewController.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C8029052BDF00BD2A04 /* ViewController.m */; };
@@ -14,7 +17,7 @@
18627C8629052BE000BD2A04 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 18627C8529052BE000BD2A04 /* Assets.xcassets */; };
18627C8929052BE000BD2A04 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 18627C8729052BE000BD2A04 /* LaunchScreen.storyboard */; };
18627C8C29052BE000BD2A04 /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C8B29052BE000BD2A04 /* main.m */; };
- 18627C9429052C4900BD2A04 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9329052C4900BD2A04 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-DWHISPER_USE_COREML -DWHISPER_COREML_ALLOW_FALLBACK"; }; };
+ 18627C9429052C4900BD2A04 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9329052C4900BD2A04 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-DWHISPER_USE_COREML"; }; };
18627C9629052C5800BD2A04 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9529052C5800BD2A04 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE"; }; };
18627C9B29052CFF00BD2A04 /* ggml-base.en.bin in Resources */ = {isa = PBXBuildFile; fileRef = 18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */; };
7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FE342452A0C3FA20015A058 /* whisper-encoder-impl.m */; };
@@ -23,7 +26,24 @@
7FE3424F2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc in Resources */ = {isa = PBXBuildFile; fileRef = 7FE3424E2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc */; };
/* End PBXBuildFile section */
+/* Begin PBXCopyFilesBuildPhase section */
+ 184447202AB21B25007D6BFE /* CopyFiles */ = {
+ isa = PBXCopyFilesBuildPhase;
+ buildActionMask = 2147483647;
+ dstPath = "";
+ dstSubfolderSpec = 7;
+ files = (
+ 184447212AB21B43007D6BFE /* ggml-metal.metal in CopyFiles */,
+ );
+ runOnlyForDeploymentPostprocessing = 0;
+ };
+/* End PBXCopyFilesBuildPhase section */
+
/* Begin PBXFileReference section */
+ 184447182AB211A2007D6BFE /* ggml-alloc.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = "ggml-alloc.c"; path = "../../../ggml-alloc.c"; sourceTree = ""; };
+ 184447192AB211A2007D6BFE /* ggml-alloc.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = "ggml-alloc.h"; path = "../../../ggml-alloc.h"; sourceTree = ""; };
+ 1844471B2AB21655007D6BFE /* ggml-metal.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; name = "ggml-metal.m"; path = "../../../ggml-metal.m"; sourceTree = ""; };
+ 1844471D2AB2195F007D6BFE /* ggml-metal.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; name = "ggml-metal.metal"; path = "../../../ggml-metal.metal"; sourceTree = ""; };
18627C7629052BDF00BD2A04 /* whisper.objc.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = whisper.objc.app; sourceTree = BUILT_PRODUCTS_DIR; };
18627C7929052BDF00BD2A04 /* AppDelegate.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = ""; };
18627C7A29052BDF00BD2A04 /* AppDelegate.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = AppDelegate.m; sourceTree = ""; };
@@ -80,6 +100,10 @@
18627C7829052BDF00BD2A04 /* whisper.objc */ = {
isa = PBXGroup;
children = (
+ 1844471D2AB2195F007D6BFE /* ggml-metal.metal */,
+ 1844471B2AB21655007D6BFE /* ggml-metal.m */,
+ 184447182AB211A2007D6BFE /* ggml-alloc.c */,
+ 184447192AB211A2007D6BFE /* ggml-alloc.h */,
7FE3424E2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc */,
7FE342442A0C3FA20015A058 /* coreml */,
18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */,
@@ -126,6 +150,7 @@
18627C7229052BDF00BD2A04 /* Sources */,
18627C7329052BDF00BD2A04 /* Frameworks */,
18627C7429052BDF00BD2A04 /* Resources */,
+ 184447202AB21B25007D6BFE /* CopyFiles */,
);
buildRules = (
);
@@ -194,8 +219,10 @@
18627C9629052C5800BD2A04 /* ggml.c in Sources */,
18627C7B29052BDF00BD2A04 /* AppDelegate.m in Sources */,
7FE3424D2A0C3FA20015A058 /* whisper-decoder-impl.m in Sources */,
+ 1844471A2AB211A2007D6BFE /* ggml-alloc.c in Sources */,
18627C8C29052BE000BD2A04 /* main.m in Sources */,
18627C7E29052BDF00BD2A04 /* SceneDelegate.m in Sources */,
+ 1844471C2AB21655007D6BFE /* ggml-metal.m in Sources */,
7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
diff --git a/examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj b/examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj
index ab9f688..d2d0b05 100644
--- a/examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj
+++ b/examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj
@@ -20,6 +20,7 @@
0AAC5DCC29539EB1003032C3 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DC929539EB0003032C3 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE -Wno-shorten-64-to-32"; }; };
0AAC5DCE2953A05C003032C3 /* WhisperState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DCD2953A05C003032C3 /* WhisperState.swift */; };
0AAC5DD12953A394003032C3 /* LibWhisper.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DD02953A394003032C3 /* LibWhisper.swift */; };
+ 18AED4812AB21F2B009D854F /* ggml-alloc.c in Sources */ = {isa = PBXBuildFile; fileRef = 18AED47F2AB21F2B009D854F /* ggml-alloc.c */; };
/* End PBXBuildFile section */
/* Begin PBXFileReference section */
@@ -41,6 +42,8 @@
0AAC5DCA29539EB0003032C3 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ggml.h; sourceTree = ""; };
0AAC5DCD2953A05C003032C3 /* WhisperState.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = WhisperState.swift; sourceTree = ""; };
0AAC5DD02953A394003032C3 /* LibWhisper.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LibWhisper.swift; sourceTree = ""; };
+ 18AED47F2AB21F2B009D854F /* ggml-alloc.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = "ggml-alloc.c"; sourceTree = ""; };
+ 18AED4802AB21F2B009D854F /* ggml-alloc.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "ggml-alloc.h"; sourceTree = ""; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
@@ -124,6 +127,8 @@
0AAC5DC529539E89003032C3 /* whisper.cpp */ = {
isa = PBXGroup;
children = (
+ 18AED47F2AB21F2B009D854F /* ggml-alloc.c */,
+ 18AED4802AB21F2B009D854F /* ggml-alloc.h */,
0AAC5DC929539EB0003032C3 /* ggml.c */,
0AAC5DCA29539EB0003032C3 /* ggml.h */,
0AAC5DC729539EB0003032C3 /* whisper.cpp */,
@@ -242,6 +247,7 @@
0AA7514C2953B569001EE061 /* RiffWaveUtils.swift in Sources */,
0AAC5DCB29539EB1003032C3 /* whisper.cpp in Sources */,
0AA7514E2953D958001EE061 /* Recorder.swift in Sources */,
+ 18AED4812AB21F2B009D854F /* ggml-alloc.c in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
@@ -369,7 +375,7 @@
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEVELOPMENT_ASSET_PATHS = "\"whisper.swiftui.demo/Supporting files/Preview Content\"";
- DEVELOPMENT_TEAM = 3TZ9BM962G;
+ DEVELOPMENT_TEAM = P8JZH34X63;
ENABLE_HARDENED_RUNTIME = YES;
ENABLE_PREVIEWS = YES;
GENERATE_INFOPLIST_FILE = YES;
@@ -410,7 +416,7 @@
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEVELOPMENT_ASSET_PATHS = "\"whisper.swiftui.demo/Supporting files/Preview Content\"";
- DEVELOPMENT_TEAM = 3TZ9BM962G;
+ DEVELOPMENT_TEAM = P8JZH34X63;
ENABLE_HARDENED_RUNTIME = YES;
ENABLE_PREVIEWS = YES;
GENERATE_INFOPLIST_FILE = YES;
diff --git a/extra/bench-all.sh b/extra/bench-all.sh
index 43f989d..352a223 100755
--- a/extra/bench-all.sh
+++ b/extra/bench-all.sh
@@ -44,27 +44,26 @@ if [ "$encoder_only" -eq 0 ]; then
printf "\n"
fi
-printf "| CPU | OS | Config | Model | Th | Load | Enc. | Commit |\n"
-printf "| --- | -- | ------ | ----- | -- | ---- | ---- | ------ |\n"
+printf "| %6s | %6s | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
+printf "| %6s | %6s | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"
for model in "${models[@]}"; do
- # run once to heat-up the cache
- ./bench -m ./models/ggml-$model.bin -t $n_threads 2>/dev/null 1>/dev/null
-
# actual run
# store stderr output in a variable in order to parse it later
output=$(./bench -m ./models/ggml-$model.bin -t $n_threads 2>&1)
ret=$?
# parse the output:
- load_time=$(echo "$output" | grep "load time" | awk '{print $5}')
- encode_time=$(echo "$output" | grep "encode time" | awk '{print $5}')
+ encode_time=$(echo "$output" | grep "encode time" | awk '{print $11}')
+ decode_time=$(echo "$output" | grep "decode time" | awk '{print $11}')
+ prompt_time=$(echo "$output" | grep "prompt time" | awk '{print $11}')
system_info=$(echo "$output" | grep "system_info")
n_threads=$(echo "$output" | grep "system_info" | awk '{print $4}')
# floor to milliseconds
- load_time=${load_time%.*}
- encode_time=${encode_time%.*}
+ #encode_time=${encode_time%.*}
+ #decode_time=${decode_time%.*}
+ #prompt_time=${prompt_time%.*}
config=""
@@ -87,6 +86,6 @@ for model in "${models[@]}"; do
commit=$(git rev-parse --short HEAD)
if [ $ret -eq 0 ]; then
- printf "| | | $config | $model | $n_threads | $load_time | $encode_time | $commit |\n"
+ printf "| | | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
fi
done
diff --git a/extra/sync-ggml.sh b/extra/sync-ggml.sh
index 3bd99e3..0070e9e 100755
--- a/extra/sync-ggml.sh
+++ b/extra/sync-ggml.sh
@@ -1,18 +1,20 @@
#!/bin/bash
-cp -rpv ../ggml/src/ggml.c ./ggml.c
-cp -rpv ../ggml/src/ggml-cuda.h ./ggml-cuda.h
-cp -rpv ../ggml/src/ggml-cuda.cu ./ggml-cuda.cu
-cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h
-cp -rpv ../ggml/src/ggml-opencl.cpp ./ggml-opencl.cpp
-cp -rpv ../ggml/src/ggml-metal.h ./ggml-metal.h
-cp -rpv ../ggml/src/ggml-metal.m ./ggml-metal.m
-cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal
-cp -rpv ../ggml/include/ggml/ggml.h ./ggml.h
-cp -rpv ../ggml/examples/common.h ./examples/common.h
-cp -rpv ../ggml/examples/common.cpp ./examples/common.cpp
-cp -rpv ../ggml/examples/common-ggml.h ./examples/common-ggml.h
-cp -rpv ../ggml/examples/common-ggml.cpp ./examples/common-ggml.cpp
+cp -rpv ../ggml/src/ggml.c ./ggml.c
+cp -rpv ../ggml/src/ggml-alloc.c ./ggml-alloc.c
+cp -rpv ../ggml/src/ggml-cuda.h ./ggml-cuda.h
+cp -rpv ../ggml/src/ggml-cuda.cu ./ggml-cuda.cu
+cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h
+cp -rpv ../ggml/src/ggml-opencl.cpp ./ggml-opencl.cpp
+cp -rpv ../ggml/src/ggml-metal.h ./ggml-metal.h
+cp -rpv ../ggml/src/ggml-metal.m ./ggml-metal.m
+cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal
+cp -rpv ../ggml/include/ggml/ggml.h ./ggml.h
+cp -rpv ../ggml/include/ggml/ggml-alloc.h ./ggml-alloc.h
+cp -rpv ../ggml/examples/common.h ./examples/common.h
+cp -rpv ../ggml/examples/common.cpp ./examples/common.cpp
+cp -rpv ../ggml/examples/common-ggml.h ./examples/common-ggml.h
+cp -rpv ../ggml/examples/common-ggml.cpp ./examples/common-ggml.cpp
cp -rpv ../ggml/examples/whisper/whisper.h ./whisper.h
cp -rpv ../ggml/examples/whisper/whisper.cpp ./whisper.cpp
diff --git a/ggml-alloc.c b/ggml-alloc.c
index 856a4cd..304964b 100644
--- a/ggml-alloc.c
+++ b/ggml-alloc.c
@@ -6,6 +6,26 @@
#include
#include
+#ifdef __has_include
+ #if __has_include(<unistd.h>)
+ #include <unistd.h>
+ #if defined(_POSIX_MAPPED_FILES)
+ #include <sys/types.h>
+ #include <sys/mman.h>
+ #endif
+ #endif
+#endif
+
+#if defined(_WIN32)
+ #define WIN32_LEAN_AND_MEAN
+ #ifndef NOMINMAX
+ #define NOMINMAX
+ #endif
+ #include <windows.h>
+ #include <memoryapi.h>
+#endif
+
+
#define UNUSED(x) (void)(x)
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
@@ -99,15 +119,28 @@ static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tens
}
#endif
-
-static size_t ggml_allocator_get_alloc_size(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
+static size_t ggml_allocr_get_alloc_size(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
return ggml_nbytes(tensor);
UNUSED(alloc);
}
+// check if a tensor is allocated by this buffer
+static bool ggml_allocr_is_own(struct ggml_allocr * alloc, const struct ggml_tensor * tensor) {
+ void * ptr = tensor->data;
+ return ptr >= alloc->data && (char *)ptr < (char *)alloc->data + alloc->max_size;
+}
+
+static bool ggml_is_view(struct ggml_tensor * t) {
+ return t->view_src != NULL;
+}
+
void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
- size_t size = ggml_allocator_get_alloc_size(alloc, tensor);
+#ifdef GGML_ALLOCATOR_DEBUG
+ GGML_ASSERT(!ggml_is_view(tensor)); // views generally get data pointer from one of their sources
+ GGML_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated
+#endif
+ size_t size = ggml_allocr_get_alloc_size(alloc, tensor);
size = aligned_offset(NULL, size, alloc->alignment);
AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size);
@@ -131,14 +164,14 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
if (best_fit_block == -1) {
// the last block is our last resort
struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1];
+ max_avail = MAX(max_avail, block->size);
if (block->size >= size) {
best_fit_block = alloc->n_free_blocks - 1;
- max_avail = MAX(max_avail, block->size);
} else {
fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n",
__func__, size, max_avail);
GGML_ASSERT(!"not enough space in the buffer");
- return;
+ return;
}
}
struct free_block * block = &alloc->free_blocks[best_fit_block];
@@ -173,17 +206,17 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
}
// this is a very naive implementation, but for our case the number of free blocks should be very small
-static void ggml_allocator_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
+static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
void * ptr = tensor->data;
- if (ptr < alloc->data || (char*)ptr >= (char*)alloc->data + alloc->max_size) {
+ if (ggml_allocr_is_own(alloc, tensor) == false) {
// the tensor was not allocated in this buffer
// this can happen because the graph allocator will try to free weights and other tensors from different buffers
// the easiest way to deal with this is just to ignore it
return;
}
- size_t size = ggml_allocator_get_alloc_size(alloc, tensor);
+ size_t size = ggml_allocr_get_alloc_size(alloc, tensor);
size = aligned_offset(NULL, size, alloc->alignment);
AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks);
@@ -277,17 +310,68 @@ struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment)
return alloc;
}
-// address and size of the buffer when measuring
-// it needs to be large enough to fit all the tensors, but it cannot overlap with other existing buffers
-static void * const MEASURE_BASE_ADDR = (void *) 0x1000;
-static const size_t MEASURE_MAX_SIZE = 1ULL<<40; // 1 TB
+// OS specific functions to allocate and free uncommitted virtual memory
+static void * alloc_vmem(size_t size) {
+#if defined(_WIN32)
+ return VirtualAlloc(NULL, size, MEM_RESERVE, PAGE_NOACCESS);
+#elif defined(_POSIX_MAPPED_FILES)
+ void * ptr = mmap(NULL, size, PROT_NONE, MAP_PRIVATE | MAP_ANON, -1, 0);
+ if (ptr == MAP_FAILED) {
+ return NULL;
+ }
+ return ptr;
+#else
+ // use a fixed address for other platforms
+ uintptr_t base_addr = (uintptr_t)-size - 0x100;
+ return (void *)base_addr;
+#endif
+}
+
+static void free_vmem(void * base_addr, size_t size) {
+#if defined(_WIN32)
+ VirtualFree(base_addr, 0, MEM_RELEASE);
+ UNUSED(size);
+#elif defined(_POSIX_MAPPED_FILES)
+ munmap(base_addr, size);
+#else
+ // nothing to do
+ UNUSED(base_addr);
+ UNUSED(size);
+#endif
+}
+
+// allocate uncommitted virtual memory to measure the size of the graph
+static void alloc_measure_vmem(void ** base_addr, size_t * size) {
+ // 128GB for 64-bit, 1GB for 32-bit
+ *size = sizeof(void *) == 4 ? 1ULL<<30 : 1ULL<<37;
+ do {
+ *base_addr = alloc_vmem(*size);
+ if (*base_addr != NULL) {
+ AT_PRINTF("allocated %.2f GB of virtual memory for measure buffer at %p\n", *size / 1024.0 / 1024.0 / 1024.0, *base_addr);
+ return;
+ }
+ // try again with half the size
+ *size /= 2;
+ } while (*size > 0);
+
+ GGML_ASSERT(!"failed to allocate virtual memory for measure buffer");
+}
+
+static void free_measure_vmem(void * base_addr, size_t size) {
+ free_vmem(base_addr, size);
+}
struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */);
+ void * base_addr;
+ size_t size;
+
+ alloc_measure_vmem(&base_addr, &size);
+
*alloc = (struct ggml_allocr){
- /*.data = */ MEASURE_BASE_ADDR,
- /*.size = */ MEASURE_MAX_SIZE,
+ /*.data = */ base_addr,
+ /*.size = */ size,
/*.alignment = */ alignment,
/*.n_free_blocks = */ 0,
/*.free_blocks = */ {{0}},
@@ -307,6 +391,9 @@ struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
}
void ggml_allocr_free(struct ggml_allocr * alloc) {
+ if (alloc->measure) {
+ free_measure_vmem(alloc->data, alloc->size);
+ }
free(alloc);
}
@@ -316,11 +403,6 @@ bool ggml_allocr_is_measure(struct ggml_allocr * alloc) {
//////////// compute graph allocator
-static bool ggml_is_view(struct ggml_tensor * t) {
- return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE ||
- t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY;
-}
-
static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
if (a->type != b->type) {
return false;
@@ -336,28 +418,6 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml
return true;
}
-static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) {
- switch (t->op) {
- case GGML_OP_PERMUTE:
- case GGML_OP_RESHAPE:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_VIEW:
- return t->src[0];
- case GGML_OP_CPY:
- return t->src[1];
- default:
- return NULL;
- }
-}
-
-static struct ggml_tensor * get_view_source(struct ggml_tensor * t) {
- struct ggml_tensor * parent = t;
- do {
- parent = get_view_parent(parent);
- } while (ggml_is_view(parent));
- return parent;
-}
-
static bool ggml_op_can_inplace(enum ggml_op op) {
switch (op) {
case GGML_OP_SCALE:
@@ -365,7 +425,6 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_ADD:
case GGML_OP_ADD1:
- case GGML_OP_ACC:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
@@ -375,10 +434,8 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
case GGML_OP_UNARY:
case GGML_OP_ROPE:
case GGML_OP_RMS_NORM:
- case GGML_OP_SET:
case GGML_OP_SOFT_MAX:
case GGML_OP_CONT:
- case GGML_OP_ADD_REL_POS:
return true;
default:
@@ -390,24 +447,8 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
struct hash_node * ht = alloc->hash_table;
if (node->data == NULL) {
if (ggml_is_view(node)) {
- size_t offset;
- switch(node->op) {
- case GGML_OP_VIEW:
- memcpy(&offset, node->op_params, sizeof(size_t));
- node->data = (char *) node->src[0]->data + offset;
- break;
- case GGML_OP_PERMUTE:
- case GGML_OP_RESHAPE:
- case GGML_OP_TRANSPOSE:
- node->data = node->src[0]->data;
- break;
- case GGML_OP_CPY:
- node->data = node->src[1]->data;
- break;
- default:
- GGML_ASSERT(!"unknown view op");
- break;
- }
+ assert(node->view_src->data != NULL);
+ node->data = (char *)node->view_src->data + node->view_offs;
} else {
// see if we can reuse a parent's buffer (inplace)
if (ggml_op_can_inplace(node->op)) {
@@ -418,8 +459,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
}
// if the node's data is external, then we cannot re-use it
- if ((char *) parent->data < (char *) alloc->data ||
- (char *) parent->data >= ((char *) alloc->data + alloc->size)) {
+ if (ggml_allocr_is_own(alloc, parent) == false) {
AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
continue;
}
@@ -427,7 +467,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
struct hash_node * p_hn = hash_get(ht, parent);
if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) {
if (ggml_is_view(parent)) {
- struct ggml_tensor * view_src = get_view_source(parent);
+ struct ggml_tensor * view_src = parent->view_src;
struct hash_node * view_src_hn = hash_get(ht, view_src);
if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
// TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
@@ -453,7 +493,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
}
}
-static size_t ggml_allocator_alloc_graph_tensors_n(
+static size_t ggml_allocr_alloc_graph_tensors_n(
struct ggml_allocr * alloc,
struct ggml_cgraph ** graphs, int n_graphs,
struct ggml_tensor *** inputs, struct ggml_tensor *** outputs) {
@@ -469,7 +509,7 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
struct ggml_tensor * node = gf->nodes[i];
if (ggml_is_view(node)) {
- struct ggml_tensor * view_src = get_view_source(node);
+ struct ggml_tensor * view_src = node->view_src;
hash_get(ht, view_src)->n_views += 1;
}
@@ -531,11 +571,10 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
AT_PRINTF("\n");
}
-
// update parents
// update immediately if there is no parse_seq
// update only at barriers if there is parse_seq
- if ((alloc->parse_seq_len==0) || alloc->parse_seq[ind] == -1) {
+ if ((alloc->parse_seq_len == 0) || alloc->parse_seq[ind] == -1) {
int update_start = alloc->parse_seq_len ? last_barrier_pos : ind;
int update_end = alloc->parse_seq_len ? ind : ind + 1;
for (int i = update_start; i < update_end; i++) {
@@ -554,17 +593,17 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
if (p_hn->n_children == 0 && p_hn->n_views == 0) {
if (ggml_is_view(parent)) {
- struct ggml_tensor * view_src = get_view_source(parent);
+ struct ggml_tensor * view_src = parent->view_src;
struct hash_node * view_src_hn = hash_get(ht, view_src);
view_src_hn->n_views -= 1;
AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
- ggml_allocator_free_tensor(alloc, view_src);
+ ggml_allocr_free_tensor(alloc, view_src);
}
}
else {
if (parent->data != node->data) {
- ggml_allocator_free_tensor(alloc, parent);
+ ggml_allocr_free_tensor(alloc, parent);
}
}
}
@@ -581,7 +620,7 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
for (int i = 0; outputs[g][i] != NULL; i++) {
struct ggml_tensor * output = outputs[g][i];
AT_PRINTF("output: %s\n", output->name);
- ggml_allocator_free_tensor(alloc, output);
+ ggml_allocr_free_tensor(alloc, output);
}
}
}
@@ -590,5 +629,5 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
}
size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) {
- return ggml_allocator_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL);
+ return ggml_allocr_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL);
}
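The measure allocator now reserves a large range of uncommitted virtual memory instead of pretending to allocate from a fixed fake base address. A small standalone C sketch of the same reserve/release pattern on POSIX systems (illustrative only; it mirrors the `mmap(PROT_NONE)` approach used by the patch's `alloc_vmem`/`free_vmem`, assuming `MAP_ANON` is available as on macOS and Linux):

```c
#include <stdio.h>
#include <sys/mman.h>

int main(void) {
    // reserve 1 GiB of address space without committing physical memory
    size_t size = (size_t)1 << 30;
    void * base = mmap(NULL, size, PROT_NONE, MAP_PRIVATE | MAP_ANON, -1, 0);
    if (base == MAP_FAILED) {
        perror("mmap");
        return 1;
    }

    // the range can be used for pointer arithmetic (e.g. measuring allocation sizes),
    // but must not be dereferenced while it is mapped PROT_NONE
    printf("reserved %zu bytes at %p\n", size, base);

    munmap(base, size);
    return 0;
}
```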
diff --git a/ggml-metal.m b/ggml-metal.m
index 7e2355c..b438b83 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -63,7 +63,10 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(relu);
GGML_METAL_DECL_KERNEL(gelu);
GGML_METAL_DECL_KERNEL(soft_max);
+ GGML_METAL_DECL_KERNEL(soft_max_4);
GGML_METAL_DECL_KERNEL(diag_mask_inf);
+ GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
+ GGML_METAL_DECL_KERNEL(get_rows_f32);
GGML_METAL_DECL_KERNEL(get_rows_f16);
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
@@ -77,6 +80,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(norm);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
+ GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
@@ -117,14 +121,17 @@ static NSString * const msl_library_source = @"see metal.metal";
struct ggml_metal_context * ggml_metal_init(int n_cb) {
metal_printf("%s: allocating\n", __func__);
- // Show all the Metal device instances in the system
- NSArray * devices = MTLCopyAllDevices();
id<MTLDevice> device;
NSString * s;
+
+#if TARGET_OS_OSX
+ // Show all the Metal device instances in the system
+ NSArray * devices = MTLCopyAllDevices();
for (device in devices) {
s = [device name];
metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
}
+#endif
// Pick and show default Metal device
device = MTLCreateSystemDefaultDevice();
@@ -139,14 +146,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
ctx->n_buffers = 0;
ctx->concur_list_len = 0;
- ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
+ ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
-#if 0
- // compile from source string and show compile log
+#ifdef GGML_SWIFT
+ // load the default.metallib file
{
NSError * error = nil;
- ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
+ NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
+ NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"];
+ NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath];
+ NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"];
+ NSURL * libURL = [NSURL fileURLWithPath:libPath];
+
+ // Load the metallib file into a Metal library
+ ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
+
if (error) {
metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
return NULL;
@@ -161,7 +176,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
//NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
- NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+ NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
@@ -207,7 +222,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(relu);
GGML_METAL_ADD_KERNEL(gelu);
GGML_METAL_ADD_KERNEL(soft_max);
+ GGML_METAL_ADD_KERNEL(soft_max_4);
GGML_METAL_ADD_KERNEL(diag_mask_inf);
+ GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
+ GGML_METAL_ADD_KERNEL(get_rows_f32);
GGML_METAL_ADD_KERNEL(get_rows_f16);
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
@@ -221,6 +239,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(norm);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
+ GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
@@ -247,13 +266,15 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
#undef GGML_METAL_ADD_KERNEL
}
- metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+#if TARGET_OS_OSX
+ metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
if (ctx->device.maxTransferRate != 0) {
metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
} else {
metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
}
+#endif
return ctx;
}
@@ -273,7 +294,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(relu);
GGML_METAL_DEL_KERNEL(gelu);
GGML_METAL_DEL_KERNEL(soft_max);
+ GGML_METAL_DEL_KERNEL(soft_max_4);
GGML_METAL_DEL_KERNEL(diag_mask_inf);
+ GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
+ GGML_METAL_DEL_KERNEL(get_rows_f32);
GGML_METAL_DEL_KERNEL(get_rows_f16);
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
@@ -287,6 +311,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(norm);
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
+ GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
@@ -365,6 +390,7 @@ static id ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
for (int i = 0; i < ctx->n_buffers; ++i) {
const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
+ //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
*offs = (size_t) ioffs;
@@ -454,6 +480,7 @@ bool ggml_metal_add_buffer(
}
}
+#if TARGET_OS_OSX
metal_printf(", (%8.2f / %8.2f)",
ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
@@ -463,6 +490,9 @@ bool ggml_metal_add_buffer(
} else {
metal_printf("\n");
}
+#else
+ metal_printf(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
+#endif
}
return true;
@@ -698,6 +728,7 @@ void ggml_metal_graph_compute(
case GGML_OP_ADD:
{
GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));
// utilize float4
GGML_ASSERT(ne00 % 4 == 0);
@@ -705,6 +736,7 @@ void ggml_metal_graph_compute(
if (ggml_nelements(src1) == ne10) {
// src1 is a row
+ GGML_ASSERT(ne11 == 1);
[encoder setComputePipelineState:ctx->pipeline_add_row];
} else {
[encoder setComputePipelineState:ctx->pipeline_add];
@@ -721,6 +753,7 @@ void ggml_metal_graph_compute(
case GGML_OP_MUL:
{
GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));
// utilize float4
GGML_ASSERT(ne00 % 4 == 0);
@@ -728,6 +761,7 @@ void ggml_metal_graph_compute(
if (ggml_nelements(src1) == ne10) {
// src1 is a row
+ GGML_ASSERT(ne11 == 1);
[encoder setComputePipelineState:ctx->pipeline_mul_row];
} else {
[encoder setComputePipelineState:ctx->pipeline_mul];
@@ -743,6 +777,8 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_SCALE:
{
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
const float scale = *(const float *) src1->data;
[encoder setComputePipelineState:ctx->pipeline_scale];
@@ -750,7 +786,7 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
- const int64_t n = ggml_nelements(dst);
+ const int64_t n = ggml_nelements(dst)/4;
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
@@ -762,7 +798,7 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- const int64_t n = ggml_nelements(dst);
+ const int64_t n = ggml_nelements(dst)/4;
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
@@ -782,7 +818,7 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- const int64_t n = ggml_nelements(dst);
+ const int64_t n = ggml_nelements(dst)/4;
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
@@ -796,13 +832,16 @@ void ggml_metal_graph_compute(
{
const int nth = 32;
- [encoder setComputePipelineState:ctx->pipeline_soft_max];
+ if (ne00%4 == 0) {
+ [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
+ } else {
+ [encoder setComputePipelineState:ctx->pipeline_soft_max];
+ }
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
@@ -810,14 +849,23 @@ void ggml_metal_graph_compute(
{
const int n_past = ((int32_t *)(dst->op_params))[0];
- [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
+ if (ne00%8 == 0) {
+ [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
+ } else {
+ [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
+ }
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
- [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ if (ne00%8 == 0) {
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ }
+ else {
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ }
} break;
case GGML_OP_MUL_MAT:
{
@@ -830,8 +878,8 @@ void ggml_metal_graph_compute(
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
- if (ggml_is_contiguous(src0) &&
- ggml_is_contiguous(src1) &&
+ if (!ggml_is_transposed(src0) &&
+ !ggml_is_transposed(src1) &&
src1t == GGML_TYPE_F32 &&
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
ne00%32 == 0 &&
@@ -856,14 +904,18 @@ void ggml_metal_graph_compute(
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} else {
int nth0 = 32;
int nth1 = 1;
+ int nrows = 1;
// use custom matrix x vector kernel
switch (src0t) {
@@ -873,8 +925,14 @@ void ggml_metal_graph_compute(
nth1 = 1;
if (ne11 * ne12 < 4) {
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
+ //} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+ } else if (false) {
+ // TODO: with ggml_mul_mat_pad this kernel no longer seems to be needed
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
+ nrows = ne11;
} else {
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+ nrows = 4;
}
} break;
case GGML_TYPE_Q4_0:
@@ -995,7 +1053,7 @@ void ggml_metal_graph_compute(
else if (src0t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
- int64_t ny = (ne11 + 3)/4;
+ int64_t ny = (ne11 + nrows - 1)/nrows;
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
}
@@ -1003,6 +1061,7 @@ void ggml_metal_graph_compute(
case GGML_OP_GET_ROWS:
{
switch (src0->type) {
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
@@ -1018,9 +1077,9 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
- [encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
const int64_t n = ggml_nelements(src1);
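The rewritten soft_max kernels (see the ggml-metal.metal changes below) replace the threadgroup-memory reduction with `simd_max`/`simd_sum`, but the math is the standard numerically stable softmax: subtract the row maximum before exponentiating. A plain C reference of that computation for a single row (illustrative only, not code from this patch; compile with `-lm`):

```c
#include <math.h>
#include <stdio.h>

// numerically stable softmax over one row of n values
static void softmax_row(const float * src, float * dst, int n) {
    float max = src[0];
    for (int i = 1; i < n; ++i) {
        if (src[i] > max) max = src[i];
    }
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        dst[i] = expf(src[i] - max); // subtracting the max avoids overflow in expf
        sum += dst[i];
    }
    for (int i = 0; i < n; ++i) {
        dst[i] /= sum;
    }
}

int main(void) {
    const float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    float y[4];
    softmax_row(x, y, 4);
    for (int i = 0; i < 4; ++i) printf("%.4f ", y[i]); // 0.0321 0.0871 0.2369 0.6439
    printf("\n");
    return 0;
}
```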
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 5070561..0db037c 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -38,7 +38,7 @@ kernel void kernel_add_row(
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
- constant int64_t & nb,
+ constant int64_t & nb,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] + src1[tpig % nb];
}
@@ -63,18 +63,18 @@ kernel void kernel_mul_row(
}
kernel void kernel_scale(
- device const float * src0,
- device float * dst,
+ device const float4 * src0,
+ device float4 * dst,
constant float & scale,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * scale;
}
kernel void kernel_silu(
- device const float * src0,
- device float * dst,
+ device const float4 * src0,
+ device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
- float x = src0[tpig];
+ device const float4 & x = src0[tpig];
dst[tpig] = x / (1.0f + exp(-x));
}
@@ -89,10 +89,10 @@ constant float GELU_COEF_A = 0.044715f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
kernel void kernel_gelu(
- device const float * src0,
- device float * dst,
+ device const float4 * src0,
+ device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
- float x = src0[tpig];
+ device const float4 & x = src0[tpig];
// BEWARE !!!
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
@@ -107,7 +107,6 @@ kernel void kernel_soft_max(
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
- threadgroup float * buf [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
@@ -119,64 +118,70 @@ kernel void kernel_soft_max(
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
// parallel max
- buf[tpitg[0]] = -INFINITY;
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
- buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
+ float lmax = psrc0[tpitg[0]];
+ for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+ lmax = MAX(lmax, psrc0[i00]);
}
-
- // reduce
- threadgroup_barrier(mem_flags::mem_threadgroup);
- for (uint i = ntg[0]/2; i > 0; i /= 2) {
- if (tpitg[0] < i) {
- buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- }
-
- //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
- // the loop, and when that is done, buf[0] has the correct (synchronized) value
- //if (tpitg[0] == 0) {
- // buf[0] = buf[0];
- //}
-
- //threadgroup_barrier(mem_flags::mem_threadgroup);
-
- const float max = buf[0];
+ const float max = simd_max(lmax);
// parallel sum
- buf[tpitg[0]] = 0.0f;
+ float lsum = 0.0f;
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
const float exp_psrc0 = exp(psrc0[i00] - max);
- buf[tpitg[0]] += exp_psrc0;
+ lsum += exp_psrc0;
// Remember the result of exp here. exp is expensive, so we really do not
// whish to compute it twice.
pdst[i00] = exp_psrc0;
}
- // reduce
- threadgroup_barrier(mem_flags::mem_threadgroup);
- for (uint i = ntg[0]/2; i > 0; i /= 2) {
- if (tpitg[0] < i) {
- buf[tpitg[0]] += buf[tpitg[0] + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- }
-
- // broadcast - not needed, see above
- //// broadcast
- //if (tpitg[0] == 0) {
- // buf[0] = buf[0];
- //}
-
- //threadgroup_barrier(mem_flags::mem_threadgroup);
-
- const float sum = buf[0];
+ const float sum = simd_sum(lsum);
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
pdst[i00] /= sum;
}
}
+kernel void kernel_soft_max_4(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+
+ // parallel max
+ float4 lmax4 = psrc4[tpitg[0]];
+ for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
+ lmax4 = fmax(lmax4, psrc4[i00]);
+ }
+ float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
+
+ const float max = simd_max(lmax);
+
+ // parallel sum
+ float4 lsum4 = 0.0f;
+ for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+ const float4 exp_psrc4 = exp(psrc4[i00] - max);
+ lsum4 += exp_psrc4;
+ pdst4[i00] = exp_psrc4;
+ }
+ float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+
+ const float sum = simd_sum(lsum);
+
+ for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+ pdst4[i00] /= sum;
+ }
+}
+
kernel void kernel_diag_mask_inf(
device const float * src0,
device float * dst,
@@ -192,6 +197,33 @@ kernel void kernel_diag_mask_inf(
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
} else {
dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+ }
+}
+
+kernel void kernel_diag_mask_inf_8(
+ device const float4 * src0,
+ device float4 * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int & n_past,
+ uint3 tpig[[thread_position_in_grid]]) {
+
+ const int64_t i = 2*tpig[0];
+
+ dst[i+0] = src0[i+0];
+ dst[i+1] = src0[i+1];
+ int64_t i4 = 4*i;
+ const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
+ const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
+ const int64_t i00 = i4;
+ for (int k = 3; k >= 0; --k) {
+ if (i00 + 4 + k <= n_past + i01) {
+ break;
+ }
+ dst[i+1][k] = -INFINITY;
+ if (i00 + k > n_past + i01) {
+ dst[i][k] = -INFINITY;
+ }
}
}
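// A short reading of the index math in kernel_diag_mask_inf_8 (derived from the code above):
// each thread owns two consecutive float4 values, i.e. 8 floats starting at flat index i4 = 4*i.
// The divisions recover the (i02, i01, i00) coordinates of the first of those floats, and the
// backwards loop then overwrites with -INFINITY every element whose column index exceeds
// n_past + i01 (the causal limit for row i01), touching the second float4 first and only falling
// through to the first float4 for the straddling positions. The early break is valid because once
// a column is inside the limit, every smaller k is as well.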
@@ -616,6 +648,49 @@ kernel void kernel_mul_mat_f16_f32(
}
}
+// Assumes row size (ne00) is a multiple of 4
+kernel void kernel_mul_mat_f16_f32_l4(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+
+ const int nrows = ne11;
+ const int64_t r0 = tgpig.x;
+ const int64_t im = tgpig.z;
+
+ device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+ for (int r1 = 0; r1 < nrows; ++r1) {
+ device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00/4; i += 32) {
+ for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+}
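// Reading of the "_l4" kernel above: each threadgroup handles one src0 row r0, every lane of the
// simdgroup strides over that row in half4/float4 chunks (hence the requirement that ne00 be a
// multiple of 4), and simd_sum folds the 32 per-lane partial dot products so that lane 0 can
// write the result for each src1 row r1. It is a specialization for aligned row sizes; rows whose
// size is not a multiple of 4 presumably still go through the generic kernel_mul_mat_f16_f32.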
+
kernel void kernel_alibi_f32(
device const float * src0,
device float * dst,
@@ -1123,31 +1198,40 @@ kernel void kernel_mul_mat_q3_K_f32(
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
- float yl[16];
+ float yl[32];
- const uint16_t kmask1 = 0x0303;
+ const uint16_t kmask1 = 0x3030;
const uint16_t kmask2 = 0x0f0f;
- const int tid = tiisg/2;
- const int ix = tiisg%2;
- const int ip = tid/8; // 0 or 1
- const int il = tid/2 - 4*ip; // 0...3
+ const int tid = tiisg/4;
+ const int ix = tiisg%4;
+ const int ip = tid/4; // 0 or 1
+ const int il = 2*((tid%4)/2); // 0 or 2
const int ir = tid%2;
const int n = 8;
const int l0 = n*ir;
- const uint16_t m1 = 1 << (4*ip + il);
- const uint16_t m2 = m1 << 8;
+ // One would think that the Metal compiler would figure out that ip and il can only have
+ // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
+ // with these two tables.
+ //
+ // Possible masks for the high bit
+ const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0
+ {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2
+ {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0
+ {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
+
+ // Possible masks for the low 2 bits
+ const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+
+ const ushort4 hm = mm[2*ip + il/2];
const int shift = 2*il;
- const uint16_t qm1 = 0x0003 << shift;
- const uint16_t qm2 = 0x0300 << shift;
- const int32_t v1 = 4 << shift;
- const int32_t v2 = 1024 << shift;
+ const float v1 = il == 0 ? 4.f : 64.f;
+ const float v2 = 4.f * v1;
const uint16_t s_shift1 = 4*ip;
- const uint16_t s_shift2 = s_shift1 + 2*(il/2);
- const int ik = 4 + (il%2);
+ const uint16_t s_shift2 = s_shift1 + il;
const int q_offset = 32*ip + l0;
const int y_offset = 128*ip + 32*il + l0;
@@ -1156,12 +1240,19 @@ kernel void kernel_mul_mat_q3_K_f32(
device const float * y1 = yy + ix*QK_K + y_offset;
- float sumf1[2] = {0.f}, sumf2[2] = {0.f};
- for (int i = ix; i < nb; i += 2) {
+ uint32_t scales32, aux32;
+ thread uint16_t * scales16 = (thread uint16_t *)&scales32;
+ thread const int8_t * scales = (thread const int8_t *)&scales32;
+
+ float sumf1[2] = {0.f};
+ float sumf2[2] = {0.f};
+ for (int i = ix; i < nb; i += 4) {
for (int l = 0; l < 8; ++l) {
- yl[l+0] = y1[l+ 0];
- yl[l+8] = y1[l+16];
+ yl[l+ 0] = y1[l+ 0];
+ yl[l+ 8] = y1[l+16];
+ yl[l+16] = y1[l+32];
+ yl[l+24] = y1[l+48];
}
device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
@@ -1172,27 +1263,43 @@ kernel void kernel_mul_mat_q3_K_f32(
for (int row = 0; row < 2; ++row) {
const float d_all = (float)dh[0];
- const char2 scales = as_type((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
- float s1 = 0, s2 = 0;
- for (int l = 0; l < n; l += 2) {
- const uint16_t qs = q[l/2];
- s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
- s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
- }
- float d = d_all * (s1 + 1.f/256.f * s2);
- sumf1[row] += d * scales[0];
- sumf2[row] += d;
+ scales16[0] = a[4];
+ scales16[1] = a[5];
+ aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
+ scales16[0] = a[il+0];
+ scales16[1] = a[il+1];
+ scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
- s1 = s2 = 0;
+ float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
for (int l = 0; l < n; l += 2) {
- const uint16_t qs = q[l/2+8];
- s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
- s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
+ const int32_t qs = q[l/2];
+ s1 += yl[l+0] * (qs & qm[il/2][0]);
+ s2 += yl[l+1] * (qs & qm[il/2][1]);
+ s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
+ s4 += yl[l+16] * (qs & qm[il/2][2]);
+ s5 += yl[l+17] * (qs & qm[il/2][3]);
+ s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
}
- d = d_all * (s1 + 1.f/256.f * s2);
- sumf1[row] += d * scales[1];
- sumf2[row] += d;
+ float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+ float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+ sumf1[row] += d1 * (scales[0] - 32);
+ sumf2[row] += d2 * (scales[2] - 32);
+
+ s1 = s2 = s3 = s4 = s5 = s6 = 0;
+ for (int l = 0; l < n; l += 2) {
+ const int32_t qs = q[l/2+8];
+ s1 += yl[l+8] * (qs & qm[il/2][0]);
+ s2 += yl[l+9] * (qs & qm[il/2][1]);
+ s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
+ s4 += yl[l+24] * (qs & qm[il/2][2]);
+ s5 += yl[l+25] * (qs & qm[il/2][3]);
+ s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
+ }
+ d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+ d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+ sumf1[row] += d1 * (scales[1] - 32);
+ sumf2[row] += d2 * (scales[3] - 32);
q += step;
h += step;
@@ -1201,15 +1308,17 @@ kernel void kernel_mul_mat_q3_K_f32(
}
- y1 += 2 * QK_K;
+ y1 += 4 * QK_K;
}
for (int row = 0; row < 2; ++row) {
- const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
- const float tot = simd_sum(sumf);
- if (tiisg == 0) {
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
+ const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
+ sumf1[row] = simd_sum(sumf);
+ }
+ if (tiisg == 0) {
+ for (int row = 0; row < 2; ++row) {
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
}
}
}
@@ -1564,17 +1673,25 @@ kernel void kernel_mul_mat_q5_K_f32(
sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
- float4 acc = {0.f, 0.f, 0.f, 0.f};
+ float4 acc1 = {0.f};
+ float4 acc2 = {0.f};
for (int l = 0; l < n; ++l) {
uint8_t h = qh[l];
- acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
- acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
- acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
- acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
+ acc1[0] += yl[l+0] * (q1[l] & 0x0F);
+ acc1[1] += yl[l+8] * (q1[l] & 0xF0);
+ acc1[2] += yh[l+0] * (q2[l] & 0x0F);
+ acc1[3] += yh[l+8] * (q2[l] & 0xF0);
+ acc2[0] += h & hm1 ? yl[l+0] : 0.f;
+ acc2[1] += h & hm2 ? yl[l+8] : 0.f;
+ acc2[2] += h & hm3 ? yh[l+0] : 0.f;
+ acc2[3] += h & hm4 ? yh[l+8] : 0.f;
}
const float dall = dh[0];
const float dmin = dh[1];
- sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
+ sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
+ sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+ sc8[4] * (acc1[2] + 16.f*acc2[2]) +
+ sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
q1 += step;
@@ -1747,6 +1864,15 @@ kernel void kernel_mul_mat_q6_K_f32(
//============================= templates and their specializations =============================
+// NOTE: this is not dequantizing - we are simply fitting the template
+template <typename type4x4>
+void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
+ float4x4 temp = *(((device float4x4 *)src));
+ for (int i = 0; i < 16; i++){
+ reg[i/4][i%4] = temp[i/4][i%4];
+ }
+}
+
template <typename type4x4>
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
half4x4 temp = *(((device half4x4 *)src));
@@ -1758,28 +1884,30 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
- const half d = il ? (xb->d / 16.h) : xb->d;
- const half m = il ? ( -8.h * 16.h) : -8.h;
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
+ const float d2 = d1 / 256.f;
+ const float md = -8.h * xb->d;
const ushort mask0 = il ? 0x00F0 : 0x000F;
- const ushort mask1 = il ? 0xF000 : 0x0F00;
+ const ushort mask1 = mask0 << 8;
for (int i=0;i<8;i++) {
- reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
- reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
+ reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
+ reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
}
}
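// Note on the q4_0 rewrite above (reasoning only, no behavioural change intended): the old code
// extracted the high nibbles with "(qs[i] & mask1) >> 8" and then scaled. Since the masked value
// has its low byte cleared, shifting right by 8 equals dividing by 256, so the shift can be folded
// into the scale: d2 = d1/256 and "d2 * (qs[i] & mask1)" yields the same result without the
// per-element shift. The offset changes form as well: instead of adding m = -8 before scaling,
// md = -8 * xb->d is added after scaling, which is algebraically identical for both nibble halves.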
template <typename type4x4>
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
- const half d = il ? (xb->d / 16.h) : xb->d;
- const half m = xb->m;
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
+ const float d2 = d1 / 256.f;
+ const float m = xb->m;
const ushort mask0 = il ? 0x00F0 : 0x000F;
- const ushort mask1 = il ? 0xF000 : 0x0F00;
+ const ushort mask1 = mask0 << 8;
for (int i=0;i<8;i++) {
- reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m;
- reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
+ reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
+ reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
}
}
@@ -1815,7 +1943,7 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
template <typename type4x4>
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
- const float d_all = (float)(xb->d);
+ const half d_all = xb->d;
device const uint8_t * q = (device const uint8_t *)xb->qs;
device const uint8_t * h = (device const uint8_t *)xb->hmask;
device const int8_t * scales = (device const int8_t *)xb->scales;
@@ -1828,16 +1956,18 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
((il/4)>0 ? 12 : 3);
uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
- int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
- (scale_2&kmask2) | ((scale_1&kmask1) << 4);
- float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
+ int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
+ : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+ half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
+ const half ml = 4.h * dl;
- il = (il/2)%4;
- float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
- uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ il = (il/2) & 3;
+ const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+ const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ dl *= coef;
for (int i = 0; i < 16; ++i) {
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
+ reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
}
#else
float kcoef = il&1 ? 1.f/16.f : 1.f;
@@ -1852,26 +1982,31 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
#endif
}
+static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
+ return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
+ : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
+}
+
template <typename type4x4>
void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
- device const uint8_t * q = xb->qs;
+ device const uchar * q = xb->qs;
#if QK_K == 256
- const float d = (float)(xb->d);
- const float min = (float)(xb->dmin);
short is = (il/4) * 2;
q = q + (il/4) * 32 + 16 * (il&1);
- il = il%4;
- const uchar4 sc = get_scale_min_k4(is, xb->scales);
- const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
- const float ml = il<2 ? min * sc[1] : min * sc[3];
+ il = il & 3;
+ const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+ const half d = il < 2 ? xb->d : xb->d / 16.h;
+ const half min = xb->dmin;
+ const half dl = d * sc[0];
+ const half ml = min * sc[1];
#else
q = q + 16 * (il&1);
device const uint8_t * s = xb->scales;
device const half2 * dh = (device const half2 *)xb->d;
const float2 d = (float2)dh[0];
const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
- const float ml = il<2 ? d[1] * (s[0]>>4) : d[1 ]* (s[1]>>4);
+ const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
#endif
const ushort mask = il<2 ? 0x0F : 0xF0;
for (int i = 0; i < 16; ++i) {
@@ -1885,19 +2020,19 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
device const uint8_t * qh = xb->qh;
#if QK_K == 256
- const float d = (float)(xb->d);
- const float min = (float)(xb->dmin);
short is = (il/4) * 2;
q = q + 32 * (il/4) + 16 * (il&1);
qh = qh + 16 * (il&1);
uint8_t ul = 1 << (il/2);
- il = il%4;
- const uchar4 sc = get_scale_min_k4(is, xb->scales);
- const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
- const float ml = il<2 ? min * sc[1] : min * sc[3];
+ il = il & 3;
+ const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+ const half d = il < 2 ? xb->d : xb->d / 16.h;
+ const half min = xb->dmin;
+ const half dl = d * sc[0];
+ const half ml = min * sc[1];
- const ushort mask = il<2 ? 0x0F : 0xF0;
- const float qh_val = il<2 ? 16.f : 256.f;
+ const ushort mask = il<2 ? 0x0F : 0xF0;
+ const half qh_val = il<2 ? 16.h : 256.h;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
}
@@ -1916,7 +2051,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
template <typename type4x4>
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
- const float d_all = (float)(xb->d);
+ const half d_all = xb->d;
device const uint8_t * ql = (device const uint8_t *)xb->ql;
device const uint8_t * qh = (device const uint8_t *)xb->qh;
device const int8_t * scales = (device const int8_t *)xb->scales;
@@ -1924,19 +2059,21 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
#if QK_K == 256
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
qh = qh + 32*(il/8) + 16*(il&1);
- float sc = scales[(il%2) + 2 * ((il/2))];
- il = (il/2)%4;
+ half sc = scales[(il%2) + 2 * ((il/2))];
+ il = (il/2) & 3;
#else
ql = ql + 16 * (il&1);
- float sc = scales[il];
+ half sc = scales[il];
#endif
+ const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
+ const half coef = il>1 ? 1.f/16.h : 1.h;
+ const half ml = d_all * sc * 32.h;
+ const half dl = d_all * sc * coef;
for (int i = 0; i < 16; ++i) {
- uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
- uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
- const float coef = il>1 ? 1.f/16.f : 1.f;
- float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
- ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
- reg[i/4][i%4] = d_all * sc * q * coef;
+ const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
+ : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
+ reg[i/4][i%4] = dl * q - ml;
}
}
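// The q6_K refactor above is pure algebra (no change in the dequantized values intended):
//
//   d_all * sc * coef * (q_raw - 32/coef)  =  (d_all*sc*coef) * q_raw - d_all*sc*32  =  dl*q_raw - ml
//
// with dl = d_all*sc*coef and ml = d_all*sc*32.h. Hoisting dl, ml and the masks out of the
// 16-iteration loop removes the per-element selects, and keeping the arithmetic in half matches
// the precision of the block's stored d value.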
@@ -1976,22 +2113,25 @@ kernel void kernel_get_rows(
// each block_q contains 16*nl weights
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm(device const uchar * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne02,
- constant int64_t & nb01,
- constant int64_t & nb02,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & gqa,
- threadgroup uchar * shared_memory [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const uchar * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne02,
+ constant int64_t & nb01,
+ constant int64_t & nb02,
+ constant int64_t & ne12,
+ constant int64_t & nb10,
+ constant int64_t & nb11,
+ constant int64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & gqa,
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
- threadgroup half * sa = ((threadgroup half *)shared_memory);
+ threadgroup half * sa = (threadgroup half *)(shared_memory);
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
const uint r0 = tgpig.y;
@@ -2004,7 +2144,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
- simdgroup_half8x8 ma[4];
+ simdgroup_half8x8 ma[4];
simdgroup_float8x8 mb[2];
simdgroup_float8x8 c_res[8];
for (int i = 0; i < 8; i++){
@@ -2012,10 +2152,15 @@ kernel void kernel_mul_mm(device const uchar * src0,
}
short il = (tiitg % THREAD_PER_ROW);
- uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
- device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
- device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
- + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
+
+ uint offset0 = im/gqa*nb02;
+ ushort offset1 = il/nl;
+
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+ device const float * y = (device const float *)(src1
+ + nb12 * im
+ + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+ + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
//load data and store to threadgroup memory
@@ -2095,6 +2240,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
constant uint64_t &, constant uint64_t &, uint, uint, uint);
+template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows;
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows;
@@ -2105,14 +2251,27 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows;
-typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
- constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
- constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
+typedef void (mat_mm_t)(
+ device const uchar * src0,
+ device const uchar * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne02,
+ constant int64_t & nb01,
+ constant int64_t & nb02,
+ constant int64_t & ne12,
+ constant int64_t & nb10,
+ constant int64_t & nb11,
+ constant int64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & gqa,
+ threadgroup uchar *, uint3, uint, uint);
-template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm;
-template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm;
-template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm;
-template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm;
diff --git a/ggml.c b/ggml.c
index dcdebd2..c5b5dd6 100644
--- a/ggml.c
+++ b/ggml.c
@@ -4303,10 +4303,21 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
}
size_t ggml_nbytes(const struct ggml_tensor * tensor) {
- size_t nbytes = tensor->ne[0]*tensor->nb[0]/ggml_blck_size(tensor->type);
- for (int i = 1; i < GGML_MAX_DIMS; ++i) {
- nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
+ size_t nbytes;
+ size_t blck_size = ggml_blck_size(tensor->type);
+ if (blck_size == 1) {
+ nbytes = ggml_type_size(tensor->type);
+ for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+ nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
+ }
}
+ else {
+ nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
+ for (int i = 1; i < GGML_MAX_DIMS; ++i) {
+ nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
+ }
+ }
+
return nbytes;
}
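// Why ggml_nbytes now special-cases blck_size == 1 (reasoning, not part of the patch text): the
// old formula ne[0]*nb[0]/blck_size implicitly assumed nb[0] equals the element size, which does
// not hold for permuted or transposed views where nb[0] can be large, so the byte count could be
// overestimated. Summing (ne[i] - 1)*nb[i] over every dimension and adding one element size gives
// the span from the first element to one past the last, regardless of stride layout. Quantized
// types keep the old path because their dimension 0 is addressed in whole blocks.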
@@ -18345,10 +18356,11 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
for (int i = 0; i < cgraph->n_leafs; i++) {
struct ggml_tensor * node = cgraph->leafs[i];
- GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
+ GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n",
i,
node->ne[0], node->ne[1],
- ggml_op_name(node->op));
+ ggml_op_name(node->op),
+ ggml_get_name(node));
}
for (int i = 0; i < GGML_OP_COUNT; i++) {
diff --git a/whisper.cpp b/whisper.cpp
index f5a9a71..23ebd7e 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -3,11 +3,16 @@
#include "coreml/whisper-encoder.h"
#endif
+#ifdef GGML_USE_METAL
+# include "ggml-metal.h"
+#endif
+
#ifdef WHISPER_USE_OPENVINO
#include "openvino/whisper-openvino-encoder.h"
#endif
#include "ggml.h"
+#include "ggml-alloc.h"
#include
#include
@@ -24,6 +29,7 @@
#include
#include
#include
+#include
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
@@ -115,9 +121,6 @@ static void byteswap_tensor(ggml_tensor * tensor) {
//#define WHISPER_USE_FLASH_FF
#define WHISPER_MAX_DECODERS 16
-#define WHISPER_USE_SCRATCH
-#define WHISPER_MAX_SCRATCH_BUFFERS 16
-
//
// ggml helpers
//
@@ -133,6 +136,44 @@ static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph *
ggml_graph_compute(graph, &plan);
}
+// faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
+// the idea is to represent the original matrix multiplication:
+//
+// Z = X @ Y
+//
+// with the sum of two matrix multiplications:
+//
+// Z = (X_0 @ Y_0) + (X_1 @ Y_1)
+//
+// here X_0 and Y_0 are views of X and Y that have dimension 0 divisible by "pad"
+// and X_1 and Y_1 are the remaining views. X_1 and Y_1 end up being small matrices that can be processed with more
+// general-purpose kernels
+//
+static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y, int pad = 32) {
+ // use padding only if dimension 0 is at least 8 times larger than the padding
+ // else we won't get much benefit from the optimization
+ const int n_pad_req = 8;
+
+ if (x->ne[0] % pad == 0 || x->ne[0] / pad < n_pad_req) {
+ return ggml_mul_mat(ctx, x, y);
+ }
+
+ struct ggml_tensor * x_0 = ggml_view_3d(ctx, x, (x->ne[0]/pad)*pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], 0);
+ struct ggml_tensor * x_1 = ggml_view_3d(ctx, x, x->ne[0]%pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], x_0->ne[0]*x_0->nb[0]);
+
+ struct ggml_tensor * y_0 = ggml_view_3d(ctx, y, (y->ne[0]/pad)*pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], 0);
+ struct ggml_tensor * y_1 = ggml_view_3d(ctx, y, y->ne[0]%pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], y_0->ne[0]*y_0->nb[0]);
+
+ return ggml_add(ctx,
+ ggml_mul_mat(ctx, x_0, y_0),
+ ggml_mul_mat(ctx, x_1, y_1));
+}
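// A rough worked example of the split above (illustrative numbers, not taken from the patch):
// with pad = 32 and x->ne[0] == 1500, x_0 covers the first 1472 = 46*32 elements of dimension 0
// and x_1 the remaining 28, with matching views taken of y. The bulk of the work then hits the
// padded fast path and only the 28-wide remainder goes through the generic kernels, after which
// the two partial products are combined with ggml_add. When ne[0] is already a multiple of pad,
// or when ne[0]/pad < n_pad_req, the helper simply falls back to a single ggml_mul_mat.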
+
+// TODO: check if other platforms can benefit from this optimization
+#if defined(GGML_USE_METAL)
+#define ggml_mul_mat ggml_mul_mat_pad
+#endif
+
// available whisper models
enum e_model {
MODEL_UNKNOWN,
@@ -247,38 +288,7 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
static const size_t MB = 1ull*1024*1024;
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
- { MODEL_TINY, 62ull*MB },
- { MODEL_BASE, 80ull*MB },
- { MODEL_SMALL, 120ull*MB },
- { MODEL_MEDIUM, 158ull*MB },
- { MODEL_LARGE, 198ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
- { MODEL_TINY, 18ull*MB },
- { MODEL_BASE, 24ull*MB },
- { MODEL_SMALL, 36ull*MB },
- { MODEL_MEDIUM, 48ull*MB },
- { MODEL_LARGE, 60ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH2 = {
- { MODEL_TINY, 4ull*MB },
- { MODEL_BASE, 4ull*MB },
- { MODEL_SMALL, 6ull*MB },
- { MODEL_MEDIUM, 7ull*MB },
- { MODEL_LARGE, 9ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
- { MODEL_TINY, 4ull*MB },
- { MODEL_BASE, 4ull*MB },
- { MODEL_SMALL, 6ull*MB },
- { MODEL_MEDIUM, 7ull*MB },
- { MODEL_LARGE, 9ull*MB },
-};
-
+// TODO: avoid using GGUF
static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
{ GGML_TYPE_F32,
{
@@ -345,38 +355,6 @@ static const std::map> MEM_REQ_MODEL = {
},
};
-static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
- { MODEL_TINY, 3ull*MB },
- { MODEL_BASE, 6ull*MB },
- { MODEL_SMALL, 16ull*MB },
- { MODEL_MEDIUM, 43ull*MB },
- { MODEL_LARGE, 71ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
- { MODEL_TINY, 9ull*MB },
- { MODEL_BASE, 18ull*MB },
- { MODEL_SMALL, 53ull*MB },
- { MODEL_MEDIUM, 141ull*MB },
- { MODEL_LARGE, 235ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
- { MODEL_TINY, 30ull*MB },
- { MODEL_BASE, 38ull*MB },
- { MODEL_SMALL, 56ull*MB },
- { MODEL_MEDIUM, 74ull*MB },
- { MODEL_LARGE, 94ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_DECODE = {
- { MODEL_TINY, 3ull*MB },
- { MODEL_BASE, 5ull*MB },
- { MODEL_SMALL, 10ull*MB },
- { MODEL_MEDIUM, 18ull*MB },
- { MODEL_LARGE, 27ull*MB },
-};
-
struct whisper_mel {
int n_len;
int n_len_org;
@@ -657,15 +635,57 @@ struct kv_buf {
std::vector<uint8_t> v;
};
+// ggml_allocr wrapper for whisper usage
+struct whisper_allocr {
+ ggml_allocr * alloc = nullptr;
+
+ std::vector<uint8_t> meta;
+ std::vector<uint8_t> data;
+};
+
+static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
+ return allocr.meta.size() + allocr.data.size();
+}
+
+// measure the memory usage of a graph and prepare the allocr's internal data buffer
+static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::function<struct ggml_cgraph *()> && get_graph) {
+ const int tensor_alignment = 32;
+
+ auto & alloc = allocr.alloc;
+ auto & meta = allocr.meta;
+ auto & data = allocr.data;
+
+ meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
+
+ alloc = ggml_allocr_new_measure(tensor_alignment);
+
+ const size_t alloc_size = ggml_allocr_alloc_graph(alloc, get_graph()) + tensor_alignment;
+
+ ggml_allocr_free(alloc);
+
+ data.resize(alloc_size);
+
+ alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
+}
+
+static void whisper_allocr_free(struct whisper_allocr & allocr) {
+ if (allocr.alloc) {
+ ggml_allocr_free(allocr.alloc);
+ allocr.alloc = nullptr;
+ }
+}
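// Hypothetical usage sketch (the actual wiring lives in whisper_init_state, outside this hunk;
// the variable names below are illustrative assumptions): each allocator is measured against the
// graph it will later execute, e.g.
//
//   whisper_allocr_graph_init(state->alloc_encode,
//       [&]() { return whisper_build_graph_encoder(*ctx, *state); });
//
// The lambda runs once in "measure" mode to size the data buffer; at run time the same
// graph-building function is called again after ggml_allocr_reset, and ggml_allocr_alloc_graph
// places the intermediate tensors into the pre-sized buffer (see whisper_encode_internal below).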
+
struct whisper_state {
int64_t t_sample_us = 0;
int64_t t_encode_us = 0;
int64_t t_decode_us = 0;
+ int64_t t_prompt_us = 0;
int64_t t_mel_us = 0;
int32_t n_sample = 0; // number of tokens sampled
int32_t n_encode = 0; // number of encoder calls
- int32_t n_decode = 0; // number of decoder calls
+ int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
+ int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding)
int32_t n_fail_p = 0; // number of logprob threshold failures
int32_t n_fail_h = 0; // number of entropy threshold failures
@@ -679,13 +699,20 @@ struct whisper_state {
// buffer for swapping KV caches between decoders during beam-search
std::vector<kv_buf> kv_swap_bufs;
- // memory buffers used by encode / decode contexts
- std::vector<uint8_t> buf_compute;
- std::vector<uint8_t> buf_work;
- std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
+ // reusable buffer for `struct ggml_graph_plan.work_data`
+ std::vector<uint8_t> work_buffer;
- int buf_last = 0;
- size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 };
+ // ggml-alloc:
+ // - stores meta info about the intermediate tensors into the `meta` buffers
+ // - stores the actual tensor data into the `data` buffers
+ whisper_allocr alloc_conv;
+ whisper_allocr alloc_encode;
+ whisper_allocr alloc_cross;
+ whisper_allocr alloc_decode;
+
+ // result of the encoder
+ struct ggml_tensor * embd_conv = nullptr;
+ struct ggml_tensor * embd_enc = nullptr;
// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
@@ -705,6 +732,10 @@ struct whisper_state {
whisper_coreml_context * ctx_coreml = nullptr;
#endif
+#ifdef GGML_USE_METAL
+ ggml_metal_context * ctx_metal = nullptr;
+#endif
+
#ifdef WHISPER_USE_OPENVINO
whisper_openvino_context * ctx_openvino = nullptr;
#endif
@@ -717,37 +748,6 @@ struct whisper_state {
// [EXPERIMENTAL] speed-up techniques
int32_t exp_n_audio_ctx = 0; // 0 - use default
-
- void use_buf(struct ggml_context * ctx, int i) {
-#if defined(WHISPER_USE_SCRATCH)
- size_t last_size = 0;
-
- if (i == -1) {
- last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
- } else {
- auto & buf = buf_scratch[i];
- last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
- }
-
- if (buf_last >= 0) {
- buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
- }
-
- buf_last = i;
-#else
- (void) i;
- (void) ctx;
-#endif
- }
-
- size_t get_buf_max_mem(int i) const {
-#if defined(WHISPER_USE_SCRATCH)
- return buf_max_size[i];
-#else
- (void) i;
- return 0;
-#endif
- }
};
struct whisper_context {
@@ -794,10 +794,17 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
static bool kv_cache_init(
const struct whisper_hparams & hparams,
- const size_t mem_bytes,
struct whisper_kv_cache & cache,
ggml_type wtype,
int n_ctx) {
+ const int64_t n_text_state = hparams.n_text_state;
+ const int64_t n_text_layer = hparams.n_text_layer;
+
+ const int64_t n_mem = n_text_layer*n_ctx;
+ const int64_t n_elements = n_text_state*n_mem;
+
+ const size_t mem_bytes = 2*(ggml_type_size(wtype)*n_elements + ggml_tensor_overhead());
+
cache.buf.resize(mem_bytes);
struct ggml_init_params params = {
@@ -813,12 +820,6 @@ static bool kv_cache_init(
return false;
}
- const int n_text_state = hparams.n_text_state;
- const int n_text_layer = hparams.n_text_layer;
-
- const int n_mem = n_text_layer*n_ctx;
- const int n_elements = n_text_state*n_mem;
-
cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
@@ -961,22 +962,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
// print memory requirements
{
- // this is the total memory required to run the inference
- const size_t mem_required =
- MEM_REQ_SCRATCH0.at(model.type) +
- MEM_REQ_SCRATCH1.at(model.type) +
- MEM_REQ_SCRATCH2.at(model.type) +
- MEM_REQ_SCRATCH3.at(model.type) +
- scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type) +
- scale*MEM_REQ_KV_CROSS.at(model.type) +
- scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
-
- // this is the memory required by one decoder
- const size_t mem_required_decoder =
- scale*MEM_REQ_KV_SELF.at(model.type);
-
- log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
- mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
+ // TODO
+ //log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
+ // mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
}
// initialize all memory buffers
@@ -1485,6 +1473,441 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
return true;
}
+static bool whisper_encode_external(const whisper_state & wstate) {
+ GGML_UNUSED(wstate);
+
+#ifndef WHISPER_USE_COREML
+ const bool use_coreml = false;
+#else
+ const bool use_coreml = wstate.ctx_coreml != nullptr;
+#endif
+
+#ifndef WHISPER_USE_OPENVINO
+ const bool use_openvino = false;
+#else
+ const bool use_openvino = wstate.ctx_openvino != nullptr;
+#endif
+
+ return use_coreml || use_openvino;
+}
+
+static struct ggml_cgraph * whisper_build_graph_conv(
+ whisper_context & wctx,
+ whisper_state & wstate,
+ const int mel_offset) {
+ const auto & model = wctx.model;
+ const auto & mel_inp = wstate.mel;
+ const auto & hparams = model.hparams;
+
+ const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
+ const int n_state = hparams.n_audio_state; GGML_UNUSED(n_state);
+
+ const int n_mels = hparams.n_mels;
+
+ struct ggml_init_params params = {
+ /*.mem_size =*/ wstate.alloc_conv.meta.size(),
+ /*.mem_buffer =*/ wstate.alloc_conv.meta.data(),
+ /*.no_alloc =*/ true,
+ };
+
+ struct ggml_context * ctx0 = ggml_init(params);
+
+ ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+ ggml_allocr * alloc = wstate.alloc_conv.alloc;
+
+ struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
+ ggml_allocr_alloc(alloc, mel);
+
+ assert(mel->type == GGML_TYPE_F32);
+ if (!ggml_allocr_is_measure(alloc)) {
+ assert(mel_inp.n_mel == n_mels);
+
+ float * dst = (float *) mel->data;
+ memset(dst, 0, ggml_nbytes(mel));
+
+ const int i0 = std::min(mel_offset, mel_inp.n_len);
+ const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
+
+ for (int j = 0; j < mel_inp.n_mel; ++j) {
+ for (int i = i0; i < i1; ++i) {
+ dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
+ }
+ }
+ }
+
+ struct ggml_tensor * cur = nullptr;
+
+ if (!whisper_encode_external(wstate)) {
+ // convolution + gelu
+ {
+ cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
+ cur = ggml_add(ctx0,
+ ggml_repeat(ctx0,
+ model.e_conv_1_b,
+ cur),
+ cur);
+
+ cur = ggml_gelu(ctx0, cur);
+
+ cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
+ cur = ggml_add(ctx0,
+ ggml_repeat(ctx0,
+ model.e_conv_2_b,
+ cur),
+ cur);
+
+ cur = ggml_gelu(ctx0, cur);
+ }
+
+ wstate.embd_conv = cur;
+ } else {
+#ifdef WHISPER_USE_COREML
+ cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
+ ggml_allocr_alloc(alloc, cur);
+
+ if (!ggml_allocr_is_measure(alloc)) {
+ whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
+ }
+#endif
+#ifdef WHISPER_USE_OPENVINO
+ cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
+ ggml_allocr_alloc(alloc, cur);
+
+ if (!ggml_allocr_is_measure(alloc)) {
+ whisper_openvino_encode(wstate.ctx_openvino, mel, cur);
+ }
+#endif
+
+ wstate.embd_enc = cur;
+ }
+
+ ggml_build_forward_expand(gf, cur);
+
+ ggml_free(ctx0);
+
+ return gf;
+}
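// The allocation pattern above repeats in the graph builders that follow: the ggml context is
// created with no_alloc = true over the allocator's small "meta" buffer, so tensor creation only
// records metadata; ggml_allocr_alloc reserves space for the graph inputs inside the allocator's
// data buffer; and input data is written only when ggml_allocr_is_measure() is false, because
// during the sizing pass the tensors have no backing memory yet.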
+
+static struct ggml_cgraph * whisper_build_graph_encoder(
+ whisper_context & wctx,
+ whisper_state & wstate) {
+ const auto & model = wctx.model;
+ const auto & hparams = model.hparams;
+
+ const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
+ const int n_state = hparams.n_audio_state;
+ const int n_head = hparams.n_audio_head;
+ const int n_layer = hparams.n_audio_layer;
+
+ struct ggml_init_params params = {
+ /*.mem_size =*/ wstate.alloc_encode.meta.size(),
+ /*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
+ /*.no_alloc =*/ true,
+ };
+
+ struct ggml_context * ctx0 = ggml_init(params);
+
+ ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+ ggml_allocr * alloc = wstate.alloc_encode.alloc;
+
+ struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+ ggml_allocr_alloc(alloc, KQscale);
+
+ if (!ggml_allocr_is_measure(alloc)) {
+ ggml_set_f32(KQscale, 1.0f/sqrt(float(n_state)/n_head));
+ }
+
+ struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
+
+ // ===================================================================
+ // NOTE: experimenting with partial evaluation of the encoder (ignore)
+ //static int iter = -1;
+ //const int n_iter = 1500/n_ctx;
+
+ //iter = (iter + 1) % n_iter;
+
+ //if (iter == 0) {
+ // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
+ // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
+ //}
+
+ static int iter = 0;
+
+ const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
+ const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
+
+ struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
+
+ cur = ggml_add(ctx0, e_pe, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));
+
+ // ===================================================================
+
+ // original:
+ //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
+
+ struct ggml_tensor * inpL = cur;
+
+ for (int il = 0; il < n_layer; ++il) {
+ const auto & layer = model.layers_encoder[il];
+
+ // norm
+ {
+ cur = ggml_norm(ctx0, inpL, hparams.eps);
+
+ // cur = ln_0_w*cur + ln_0_b
+ cur = ggml_add(ctx0,
+ ggml_mul(ctx0, cur, layer.attn_ln_0_w),
+ layer.attn_ln_0_b);
+ }
+
+ // self-attention
+ {
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
+ layer.attn_q_w,
+ cur);
+
+ Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);
+
+ //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+
+ // note: no bias for Key
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
+ layer.attn_k_w,
+ cur);
+
+ //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
+ layer.attn_v_w,
+ cur);
+
+ Vcur = ggml_add(ctx0, Vcur, layer.attn_v_b);
+
+ // ------
+
+#ifdef WHISPER_USE_FLASH_ATTN
+ struct ggml_tensor * Q =
+ ggml_permute(ctx0,
+ ggml_cpy(ctx0,
+ Qcur,
+ ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
+ 0, 2, 1, 3);
+
+ struct ggml_tensor * K =
+ ggml_permute(ctx0,
+ ggml_cpy(ctx0,
+ Kcur,
+ ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
+ 0, 2, 1, 3);
+
+ struct ggml_tensor * V =
+ ggml_cpy(ctx0,
+ ggml_permute(ctx0,
+ ggml_reshape_3d(ctx0,
+ Vcur,
+ n_state/n_head, n_head, n_ctx),
+ 1, 2, 0, 3),
+ ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
+
+ struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
+#else
+ struct ggml_tensor * Q =
+ ggml_permute(ctx0,
+ ggml_cpy(ctx0,
+ Qcur,
+ ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
+ 0, 2, 1, 3);
+
+ struct ggml_tensor * K =
+ ggml_permute(ctx0,
+ ggml_cpy(ctx0,
+ Kcur,
+ ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
+ 0, 2, 1, 3);
+
+ // K * Q
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+
+ struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQscale);
+
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);
+
+ struct ggml_tensor * V =
+ ggml_cpy(ctx0,
+ ggml_permute(ctx0,
+ ggml_reshape_3d(ctx0,
+ Vcur,
+ n_state/n_head, n_head, n_ctx),
+ 1, 2, 0, 3),
+ ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
+ );
+
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+#endif
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+ cur = ggml_cpy(ctx0,
+ KQV_merged,
+ ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
+ }
+
+ // projection
+ {
+ cur = ggml_mul_mat(ctx0,
+ layer.attn_ln_1_w,
+ cur);
+
+ cur = ggml_add(ctx0, cur, layer.attn_ln_1_b);
+ }
+
+ // add the input
+ cur = ggml_add(ctx0, cur, inpL);
+
+ struct ggml_tensor * inpFF = cur;
+
+ // feed-forward network
+ {
+ // norm
+ {
+ cur = ggml_norm(ctx0, inpFF, hparams.eps);
+
+ // cur = mlp_ln_w*cur + mlp_ln_b
+ cur = ggml_add(ctx0,
+ ggml_mul(ctx0, cur, layer.mlp_ln_w),
+ layer.mlp_ln_b);
+ }
+
+#ifdef WHISPER_USE_FLASH_FF
+ cur = ggml_flash_ff(ctx0,
+ ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
+ layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
+#else
+ // fully connected
+ cur = ggml_mul_mat(ctx0,
+ layer.mlp_0_w,
+ cur);
+
+ cur = ggml_add(ctx0, cur, layer.mlp_0_b);
+
+ // GELU activation
+ cur = ggml_gelu(ctx0, cur);
+
+ // projection
+ cur = ggml_mul_mat(ctx0,
+ layer.mlp_1_w,
+ cur);
+
+ cur = ggml_add(ctx0, cur, layer.mlp_1_b);
+#endif
+ }
+
+ inpL = ggml_add(ctx0, cur, inpFF);
+ }
+
+ cur = inpL;
+
+ // norm
+ {
+ cur = ggml_norm(ctx0, cur, hparams.eps);
+
+ // cur = ln_f_g*cur + ln_f_b
+ cur = ggml_add(ctx0,
+ ggml_mul(ctx0, cur, model.e_ln_w),
+ model.e_ln_b);
+ }
+
+ ggml_build_forward_expand(gf, cur);
+
+ wstate.embd_enc = cur;
+
+ //ggml_graph_print(gf);
+
+ ////////////////////////////////////////////////////////////////////////////
+
+ //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
+ // ggml_used_mem(ctx0)/1024.0/1024.0,
+ // wstate.get_buf_max_mem(0)/1024.0/1024.0,
+ // wstate.get_buf_max_mem(1)/1024.0/1024.0,
+ // wstate.get_buf_max_mem(2)/1024.0/1024.0,
+ // wstate.get_buf_max_mem(3)/1024.0/1024.0);
+
+ ggml_free(ctx0);
+
+ return gf;
+}
+
+// pre-compute cross-attention memory
+static struct ggml_cgraph * whisper_build_graph_cross(
+ whisper_context & wctx,
+ whisper_state & wstate) {
+ const auto & model = wctx.model;
+ const auto & hparams = model.hparams;
+
+ const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
+ const int n_state = hparams.n_audio_state;
+ const int n_head = hparams.n_audio_head;
+
+ struct ggml_init_params params = {
+ /*.mem_size =*/ wstate.alloc_cross.meta.size(),
+ /*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
+ /*.no_alloc =*/ true,
+ };
+
+ struct ggml_context * ctx0 = ggml_init(params);
+
+ ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+ ggml_allocr * alloc = wstate.alloc_cross.alloc;
+
+ struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
+
+ struct ggml_tensor * Kscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+ ggml_allocr_alloc(alloc, Kscale);
+
+ if (!ggml_allocr_is_measure(alloc)) {
+ ggml_set_f32(Kscale, pow(float(n_state) / n_head, -0.25));
+ }
+
+ for (int il = 0; il < model.hparams.n_text_layer; ++il) {
+ auto & layer = model.layers_decoder[il];
+
+ struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
+ layer.cross_attn_k_w,
+ cur);
+
+ Kcross = ggml_scale(ctx0, Kcross, Kscale);
+
+ struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
+ layer.cross_attn_v_w,
+ cur);
+
+ Vcross = ggml_add(ctx0,
+ Vcross,
+ layer.cross_attn_v_b);
+
+ Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
+
+ struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k,
+ n_state*n_ctx,
+ (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
+
+ struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
+ ( n_ctx)*ggml_element_size(wstate.kv_cross.v),
+ (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
+
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k));
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v));
+ }
+
+ //ggml_graph_print(gf);
+
+ ggml_free(ctx0);
+
+ return gf;
+}
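// Note on the cache layout used above (read off the views, not new behaviour): the cross-attention
// keys of layer il occupy a contiguous n_state*n_ctx slice of kv_cross.k, while the values are
// stored transposed as an n_ctx x n_state 2-D view of kv_cross.v, so the decoder can multiply
// against them without an extra permute. Scaling K here by (n_state/n_head)^-0.25 means the
// decoder only has to apply the same factor to Q to obtain the usual 1/sqrt(d_head) scaling of K*Q.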
+
// evaluate the encoder with the given state
//
// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
@@ -1499,453 +1922,69 @@ static bool whisper_encode_internal(
whisper_context & wctx,
whisper_state & wstate,
const int mel_offset,
- const int n_threads){
-
+ const int n_threads) {
const int64_t t_start_us = ggml_time_us();
- const auto & model = wctx.model;
- const auto & mel_inp = wstate.mel;
- const auto & hparams = model.hparams;
-
- const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
- const int n_state = hparams.n_audio_state;
- const int n_head = hparams.n_audio_head;
- const int n_layer = hparams.n_audio_layer;
-
- const int n_mels = hparams.n_mels;
- assert(mel_inp.n_mel == n_mels);
-
- struct ggml_init_params params = {
- /*.mem_size =*/ wstate.buf_compute.size(),
- /*.mem_buffer =*/ wstate.buf_compute.data(),
- /*.no_alloc =*/ false,
- };
-
- struct ggml_context * ctx0 = ggml_init(params);
-
- wstate.use_buf(ctx0, 0);
-
- struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
- assert(mel->type == GGML_TYPE_F32);
+ // conv
{
- float * dst = (float *) mel->data;
- memset(dst, 0, ggml_nbytes(mel));
+ auto & alloc = wstate.alloc_conv.alloc;
- const int i0 = std::min(mel_offset, mel_inp.n_len);
- const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
+ ggml_allocr_reset(alloc);
- for (int j = 0; j < mel_inp.n_mel; ++j) {
- for (int i = i0; i < i1; ++i) {
- dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
- }
+ ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset);
+
+ ggml_allocr_alloc_graph(alloc, gf);
+
+ if (!whisper_encode_external(wstate)) {
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
}
}
- struct ggml_tensor * cur;
+ // encoder
+ if (!whisper_encode_external(wstate)) {
+ auto & alloc = wstate.alloc_encode.alloc;
-#ifndef WHISPER_USE_COREML
- const bool use_coreml = false;
-#else
- const bool use_coreml = wstate.ctx_coreml != nullptr;
-#endif
+ ggml_allocr_reset(alloc);
-#ifndef WHISPER_USE_OPENVINO
- const bool use_openvino = false;
-#else
- const bool use_openvino = wstate.ctx_openvino != nullptr;
-#endif
+ ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate);
- if (!use_coreml && !use_openvino) {
- // convolution + gelu
- {
- wstate.use_buf(ctx0, 1);
+ ggml_allocr_alloc_graph(alloc, gf);
- cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
- cur = ggml_add(ctx0,
- ggml_repeat(ctx0,
- model.e_conv_1_b,
- cur),
- cur);
-
- cur = ggml_gelu(ctx0, cur);
-
- wstate.use_buf(ctx0, 0);
-
- cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
- cur = ggml_add(ctx0,
- ggml_repeat(ctx0,
- model.e_conv_2_b,
- cur),
- cur);
-
- cur = ggml_gelu(ctx0, cur);
+#ifdef GGML_USE_METAL
+ if (wstate.ctx_metal) {
+ ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
+ ggml_metal_graph_compute(wstate.ctx_metal, gf);
+ } else {
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
}
-
- wstate.use_buf(ctx0, 3);
-
- // ===================================================================
- // NOTE: experimenting with partial evaluation of the encoder (ignore)
- //static int iter = -1;
- //const int n_iter = 1500/n_ctx;
-
- //iter = (iter + 1) % n_iter;
-
- //if (iter == 0) {
- // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
- // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
- //}
-
- static int iter = 0;
-
- const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
- const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
-
- struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
-
- cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));
-
- // ===================================================================
-
- // original:
- //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
-
- struct ggml_tensor * inpL = cur;
-
- for (int il = 0; il < n_layer; ++il) {
- const auto & layer = model.layers_encoder[il];
-
- // norm
- {
- wstate.use_buf(ctx0, 0);
-
- cur = ggml_norm(ctx0, inpL, hparams.eps);
-
- // cur = ln_0_w*cur + ln_0_b
- cur = ggml_add(ctx0,
- ggml_mul(ctx0,
- ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
- cur),
- ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
- }
-
- // self-attention
- {
- wstate.use_buf(ctx0, 1);
-
- struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
- layer.attn_q_w,
- cur);
-
- Qcur = ggml_add(ctx0,
- ggml_repeat(ctx0,
- layer.attn_q_b,
- Qcur),
- Qcur);
-
- //Qcur = ggml_scale_inplace(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
-
- // note: no bias for Key
- struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
- layer.attn_k_w,
- cur);
-
- //Kcur = ggml_scale_inplace(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
-
- struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
- layer.attn_v_w,
- cur);
-
- Vcur = ggml_add(ctx0,
- ggml_repeat(ctx0,
- layer.attn_v_b,
- Vcur),
- Vcur);
-
- // ------
-
- wstate.use_buf(ctx0, 0);
-
-#ifdef WHISPER_USE_FLASH_ATTN
- struct ggml_tensor * Q =
- ggml_permute(ctx0,
- ggml_cpy(ctx0,
- Qcur,
- ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
- 0, 2, 1, 3);
-
- struct ggml_tensor * K =
- ggml_permute(ctx0,
- ggml_cpy(ctx0,
- Kcur,
- ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
- 0, 2, 1, 3);
-
- struct ggml_tensor * V =
- ggml_cpy(ctx0,
- ggml_permute(ctx0,
- ggml_reshape_3d(ctx0,
- Vcur,
- n_state/n_head, n_head, n_ctx),
- 1, 2, 0, 3),
- ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
-
- struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
#else
- struct ggml_tensor * Q =
- ggml_permute(ctx0,
- ggml_cpy(ctx0,
- Qcur,
- ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
- 0, 2, 1, 3);
-
- struct ggml_tensor * K =
- ggml_permute(ctx0,
- ggml_cpy(ctx0,
- Kcur,
- ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
- 0, 2, 1, 3);
-
- // K * Q
- struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-
- struct ggml_tensor * KQ_scaled =
- ggml_scale_inplace(ctx0,
- KQ,
- ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
- );
-
- struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_scaled);
-
- struct ggml_tensor * V =
- ggml_cpy(ctx0,
- ggml_permute(ctx0,
- ggml_reshape_3d(ctx0,
- Vcur,
- n_state/n_head, n_head, n_ctx),
- 1, 2, 0, 3),
- ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
- );
-
- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
#endif
- struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
-
- wstate.use_buf(ctx0, 1);
-
- cur = ggml_cpy(ctx0,
- KQV_merged,
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
- }
-
- // projection
- {
- wstate.use_buf(ctx0, 0);
-
- cur = ggml_mul_mat(ctx0,
- layer.attn_ln_1_w,
- cur);
-
- wstate.use_buf(ctx0, 1);
-
- cur = ggml_add(ctx0,
- ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
- cur);
- }
-
- wstate.use_buf(ctx0, 2);
-
- // add the input
- cur = ggml_add(ctx0, cur, inpL);
-
- struct ggml_tensor * inpFF = cur;
-
- // feed-forward network
- {
- // norm
- {
- wstate.use_buf(ctx0, 0);
-
- cur = ggml_norm(ctx0, inpFF, hparams.eps);
-
- wstate.use_buf(ctx0, 1);
-
- // cur = mlp_ln_w*cur + mlp_ln_b
- cur = ggml_add(ctx0,
- ggml_mul(ctx0,
- ggml_repeat(ctx0, layer.mlp_ln_w, cur),
- cur),
- ggml_repeat(ctx0, layer.mlp_ln_b, cur));
- }
-
-#ifdef WHISPER_USE_FLASH_FF
- wstate.use_buf(ctx0, 0);
-
- cur = ggml_flash_ff(ctx0,
- ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
- layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
-#else
- wstate.use_buf(ctx0, 0);
-
- // fully connected
- cur = ggml_mul_mat(ctx0,
- layer.mlp_0_w,
- cur);
-
- wstate.use_buf(ctx0, 1);
-
- cur = ggml_add(ctx0,
- ggml_repeat(ctx0, layer.mlp_0_b, cur),
- cur);
-
- wstate.use_buf(ctx0, 0);
-
- // GELU activation
- cur = ggml_gelu(ctx0, cur);
-
- wstate.use_buf(ctx0, 1);
-
- // projection
- cur = ggml_mul_mat(ctx0,
- layer.mlp_1_w,
- cur);
-
- wstate.use_buf(ctx0, 0);
-
- cur = ggml_add(ctx0,
- ggml_repeat(ctx0, layer.mlp_1_b, cur),
- cur);
-#endif
- }
-
- wstate.use_buf(ctx0, 3);
-
- inpL = ggml_add(ctx0, cur, inpFF);
- }
-
- cur = inpL;
-
- // norm
- {
- wstate.use_buf(ctx0, 0);
-
- cur = ggml_norm(ctx0, cur, hparams.eps);
-
- wstate.use_buf(ctx0, 1);
-
- // cur = ln_f_g*cur + ln_f_b
- cur = ggml_add(ctx0,
- ggml_mul(ctx0,
- ggml_repeat(ctx0, model.e_ln_w, cur),
- cur),
- ggml_repeat(ctx0, model.e_ln_b, cur));
- }
-
- wstate.use_buf(ctx0, -1);
-
- // run the computation
- {
- struct ggml_cgraph gf = {};
-
- ggml_build_forward_expand(&gf, cur);
- ggml_graph_compute_helper(wstate.buf_work, &gf, n_threads);
-
- //ggml_graph_print(&gf);
- }
}
-#ifdef WHISPER_USE_COREML
- else if (use_coreml) {
- wstate.use_buf(ctx0, -1);
- cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
-
- whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
- }
-#endif
-#ifdef WHISPER_USE_OPENVINO
- else if (use_openvino) {
- wstate.use_buf(ctx0, -1);
-
- cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
-
- if (!whisper_openvino_encode(wstate.ctx_openvino, mel, cur)) {
- return false;
- }
- }
-#endif
-
- // cur
- //{
- // printf("ne0 = %d\n", cur->ne[0]);
- // printf("ne1 = %d\n", cur->ne[1]);
- // for (int i = 0; i < 10; ++i) {
- // printf("%8.4f ", ((float *)(cur->data))[i]);
- // }
- // printf("... ");
- // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
- // printf("%8.4f ", ((float *)(cur->data))[i]);
- // }
- // printf("\n");
- //}
-
- // pre-compute cross-attention memory
+ // cross
{
- struct ggml_cgraph gf = {};
+ auto & alloc = wstate.alloc_cross.alloc;
- // TODO: hack to disconnect the encoded features from the previous graph
- cur->op = GGML_OP_NONE;
- cur->src[0] = nullptr;
- cur->src[1] = nullptr;
+ ggml_allocr_reset(alloc);
- for (int il = 0; il < model.hparams.n_text_layer; ++il) {
- auto& layer = model.layers_decoder[il];
+ ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);
- wstate.use_buf(ctx0, 0);
+ ggml_allocr_alloc_graph(alloc, gf);
- struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
- layer.cross_attn_k_w,
- cur);
-
- Kcross = ggml_scale_inplace(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25)));
-
- wstate.use_buf(ctx0, 1);
-
- struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
- layer.cross_attn_v_w,
- cur);
-
- Vcross = ggml_add(ctx0,
- ggml_repeat(ctx0,
- layer.cross_attn_v_b,
- Vcross),
- Vcross);
-
- wstate.use_buf(ctx0, -1);
-
- Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
-
- struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
- struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
- ( n_ctx)*ggml_element_size(wstate.kv_cross.v),
- (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
-
- ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
- ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
+#ifdef GGML_USE_METAL
+ if (wstate.ctx_metal) {
+ ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
+ ggml_metal_graph_compute(wstate.ctx_metal, gf);
+ } else {
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
}
-
- ggml_graph_compute_helper(wstate.buf_work, &gf, n_threads);
- //ggml_graph_print(&gf);
+#else
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+#endif
}
- ////////////////////////////////////////////////////////////////////////////
-
- //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
- // ggml_used_mem(ctx0)/1024.0/1024.0,
- // wstate.get_buf_max_mem(0)/1024.0/1024.0,
- // wstate.get_buf_max_mem(1)/1024.0/1024.0,
- // wstate.get_buf_max_mem(2)/1024.0/1024.0,
- // wstate.get_buf_max_mem(3)/1024.0/1024.0);
-
- ggml_free(ctx0);
+ // ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
wstate.t_encode_us += ggml_time_us() - t_start_us;
wstate.n_encode++;
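
// note: the "cross" block above is the new ggml-alloc compute pattern used throughout this
// patch: reset the allocator, rebuild the graph into the pre-sized meta buffer, let
// ggml_allocr_alloc_graph() place every tensor, then run the graph on Metal when a context
// is available and on the CPU otherwise. A minimal sketch of the dispatch step, using only
// calls that appear in this patch (illustrative helper, not part of the change itself):
//
static void whisper_graph_compute_sketch(
        std::vector<uint8_t>      & work_buffer,
        struct ggml_cgraph        * gf,
        int                         n_threads
#ifdef GGML_USE_METAL
      , struct ggml_metal_context * ctx_metal
#endif
        ) {
#ifdef GGML_USE_METAL
    if (ctx_metal) {
        ggml_metal_set_n_cb      (ctx_metal, n_threads); // number of command buffers to split the graph into
        ggml_metal_graph_compute (ctx_metal, gf);        // run the whole graph on the GPU
        return;
    }
#endif
    ggml_graph_compute_helper(work_buffer, gf, n_threads); // CPU path
}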
@@ -1953,6 +1992,343 @@ static bool whisper_encode_internal(
return true;
}
+static struct ggml_cgraph * whisper_build_graph_decoder(
+ whisper_context & wctx,
+ whisper_state & wstate,
+ whisper_decoder & decoder,
+ const whisper_token * tokens,
+ int n_tokens,
+ int n_past) {
+ const auto & model = wctx.model;
+ const auto & hparams = model.hparams;
+
+ auto & kv_self = decoder.kv_self;
+
+ WHISPER_ASSERT(!!kv_self.ctx);
+
+ const int n_ctx = hparams.n_text_ctx;
+ const int n_state = hparams.n_text_state;
+ const int n_head = hparams.n_text_head;
+ const int n_layer = hparams.n_text_layer;
+
+ const int N = n_tokens;
+ const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
+
+ //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
+
+ struct ggml_init_params params = {
+ /*.mem_size =*/ wstate.alloc_decode.meta.size(),
+ /*.mem_buffer =*/ wstate.alloc_decode.meta.data(),
+ /*.no_alloc =*/ true,
+ };
+
+ struct ggml_context * ctx0 = ggml_init(params);
+
+ ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+ ggml_allocr * alloc = wstate.alloc_decode.alloc;
+
+ struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+ ggml_allocr_alloc(alloc, embd);
+
+ if (!ggml_allocr_is_measure(alloc)) {
+ memcpy(embd->data, tokens, N*ggml_element_size(embd));
+ }
+
+ struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+ ggml_allocr_alloc(alloc, position);
+
+ if (!ggml_allocr_is_measure(alloc)) {
+ for (int i = 0; i < N; ++i) {
+ ((int32_t *) position->data)[i] = n_past + i;
+ }
+ }
+
+ struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+ ggml_allocr_alloc(alloc, KQscale);
+
+ if (!ggml_allocr_is_measure(alloc)) {
+ ggml_set_f32(KQscale, pow(float(n_state)/n_head, -0.25));
+ }
+
+ // token encoding + position encoding
+ struct ggml_tensor * cur =
+ ggml_add(ctx0,
+ ggml_get_rows(ctx0, model.d_te, embd),
+ ggml_get_rows(ctx0, model.d_pe, position));
+
+ struct ggml_tensor * inpL = cur;
+
+ for (int il = 0; il < n_layer; ++il) {
+ const auto & layer = model.layers_decoder[il];
+
+ // norm
+ {
+ cur = ggml_norm(ctx0, inpL, hparams.eps);
+
+ // cur = ln_0_w*cur + ln_0_b
+ cur = ggml_add(ctx0,
+ ggml_mul(ctx0,
+ cur,
+ layer.attn_ln_0_w),
+ layer.attn_ln_0_b);
+ }
+
+ // self-attention
+ {
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
+ layer.attn_q_w,
+ cur);
+
+ Qcur = ggml_add(ctx0,
+ Qcur,
+ layer.attn_q_b);
+
+ Qcur = ggml_scale(ctx0, Qcur, KQscale);
+
+ // note: no bias for Key
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
+ layer.attn_k_w,
+ cur);
+
+ Kcur = ggml_scale(ctx0, Kcur, KQscale);
+
+ // store key and value to memory
+ {
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
+ layer.attn_v_w,
+ cur);
+
+ Vcur = ggml_add(ctx0,
+ Vcur,
+ layer.attn_v_b);
+
+ Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
+
+ struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
+ struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_state,
+ ( n_ctx)*ggml_element_size(kv_self.v),
+ (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v));
+
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
+ }
+
+ // ------
+
+ struct ggml_tensor * Q =
+ ggml_permute(ctx0,
+ ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
+ 0, 2, 1, 3);
+
+ struct ggml_tensor * K =
+ ggml_view_3d(ctx0, kv_self.k,
+ n_state/n_head, n_past + N, n_head,
+ ggml_element_size(kv_self.k)*n_state,
+ ggml_element_size(kv_self.k)*n_state/n_head,
+ ggml_element_size(kv_self.k)*n_state*n_ctx*il);
+
+ // K * Q
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+
+ //struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
+
+ struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
+
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+
+ struct ggml_tensor * V =
+ ggml_view_3d(ctx0, kv_self.v,
+ n_past + N, n_state/n_head, n_head,
+ n_ctx*ggml_element_size(kv_self.v),
+ n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
+ il*n_ctx*ggml_element_size(kv_self.v)*n_state);
+
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+ cur = ggml_cpy(ctx0,
+ KQV_merged,
+ ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
+ }
+
+ // projection
+ {
+ cur = ggml_mul_mat(ctx0,
+ layer.attn_ln_1_w,
+ cur);
+
+ cur = ggml_add(ctx0,
+ cur,
+ layer.attn_ln_1_b);
+ }
+
+ // add the input
+ struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
+
+ // norm
+ {
+ cur = ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here
+
+ // cur = ln_0_w*cur + ln_0_b
+ cur = ggml_add(ctx0,
+ ggml_mul(ctx0,
+ cur,
+ layer.cross_attn_ln_0_w),
+ layer.cross_attn_ln_0_b);
+ }
+
+ // cross-attention
+ {
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
+ layer.cross_attn_q_w,
+ cur);
+
+ Qcur = ggml_add(ctx0,
+ Qcur,
+ layer.cross_attn_q_b);
+
+ Qcur = ggml_scale(ctx0, Qcur, KQscale);
+
+ // Kcross is already scaled
+ struct ggml_tensor * Kcross =
+ ggml_view_3d(ctx0, wstate.kv_cross.k,
+ n_state/n_head, M, n_head,
+ ggml_element_size(wstate.kv_cross.k)*n_state,
+ ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
+ ggml_element_size(wstate.kv_cross.k)*n_state*M*il);
+
+ //struct ggml_tensor * Vcross =
+ // ggml_reshape_3d(ctx0,
+ // ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
+ // n_state/n_head, n_head, M);
+
+ //struct ggml_tensor * V_trans =
+ // ggml_cpy(ctx0,
+ // ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
+ // ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
+
+ struct ggml_tensor * V =
+ ggml_view_3d(ctx0, wstate.kv_cross.v,
+ M, n_state/n_head, n_head,
+ M*ggml_element_size(wstate.kv_cross.v),
+ M*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
+ il*M*ggml_element_size(wstate.kv_cross.v)*n_state);
+
+ // ------
+
+ struct ggml_tensor * Q =
+ ggml_permute(ctx0,
+ ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
+ 0, 2, 1, 3);
+
+ // K * Q
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
+
+ //struct ggml_tensor * KQ_scaled =
+ // ggml_scale(ctx0,
+ // KQ,
+ // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
+ // );
+
+ // no masking for cross-attention
+ //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
+
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+ // cur = KQV_merged.contiguous().view(n_state, N)
+ cur = ggml_cpy(ctx0,
+ KQV_merged,
+ ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
+ }
+
+ // projection
+ {
+ cur = ggml_mul_mat(ctx0,
+ layer.cross_attn_ln_1_w,
+ cur);
+
+ cur = ggml_add(ctx0,
+ cur,
+ layer.cross_attn_ln_1_b);
+ }
+
+ // add the input
+ cur = ggml_add(ctx0, cur, inpCA);
+
+ struct ggml_tensor * inpFF = cur;
+
+ // feed-forward network
+ {
+ // norm
+ {
+ cur = ggml_norm(ctx0, inpFF, hparams.eps);
+
+ // cur = mlp_ln_w*cur + mlp_ln_b
+ cur = ggml_add(ctx0,
+ ggml_mul(ctx0,
+ cur,
+ layer.mlp_ln_w),
+ layer.mlp_ln_b);
+ }
+
+ // fully connected
+ cur = ggml_mul_mat(ctx0,
+ layer.mlp_0_w,
+ cur);
+
+ cur = ggml_add(ctx0,
+ cur,
+ layer.mlp_0_b);
+
+ // GELU activation
+ cur = ggml_gelu(ctx0, cur);
+
+ // projection
+ cur = ggml_mul_mat(ctx0,
+ layer.mlp_1_w,
+ cur);
+
+ cur = ggml_add(ctx0,
+ cur,
+ layer.mlp_1_b);
+ }
+
+ inpL = ggml_add(ctx0, cur, inpFF);
+ }
+
+ cur = inpL;
+
+ // norm
+ {
+ cur = ggml_norm(ctx0, cur, hparams.eps);
+
+ cur = ggml_add(ctx0,
+ ggml_mul(ctx0,
+ cur,
+ model.d_ln_w),
+ model.d_ln_b);
+ }
+
+ // compute logits only for the last token
+ // comment this line to compute logits for all N tokens
+ // might be useful in the future
+ cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
+
+ struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
+
+ ggml_build_forward_expand(gf, logits);
+
+ ggml_free(ctx0);
+
+ return gf;
+}
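
// note: whisper_build_graph_decoder() above creates its context with no_alloc = true, so
// tensor metadata lives in alloc_decode.meta while ggml-alloc decides where the data goes.
// Graph inputs therefore have to be reserved through the allocator and written only when
// this is not the measure pass. A condensed sketch of that idiom (illustrative helper,
// not part of the change):
//
static struct ggml_tensor * whisper_alloc_input_i32_sketch(
        struct ggml_context * ctx,
        struct ggml_allocr  * alloc,
        const int32_t       * src,
        int                   n) {
    struct ggml_tensor * t = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n);

    ggml_allocr_alloc(alloc, t); // reserve space (or just account for it during measurement)

    if (!ggml_allocr_is_measure(alloc) && src) {
        memcpy(t->data, src, n*ggml_element_size(t)); // write real data only on actual runs
    }

    return t;
}
//
// The same guard protects KQscale, which holds (n_state/n_head)^(-1/4); scaling both Q and
// K by this factor gives their product the usual 1/sqrt(d_head) attention scaling without
// a separate KQ_scaled node.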
+
// evaluate the decoder
//
// given text prompt + audio features -> computes the logits for the next token
@@ -1976,388 +2352,45 @@ static bool whisper_decode_internal(
const auto & model = wctx.model;
const auto & hparams = model.hparams;
- auto & kv_self = decoder.kv_self;
-
- WHISPER_ASSERT(!!kv_self.ctx);
+ const int n_vocab = hparams.n_vocab;
auto & logits_out = wstate.logits;
- const int n_vocab = hparams.n_vocab;
+ struct ggml_tensor * logits;
- const int n_ctx = hparams.n_text_ctx;
- const int n_state = hparams.n_text_state;
- const int n_head = hparams.n_text_head;
- const int n_layer = hparams.n_text_layer;
-
- const int N = n_tokens;
- const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
-
- //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
-
- struct ggml_init_params params = {
- /*.mem_size =*/ wstate.buf_compute.size(),
- /*.mem_buffer =*/ wstate.buf_compute.data(),
- /*.no_alloc =*/ false,
- };
-
- struct ggml_context * ctx0 = ggml_init(params);
-
- struct ggml_cgraph gf = {};
-
- struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
- memcpy(embd->data, tokens, N*ggml_element_size(embd));
-
- struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
- for (int i = 0; i < N; ++i) {
- ((int32_t *) position->data)[i] = n_past + i;
- }
-
- wstate.use_buf(ctx0, 3);
-
- // token encoding + position encoding
- struct ggml_tensor * cur =
- ggml_add(ctx0,
- ggml_get_rows(ctx0, model.d_te, embd),
- ggml_get_rows(ctx0, model.d_pe, position));
-
- struct ggml_tensor * inpL = cur;
-
- for (int il = 0; il < n_layer; ++il) {
- const auto & layer = model.layers_decoder[il];
-
- // norm
- {
- wstate.use_buf(ctx0, 0);
-
- cur = ggml_norm(ctx0, inpL, hparams.eps);
-
- // cur = ln_0_w*cur + ln_0_b
- cur = ggml_add(ctx0,
- ggml_mul(ctx0,
- ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
- cur),
- ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
- }
-
- // self-attention
- {
- struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
- layer.attn_q_w,
- cur);
-
- Qcur = ggml_add(ctx0,
- ggml_repeat(ctx0,
- layer.attn_q_b,
- Qcur),
- Qcur);
-
- Qcur = ggml_scale_inplace(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
-
- // note: no bias for Key
- struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
- layer.attn_k_w,
- cur);
-
- Kcur = ggml_scale_inplace(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
-
- // store key and value to memory
- {
- struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
- layer.attn_v_w,
- cur);
-
- Vcur = ggml_add(ctx0,
- ggml_repeat(ctx0,
- layer.attn_v_b,
- Vcur),
- Vcur);
-
- Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
-
- struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
- struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_state,
- ( n_ctx)*ggml_element_size(kv_self.v),
- (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v));
-
- ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
- ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
- }
-
- // ------
-
- wstate.use_buf(ctx0, 0);
-
- struct ggml_tensor * Q =
- ggml_permute(ctx0,
- ggml_cpy(ctx0,
- Qcur,
- ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
- 0, 2, 1, 3);
-
- struct ggml_tensor * K =
- ggml_permute(ctx0,
- ggml_reshape_3d(ctx0,
- ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state),
- n_state/n_head, n_head, n_past + N),
- 0, 2, 1, 3);
-
- wstate.use_buf(ctx0, 1);
-
- // K * Q
- struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-
- //struct ggml_tensor * KQ_scaled =
- // ggml_scale_inplace(ctx0,
- // KQ,
- // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
- // );
-
- struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ, n_past);
-
- struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
-
- struct ggml_tensor * V =
- ggml_view_3d(ctx0, kv_self.v,
- n_past + N, n_state/n_head, n_head,
- n_ctx*ggml_element_size(kv_self.v),
- n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
- il*n_ctx*ggml_element_size(kv_self.v)*n_state);
-
- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
-
- struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
-
- cur = ggml_cpy(ctx0,
- KQV_merged,
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
- }
-
- // projection
- {
- wstate.use_buf(ctx0, 0);
-
- cur = ggml_mul_mat(ctx0,
- layer.attn_ln_1_w,
- cur);
-
- wstate.use_buf(ctx0, 1);
-
- cur = ggml_add(ctx0,
- ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
- cur);
- }
-
- wstate.use_buf(ctx0, 2);
-
- // add the input
- struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
-
- // norm
- {
- wstate.use_buf(ctx0, 0);
-
- cur = ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here
-
- // cur = ln_0_w*cur + ln_0_b
- cur = ggml_add(ctx0,
- ggml_mul(ctx0,
- ggml_repeat(ctx0, layer.cross_attn_ln_0_w, cur),
- cur),
- ggml_repeat(ctx0, layer.cross_attn_ln_0_b, cur));
- }
-
- // cross-attention
- {
- struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
- layer.cross_attn_q_w,
- cur);
-
- Qcur = ggml_add(ctx0,
- ggml_repeat(ctx0,
- layer.cross_attn_q_b,
- Qcur),
- Qcur);
-
- Qcur = ggml_scale_inplace(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
-
- // Kcross is already scaled
- struct ggml_tensor * Kcross =
- ggml_reshape_3d(ctx0,
- ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*ggml_element_size(wstate.kv_cross.k)*n_state),
- n_state/n_head, n_head, M);
-
- //struct ggml_tensor * Vcross =
- // ggml_reshape_3d(ctx0,
- // ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
- // n_state/n_head, n_head, M);
-
- //struct ggml_tensor * V_trans =
- // ggml_cpy(ctx0,
- // ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
- // ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
-
- struct ggml_tensor * V =
- ggml_view_3d(ctx0, wstate.kv_cross.v,
- M, n_state/n_head, n_head,
- M*ggml_element_size(wstate.kv_cross.v),
- M*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
- il*M*ggml_element_size(wstate.kv_cross.v)*n_state);
-
- // ------
-
- struct ggml_tensor * Q =
- ggml_permute(ctx0,
- ggml_cpy(ctx0,
- Qcur,
- ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
- 0, 2, 1, 3);
-
- struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
-
- // K * Q
- struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-
- //struct ggml_tensor * KQ_scaled =
- // ggml_scale_inplace(ctx0,
- // KQ,
- // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
- // );
-
- // no masking for cross-attention
- //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
-
- struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ);
-
- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
-
- struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
-
- // cur = KQV_merged.contiguous().view(n_state, N)
- cur = ggml_cpy(ctx0,
- KQV_merged,
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
- }
-
- // projection
- {
- wstate.use_buf(ctx0, 0);
-
- cur = ggml_mul_mat(ctx0,
- layer.cross_attn_ln_1_w,
- cur);
-
- wstate.use_buf(ctx0, 1);
-
- cur = ggml_add(ctx0,
- ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
- cur);
- }
-
- wstate.use_buf(ctx0, 2);
-
- // add the input
- cur = ggml_add(ctx0, cur, inpCA);
-
- struct ggml_tensor * inpFF = cur;
-
- // feed-forward network
- {
- // norm
- {
- wstate.use_buf(ctx0, 0);
-
- cur = ggml_norm(ctx0, inpFF, hparams.eps);
-
- wstate.use_buf(ctx0, 1);
-
- // cur = mlp_ln_w*cur + mlp_ln_b
- cur = ggml_add(ctx0,
- ggml_mul(ctx0,
- ggml_repeat(ctx0, layer.mlp_ln_w, cur),
- cur),
- ggml_repeat(ctx0, layer.mlp_ln_b, cur));
- }
-
- wstate.use_buf(ctx0, 0);
-
- // fully connected
- cur = ggml_mul_mat(ctx0,
- layer.mlp_0_w,
- cur);
-
- wstate.use_buf(ctx0, 1);
-
- cur = ggml_add(ctx0,
- ggml_repeat(ctx0, layer.mlp_0_b, cur),
- cur);
-
- wstate.use_buf(ctx0, 0);
-
- // GELU activation
- cur = ggml_gelu(ctx0, cur);
-
- wstate.use_buf(ctx0, 1);
-
- // projection
- cur = ggml_mul_mat(ctx0,
- layer.mlp_1_w,
- cur);
-
- wstate.use_buf(ctx0, 0);
-
- cur = ggml_add(ctx0,
- ggml_repeat(ctx0, layer.mlp_1_b, cur),
- cur);
- }
-
- wstate.use_buf(ctx0, 3);
-
- inpL = ggml_add(ctx0, cur, inpFF);
- }
-
- cur = inpL;
-
- // norm
+ // decoder
{
- wstate.use_buf(ctx0, 0);
+ auto & alloc = wstate.alloc_decode.alloc;
- cur = ggml_norm(ctx0, cur, hparams.eps);
+ ggml_allocr_reset(alloc);
- wstate.use_buf(ctx0, 1);
+ ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, tokens, n_tokens, n_past);
- cur = ggml_add(ctx0,
- ggml_mul(ctx0,
- ggml_repeat(ctx0, model.d_ln_w, cur),
- cur),
- ggml_repeat(ctx0, model.d_ln_b, cur));
- }
+ ggml_allocr_alloc_graph(alloc, gf);
- wstate.use_buf(ctx0, 0);
+ logits = gf->nodes[gf->n_nodes - 1];
- // compute logits only for the last token
- // comment this line to compute logits for all N tokens
- // might be useful in the future
- cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
-
- struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
-
- wstate.use_buf(ctx0, -1);
-
- // run the computation
- {
- ggml_build_forward_expand(&gf, logits);
- ggml_graph_compute_helper(wstate.buf_work, &gf, n_threads);
+#ifdef GGML_USE_METAL
+ if (wstate.ctx_metal) {
+ ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
+ ggml_metal_graph_compute(wstate.ctx_metal, gf);
+ } else {
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+ }
+#else
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+#endif
}
// extract logits for all N tokens
- //logits_out.resize(N*n_vocab);
- //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
+ //logits_out.resize(n_tokens*n_vocab);
+ //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
// extract logits only for the last token
logits_out.resize(n_vocab);
memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
- if (N > 1) {
+ if (n_tokens > 1) {
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
// ggml_used_mem(ctx0)/1024.0/1024.0,
// wstate.get_buf_max_mem(0)/1024.0/1024.0,
@@ -2366,14 +2399,18 @@ static bool whisper_decode_internal(
// wstate.get_buf_max_mem(3)/1024.0/1024.0);
}
- ggml_free(ctx0);
-
- wstate.t_decode_us += ggml_time_us() - t_start_us;
- wstate.n_decode++;
+ if (n_tokens == 1) {
+ wstate.t_decode_us += ggml_time_us() - t_start_us;
+ wstate.n_decode++;
+ } else {
+ wstate.t_prompt_us += ggml_time_us() - t_start_us;
+ wstate.n_prompt++;
+ }
return true;
}
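
// note: whisper_decode_internal() above no longer gets a tensor back from the builder;
// since ggml_build_forward_expand(gf, logits) is the last expansion in
// whisper_build_graph_decoder(), the logits are simply the final node of the graph. The
// function also now splits its timing between single-token decode steps
// (n_decode/t_decode_us) and multi-token prompt evaluation (n_prompt/t_prompt_us).
// A sketch of the output lookup (illustrative, not part of the change):
//
static struct ggml_tensor * whisper_graph_output_sketch(struct ggml_cgraph * gf) {
    // the builder expands the output tensor last, so it ends up as the last graph node
    return gf->n_nodes > 0 ? gf->nodes[gf->n_nodes - 1] : nullptr;
}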
+
// 500 -> 00:05.000
// 6000 -> 01:00.000
static std::string to_timestamp(int64_t t, bool comma = false) {
@@ -2782,9 +2819,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
fill_sin_cos_table();
whisper_state * state = new whisper_state;
- const size_t scale = ctx->model.hparams.ftype ? 1 : 2;
-
- if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
+ if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
log("%s: kv_cache_init() failed for self-attention cache\n", __func__);
delete state;
return nullptr;
@@ -2795,7 +2830,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
}
- if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
+ if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
log("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
delete state;
return nullptr;
@@ -2816,6 +2851,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
if (!state->ctx_coreml) {
log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
#ifndef WHISPER_COREML_ALLOW_FALLBACK
+ delete state;
return nullptr;
#endif
} else {
@@ -2830,15 +2866,111 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
// TAGS: WHISPER_DECODER_INIT
state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
- state->decoders[0].probs.reserve(ctx->vocab.n_vocab);
- state->decoders[0].logits.reserve(ctx->vocab.n_vocab);
+ state->decoders[0].probs.reserve (ctx->vocab.n_vocab);
+ state->decoders[0].logits.reserve (ctx->vocab.n_vocab);
state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
- state->buf_compute.resize(scale * std::max(MEM_REQ_ENCODE.at(ctx->model.type), MEM_REQ_DECODE.at(ctx->model.type)));
- state->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type));
- state->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type));
- state->buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(ctx->model.type));
- state->buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(ctx->model.type));
+ // conv allocator
+ {
+ whisper_allocr_graph_init(state->alloc_conv,
+ [&]() {
+ return whisper_build_graph_conv(*ctx, *state, 0);
+ });
+
+ log("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0);
+ }
+
+ // encoder allocator
+ if (!whisper_encode_external(*state)) {
+ whisper_allocr_graph_init(state->alloc_encode,
+ [&]() {
+ return whisper_build_graph_encoder(*ctx, *state);
+ });
+
+ log("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0);
+ }
+
+ // cross allocator
+ {
+ whisper_allocr_graph_init(state->alloc_cross,
+ [&]() {
+ return whisper_build_graph_cross(*ctx, *state);
+ });
+
+ log("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0);
+ }
+
+ // decoder allocator
+ {
+ whisper_allocr_graph_init(state->alloc_decode,
+ [&]() {
+ const auto & hparams = ctx->model.hparams;
+
+ // TODO: make sure this is the worst-case scenario
+ const int n_tokens = hparams.n_text_ctx;
+ const int n_past = 0;
+
+ return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
+ });
+
+ log("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
+ }
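
// note: each whisper_allocr_graph_init() call above (the helper is introduced by this
// patch) sizes a dedicated compute buffer from a worst-case graph. The underlying
// ggml-alloc idiom is a measure pass followed by a real allocator over a buffer of the
// measured size; a sketch of that idiom with assumed alignment and naming (not the actual
// helper):
//
static size_t whisper_measure_allocr_sketch(
        const std::function<ggml_cgraph *()> & build_graph, // builds the worst-case graph
              std::vector<uint8_t>           & data,        // compute data buffer, resized here
              ggml_allocr                   ** out_alloc) {  // allocator reused on every run
    const size_t tensor_alignment = 32; // assumption: typical alignment for compute buffers

    // measure pass: no memory is written, ggml-alloc only records the peak size needed
    ggml_allocr * measure = ggml_allocr_new_measure(tensor_alignment);
    const size_t size = ggml_allocr_alloc_graph(measure, build_graph()) + tensor_alignment;
    ggml_allocr_free(measure);

    // real allocator over a buffer of exactly the measured size
    data.resize(size);
    *out_alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);

    return size;
}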
+
+#ifdef GGML_USE_METAL
+ state->ctx_metal = ggml_metal_init(1);
+ if (!state->ctx_metal) {
+ log("%s: ggml_metal_init() failed\n", __func__);
+ delete state;
+ return nullptr;
+ }
+
+ log("%s: Metal context initialized\n", __func__);
+
+ // this allocates all Metal resources and memory buffers
+
+ void * data_ptr = NULL;
+ size_t data_size = 0;
+
+ // TODO: add mmap support
+ //if (params.use_mmap) {
+ // data_ptr = ctx->model.mapping->addr;
+ // data_size = ctx->model.mapping->size;
+ //} else {
+ // data_ptr = ggml_get_mem_buffer(ctx->model.ctx);
+ // data_size = ggml_get_mem_size (ctx->model.ctx);
+ //}
+
+ data_ptr = ggml_get_mem_buffer(ctx->model.ctx);
+ data_size = ggml_get_mem_size (ctx->model.ctx);
+
+ const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx);
+
+ log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
+
+#define WHISPER_METAL_CHECK_BUF(result) \
+ if (!(result)) { \
+ log("%s: failed to add metal buffer\n", __func__); \
+ delete state; \
+ return nullptr; \
+ }
+
+ WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
+
+ WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->alloc_conv.meta.data(), state->alloc_conv.meta.size(), 0));
+ WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
+ WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->alloc_cross.meta.data(), state->alloc_cross.meta.size(), 0));
+ WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
+
+ WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->alloc_conv.data.data(), state->alloc_conv.data.size(), 0));
+ WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
+ WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->alloc_cross.data.data(), state->alloc_cross.data.size(), 0));
+ WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
+
+ WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
+
+ WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
+#undef WHISPER_METAL_CHECK_BUF
+#endif
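
// note: ggml_metal_graph_compute() can only resolve tensors whose backing host memory has
// been registered with the Metal context, which is why the model weights, the per-graph
// meta buffers, the compute data buffers and the KV caches are all added above before the
// first GPU run. A macro-free sketch of one such registration (illustrative, not part of
// the change):
//
static bool whisper_metal_add_buffer_sketch(
        struct ggml_metal_context * ctx_metal,
        const char * name, void * data, size_t size, size_t max_tensor_size) {
    // max_tensor_size only matters for the large weight region, which may have to be split
    // across several MTLBuffers; the small regions pass 0 here
    if (!ggml_metal_add_buffer(ctx_metal, name, data, size, max_tensor_size)) {
        log("%s: failed to add metal buffer '%s'\n", __func__, name);
        return false;
    }
    return true;
}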
state->rng = std::mt19937(0);
@@ -2895,7 +3027,6 @@ int whisper_ctx_init_openvino_encoder(
}
struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
-
log("%s: loading model from '%s'\n", __func__, path_model);
auto fin = std::ifstream(path_model, std::ios::binary);
@@ -3048,6 +3179,13 @@ void whisper_free_state(struct whisper_state * state)
}
#endif
+#ifdef GGML_USE_METAL
+ if (state->ctx_metal) {
+ ggml_metal_free(state->ctx_metal);
+ state->ctx_metal = nullptr;
+ }
+#endif
+
#ifdef WHISPER_USE_OPENVINO
if (state->ctx_openvino != nullptr) {
whisper_openvino_free(state->ctx_openvino);
@@ -3055,6 +3193,11 @@ void whisper_free_state(struct whisper_state * state)
}
#endif
+ whisper_allocr_free(state->alloc_conv);
+ whisper_allocr_free(state->alloc_decode);
+ whisper_allocr_free(state->alloc_cross);
+ whisper_allocr_free(state->alloc_encode);
+
delete state;
}
}
@@ -3475,12 +3618,14 @@ void whisper_print_timings(struct whisper_context * ctx) {
const int32_t n_sample = std::max(1, ctx->state->n_sample);
const int32_t n_encode = std::max(1, ctx->state->n_encode);
const int32_t n_decode = std::max(1, ctx->state->n_decode);
+ const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
log("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
log("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
log("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
log("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
log("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
+ log("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
}
log("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
}
@@ -3490,6 +3635,11 @@ void whisper_reset_timings(struct whisper_context * ctx) {
ctx->state->t_sample_us = 0;
ctx->state->t_encode_us = 0;
ctx->state->t_decode_us = 0;
+ ctx->state->t_prompt_us = 0;
+ ctx->state->n_sample = 0;
+ ctx->state->n_encode = 0;
+ ctx->state->n_decode = 0;
+ ctx->state->n_prompt = 0;
}
}
@@ -4339,6 +4489,21 @@ int whisper_full_with_state(
decoder.probs.resize (ctx->vocab.n_vocab);
decoder.logits.resize (ctx->vocab.n_vocab);
decoder.logprobs.resize(ctx->vocab.n_vocab);
+
+ // TODO: not very clean - look for a better way and potentially merging with the init of decoder 0
+#ifdef GGML_USE_METAL
+#define WHISPER_METAL_CHECK_BUF(result) \
+ if (!(result)) { \
+ log("%s: failed to add metal buffer\n", __func__); \
+ return 0; \
+ }
+
+ const std::string kv_name = "kv_self_" + std::to_string(j);
+ auto & kv_self = decoder.kv_self;
+
+ WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0));
+#undef WHISPER_METAL_CHECK_BUF
+#endif
}
}
@@ -4531,8 +4696,8 @@ int whisper_full_with_state(
decoder.kv_self.n += prompt.size();
- memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
- memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
+ memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
+ memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
}
@@ -5045,6 +5210,12 @@ int whisper_full_parallel(
ctx->state->t_sample_us += states[i]->t_sample_us;
ctx->state->t_encode_us += states[i]->t_encode_us;
ctx->state->t_decode_us += states[i]->t_decode_us;
+ ctx->state->t_prompt_us += states[i]->t_prompt_us;
+
+ ctx->state->n_sample += states[i]->n_sample;
+ ctx->state->n_encode += states[i]->n_encode;
+ ctx->state->n_decode += states[i]->n_decode;
+ ctx->state->n_prompt += states[i]->n_prompt;
whisper_free_state(states[i]);
}
@@ -5241,8 +5412,8 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
// b: N*N*sizeof(float)
// c: N*N*sizeof(float)
// when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
-    std::vector<uint8_t> buf (3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead());
-    std::vector<uint8_t> work(1llu*N_max*N_max*sizeof(float) + 1*ggml_tensor_overhead());
+    std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead());
+    std::vector<uint8_t> work;
// put a bunch of random data in the buffer
for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
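
// note: `work` no longer needs to be pre-sized: the compute helper used by the benchmark
// asks ggml for a plan and grows the work buffer to the plan's work_size before running
// the graph. A sketch of that pattern (assumption about the helper's internals):
//
static void whisper_compute_with_plan_sketch(std::vector<uint8_t> & work, struct ggml_cgraph * gf, int n_threads) {
    struct ggml_cplan plan = ggml_graph_plan(gf, n_threads);

    if (plan.work_size > 0) {
        work.resize(plan.work_size);  // scratch is allocated only when the graph needs it
        plan.work_data = work.data();
    }

    ggml_graph_compute(gf, &plan);
}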