Skip to content

Commit 4c712ca

Browse files
authored
[RFC][Error][ABI] Update Error to enable future compact to cause chaining (#396)
This PR brings a backward compatible update to error ABI to enable possible future support of cause chaining. Specifically, we add two fields: - cause_chain is an optional field for chaining errors - extra_context can be used to optionally attach opaque object (e.g. python error) if needed. The change is backward compatible as we only append to the error field. Most of the existing usages will simply ignore the two fields and use a single error.
1 parent f9b5e7d commit 4c712ca

File tree

4 files changed

+136
-0
lines changed

4 files changed

+136
-0
lines changed

include/tvm/ffi/c_api.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,16 @@ typedef struct {
394394
*/
395395
void (*update_backtrace)(TVMFFIObjectHandle self, const TVMFFIByteArray* backtrace,
396396
int32_t update_mode);
397+
/*!
398+
* \brief Optional cause error chain that caused this error to be raised.
399+
* \note This handle is owned by the ErrorCell.
400+
*/
401+
TVMFFIObjectHandle cause_chain;
402+
/*!
403+
* \brief Optional extra context that can be used to record additional info about the error.
404+
* \note This handle is owned by the ErrorCell.
405+
*/
406+
TVMFFIObjectHandle extra_context;
397407
} TVMFFIErrorCell;
398408

399409
/*!
@@ -631,6 +641,20 @@ TVM_FFI_DLL void TVMFFIErrorSetRaisedFromCStrParts(const char* kind, const char*
631641
TVM_FFI_DLL int TVMFFIErrorCreate(const TVMFFIByteArray* kind, const TVMFFIByteArray* message,
632642
const TVMFFIByteArray* backtrace, TVMFFIObjectHandle* out);
633643

644+
/*!
645+
* \brief Create an initial error object with cause chain and extra context.
646+
* \param kind The kind of the error.
647+
* \param message The error message.
648+
* \param backtrace The backtrace of the error.
649+
* \param cause_chain The cause error chain that caused this error to be raised.
650+
* \param extra_context The extra context that can be used to record additional information.
651+
* \param out The output error object handle.
652+
* \return 0 on success, nonzero on failure.
653+
*/
654+
TVM_FFI_DLL int TVMFFIErrorCreateWithCauseAndExtraContext(
655+
const TVMFFIByteArray* kind, const TVMFFIByteArray* message, const TVMFFIByteArray* backtrace,
656+
TVMFFIObjectHandle cause_chain, TVMFFIObjectHandle extra_context, TVMFFIObjectHandle* out);
657+
634658
//------------------------------------------------------------
635659
// Section: DLPack support APIs
636660
//------------------------------------------------------------

include/tvm/ffi/error.h

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include <cstring>
3333
#include <iostream>
3434
#include <memory>
35+
#include <optional>
3536
#include <sstream>
3637
#include <string>
3738
#include <utility>
@@ -84,6 +85,20 @@ struct EnvErrorAlreadySet : public std::exception {};
8485
*/
8586
class ErrorObj : public Object, public TVMFFIErrorCell {
8687
public:
88+
ErrorObj() {
89+
this->cause_chain = nullptr;
90+
this->extra_context = nullptr;
91+
}
92+
93+
~ErrorObj() {
94+
if (this->cause_chain != nullptr) {
95+
details::ObjectUnsafe::DecRefObjectHandle(this->cause_chain);
96+
}
97+
if (this->extra_context != nullptr) {
98+
details::ObjectUnsafe::DecRefObjectHandle(this->extra_context);
99+
}
100+
}
101+
87102
/// \cond Doxygen_Suppress
88103
static constexpr const int32_t _type_index = TypeIndex::kTVMFFIError;
89104
TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIError, ErrorObj, Object);
@@ -146,6 +161,29 @@ class Error : public ObjectRef, public std::exception {
146161
std::move(backtrace));
147162
}
148163

