diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 4881844ac6d..7d34f1d6a64 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -173,7 +173,7 @@ Obtaining the constant data pointer can either be from within the flatbuffer payload (deprecated) or via offsets to the constant_data_ptr. If no constant data associated with the tensor value, then returns nullptr. */ -const uint8_t* getConstantDataPtr( +Result getConstantDataPtr( uint32_t buffer_idx, GraphPtr flatbuffer_graph, const uint8_t* constant_data_ptr, @@ -184,13 +184,39 @@ const uint8_t* getConstantDataPtr( if (!constant_data_ptr) { // TODO(T172265611): Remove constant_buffer in flatbuffer path after BC // window - const auto& constant_buffer = *flatbuffer_graph->constant_buffer(); - return constant_buffer[buffer_idx]->storage()->data(); + auto* cb = flatbuffer_graph->constant_buffer(); + ET_CHECK_OR_RETURN_ERROR( + cb != nullptr, InvalidProgram, "constant_buffer is null"); + ET_CHECK_OR_RETURN_ERROR( + buffer_idx < cb->size(), + InvalidProgram, + "buffer_idx %u out of bounds for constant_buffer of size %u", + buffer_idx, + cb->size()); + auto* buffer_entry = (*cb)[buffer_idx]; + ET_CHECK_OR_RETURN_ERROR( + buffer_entry != nullptr && buffer_entry->storage() != nullptr, + InvalidProgram, + "Null constant_buffer entry at buffer_idx %u", + buffer_idx); + return buffer_entry->storage()->data(); } else { - ConstantDataOffsetPtr constant_data_offset = - flatbuffer_graph->constant_data()->Get(buffer_idx); + auto* cd = flatbuffer_graph->constant_data(); + ET_CHECK_OR_RETURN_ERROR( + cd != nullptr, InvalidProgram, "constant_data is null"); + ET_CHECK_OR_RETURN_ERROR( + buffer_idx < cd->size(), + InvalidProgram, + "buffer_idx %u out of bounds for constant_data of size %u", + buffer_idx, + cd->size()); + ConstantDataOffsetPtr constant_data_offset = cd->Get(buffer_idx); + ET_CHECK_OR_RETURN_ERROR( + constant_data_offset != nullptr, + InvalidProgram, + "Null constant_data entry at buffer_idx %u", + buffer_idx); uint64_t offset = constant_data_offset->offset(); - bool has_named_key = flatbuffers::IsFieldPresent( constant_data_offset, fb_xnnpack::ConstantDataOffset::VT_NAMED_KEY); // If there is no tensor name @@ -203,7 +229,7 @@ const uint8_t* getConstantDataPtr( weights_cache->load_unpacked_data(data_name); if (!data_ptr.ok()) { ET_LOG(Error, "Failed to load weights from cache"); - return nullptr; + return Error::InvalidProgram; } return data_ptr.get(); #else @@ -215,7 +241,7 @@ const uint8_t* getConstantDataPtr( "Failed to get constant data for key %s from named_data_map. Error code: %u", data_name.c_str(), static_cast(buffer.error())); - return nullptr; + return Error::InvalidProgram; } const uint8_t* data_ptr = static_cast(buffer.get().data()); @@ -229,7 +255,7 @@ const uint8_t* getConstantDataPtr( return nullptr; } -const uint8_t* getConstantDataPtr( +Result getConstantDataPtr( const fb_xnnpack::XNNTensorValue* tensor_value, GraphPtr flatbuffer_graph, const uint8_t* constant_data_ptr, @@ -298,13 +324,17 @@ Error defineTensor( // Get Pointer to constant data from flatbuffer, if its non-constant // it is a nullptr - const uint8_t* buffer_ptr = getConstantDataPtr( + auto buffer_result = getConstantDataPtr( tensor_value, flatbuffer_graph, constant_data_ptr, named_data_map, freeable_buffers, weights_cache); + if (!buffer_result.ok()) { + return buffer_result.error(); + } + const uint8_t* buffer_ptr = buffer_result.get(); xnn_status status; // The type we might have to convert to @@ -449,13 +479,17 @@ Error defineTensor( const float* scale = qparams->scale()->data(); if (qparams->scale_buffer_idx() != 0) { - scale = reinterpret_cast(getConstantDataPtr( + auto scale_result = getConstantDataPtr( qparams->scale_buffer_idx(), flatbuffer_graph, constant_data_ptr, named_data_map, freeable_buffers, - weights_cache)); + weights_cache); + if (!scale_result.ok()) { + return scale_result.error(); + } + scale = reinterpret_cast(scale_result.get()); ET_CHECK_OR_RETURN_ERROR( scale != nullptr, Internal, "Failed to load scale data."); } @@ -491,13 +525,18 @@ Error defineTensor( // Block scales are preferably serialized as bf16 but can also be // serialized as fp32 for backwards compatability. if (qparams->scale_buffer_idx() != 0) { - scale_data = reinterpret_cast(getConstantDataPtr( + auto scale_data_result = getConstantDataPtr( qparams->scale_buffer_idx(), flatbuffer_graph, constant_data_ptr, named_data_map, freeable_buffers, - weights_cache)); + weights_cache); + if (!scale_data_result.ok()) { + return scale_data_result.error(); + } + scale_data = + reinterpret_cast(scale_data_result.get()); ET_CHECK_OR_RETURN_ERROR( scale_data != nullptr, Internal, "Failed to load scale data."); scale_numel = qparams->num_scales();