Skip to content

Commit ae0d0ef

Browse files
metascroy authored and facebook-github-bot committed
Add CoreML backend options with compute unit configuration (#18369)
Summary: This diff adds type-safe backend options support to the CoreML delegate, allowing users to configure compute units (CPU, GPU, Neural Engine) at model load time using the new `LoadBackendOptionsMap` infrastructure. Key changes: - Added `LoadOptionsBuilder` class in `coreml_backend_options.h` providing a fluent API for setting CoreML options with compile-time type safety - `ComputeUnit` enum nested inside the builder for type-safe compute unit selection (CPU_ONLY, CPU_AND_GPU, CPU_AND_NE, ALL) - Integrated runtime spec retrieval in `backend_delegate.mm` to read `compute_unit` option from `BackendInitContext` - Added comprehensive unit tests for the new options builder Example usage: ```cpp using executorch::backends::coreml::LoadOptionsBuilder; LoadOptionsBuilder coreml_opts; coreml_opts.setComputeUnit(LoadOptionsBuilder::ComputeUnit::CPU_AND_GPU); LoadBackendOptionsMap map; map.set_options(coreml_opts); module.load(method_name, map); ``` Differential Revision: D92358632
1 parent aa7c8ce commit ae0d0ef

File tree

6 files changed

+361
-30
lines changed

6 files changed

+361
-30
lines changed

backends/apple/coreml/BUCK

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ runtime.cxx_library(
8989
platforms = [APPLE],
9090
visibility = ["PUBLIC"],
9191
deps = [
92+
"//executorch/runtime/backend:backend_options",
9293
"//executorch/runtime/backend:interface",
9394
"//executorch/runtime/core:core",
9495
"//executorch/runtime/kernel:kernel_includes",
@@ -133,6 +134,32 @@ _PROTOS = [
133134
"WordTagger",
134135
]
135136

137+
runtime.cxx_test(
138+
name = "coreml_backend_options_test",
139+
srcs = [
140+
"runtime/test/coreml_backend_options_test.cpp",
141+
],
142+
deps = [
143+
":coreml_backend_options",
144+
"//executorch/runtime/backend:backend_options",
145+
"//executorch/runtime/backend:backend_options_map",
146+
"//executorch/runtime/core:core",
147+
],
148+
)
149+
150+
# Header-only library for CoreML backend options
151+
runtime.cxx_library(
152+
name = "coreml_backend_options",
153+
exported_headers = [
154+
"runtime/include/coreml_backend/coreml_backend_options.h",
155+
],
156+
header_namespace = "executorch/backends/apple/coreml",
157+
visibility = ["PUBLIC"],
158+
exported_deps = [
159+
"//executorch/runtime/backend:backend_options",
160+
],
161+
)
162+
136163
runtime.cxx_library(
137164
name = "proto",
138165
srcs = [

backends/apple/coreml/TARGETS

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,29 @@ runtime.python_test(
133133
"fbsource//third-party/pypi/scikit-learn:scikit-learn",
134134
],
135135
)
136+
137+
# Header-only library for CoreML backend options
138+
runtime.cxx_library(
139+
name = "coreml_backend_options",
140+
exported_headers = [
141+
"runtime/include/coreml_backend/coreml_backend_options.h",
142+
],
143+
header_namespace = "executorch/backends/apple/coreml",
144+
visibility = ["PUBLIC"],
145+
exported_deps = [
146+
"//executorch/runtime/backend:backend_options",
147+
],
148+
)
149+
150+
runtime.cxx_test(
151+
name = "coreml_backend_options_test",
152+
srcs = [
153+
"runtime/test/coreml_backend_options_test.cpp",
154+
],
155+
deps = [
156+
":coreml_backend_options",
157+
"//executorch/runtime/backend:backend_options",
158+
"//executorch/runtime/backend:backend_options_map",
159+
"//executorch/runtime/core:core",
160+
],
161+
)

backends/apple/coreml/runtime/delegate/backend_delegate.mm

Lines changed: 53 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,30 +19,49 @@
1919
namespace {
2020
using namespace executorchcoreml;
2121

22-
MLComputeUnits get_compute_units(const Buffer& buffer) {
22+
std::optional<MLComputeUnits> get_compute_units(const Buffer& buffer) {
2323
std::string value(reinterpret_cast<const char *>(buffer.data()), buffer.size());
2424
if (value == std::string(ETCoreMLStrings.cpuComputeUnitName.UTF8String)) {
2525
return MLComputeUnitsCPUOnly;
2626
} else if (value == std::string(ETCoreMLStrings.cpuAndGpuComputeUnitsName.UTF8String)) {
2727
return MLComputeUnitsCPUAndGPU;
2828
} else if (value == std::string(ETCoreMLStrings.cpuAndNeuralEngineComputeUnitsName.UTF8String)) {
2929
return MLComputeUnitsCPUAndNeuralEngine;
30-
} else {
30+
} else if (value == std::string(ETCoreMLStrings.allComputeUnitsName.UTF8String)) {
3131
return MLComputeUnitsAll;
32+
} else {
33+
return std::nullopt;
3234
}
3335
}
3436

35-
MLModelConfiguration *get_model_configuration(const std::unordered_map<std::string, Buffer>& specs) {
37+
MLModelConfiguration * _Nullable get_model_configuration(const std::unordered_map<std::string, Buffer>& specs,
38+
NSError * __autoreleasing *error) {
3639
std::string compute_units_key(ETCoreMLStrings.computeUnitsKeyName.UTF8String);
3740
MLModelConfiguration *configuration = [[MLModelConfiguration alloc] init];
38-
41+
3942
for (const auto& [key, buffer] : specs) {
4043
if (key == compute_units_key) {
41-
configuration.computeUnits = get_compute_units(buffer);
44+
auto compute_units = get_compute_units(buffer);
45+
if (!compute_units.has_value()) {
46+
std::string value(reinterpret_cast<const char *>(buffer.data()), buffer.size());
47+
NSString *errorMessage = [NSString stringWithFormat:@"Invalid compute_unit value: '%s'. Valid values are: %@, %@, %@, %@",
48+
value.c_str(),
49+
ETCoreMLStrings.cpuComputeUnitName,
50+
ETCoreMLStrings.cpuAndGpuComputeUnitsName,
51+
ETCoreMLStrings.cpuAndNeuralEngineComputeUnitsName,
52+
ETCoreMLStrings.allComputeUnitsName];
53+
if (error) {
54+
*error = [NSError errorWithDomain:ETCoreMLStrings.productIdentifier
55+
code:-1
56+
userInfo:@{NSLocalizedDescriptionKey: errorMessage}];
57+
}
58+
return nil;
59+
}
60+
configuration.computeUnits = compute_units.value();
4261
break;
4362
}
4463
}
45-
64+
4665
return configuration;
4766
}
4867

@@ -112,15 +131,15 @@ - (instancetype)initWithConfig:(BackendDelegate::Config)config {
112131
_config = std::move(config);
113132
_syncQueue = dispatch_queue_create("com.executorchcoreml.modelmanagerdelegate.sync", DISPATCH_QUEUE_SERIAL_WITH_AUTORELEASE_POOL);
114133
}
115-
134+
116135
return self;
117136
}
118137

119138
- (BOOL)_loadAndReturnError:(NSError * _Nullable __autoreleasing *)error {
120139
if (self.impl != nil) {
121140
return YES;
122141
}
123-
142+
124143
ETCoreMLAssetManager *assetManager = create_asset_manager(ETCoreMLStrings.assetsDirectoryPath,
125144
ETCoreMLStrings.trashDirectoryPath,
126145
ETCoreMLStrings.databaseDirectoryPath,
@@ -130,14 +149,14 @@ - (BOOL)_loadAndReturnError:(NSError * _Nullable __autoreleasing *)error {
130149
if (!assetManager) {
131150
return NO;
132151
}
133-
152+
134153
ETCoreMLModelManager *modelManager = [[ETCoreMLModelManager alloc] initWithAssetManager:assetManager];
135154
if (!modelManager) {
136155
return NO;
137156
}
138-
157+
139158
self.impl = modelManager;
140-
159+
141160
if (self.config.should_prewarm_asset) {
142161
[modelManager prewarmRecentlyUsedAssetsWithMaxCount:1];
143162
}
@@ -151,11 +170,11 @@ - (BOOL)loadAndReturnError:(NSError * _Nullable __autoreleasing *)error {
151170
dispatch_sync(self.syncQueue, ^{
152171
result = [self _loadAndReturnError:&localError];
153172
});
154-
173+
155174
if (error) {
156175
*error = localError;
157176
}
158-
177+
159178
return result;
160179
}
161180

@@ -183,7 +202,7 @@ - (ModelHandle*)loadModelFromAOTData:(NSData*)data
183202
if (![self loadAndReturnError:error]) {
184203
return nil;
185204
}
186-
205+
187206
auto handle = [self.impl loadModelFromAOTData:data
188207
configuration:configuration
189208
methodName:methodName
@@ -223,15 +242,15 @@ - (BOOL)purgeModelsCacheAndReturnError:(NSError * _Nullable __autoreleasing *)er
223242
if (![self loadAndReturnError:error]) {
224243
return NO;
225244
}
226-
245+
227246
return [self.impl purgeModelsCacheAndReturnError:error];;
228247
}
229248

230249
- (BOOL)isAvailable {
231250
if (![self loadAndReturnError:nil]) {
232251
return NO;
233252
}
234-
253+
235254
return YES;
236255
}
237256

@@ -267,20 +286,24 @@ explicit BackendDelegateImpl(const Config& config) noexcept
267286
{
268287
[model_manager_ loadAsynchronously];
269288
}
270-
289+
271290
BackendDelegateImpl(BackendDelegateImpl const&) = delete;
272291
BackendDelegateImpl& operator=(BackendDelegateImpl const&) = delete;
273-
292+
274293
Handle *init(Buffer processed,
275294
const std::unordered_map<std::string, Buffer>& specs,
276295
const char* method_name = nullptr,
277296
const char* function_name = nullptr) const noexcept override {
278297
NSError *localError = nil;
279-
MLModelConfiguration *configuration = get_model_configuration(specs);
280-
298+
MLModelConfiguration *configuration = get_model_configuration(specs, &localError);
299+
if (configuration == nil) {
300+
ETCoreMLLogError(localError, "Invalid model configuration");
301+
return nullptr;
302+
}
303+
281304
NSString *methodNameStr = method_name ? @(method_name) : nil;
282305
NSString *functionNameStr = function_name ? @(function_name) : nil;
283-
306+
284307
NSData *data = [NSData dataWithBytesNoCopy:const_cast<void *>(processed.data())
285308
length:processed.size()
286309
freeWhenDone:NO];
@@ -294,7 +317,7 @@ explicit BackendDelegateImpl(const Config& config) noexcept
294317
}
295318
return modelHandle;
296319
}
297-
320+
298321
bool execute(Handle* handle,
299322
std::vector<MultiArray>& args,
300323
const ModelLoggingOptions& logging_options,
@@ -309,36 +332,36 @@ bool execute(Handle* handle,
309332
if (localError != nil) {
310333
ETCoreMLLogError(localError, "Model execution failed");
311334
ec = static_cast<ErrorCode>(localError.code);
312-
}
335+
}
313336
return false;
314337
}
315-
338+
316339
return true;
317340
}
318-
341+
319342
bool is_valid_handle(Handle* handle) const noexcept override {
320343
return [model_manager_ modelWithHandle:handle] != nil;
321344
}
322-
345+
323346
bool is_available() const noexcept override {
324347
return static_cast<bool>(model_manager_.isAvailable);
325348
}
326-
349+
327350
std::pair<size_t, size_t> get_num_arguments(Handle* handle) const noexcept override {
328351
ETCoreMLModel *model = [model_manager_ modelWithHandle:handle];
329352
return {model.orderedInputNames.count, model.orderedOutputNames.count};
330353
}
331-
354+
332355
void destroy(Handle* handle) const noexcept override {
333356
[model_manager_ unloadModelWithHandle:handle];
334357
}
335-
358+
336359
bool purge_models_cache() const noexcept override {
337360
NSError *localError = nil;
338361
bool result = static_cast<bool>([model_manager_ purgeModelsCacheAndReturnError:&localError]);
339362
return result;
340363
}
341-
364+
342365
ETCoreMLModelManagerDelegate *model_manager_;
343366
Config config_;
344367
};

backends/apple/coreml/runtime/delegate/coreml_backend_delegate.mm

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,24 @@ ModelLoggingOptions get_logging_options(BackendExecutionContext& context) {
263263
specs_map.emplace(spec.key, std::move(buffer));
264264
}
265265

266+
// Check RuntimeSpec for compute_unit override.
267+
// RuntimeSpec takes precedence over CompileSpec for load-time configuration.
268+
std::string runtime_compute_unit_value;
269+
auto runtime_specs = context.runtime_specs();
270+
if (runtime_specs.size() > 0) {
271+
auto compute_unit_result = context.get_runtime_spec<const char*>("compute_unit");
272+
if (compute_unit_result.ok()) {
273+
runtime_compute_unit_value = compute_unit_result.get();
274+
ET_LOG(Debug, "%s: Using compute_unit from RuntimeSpec: %s",
275+
ETCoreMLStrings.delegateIdentifier.UTF8String,
276+
runtime_compute_unit_value.c_str());
277+
// Override the compile spec with runtime spec value
278+
std::string compute_units_key(ETCoreMLStrings.computeUnitsKeyName.UTF8String);
279+
auto buffer = Buffer(runtime_compute_unit_value.data(), runtime_compute_unit_value.size());
280+
specs_map.insert_or_assign(compute_units_key, std::move(buffer));
281+
}
282+
}
283+
266284
// This will hold the NamedDataStore data if needed, keeping it alive until scope exit
267285
std::optional<FreeableBuffer> namedDataStoreBuffer;
268286
Buffer buffer(nullptr, 0);

0 commit comments

Comments
 (0)