164+
/*!
165+
* \brief Constructor
166+
* \param kind The kind of the error.
167+
* \param message The message of the error.
168+
* \param backtrace The backtrace of the error.
169+
* \param cause_chain The cause chain of the error.
170+
* \param extra_context The extra context of the error.
171+
*/
172+
Error(std::string kind, std::string message, std::string backtrace,
173+
std::optional<Error> cause_chain, std::optional<ObjectRef> extra_context) {
174+
ObjectPtr<ErrorObj> error_obj = make_object<details::ErrorObjFromStd>(
175+
std::move(kind), std::move(message), std::move(backtrace));
176+
if (cause_chain.has_value()) {
177+
error_obj->cause_chain =
178+
details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(*std::move(cause_chain));
179+
}
180+
if (extra_context.has_value()) {
181+
error_obj->extra_context =
182+
details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(*std::move(extra_context));
183+
}
184+
data_ = std::move(error_obj);
185+
}
186+
149187
/*!
150188
* \brief Constructor
151189
* \param kind The kind of the error.
@@ -173,6 +211,37 @@ class Error : public ObjectRef, public std::exception {
173211
return std::string(obj->message.data, obj->message.size);
174212
}
175213

214+
/*!
215+
216+
* \brief Get the cause chain of the error object.
217+
* \return The cause chain of the error object.
218+
*/
219+
std::optional<Error> cause_chain() const {
220+
ErrorObj* obj = static_cast<ErrorObj*>(data_.get());
221+
if (obj->cause_chain != nullptr) {
222+
return details::ObjectUnsafe::ObjectRefFromObjectPtr<Error>(
223+
details::ObjectUnsafe::ObjectPtrFromUnowned<ErrorObj>(
224+
static_cast<Object*>(obj->cause_chain)));
225+
} else {
226+
return std::nullopt;
227+
}
228+
}
229+
230+
/*!
231+
* \brief Get the extra context of the error object.
232+
* \return The extra context of the error object.
233+
*/
234+
std::optional<ObjectRef> extra_context() const {
235+
ErrorObj* obj = static_cast<ErrorObj*>(data_.get());
236+
if (obj->extra_context != nullptr) {
237+
return details::ObjectUnsafe::ObjectRefFromObjectPtr<ObjectRef>(
238+
details::ObjectUnsafe::ObjectPtrFromUnowned<Object>(
239+
static_cast<Object*>(obj->extra_context)));
240+
} else {
241+
return std::nullopt;
242+
}
243+
}
244+
176245
/*!
177246
* \brief Get the backtrace of the error object.
178247
* \return The backtrace of the error object.

src/ffi/error.cc

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,35 @@ int TVMFFIErrorCreate(const TVMFFIByteArray* kind, const TVMFFIByteArray* messag
111111
}
112112
TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIErrorCreate);
113113
}
114+
115+
int TVMFFIErrorCreateWithCauseAndExtraContext(
116+
const TVMFFIByteArray* kind, const TVMFFIByteArray* message, const TVMFFIByteArray* backtrace,
117+
TVMFFIObjectHandle cause_chain, TVMFFIObjectHandle extra_context, TVMFFIObjectHandle* out) {
118+
// log other errors to the logger
119+
TVM_FFI_LOG_EXCEPTION_CALL_BEGIN();
120+
try {
121+
std::optional<tvm::ffi::Error> cause_chain_error;
122+
if (cause_chain != nullptr) {
123+
cause_chain_error = tvm::ffi::details::ObjectUnsafe::ObjectRefFromObjectPtr<tvm::ffi::Error>(
124+
tvm::ffi::details::ObjectUnsafe::ObjectPtrFromUnowned<tvm::ffi::ErrorObj>(
125+
static_cast<tvm::ffi::ErrorObj*>(cause_chain)));
126+
}
127+
std::optional<tvm::ffi::ObjectRef> extra_context_ref;
128+
if (extra_context != nullptr) {
129+
extra_context_ref =
130+
tvm::ffi::details::ObjectUnsafe::ObjectRefFromObjectPtr<tvm::ffi::ObjectRef>(
131+
tvm::ffi::details::ObjectUnsafe::ObjectPtrFromUnowned<tvm::ffi::Object>(
132+
static_cast<tvm::ffi::Object*>(extra_context)));
133+
}
134+
135+
tvm::ffi::Error error(std::string(kind->data, kind->size),
136+
std::string(message->data, message->size),
137+
std::string(backtrace->data, backtrace->size),
138+
std::move(cause_chain_error), std::move(extra_context_ref));
139+
*out = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(error));
140+
return 0;
141+
} catch (const std::bad_alloc& e) {
142+
return -1;
143+
}
144+
TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIErrorCreateWithCauseAndExtraContext);
145+
}

tests/cpp/test_error.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,15 @@ TEST(Error, TracebackMostRecentCallLast) {
118118
Error error("TypeError", "here", "test0\ntest1\ntest2\n");
119119
EXPECT_EQ(error.TracebackMostRecentCallLast(), "test2\ntest1\ntest0\n");
120120
}
121+
122+
TEST(Error, CauseChain) {
123+
Error original_error("TypeError", "here", "test0");
124+
Error cause_chain("ValueError", "cause", "test1", original_error, std::nullopt);
125+
auto opt_cause = cause_chain.cause_chain();
126+
EXPECT_TRUE(opt_cause.has_value());
127+
if (opt_cause.has_value()) {
128+
EXPECT_EQ(opt_cause->kind(), "TypeError");
129+
}
130+
EXPECT_TRUE(!cause_chain.extra_context().has_value());
131+
}
121132
} // namespace

0 commit comments

Comments
 (0)