4 files changed: +26 -5 lines changed
@@ -62,7 +62,7 @@
 /*! \brief TVM FFI minor version. */
 #define TVM_FFI_VERSION_MINOR 1
 /*! \brief TVM FFI patch version. */
-#define TVM_FFI_VERSION_PATCH 9
+#define TVM_FFI_VERSION_PATCH 8
 // NOLINTEND(modernize-macro-to-enum)

 #ifdef __cplusplus
@@ -321,10 +321,16 @@ class TVMFFIPyCallManager {
       return -1;
     }
   }
+
   if (ctx.dlpack_c_exchange_api != nullptr &&
       prev_tensor_allocator != ctx.dlpack_c_exchange_api->managed_tensor_allocator) {
-    c_api_ret_code[0] =
-        TVMFFIEnvSetDLPackManagedTensorAllocator(prev_tensor_allocator, 0, nullptr);
+    // note: we cannot record this error in c_api_ret_code[0] here because it
+    // would overwrite the error value from the function call
+    if (TVMFFIEnvSetDLPackManagedTensorAllocator(prev_tensor_allocator, 0, nullptr) != 0) {
+      PyErr_SetString(PyExc_RuntimeError, "Failed to recover DLPack managed tensor allocator");
+      return -1;
+    }
+    // surface the original error only after the allocator has been restored
     if (c_api_ret_code[0] != 0) return 0;
   }
   if (optional_out_ctx_dlpack_api != nullptr && ctx.dlpack_c_exchange_api != nullptr) {
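
The hunk above changes how a failure while restoring the previous DLPack allocator is reported: rather than writing it into c_api_ret_code[0] (which would overwrite the callee's error), the restore failure raises a Python RuntimeError immediately, and the callee's original error is surfaced only after the allocator has been put back. Below is a minimal standalone sketch of that restore-then-report ordering, using hypothetical stand-ins (Allocator, set_allocator, call_with_temporary_allocator are illustrative, not the TVM FFI API):

#include <cstdio>

using Allocator = int;             // stand-in for DLPackManagedTensorAllocator*
static Allocator g_allocator = 0;  // currently installed allocator

static int set_allocator(Allocator a) {  // stand-in for the env setter; 0 == success
  g_allocator = a;
  return 0;
}

// Restore the previous allocator before surfacing the callee's error.
// A failure during restoration takes precedence, because returning with
// the temporary allocator still installed would corrupt later calls.
static int call_with_temporary_allocator(Allocator tmp, int (*fn)()) {
  Allocator prev = g_allocator;
  if (set_allocator(tmp) != 0) return -1;
  int ret_code = fn();                      // may fail; its error must not be masked
  if (set_allocator(prev) != 0) return -1;  // restore failure wins
  return ret_code;                          // original error surfaces afterwards
}

static int failing_fn() { return 42; }      // simulated callee error

int main() {
  int rc = call_with_temporary_allocator(7, failing_fn);
  std::printf("ret=%d allocator=%d\n", rc, g_allocator);  // ret=42 allocator=0
  return 0;
}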
@@ -66,7 +66,8 @@ class EnvContext {
                                       int write_to_global_context,
                                       DLPackManagedTensorAllocator* opt_out_original_allocator) {
     if (opt_out_original_allocator != nullptr) {
-      *opt_out_original_allocator = GetDLPackManagedTensorAllocator();
+      // only return the cached local allocator and ignore the global allocator
+      *opt_out_original_allocator = dlpack_allocator_;
     }
     if (write_to_global_context != 0) {
       GlobalTensorAllocator() = allocator;
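
The EnvContext change makes the setter hand back only the cached thread-local allocator (dlpack_allocator_) instead of the value resolved through GetDLPackManagedTensorAllocator(), which also consults the global allocator. Here is a hedged sketch of why returning the resolved value would be wrong, with illustrative stand-ins (g_global, t_local, set_local, effective are not the real EnvContext members):

#include <cassert>

static int g_global = 5;              // process-wide fallback allocator (stand-in)
static thread_local int t_local = 0;  // cached per-thread allocator; 0 == unset

static int effective() { return t_local != 0 ? t_local : g_global; }

// Returns the previous *local* value, as the patched setter does. If it
// returned effective() instead, a save/set/restore cycle would copy the
// global fallback into the local slot and permanently shadow any future
// global updates for this thread.
static int set_local(int a) {
  int prev = t_local;
  t_local = a;
  return prev;
}

int main() {
  int prev = set_local(7);   // temporarily install 7; prev == 0 (local was unset)
  set_local(prev);           // restore: the local slot is unset again
  g_global = 9;              // a later global change...
  assert(effective() == 9);  // ...remains visible, not shadowed by the old 5
  return 0;
}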
@@ -18,7 +18,7 @@
 from __future__ import annotations

 from types import ModuleType
-from typing import Any, NamedTuple
+from typing import Any, NamedTuple, NoReturn

 import numpy.typing as npt
 import pytest
@@ -78,6 +78,20 @@ def test_tensor_auto_dlpack() -> None:
     np.testing.assert_equal(y.numpy(), x.numpy())


+@pytest.mark.skipif(torch is None, reason="Fast torch dlpack importer is not enabled")
+def test_tensor_auto_dlpack_with_error() -> None:
+    assert torch is not None
+    x = torch.arange(128)
+
+    def raise_torch_error(x: Any) -> NoReturn:
+        raise ValueError("error XYZ")
+
+    f = tvm_ffi.convert(raise_torch_error)
+    with pytest.raises(ValueError):
+        # pass in a torch argument to trigger the error in the set-allocator path
+        f(x)
+
+
 def test_tensor_class_override() -> None:
     class MyTensor(tvm_ffi.Tensor):
         pass