diff --git a/cpp/whisper.cpp b/cpp/whisper.cpp
index 3ceab2b..fa58cef 100644
--- a/cpp/whisper.cpp
+++ b/cpp/whisper.cpp
@@ -770,6 +770,9 @@ struct whisper_context {
     whisper_state * state = nullptr;
 
     std::string path_model; // populated by whisper_init_from_file()
+#ifdef WHISPER_USE_COREML
+    bool load_coreml = true;
+#endif
 };
 
 static void whisper_default_log(const char * text) {
@@ -2854,6 +2857,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     }
 
 #ifdef WHISPER_USE_COREML
+if (ctx->load_coreml) { // Not in correct layer for easy patch
     const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
 
     log("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
@@ -2869,6 +2873,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     } else {
         log("%s: Core ML model loaded\n", __func__);
     }
+}
 #endif
 
     state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
@@ -2989,6 +2994,23 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     return state;
 }
 
+#ifdef WHISPER_USE_COREML
+struct whisper_context * whisper_init_from_file_no_coreml(const char * path_model) {
+    whisper_context * ctx = whisper_init_from_file_no_state(path_model);
+    if (!ctx) {
+        return nullptr;
+    }
+    ctx->load_coreml = false;
+    ctx->state = whisper_init_state(ctx);
+    if (!ctx->state) {
+        whisper_free(ctx);
+        return nullptr;
+    }
+
+    return ctx;
+}
+#endif
+
 int whisper_ctx_init_openvino_encoder(
         struct whisper_context * ctx,
         const char * model_path,
diff --git a/cpp/whisper.h b/cpp/whisper.h
index eab3f22..5d2b013 100644
--- a/cpp/whisper.h
+++ b/cpp/whisper.h
@@ -99,6 +99,9 @@ extern "C" {
     // Various functions for loading a ggml whisper model.
     // Allocate (almost) all memory needed for the model.
     // Return NULL on failure
+#ifdef WHISPER_USE_COREML
+    WHISPER_API struct whisper_context * whisper_init_from_file_no_coreml(const char * path_model);
+#endif
     WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model);
     WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
     WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
diff --git a/example/ios/Podfile b/example/ios/Podfile
index ae64303..e1c5cf6 100644
--- a/example/ios/Podfile
+++ b/example/ios/Podfile
@@ -25,8 +25,9 @@ ENV['RCT_NEW_ARCH_ENABLED'] = '1'
 
 target 'RNWhisperExample' do
   # Tip: You can use RNWHISPER_DISABLE_COREML = '1' to disable CoreML support.
-  ENV['RNWHISPER_DISABLE_COREML'] = '1' # TEMP
-  ENV['RNWHISPER_DISABLE_METAL'] = '0' # TEMP
+  ENV['RNWHISPER_DISABLE_COREML'] = '0'
+
+  ENV['RNWHISPER_ENABLE_METAL'] = '0' # TODO
 
   config = use_native_modules!
diff --git a/example/ios/Podfile.lock b/example/ios/Podfile.lock
index c28bc35..f476fd2 100644
--- a/example/ios/Podfile.lock
+++ b/example/ios/Podfile.lock
@@ -764,7 +764,7 @@ PODS:
     - SSZipArchive (~> 2.2)
   - SocketRocket (0.6.0)
   - SSZipArchive (2.4.3)
-  - whisper-rn (0.3.9):
+  - whisper-rn (0.4.0-rc.0):
     - RCT-Folly
     - RCTRequired
     - RCTTypeSafety
@@ -1006,10 +1006,10 @@ SPEC CHECKSUMS:
   RNZipArchive: ef9451b849c45a29509bf44e65b788829ab07801
   SocketRocket: fccef3f9c5cedea1353a9ef6ada904fde10d6608
   SSZipArchive: fe6a26b2a54d5a0890f2567b5cc6de5caa600aef
-  whisper-rn: b3c5abf27f09df7c9d5d089ad1275c2ec20a23aa
+  whisper-rn: a333c75700c2d031cecf12db9255459b01602d56
   Yoga: f7decafdc5e8c125e6fa0da38a687e35238420fa
   YogaKit: f782866e155069a2cca2517aafea43200b01fd5a
 
-PODFILE CHECKSUM: 37f5c1045c7d04c6e5332174cca5f32f528700cf
+PODFILE CHECKSUM: a78cf54fa529c6dc4b44aaf32b861fdf1245919a
 
 COCOAPODS: 1.11.3
diff --git a/ios/RNWhisper.mm b/ios/RNWhisper.mm
index 512e5d6..f9699af 100644
--- a/ios/RNWhisper.mm
+++ b/ios/RNWhisper.mm
@@ -48,6 +48,7 @@ - (NSDictionary *)constantsToExport
 
     NSString *modelPath = [modelOptions objectForKey:@"filePath"];
     BOOL isBundleAsset = [[modelOptions objectForKey:@"isBundleAsset"] boolValue];
+    BOOL useCoreMLIos = [[modelOptions objectForKey:@"useCoreMLIos"] boolValue];
 
     // For support debug assets in development mode
     BOOL downloadCoreMLAssets = [[modelOptions objectForKey:@"downloadCoreMLAssets"] boolValue];
@@ -75,6 +76,7 @@ - (NSDictionary *)constantsToExport
     RNWhisperContext *context = [RNWhisperContext
         initWithModelPath:path
         contextId:contextId
+        noCoreML:!useCoreMLIos
     ];
     if ([context getContext] == NULL) {
         reject(@"whisper_cpp_error", @"Failed to load the model", nil);
diff --git a/ios/RNWhisperContext.h b/ios/RNWhisperContext.h
index 0fbb7a8..d96d645 100644
--- a/ios/RNWhisperContext.h
+++ b/ios/RNWhisperContext.h
@@ -42,7 +42,7 @@ typedef struct {
     RNWhisperContextRecordState recordState;
 }
 
-+ (instancetype)initWithModelPath:(NSString *)modelPath contextId:(int)contextId;
++ (instancetype)initWithModelPath:(NSString *)modelPath contextId:(int)contextId noCoreML:(BOOL)noCoreML;
 - (struct whisper_context *)getContext;
 - (dispatch_queue_t)getDispatchQueue;
 - (OSStatus)transcribeRealtime:(int)jobId
diff --git a/ios/RNWhisperContext.mm b/ios/RNWhisperContext.mm
index 2caf82c..937c3e1 100644
--- a/ios/RNWhisperContext.mm
+++ b/ios/RNWhisperContext.mm
@@ -6,10 +6,18 @@
 
 @implementation RNWhisperContext
 
-+ (instancetype)initWithModelPath:(NSString *)modelPath contextId:(int)contextId {
++ (instancetype)initWithModelPath:(NSString *)modelPath contextId:(int)contextId noCoreML:(BOOL)noCoreML {
    RNWhisperContext *context = [[RNWhisperContext alloc] init];
    context->contextId = contextId;
+#ifdef WHISPER_USE_COREML
+    if (noCoreML) {
+        context->ctx = whisper_init_from_file_no_coreml([modelPath UTF8String]);
+    } else {
+        context->ctx = whisper_init_from_file([modelPath UTF8String]);
+    }
+#else
    context->ctx = whisper_init_from_file([modelPath UTF8String]);
+#endif
    context->dQueue = dispatch_queue_create(
        [[NSString stringWithFormat:@"RNWhisperContext-%d", contextId] UTF8String],
        DISPATCH_QUEUE_SERIAL
diff --git a/scripts/bootstrap.sh b/scripts/bootstrap.sh
index 4e5f779..36e0a88 100755
--- a/scripts/bootstrap.sh
+++ b/scripts/bootstrap.sh
@@ -52,6 +52,8 @@ yarn example
 
 # Apply patch
 patch -p0 -d ./cpp < ./scripts/ggml-metal.m.patch
+patch -p0 -d ./cpp < ./scripts/whisper.h.patch
+patch -p0 -d ./cpp < ./scripts/whisper.cpp.patch
 patch -p0 -d ./cpp/coreml < ./scripts/whisper-encoder.mm.patch
 
 # Download model for example
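Worth noting before the patch files below: the boolean flips polarity twice on its way down. The JS option `useCoreMLIos` is inverted to `noCoreML` in RNWhisper.mm, and `noCoreML` in turn clears `load_coreml` on the C++ context. A minimal TypeScript sketch of the resulting decision, for clarity only; names mirror the diff, and `coremlBuild` is a hypothetical stand-in for the compile-time `WHISPER_USE_COREML` flag:

```ts
// Sketch only: which whisper.cpp entry point the native layers end up calling.
function resolveInitFunction(useCoreMLIos: boolean, coremlBuild: boolean): string {
  const noCoreML = !useCoreMLIos // inversion performed in RNWhisper.mm
  if (coremlBuild && noCoreML) return 'whisper_init_from_file_no_coreml'
  return 'whisper_init_from_file' // builds without Core ML always take this path
}
```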
diff --git a/scripts/whisper.cpp.patch b/scripts/whisper.cpp.patch
new file mode 100644
index 0000000..a41f652
--- /dev/null
+++ b/scripts/whisper.cpp.patch
@@ -0,0 +1,53 @@
+--- whisper.cpp.orig	2023-10-12 11:44:51
++++ whisper.cpp	2023-10-12 11:43:31
+@@ -770,6 +770,9 @@
+     whisper_state * state = nullptr;
+ 
+     std::string path_model; // populated by whisper_init_from_file()
++#ifdef WHISPER_USE_COREML
++    bool load_coreml = true;
++#endif
+ };
+ 
+ static void whisper_default_log(const char * text) {
+@@ -2854,6 +2857,7 @@
+     }
+ 
+ #ifdef WHISPER_USE_COREML
++if (ctx->load_coreml) { // Not in correct layer for easy patch
+     const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
+ 
+     log("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
+@@ -2869,6 +2873,7 @@
+     } else {
+         log("%s: Core ML model loaded\n", __func__);
+     }
++}
+ #endif
+ 
+     state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
+@@ -2987,7 +2992,24 @@
+     state->rng = std::mt19937(0);
+ 
+     return state;
++}
++
++#ifdef WHISPER_USE_COREML
++struct whisper_context * whisper_init_from_file_no_coreml(const char * path_model) {
++    whisper_context * ctx = whisper_init_from_file_no_state(path_model);
++    if (!ctx) {
++        return nullptr;
++    }
++    ctx->load_coreml = false;
++    ctx->state = whisper_init_state(ctx);
++    if (!ctx->state) {
++        whisper_free(ctx);
++        return nullptr;
++    }
++
++    return ctx;
+ }
++#endif
+ 
+ int whisper_ctx_init_openvino_encoder(
+         struct whisper_context * ctx,
diff --git a/scripts/whisper.h.patch b/scripts/whisper.h.patch
new file mode 100644
index 0000000..88fcd3d
--- /dev/null
+++ b/scripts/whisper.h.patch
@@ -0,0 +1,12 @@
+--- whisper.h.orig	2023-10-12 10:41:41
++++ whisper.h	2023-10-12 10:38:11
+@@ -99,6 +99,9 @@
+     // Various functions for loading a ggml whisper model.
+     // Allocate (almost) all memory needed for the model.
+     // Return NULL on failure
++#ifdef WHISPER_USE_COREML
++    WHISPER_API struct whisper_context * whisper_init_from_file_no_coreml(const char * path_model);
++#endif
+     WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model);
+     WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
+     WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
diff --git a/src/NativeRNWhisper.ts b/src/NativeRNWhisper.ts
index 79f27f9..18c7490 100644
--- a/src/NativeRNWhisper.ts
+++ b/src/NativeRNWhisper.ts
@@ -52,6 +52,7 @@ export type CoreMLAsset = {
 type NativeContextOptions = {
   filePath: string,
   isBundleAsset: boolean,
+  useCoreMLIos?: boolean,
   downloadCoreMLAssets?: boolean,
   coreMLAssets?: CoreMLAsset[],
 }
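A subtlety connects this native type to the index.ts change below: on the iOS side, `[[modelOptions objectForKey:@"useCoreMLIos"] boolValue]` decodes a missing key as `NO`, which would silently disable Core ML. That is why the JS layer defaults the option to `true` before it crosses the bridge. A sketch of the options object as sent, with a placeholder file path:

```ts
// What initWhisper() hands to RNWhisper.initContext when the caller leaves
// useCoreMLIos unset: the JS-side default keeps Core ML enabled.
const nativeOptions = {
  filePath: '/absolute/path/to/ggml-base.bin', // placeholder path
  isBundleAsset: false,
  useCoreMLIos: true, // defaulted in initWhisper(); omitting the key would
                      // decode as NO in RNWhisper.mm and skip Core ML
}
```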
diff --git a/src/index.ts b/src/index.ts
index 6d5785a..c9fe796 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -438,6 +438,8 @@ export type ContextOptions = {
   }
   /** Is the file path a bundle asset for pure string filePath */
   isBundleAsset?: boolean
+  /** Prefer to use the Core ML model if it exists. If set to false, the Core ML model will not be used even when it exists. */
+  useCoreMLIos?: boolean
 }
 
 const coreMLModelAssetPaths = [
@@ -451,6 +453,7 @@ export async function initWhisper({
   filePath,
   coreMLModelAsset,
   isBundleAsset,
+  useCoreMLIos = true,
 }: ContextOptions): Promise<WhisperContext> {
   let path = ''
   let coreMLAssets: CoreMLAsset[] | undefined
@@ -499,6 +502,7 @@
   const id = await RNWhisper.initContext({
     filePath: path,
     isBundleAsset: !!isBundleAsset,
+    useCoreMLIos,
     // Only development mode need download Core ML model assets (from packager server)
     downloadCoreMLAssets: __DEV__ && !!coreMLAssets,
     coreMLAssets,
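The net effect for JS callers: a Core ML-enabled build can now opt out of the Core ML encoder per context. A usage sketch, with placeholder model paths and the import taken from the package name:

```ts
import { initWhisper } from 'whisper.rn'

// Default: the Core ML encoder is used when its assets sit next to the ggml model.
const withCoreML = await initWhisper({
  filePath: '/absolute/path/to/ggml-base.bin', // placeholder
})

// Opt out: the context loads via whisper_init_from_file_no_coreml and runs the
// plain ggml encoder even if the Core ML model exists.
const withoutCoreML = await initWhisper({
  filePath: '/absolute/path/to/ggml-base.bin', // placeholder
  useCoreMLIos: false,
})
```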