Skip to content

Commit

Permalink
feat(ios): make option for enable/disable Core ML (#145)
Browse files Browse the repository at this point in the history
* feat(ios): make option for enable/disable Core ML

* fix(ios): use whisper_init_from_file_no_coreml if defined WHISPER_USE_COREML

* fix(example): update env

* fix(cpp): patch
  • Loading branch information
jhen0409 committed Oct 12, 2023
1 parent f957d9d commit 66e9a0c
Show file tree
Hide file tree
Showing 12 changed files with 115 additions and 7 deletions.
22 changes: 22 additions & 0 deletions cpp/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,9 @@ struct whisper_context {
whisper_state * state = nullptr;

std::string path_model; // populated by whisper_init_from_file()
#ifdef WHISPER_USE_COREML
bool load_coreml = true;
#endif
};

static void whisper_default_log(const char * text) {
Expand Down Expand Up @@ -2854,6 +2857,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
}

#ifdef WHISPER_USE_COREML
if (ctx->load_coreml) { // Not in correct layer for easy patch
const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);

log("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
Expand All @@ -2869,6 +2873,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
} else {
log("%s: Core ML model loaded\n", __func__);
}
}
#endif

state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
Expand Down Expand Up @@ -2989,6 +2994,23 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
return state;
}

#ifdef WHISPER_USE_COREML
// Load a whisper model from file while skipping the Core ML encoder, even
// though the library was built with WHISPER_USE_COREML.
// Mirrors whisper_init_from_file(), but clears ctx->load_coreml before the
// state is initialized so whisper_init_state() skips the Core ML model load.
// Returns NULL on failure.
struct whisper_context * whisper_init_from_file_no_coreml(const char * path_model) {
// Create the context without its state first, so the load_coreml flag can be
// cleared before whisper_init_state() (which checks it) runs.
whisper_context * ctx = whisper_init_from_file_no_state(path_model);
if (!ctx) {
return nullptr;
}
ctx->load_coreml = false;
ctx->state = whisper_init_state(ctx);
if (!ctx->state) {
// Avoid leaking the partially-initialized context on state-init failure.
whisper_free(ctx);
return nullptr;
}

return ctx;
}
#endif

