Skip to content

Commit e1bd421

Browse files
authored
[FIX] Fix the error propagation in the case of tensor arguments (#409)
This PR fixes error propagation in the case of tensor arguments. The bug was previously hidden and revealed after a fix landed in 0.1.8, so it does not impact previous versions. Added a regression test to cover this case.
1 parent 10cb004 commit e1bd421

File tree

4 files changed

+26
-5
lines changed

4 files changed

+26
-5
lines changed

include/tvm/ffi/c_api.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
/*! \brief TVM FFI minor version. */
6363
#define TVM_FFI_VERSION_MINOR 1
6464
/*! \brief TVM FFI patch version. */
65-
#define TVM_FFI_VERSION_PATCH 8
65+
#define TVM_FFI_VERSION_PATCH 9
6666
// NOLINTEND(modernize-macro-to-enum)
6767

6868
#ifdef __cplusplus

python/tvm_ffi/cython/tvm_ffi_python_helpers.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,10 +321,16 @@ class TVMFFIPyCallManager {
321321
return -1;
322322
}
323323
}
324+
324325
if (ctx.dlpack_c_exchange_api != nullptr &&
325326
prev_tensor_allocator != ctx.dlpack_c_exchange_api->managed_tensor_allocator) {
326-
c_api_ret_code[0] =
327-
TVMFFIEnvSetDLPackManagedTensorAllocator(prev_tensor_allocator, 0, nullptr);
327+
// note: we cannot set the error value to c_api_ret_code[0] here because it
328+
// will be overwritten by the error value from the function call
329+
if (TVMFFIEnvSetDLPackManagedTensorAllocator(prev_tensor_allocator, 0, nullptr) != 0) {
330+
PyErr_SetString(PyExc_RuntimeError, "Failed to recover DLPack managed tensor allocator");
331+
return -1;
332+
}
333+
// return the error after the allocator has been restored
328334
if (c_api_ret_code[0] != 0) return 0;
329335
}
330336
if (optional_out_ctx_dlpack_api != nullptr && ctx.dlpack_c_exchange_api != nullptr) {

src/ffi/extra/env_context.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ class EnvContext {
6666
int write_to_global_context,
6767
DLPackManagedTensorAllocator* opt_out_original_allocator) {
6868
if (opt_out_original_allocator != nullptr) {
69-
*opt_out_original_allocator = GetDLPackManagedTensorAllocator();
69+
// only returns the cached local allocator and ignore global allocator
70+
*opt_out_original_allocator = dlpack_allocator_;
7071
}
7172
if (write_to_global_context != 0) {
7273
GlobalTensorAllocator() = allocator;

tests/python/test_tensor.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from __future__ import annotations
1919

2020
from types import ModuleType
21-
from typing import Any, NamedTuple
21+
from typing import Any, NamedTuple, NoReturn
2222

2323
import numpy.typing as npt
2424
import pytest
@@ -78,6 +78,20 @@ def test_tensor_auto_dlpack() -> None:
7878
np.testing.assert_equal(y.numpy(), x.numpy())
7979

8080

81+
@pytest.mark.skipif(torch is None, reason="Fast torch dlpack importer is not enabled")
82+
def test_tensor_auto_dlpack_with_error() -> None:
83+
assert torch is not None
84+
x = torch.arange(128)
85+
86+
def raise_torch_error(x: Any) -> NoReturn:
87+
raise ValueError("error XYZ")
88+
89+
f = tvm_ffi.convert(raise_torch_error)
90+
with pytest.raises(ValueError):
91+
# pass in torch argument to trigger the error in set allocator path
92+
f(x)
93+
94+
8195
def test_tensor_class_override() -> None:
8296
class MyTensor(tvm_ffi.Tensor):
8397
pass

0 commit comments

Comments (0)