diff --git a/ggml.c b/ggml.c index f479dc3..000d9db 100644 --- a/ggml.c +++ b/ggml.c @@ -5,6 +5,7 @@ #include "ggml-quants.h" #include "ggml.h" + #if defined(_MSC_VER) || defined(__MINGW32__) #include // using malloc.h with MSC/MINGW #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__) @@ -28,6 +29,10 @@ #include #endif +#ifdef GGML_USE_OPENMP +#include +#endif + #ifdef GGML_USE_METAL #include #endif @@ -1756,7 +1761,7 @@ struct ggml_compute_state_shared { int64_t perf_node_start_cycles; int64_t perf_node_start_time_us; - const int n_threads; + int n_threads; // synchronization primitives atomic_int n_active; // num active threads @@ -19670,6 +19675,59 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa return cplan; } +static enum ggml_status ggml_graph_compute_parallel(struct ggml_compute_state * workers, int n_threads) { + enum ggml_status compute_status = GGML_STATUS_SUCCESS; + +#ifdef GGML_USE_OPENMP + if (n_threads > 1) { + #pragma omp parallel num_threads(n_threads) + { + #pragma omp single + { + // update the number of threads from the actual number of threads that we got from OpenMP + n_threads = omp_get_num_threads(); + workers[0].shared->n_threads = n_threads; + workers[0].shared->n_active = n_threads; + } + ggml_graph_compute_thread(&workers[omp_get_thread_num()]); + } + } else { + ggml_graph_compute_thread(&workers[0]); + } +#else + // create thread pool + if (n_threads > 1) { + for (int j = 1; j < n_threads; ++j) { + const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]); + GGML_ASSERT(rc == 0); + UNUSED(rc); + } + } + + // this is a work thread too + ggml_graph_compute_thread(&workers[0]); + + // join or kill thread pool + if (n_threads > 1) { + for (int j = 1; j < n_threads; j++) { + const int rc = ggml_thread_join(workers[j].thrd, NULL); + GGML_ASSERT(rc == 0); + UNUSED(rc); + } + } +#endif + // don't leave affinity set on the main thread + clear_numa_thread_affinity(); + + for (int j = 0; j < n_threads; j++) { + if (workers[j].ec != GGML_STATUS_SUCCESS) { + compute_status = workers[j].ec; + break; + } + } + return compute_status; +} + enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) { { GGML_ASSERT(cplan); @@ -19680,7 +19738,11 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl } } - const int n_threads = cplan->n_threads; + int n_threads = cplan->n_threads; + +#if defined(GGML_USE_OPENMP) + n_threads = MIN(n_threads, omp_get_max_threads()); +#endif struct ggml_compute_state_shared state_shared = { /*.cgraph =*/ cgraph, @@ -19696,47 +19758,20 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl /*.current_chunk; =*/ 0, }; struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads); - - // create thread pool - if (n_threads > 1) { - for (int j = 1; j < n_threads; ++j) { - workers[j] = (struct ggml_compute_state) { - .thrd = 0, - .ith = j, - .shared = &state_shared, - .ec = GGML_STATUS_SUCCESS, - }; - - const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]); - GGML_ASSERT(rc == 0); - UNUSED(rc); - } - } - - workers[0].ith = 0; - workers[0].shared = &state_shared; - workers[0].ec = GGML_STATUS_SUCCESS; - const int64_t perf_start_cycles = ggml_perf_cycles(); const int64_t perf_start_time_us = ggml_perf_time_us(); - // this is a work thread too - ggml_graph_compute_thread(&workers[0]); - enum ggml_status compute_status = workers[0].ec; - - // don't leave affinity set on the main thread - clear_numa_thread_affinity(); - - // join or kill thread pool - if (n_threads > 1) { - for (int j = 1; j < n_threads; j++) { - const int rc = ggml_thread_join(workers[j].thrd, NULL); - GGML_ASSERT(rc == 0); - if (workers[j].ec != GGML_STATUS_SUCCESS) - compute_status = workers[j].ec; - } + for (int j = 0; j < n_threads; ++j) { + workers[j] = (struct ggml_compute_state) { + .thrd = 0, + .ith = j, + .shared = &state_shared, + .ec = GGML_STATUS_SUCCESS, + }; } + enum ggml_status compute_status = ggml_graph_compute_parallel(workers, n_threads); + // performance stats (graph) { int64_t perf_cycles_cur = ggml_perf_cycles() - perf_start_cycles;