int whisper_ctx_init_openvino_encoder(
struct whisper_context * ctx,
const char * model_path,
Expand Down
3 changes: 3 additions & 0 deletions cpp/whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ extern "C" {
// Various functions for loading a ggml whisper model.
// Allocate (almost) all memory needed for the model.
// Return NULL on failure
#ifdef WHISPER_USE_COREML
WHISPER_API struct whisper_context * whisper_init_from_file_no_coreml(const char * path_model);
#endif
WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model);
WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
Expand Down
5 changes: 3 additions & 2 deletions example/ios/Podfile
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ ENV['RCT_NEW_ARCH_ENABLED'] = '1'

target 'RNWhisperExample' do
# Tip: You can use RNWHISPER_DISABLE_COREML = '1' to disable CoreML support.
ENV['RNWHISPER_DISABLE_COREML'] = '1' # TEMP
ENV['RNWHISPER_DISABLE_METAL'] = '0' # TEMP
ENV['RNWHISPER_DISABLE_COREML'] = '0'

ENV['RNWHISPER_ENABLE_METAL'] = '0' # TODO

config = use_native_modules!

Expand Down
6 changes: 3 additions & 3 deletions example/ios/Podfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,7 @@ PODS:
- SSZipArchive (~> 2.2)
- SocketRocket (0.6.0)
- SSZipArchive (2.4.3)
- whisper-rn (0.3.9):
- whisper-rn (0.4.0-rc.0):
- RCT-Folly
- RCTRequired
- RCTTypeSafety
Expand Down Expand Up @@ -1006,10 +1006,10 @@ SPEC CHECKSUMS:
RNZipArchive: ef9451b849c45a29509bf44e65b788829ab07801
SocketRocket: fccef3f9c5cedea1353a9ef6ada904fde10d6608
SSZipArchive: fe6a26b2a54d5a0890f2567b5cc6de5caa600aef
whisper-rn: b3c5abf27f09df7c9d5d089ad1275c2ec20a23aa
whisper-rn: a333c75700c2d031cecf12db9255459b01602d56
Yoga: f7decafdc5e8c125e6fa0da38a687e35238420fa
YogaKit: f782866e155069a2cca2517aafea43200b01fd5a

PODFILE CHECKSUM: 37f5c1045c7d04c6e5332174cca5f32f528700cf
PODFILE CHECKSUM: a78cf54fa529c6dc4b44aaf32b861fdf1245919a

COCOAPODS: 1.11.3
2 changes: 2 additions & 0 deletions ios/RNWhisper.mm
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ - (NSDictionary *)constantsToExport

NSString *modelPath = [modelOptions objectForKey:@"filePath"];
BOOL isBundleAsset = [[modelOptions objectForKey:@"isBundleAsset"] boolValue];
BOOL useCoreMLIos = [[modelOptions objectForKey:@"useCoreMLIos"] boolValue];

// For support debug assets in development mode
BOOL downloadCoreMLAssets = [[modelOptions objectForKey:@"downloadCoreMLAssets"] boolValue];
Expand Down Expand Up @@ -75,6 +76,7 @@ - (NSDictionary *)constantsToExport
RNWhisperContext *context = [RNWhisperContext
initWithModelPath:path
contextId:contextId
noCoreML:!useCoreMLIos
];
if ([context getContext] == NULL) {
reject(@"whisper_cpp_error", @"Failed to load the model", nil);
Expand Down
2 changes: 1 addition & 1 deletion ios/RNWhisperContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ typedef struct {
RNWhisperContextRecordState recordState;
}

+ (instancetype)initWithModelPath:(NSString *)modelPath contextId:(int)contextId;
+ (instancetype)initWithModelPath:(NSString *)modelPath contextId:(int)contextId noCoreML:(BOOL)noCoreML;
- (struct whisper_context *)getContext;
- (dispatch_queue_t)getDispatchQueue;
- (OSStatus)transcribeRealtime:(int)jobId
Expand Down
10 changes: 9 additions & 1 deletion ios/RNWhisperContext.mm
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,18 @@

@implementation RNWhisperContext

+ (instancetype)initWithModelPath:(NSString *)modelPath contextId:(int)contextId {
+ (instancetype)initWithModelPath:(NSString *)modelPath contextId:(int)contextId noCoreML:(BOOL)noCoreML {
RNWhisperContext *context = [[RNWhisperContext alloc] init];
context->contextId = contextId;
#ifdef WHISPER_USE_COREML
if (noCoreML) {
context->ctx = whisper_init_from_file_no_coreml([modelPath UTF8String]);
} else {
context->ctx = whisper_init_from_file([modelPath UTF8String]);
}
#else
context->ctx = whisper_init_from_file([modelPath UTF8String]);
#endif
context->dQueue = dispatch_queue_create(
[[NSString stringWithFormat:@"RNWhisperContext-%d", contextId] UTF8String],
DISPATCH_QUEUE_SERIAL
Expand Down
2 changes: 2 additions & 0 deletions scripts/bootstrap.sh
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ yarn example

# Apply patch
patch -p0 -d ./cpp < ./scripts/ggml-metal.m.patch
patch -p0 -d ./cpp < ./scripts/whisper.h.patch
patch -p0 -d ./cpp < ./scripts/whisper.cpp.patch
patch -p0 -d ./cpp/coreml < ./scripts/whisper-encoder.mm.patch

# Download model for example
Expand Down
53 changes: 53 additions & 0 deletions scripts/whisper.cpp.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
--- whisper.cpp.orig 2023-10-12 11:44:51
+++ whisper.cpp 2023-10-12 11:43:31
@@ -770,6 +770,9 @@
whisper_state * state = nullptr;

std::string path_model; // populated by whisper_init_from_file()
+#ifdef WHISPER_USE_COREML
+ bool load_coreml = true;
+#endif
};

static void whisper_default_log(const char * text) {
@@ -2854,6 +2857,7 @@
}

#ifdef WHISPER_USE_COREML
+if (ctx->load_coreml) { // Not in correct layer for easy patch
const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);

log("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
@@ -2869,6 +2873,7 @@
} else {
log("%s: Core ML model loaded\n", __func__);
}
+}
#endif

state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
@@ -2987,7 +2992,24 @@
state->rng = std::mt19937(0);

return state;
+}
+
+#ifdef WHISPER_USE_COREML
+struct whisper_context * whisper_init_from_file_no_coreml(const char * path_model) {
+ whisper_context * ctx = whisper_init_from_file_no_state(path_model);
+ if (!ctx) {
+ return nullptr;
+ }
+ ctx->load_coreml = false;
+ ctx->state = whisper_init_state(ctx);
+ if (!ctx->state) {
+ whisper_free(ctx);
+ return nullptr;
+ }
+
+ return ctx;
}
+#endif

int whisper_ctx_init_openvino_encoder(
struct whisper_context * ctx,
12 changes: 12 additions & 0 deletions scripts/whisper.h.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
--- whisper.h.orig 2023-10-12 10:41:41
+++ whisper.h 2023-10-12 10:38:11
@@ -99,6 +99,9 @@
// Various functions for loading a ggml whisper model.
// Allocate (almost) all memory needed for the model.
// Return NULL on failure
+#ifdef WHISPER_USE_COREML
+ WHISPER_API struct whisper_context * whisper_init_from_file_no_coreml(const char * path_model);
+#endif
WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model);
WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
1 change: 1 addition & 0 deletions src/NativeRNWhisper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ export type CoreMLAsset = {
type NativeContextOptions = {
filePath: string,
isBundleAsset: boolean,
useCoreMLIos?: boolean,
downloadCoreMLAssets?: boolean,
coreMLAssets?: CoreMLAsset[],
}
Expand Down
4 changes: 4 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,8 @@ export type ContextOptions = {
}
/** Is the file path a bundle asset for pure string filePath */
isBundleAsset?: boolean
/** Prefer to use Core ML model if exists. If set to false, even if the Core ML model exists, it will not be used. */
useCoreMLIos?: boolean
}

const coreMLModelAssetPaths = [
Expand All @@ -451,6 +453,7 @@ export async function initWhisper({
filePath,
coreMLModelAsset,
isBundleAsset,
useCoreMLIos = true,
}: ContextOptions): Promise<WhisperContext> {
let path = ''
let coreMLAssets: CoreMLAsset[] | undefined
Expand Down Expand Up @@ -499,6 +502,7 @@ export async function initWhisper({
const id = await RNWhisper.initContext({
filePath: path,
isBundleAsset: !!isBundleAsset,
useCoreMLIos,
// Only development mode need download Core ML model assets (from packager server)
downloadCoreMLAssets: __DEV__ && !!coreMLAssets,
coreMLAssets,
Expand Down

0 comments on commit 66e9a0c

Please sign in to comment.