diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index a63e833..05e2279 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -107,5 +107,63 @@ whisper.transcribe("path/to/audio.wav", params) ``` +You can see model information: + +```ruby +whisper = Whisper::Context.new("path/to/model.bin") +model = whisper.model + +model.n_vocab # => 51864 +model.n_audio_ctx # => 1500 +model.n_audio_state # => 512 +model.n_audio_head # => 8 +model.n_audio_layer # => 6 +model.n_text_ctx # => 448 +model.n_text_state # => 512 +model.n_text_head # => 8 +model.n_text_layer # => 6 +model.n_mels # => 80 +model.ftype # => 1 +model.type # => "base" + +``` + +You can set log callback: + +```ruby +prefix = "[MyApp] " +log_callback = ->(level, buffer, user_data) { + case level + when Whisper::LOG_LEVEL_NONE + puts "#{user_data}none: #{buffer}" + when Whisper::LOG_LEVEL_INFO + puts "#{user_data}info: #{buffer}" + when Whisper::LOG_LEVEL_WARN + puts "#{user_data}warn: #{buffer}" + when Whisper::LOG_LEVEL_ERROR + puts "#{user_data}error: #{buffer}" + when Whisper::LOG_LEVEL_DEBUG + puts "#{user_data}debug: #{buffer}" + when Whisper::LOG_LEVEL_CONT + puts "#{user_data}same to previous: #{buffer}" + end +} +Whisper.log_set log_callback, prefix +``` + +Using this feature, you are also able to suppress log: + +```ruby +Whisper.log_set ->(level, buffer, user_data) { + # do nothing +}, nil +Whisper::Context.new(MODEL) +``` + +License +------- + +The same to [whisper.cpp][]. + [whisper.cpp]: https://github.com/ggerganov/whisper.cpp [models]: https://github.com/ggerganov/whisper.cpp/tree/master/models diff --git a/bindings/ruby/Rakefile b/bindings/ruby/Rakefile index 5a6a916..d6fc49c 100644 --- a/bindings/ruby/Rakefile +++ b/bindings/ruby/Rakefile @@ -23,30 +23,39 @@ CLEAN.include FileList[ "ext/depend" ] -task build: SOURCES + FileList[ - "ext/extconf.rb", - "ext/ruby_whisper.h", - "ext/ruby_whisper.cpp", - "whispercpp.gemspec", - ] +task build: FileList[ + "ext/Makefile", + "ext/ruby_whisper.h", + "ext/ruby_whisper.cpp", + "whispercpp.gemspec", + ] directory "pkg" CLOBBER.include "pkg" TEST_MODEL = "../../models/ggml-base.en.bin" LIB_NAME = "whisper".ext(RbConfig::CONFIG["DLEXT"]) +SO_FILE = File.join("ext", LIB_NAME) LIB_FILE = File.join("lib", LIB_NAME) -directory "lib" -task LIB_FILE => SOURCES + ["lib"] do |t| +file "ext/Makefile" => ["ext/extconf.rb", "ext/ruby_whisper.h", "ext/ruby_whisper.cpp"] + SOURCES do |t| + Dir.chdir "ext" do + ruby "extconf.rb" + end +end + +file SO_FILE => "ext/Makefile" do |t| Dir.chdir "ext" do - sh "ruby extconf.rb" sh "make" end - mv "ext/#{LIB_NAME}", t.name end CLEAN.include LIB_FILE +directory "lib" +file LIB_FILE => [SO_FILE, "lib"] do |t| + copy t.source, t.name +end + Rake::TestTask.new do |t| t.test_files = FileList["tests/test_*.rb"] end diff --git a/bindings/ruby/ext/extconf.rb b/bindings/ruby/ext/extconf.rb index 3b54a4a..5e98b39 100644 --- a/bindings/ruby/ext/extconf.rb +++ b/bindings/ruby/ext/extconf.rb @@ -2,6 +2,9 @@ require 'mkmf' # need to use c++ compiler flags $CXXFLAGS << ' -std=c++11' + +$LDFLAGS << ' -lstdc++' + # Set to true when building binary gems if enable_config('static-stdlib', false) $LDFLAGS << ' -static-libgcc -static-libstdc++' @@ -12,34 +15,6 @@ if enable_config('march-tune-native', false) $CXXFLAGS << ' -march=native -mtune=native' end -def with_disabling_unsupported_files - disabled_files = [] - - unless $GGML_METAL - disabled_files << 'ggml-metal.h' << 'ggml-metal.m' - end - - unless $GGML_METAL_EMBED_LIBRARY - disabled_files << 'ggml-metal.metal' - end - - unless $OBJ_ALL&.include? 'ggml-blas.o' - disabled_files << 'ggml-blas.h' << 'ggml-blas.cpp' - end - - disabled_files.filter! {|file| File.exist? file} - - disabled_files.each do |file| - File.rename file, "#{file}.disabled" - end - - yield - - disabled_files.each do |file| - File.rename "#{file}.disabled", file - end -end - if ENV['WHISPER_METAL'] $GGML_METAL ||= true $DEPRECATE_WARNING ||= true @@ -66,10 +41,10 @@ $MK_CXXFLAGS = '-std=c++11 -fPIC' $MK_NVCCFLAGS = '-std=c++11' $MK_LDFLAGS = '' -$OBJ_GGML = '' -$OBJ_WHISPER = '' -$OBJ_COMMON = '' -$OBJ_SDL = '' +$OBJ_GGML = [] +$OBJ_WHISPER = [] +$OBJ_COMMON = [] +$OBJ_SDL = [] $MK_CPPFLAGS << ' -D_XOPEN_SOURCE=600' @@ -152,7 +127,7 @@ unless ENV['GGML_NO_ACCELERATE'] $MK_CPPFLAGS << ' -DACCELERATE_NEW_LAPACK' $MK_CPPFLAGS << ' -DACCELERATE_LAPACK_ILP64' $MK_LDFLAGS << ' -framework Accelerate' - $OBJ_GGML << ' ggml-blas.o' + $OBJ_GGML << 'ggml-blas.o' end end @@ -160,20 +135,20 @@ if ENV['GGML_OPENBLAS'] $MK_CPPFLAGS << " -DGGML_USE_BLAS #{`pkg-config --cflags-only-I openblas`.chomp}" $MK_CFLAGS << " #{`pkg-config --cflags-only-other openblas)`.chomp}" $MK_LDFLAGS << " #{`pkg-config --libs openblas`}" - $OBJ_GGML << ' ggml-blas.o' + $OBJ_GGML << 'ggml-blas.o' end if ENV['GGML_OPENBLAS64'] $MK_CPPFLAGS << " -DGGML_USE_BLAS #{`pkg-config --cflags-only-I openblas64`.chomp}" $MK_CFLAGS << " #{`pkg-config --cflags-only-other openblas64)`.chomp}" $MK_LDFLAGS << " #{`pkg-config --libs openblas64`}" - $OBJ_GGML << ' ggml-blas.o' + $OBJ_GGML << 'ggml-blas.o' end if $GGML_METAL $MK_CPPFLAGS << ' -DGGML_USE_METAL' $MK_LDFLAGS << ' -framework Foundation -framework Metal -framework MetalKit' - $OBJ_GGML << ' ggml-metal.o' + $OBJ_GGML << 'ggml-metal.o' if ENV['GGML_METAL_NDEBUG'] $MK_CPPFLAGS << ' -DGGML_METAL_NDEBUG' @@ -181,21 +156,22 @@ if $GGML_METAL if $GGML_METAL_EMBED_LIBRARY $MK_CPPFLAGS << ' -DGGML_METAL_EMBED_LIBRARY' - $OBJ_GGML << ' ggml-metal-embed.o' + $OBJ_GGML << 'ggml-metal-embed.o' end end $OBJ_GGML << - ' ggml.o' << - ' ggml-alloc.o' << - ' ggml-backend.o' << - ' ggml-quants.o' << - ' ggml-aarch64.o' + 'ggml.o' << + 'ggml-alloc.o' << + 'ggml-backend.o' << + 'ggml-quants.o' << + 'ggml-aarch64.o' $OBJ_WHISPER << - ' whisper.o' + 'whisper.o' -$OBJ_ALL = "#{$OBJ_GGML} #{$OBJ_WHISPER} #{$OBJ_COMMON} #{$OBJ_SDL}" +$objs = $OBJ_GGML + $OBJ_WHISPER + $OBJ_COMMON + $OBJ_SDL +$objs << "ruby_whisper.o" $CPPFLAGS = "#{$MK_CPPFLAGS} #{$CPPFLAGS}" $CFLAGS = "#{$CPPFLAGS} #{$MK_CFLAGS} #{$GF_CFLAGS} #{$CFLAGS}" @@ -204,26 +180,13 @@ $CXXFLAGS = "#{$BASE_CXXFLAGS} #{$HOST_CXXFLAGS} #{$GF_CXXFLAGS} #{$CPPFLAGS}" $NVCCFLAGS = "#{$MK_NVCCFLAGS} #{$NVCCFLAGS}" $LDFLAGS = "#{$MK_LDFLAGS} #{$LDFLAGS}" -if $GGML_METAL_EMBED_LIBRARY - File.write 'depend', "$(OBJS): $(OBJS) ggml-metal-embed.o\n" -end - -with_disabling_unsupported_files do - - create_makefile('whisper') - -end +create_makefile('whisper') File.open 'Makefile', 'a' do |file| file.puts 'include get-flags.mk' if $GGML_METAL if $GGML_METAL_EMBED_LIBRARY - # mkmf determines object files to compile dependent on existing *.{c,cpp,m} files - # but ggml-metal-embed.c doesn't exist on creating Makefile. - file.puts "objs := $(OBJS)" - file.puts "OBJS = $(objs) 'ggml-metal-embed.o'" - file.puts 'include metal-embed.mk' end end diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 2c720e9..3f528ee 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -41,6 +41,8 @@ static ID id_call; static ID id___method__; static ID id_to_enum; +static bool is_log_callback_finalized = false; + /* * call-seq: * lang_max_id -> Integer @@ -88,6 +90,39 @@ static VALUE ruby_whisper_s_lang_str_full(VALUE self, VALUE id) { return rb_str_new2(str_full); } +static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) { + is_log_callback_finalized = true; + return Qnil; +} + +/* + * call-seq: + * log_set ->(level, buffer, user_data) { ... }, user_data -> nil + */ +static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) { + VALUE old_callback = rb_iv_get(self, "@log_callback"); + if (!NIL_P(old_callback)) { + rb_undefine_finalizer(old_callback); + } + + rb_iv_set(self, "@log_callback", log_callback); + rb_iv_set(self, "@user_data", user_data); + + VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback")); + rb_define_finalizer(log_callback, finalize_log_callback); + + whisper_log_set([](ggml_log_level level, const char * buffer, void * user_data) { + if (is_log_callback_finalized) { + return; + } + VALUE log_callback = rb_iv_get(mWhisper, "@log_callback"); + VALUE udata = rb_iv_get(mWhisper, "@user_data"); + rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata); + }, nullptr); + + return Qnil; +} + static void ruby_whisper_free(ruby_whisper *rw) { if (rw->context) { whisper_free(rw->context); @@ -389,6 +424,126 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { return self; } +/* + * call-seq: + * model_n_vocab -> Integer + */ +VALUE ruby_whisper_model_n_vocab(VALUE self) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_model_n_vocab(rw->context)); +} + +/* + * call-seq: + * model_n_audio_ctx -> Integer + */ +VALUE ruby_whisper_model_n_audio_ctx(VALUE self) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_model_n_audio_ctx(rw->context)); +} + +/* + * call-seq: + * model_n_audio_state -> Integer + */ +VALUE ruby_whisper_model_n_audio_state(VALUE self) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_model_n_audio_state(rw->context)); +} + +/* + * call-seq: + * model_n_audio_head -> Integer + */ +VALUE ruby_whisper_model_n_audio_head(VALUE self) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_model_n_audio_head(rw->context)); +} + +/* + * call-seq: + * model_n_audio_layer -> Integer + */ +VALUE ruby_whisper_model_n_audio_layer(VALUE self) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_model_n_audio_layer(rw->context)); +} + +/* + * call-seq: + * model_n_text_ctx -> Integer + */ +VALUE ruby_whisper_model_n_text_ctx(VALUE self) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_model_n_text_ctx(rw->context)); +} + +/* + * call-seq: + * model_n_text_state -> Integer + */ +VALUE ruby_whisper_model_n_text_state(VALUE self) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_model_n_text_state(rw->context)); +} + +/* + * call-seq: + * model_n_text_head -> Integer + */ +VALUE ruby_whisper_model_n_text_head(VALUE self) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_model_n_text_head(rw->context)); +} + +/* + * call-seq: + * model_n_text_layer -> Integer + */ +VALUE ruby_whisper_model_n_text_layer(VALUE self) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_model_n_text_layer(rw->context)); +} + +/* + * call-seq: + * model_n_mels -> Integer + */ +VALUE ruby_whisper_model_n_mels(VALUE self) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_model_n_mels(rw->context)); +} + +/* + * call-seq: + * model_ftype -> Integer + */ +VALUE ruby_whisper_model_ftype(VALUE self) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_model_ftype(rw->context)); +} + +/* + * call-seq: + * model_type -> String + */ +VALUE ruby_whisper_model_type(VALUE self) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return rb_str_new2(whisper_model_type_readable(rw->context)); +} + /* * Number of segments. * @@ -1015,7 +1170,12 @@ typedef struct { int index; } ruby_whisper_segment; +typedef struct { + VALUE context; +} ruby_whisper_model; + VALUE cSegment; +VALUE cModel; static void rb_whisper_segment_mark(ruby_whisper_segment *rws) { rb_gc_mark(rws->context); @@ -1188,6 +1348,176 @@ static VALUE ruby_whisper_segment_get_text(VALUE self) { return rb_str_new2(text); } +static void rb_whisper_model_mark(ruby_whisper_model *rwm) { + rb_gc_mark(rwm->context); +} + +static VALUE ruby_whisper_model_allocate(VALUE klass) { + ruby_whisper_model *rwm; + rwm = ALLOC(ruby_whisper_model); + return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm); +} + +static VALUE rb_whisper_model_initialize(VALUE context) { + ruby_whisper_model *rwm; + const VALUE model = ruby_whisper_model_allocate(cModel); + Data_Get_Struct(model, ruby_whisper_model, rwm); + rwm->context = context; + return model; +}; + +/* + * call-seq: + * model -> Whisper::Model + */ +static VALUE ruby_whisper_get_model(VALUE self) { + return rb_whisper_model_initialize(self); +} + +/* + * call-seq: + * n_vocab -> Integer + */ +static VALUE ruby_whisper_c_model_n_vocab(VALUE self) { + ruby_whisper_model *rwm; + Data_Get_Struct(self, ruby_whisper_model, rwm); + ruby_whisper *rw; + Data_Get_Struct(rwm->context, ruby_whisper, rw); + return INT2NUM(whisper_model_n_vocab(rw->context)); +} + +/* + * call-seq: + * n_audio_ctx -> Integer + */ +static VALUE ruby_whisper_c_model_n_audio_ctx(VALUE self) { + ruby_whisper_model *rwm; + Data_Get_Struct(self, ruby_whisper_model, rwm); + ruby_whisper *rw; + Data_Get_Struct(rwm->context, ruby_whisper, rw); + return INT2NUM(whisper_model_n_audio_ctx(rw->context)); +} + +/* + * call-seq: + * n_audio_state -> Integer + */ +static VALUE ruby_whisper_c_model_n_audio_state(VALUE self) { + ruby_whisper_model *rwm; + Data_Get_Struct(self, ruby_whisper_model, rwm); + ruby_whisper *rw; + Data_Get_Struct(rwm->context, ruby_whisper, rw); + return INT2NUM(whisper_model_n_audio_state(rw->context)); +} + +/* + * call-seq: + * n_audio_head -> Integer + */ +static VALUE ruby_whisper_c_model_n_audio_head(VALUE self) { + ruby_whisper_model *rwm; + Data_Get_Struct(self, ruby_whisper_model, rwm); + ruby_whisper *rw; + Data_Get_Struct(rwm->context, ruby_whisper, rw); + return INT2NUM(whisper_model_n_audio_head(rw->context)); +} + +/* + * call-seq: + * n_audio_layer -> Integer + */ +static VALUE ruby_whisper_c_model_n_audio_layer(VALUE self) { + ruby_whisper_model *rwm; + Data_Get_Struct(self, ruby_whisper_model, rwm); + ruby_whisper *rw; + Data_Get_Struct(rwm->context, ruby_whisper, rw); + return INT2NUM(whisper_model_n_audio_layer(rw->context)); +} + +/* + * call-seq: + * n_text_ctx -> Integer + */ +static VALUE ruby_whisper_c_model_n_text_ctx(VALUE self) { + ruby_whisper_model *rwm; + Data_Get_Struct(self, ruby_whisper_model, rwm); + ruby_whisper *rw; + Data_Get_Struct(rwm->context, ruby_whisper, rw); + return INT2NUM(whisper_model_n_text_ctx(rw->context)); +} + +/* + * call-seq: + * n_text_state -> Integer + */ +static VALUE ruby_whisper_c_model_n_text_state(VALUE self) { + ruby_whisper_model *rwm; + Data_Get_Struct(self, ruby_whisper_model, rwm); + ruby_whisper *rw; + Data_Get_Struct(rwm->context, ruby_whisper, rw); + return INT2NUM(whisper_model_n_text_state(rw->context)); +} + +/* + * call-seq: + * n_text_head -> Integer + */ +static VALUE ruby_whisper_c_model_n_text_head(VALUE self) { + ruby_whisper_model *rwm; + Data_Get_Struct(self, ruby_whisper_model, rwm); + ruby_whisper *rw; + Data_Get_Struct(rwm->context, ruby_whisper, rw); + return INT2NUM(whisper_model_n_text_head(rw->context)); +} + +/* + * call-seq: + * n_text_layer -> Integer + */ +static VALUE ruby_whisper_c_model_n_text_layer(VALUE self) { + ruby_whisper_model *rwm; + Data_Get_Struct(self, ruby_whisper_model, rwm); + ruby_whisper *rw; + Data_Get_Struct(rwm->context, ruby_whisper, rw); + return INT2NUM(whisper_model_n_text_layer(rw->context)); +} + +/* + * call-seq: + * n_mels -> Integer + */ +static VALUE ruby_whisper_c_model_n_mels(VALUE self) { + ruby_whisper_model *rwm; + Data_Get_Struct(self, ruby_whisper_model, rwm); + ruby_whisper *rw; + Data_Get_Struct(rwm->context, ruby_whisper, rw); + return INT2NUM(whisper_model_n_mels(rw->context)); +} + +/* + * call-seq: + * ftype -> Integer + */ +static VALUE ruby_whisper_c_model_ftype(VALUE self) { + ruby_whisper_model *rwm; + Data_Get_Struct(self, ruby_whisper_model, rwm); + ruby_whisper *rw; + Data_Get_Struct(rwm->context, ruby_whisper, rw); + return INT2NUM(whisper_model_ftype(rw->context)); +} + +/* + * call-seq: + * type -> String + */ +static VALUE ruby_whisper_c_model_type(VALUE self) { + ruby_whisper_model *rwm; + Data_Get_Struct(self, ruby_whisper_model, rwm); + ruby_whisper *rw; + Data_Get_Struct(rwm->context, ruby_whisper, rw); + return rb_str_new2(whisper_model_type_readable(rw->context)); +} + void Init_whisper() { id_to_s = rb_intern("to_s"); id_call = rb_intern("call"); @@ -1198,15 +1528,36 @@ void Init_whisper() { cContext = rb_define_class_under(mWhisper, "Context", rb_cObject); cParams = rb_define_class_under(mWhisper, "Params", rb_cObject); + rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE)); + rb_define_const(mWhisper, "LOG_LEVEL_INFO", INT2NUM(GGML_LOG_LEVEL_INFO)); + rb_define_const(mWhisper, "LOG_LEVEL_WARN", INT2NUM(GGML_LOG_LEVEL_WARN)); + rb_define_const(mWhisper, "LOG_LEVEL_ERROR", INT2NUM(GGML_LOG_LEVEL_ERROR)); + rb_define_const(mWhisper, "LOG_LEVEL_DEBUG", INT2NUM(GGML_LOG_LEVEL_DEBUG)); + rb_define_const(mWhisper, "LOG_LEVEL_CONT", INT2NUM(GGML_LOG_LEVEL_CONT)); + rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0); rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1); rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1); rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1); + rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2); + rb_define_singleton_method(mWhisper, "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1); rb_define_alloc_func(cContext, ruby_whisper_allocate); rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1); rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1); + rb_define_method(cContext, "model_n_vocab", ruby_whisper_model_n_vocab, 0); + rb_define_method(cContext, "model_n_audio_ctx", ruby_whisper_model_n_audio_ctx, 0); + rb_define_method(cContext, "model_n_audio_state", ruby_whisper_model_n_audio_state, 0); + rb_define_method(cContext, "model_n_audio_head", ruby_whisper_model_n_audio_head, 0); + rb_define_method(cContext, "model_n_audio_layer", ruby_whisper_model_n_audio_layer, 0); + rb_define_method(cContext, "model_n_text_ctx", ruby_whisper_model_n_text_ctx, 0); + rb_define_method(cContext, "model_n_text_state", ruby_whisper_model_n_text_state, 0); + rb_define_method(cContext, "model_n_text_head", ruby_whisper_model_n_text_head, 0); + rb_define_method(cContext, "model_n_text_layer", ruby_whisper_model_n_text_layer, 0); + rb_define_method(cContext, "model_n_mels", ruby_whisper_model_n_mels, 0); + rb_define_method(cContext, "model_ftype", ruby_whisper_model_ftype, 0); + rb_define_method(cContext, "model_type", ruby_whisper_model_type, 0); rb_define_method(cContext, "full_n_segments", ruby_whisper_full_n_segments, 0); rb_define_method(cContext, "full_lang_id", ruby_whisper_full_lang_id, 0); rb_define_method(cContext, "full_get_segment_t0", ruby_whisper_full_get_segment_t0, 1); @@ -1284,6 +1635,22 @@ void Init_whisper() { rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0); rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0); rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0); + + cModel = rb_define_class_under(mWhisper, "Model", rb_cObject); + rb_define_alloc_func(cModel, ruby_whisper_model_allocate); + rb_define_method(cContext, "model", ruby_whisper_get_model, 0); + rb_define_method(cModel, "n_vocab", ruby_whisper_c_model_n_vocab, 0); + rb_define_method(cModel, "n_audio_ctx", ruby_whisper_c_model_n_audio_ctx, 0); + rb_define_method(cModel, "n_audio_state", ruby_whisper_c_model_n_audio_state, 0); + rb_define_method(cModel, "n_audio_head", ruby_whisper_c_model_n_audio_head, 0); + rb_define_method(cModel, "n_audio_layer", ruby_whisper_c_model_n_audio_layer, 0); + rb_define_method(cModel, "n_text_ctx", ruby_whisper_c_model_n_text_ctx, 0); + rb_define_method(cModel, "n_text_state", ruby_whisper_c_model_n_text_state, 0); + rb_define_method(cModel, "n_text_head", ruby_whisper_c_model_n_text_head, 0); + rb_define_method(cModel, "n_text_layer", ruby_whisper_c_model_n_text_layer, 0); + rb_define_method(cModel, "n_mels", ruby_whisper_c_model_n_mels, 0); + rb_define_method(cModel, "ftype", ruby_whisper_c_model_ftype, 0); + rb_define_method(cModel, "type", ruby_whisper_c_model_type, 0); } #ifdef __cplusplus } diff --git a/bindings/ruby/tests/helper.rb b/bindings/ruby/tests/helper.rb new file mode 100644 index 0000000..4172ebc --- /dev/null +++ b/bindings/ruby/tests/helper.rb @@ -0,0 +1,7 @@ +require "test/unit" +require "whisper" + +class TestBase < Test::Unit::TestCase + MODEL = File.join(__dir__, "..", "..", "..", "models", "ggml-base.en.bin") + AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav") +end diff --git a/bindings/ruby/tests/test_model.rb b/bindings/ruby/tests/test_model.rb new file mode 100644 index 0000000..2310522 --- /dev/null +++ b/bindings/ruby/tests/test_model.rb @@ -0,0 +1,44 @@ +require_relative "helper" + +class TestModel < TestBase + def test_model + whisper = Whisper::Context.new(MODEL) + assert_instance_of Whisper::Model, whisper.model + end + + def test_attributes + whisper = Whisper::Context.new(MODEL) + model = whisper.model + + assert_equal 51864, model.n_vocab + assert_equal 1500, model.n_audio_ctx + assert_equal 512, model.n_audio_state + assert_equal 8, model.n_audio_head + assert_equal 6, model.n_audio_layer + assert_equal 448, model.n_text_ctx + assert_equal 512, model.n_text_state + assert_equal 8, model.n_text_head + assert_equal 6, model.n_text_layer + assert_equal 80, model.n_mels + assert_equal 1, model.ftype + assert_equal "base", model.type + end + + def test_gc + model = Whisper::Context.new(MODEL).model + GC.start + + assert_equal 51864, model.n_vocab + assert_equal 1500, model.n_audio_ctx + assert_equal 512, model.n_audio_state + assert_equal 8, model.n_audio_head + assert_equal 6, model.n_audio_layer + assert_equal 448, model.n_text_ctx + assert_equal 512, model.n_text_state + assert_equal 8, model.n_text_head + assert_equal 6, model.n_text_layer + assert_equal 80, model.n_mels + assert_equal 1, model.ftype + assert_equal "base", model.type + end +end diff --git a/bindings/ruby/tests/test_package.rb b/bindings/ruby/tests/test_package.rb index f51eab5..9c47870 100644 --- a/bindings/ruby/tests/test_package.rb +++ b/bindings/ruby/tests/test_package.rb @@ -1,9 +1,9 @@ -require 'test/unit' +require_relative "helper" require 'tempfile' require 'tmpdir' require 'shellwords' -class TestPackage < Test::Unit::TestCase +class TestPackage < TestBase def test_build Tempfile.create do |file| assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true) diff --git a/bindings/ruby/tests/test_params.rb b/bindings/ruby/tests/test_params.rb index 6386049..bf73fd6 100644 --- a/bindings/ruby/tests/test_params.rb +++ b/bindings/ruby/tests/test_params.rb @@ -1,7 +1,6 @@ -require 'test/unit' -require 'whisper' +require_relative "helper" -class TestParams < Test::Unit::TestCase +class TestParams < TestBase def setup @params = Whisper::Params.new end diff --git a/bindings/ruby/tests/test_segment.rb b/bindings/ruby/tests/test_segment.rb index f3ebc0e..8129ae5 100644 --- a/bindings/ruby/tests/test_segment.rb +++ b/bindings/ruby/tests/test_segment.rb @@ -1,18 +1,14 @@ -require "test/unit" -require "whisper" - -class TestSegment < Test::Unit::TestCase - TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..')) +require_relative "helper" +class TestSegment < TestBase class << self attr_reader :whisper def startup - @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + @whisper = Whisper::Context.new(TestBase::MODEL) params = Whisper::Params.new params.print_timestamps = false - jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') - @whisper.transcribe(jfk, params) + @whisper.transcribe(TestBase::AUDIO, params) end end @@ -60,7 +56,7 @@ class TestSegment < Test::Unit::TestCase end index += 1 end - whisper.transcribe(File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav'), params) + whisper.transcribe(AUDIO, params) assert_equal 0, seg.start_time assert_match /ask not what your country can do for you, ask what you can do for your country/, seg.text end @@ -76,7 +72,7 @@ class TestSegment < Test::Unit::TestCase assert_same seg, segment return end - whisper.transcribe(File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav'), params) + whisper.transcribe(AUDIO, params) end private diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 5ebb815..e37e24c 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -1,20 +1,20 @@ -require 'whisper' -require 'test/unit' +require_relative "helper" +require "stringio" -class TestWhisper < Test::Unit::TestCase - TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..')) +# Exists to detect memory-related bug +Whisper.log_set ->(level, buffer, user_data) {}, nil +class TestWhisper < TestBase def setup @params = Whisper::Params.new end def test_whisper - @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + @whisper = Whisper::Context.new(MODEL) params = Whisper::Params.new params.print_timestamps = false - jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') - @whisper.transcribe(jfk, params) {|text| + @whisper.transcribe(AUDIO, params) {|text| assert_match /ask not what your country can do for you, ask what you can do for your country/, text } end @@ -24,11 +24,10 @@ class TestWhisper < Test::Unit::TestCase attr_reader :whisper def startup - @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + @whisper = Whisper::Context.new(TestBase::MODEL) params = Whisper::Params.new params.print_timestamps = false - jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') - @whisper.transcribe(jfk, params) + @whisper.transcribe(TestBase::AUDIO, params) end end @@ -96,4 +95,33 @@ class TestWhisper < Test::Unit::TestCase Whisper.lang_str_full(Whisper.lang_max_id + 1) end end + + def test_log_set + user_data = Object.new + logs = [] + log_callback = ->(level, buffer, udata) { + logs << [level, buffer, udata] + } + Whisper.log_set log_callback, user_data + Whisper::Context.new(MODEL) + + assert logs.length > 30 + logs.each do |log| + assert_equal Whisper::LOG_LEVEL_INFO, log[0] + assert_same user_data, log[2] + end + end + + def test_log_suppress + stderr = $stderr + Whisper.log_set ->(level, buffer, user_data) { + # do nothing + }, nil + dev = StringIO.new("") + $stderr = dev + Whisper::Context.new(MODEL) + assert_empty dev.string + ensure + $stderr = stderr + end end