diff --git a/cava/nightwatch/generator/c/callee.py b/cava/nightwatch/generator/c/callee.py index 8915b44e..e5beeda5 100644 --- a/cava/nightwatch/generator/c/callee.py +++ b/cava/nightwatch/generator/c/callee.py @@ -437,7 +437,7 @@ def buffer_case(): return "" def default_case(): - return (Expr(type_.transfer).equals("NW_HANDLE")).if_then_else( + return (Expr(type_.transfer).one_of(["NW_HANDLE", "NW_OPAQUE"])).if_then_else( Expr(not type_.deallocates).if_then_else( assign_record_replay_functions(param_value, type_).then(record_call_metadata(param_value, type_)), expunge_calls(param_value), @@ -446,9 +446,10 @@ def default_case(): if type_.fields: return for_all_elements(values, cast_type, type_, depth=depth, original_type=original_type, **other) - return type_.is_simple_buffer().if_then_else( - simple_buffer_case, Expr(type_.transfer).equals("NW_BUFFER").if_then_else(buffer_case, default_case) - ) + return Expr(type_.transfer).equals("NW_BUFFER").if_then_else( + buffer_case, + default_case + ) with location(f"at {term.yellow(str(arg.name))}", arg.location): conv = convert_result_value( diff --git a/cava/samples/cudart/cudart.cpp b/cava/samples/cudart/cudart.cpp index f2783813..5fd92bb8 100644 --- a/cava/samples/cudart/cudart.cpp +++ b/cava/samples/cudart/cudart.cpp @@ -54,6 +54,7 @@ typedef struct { GHashTable *fatbin_funcs; /* for NULL, the hash table */ int num_funcs; struct fatbin_function *func; /* for functions */ + size_t buffer_size; /* global states */ CUmodule cur_module; @@ -394,18 +395,46 @@ EXPORTED __host__ cudaError_t CUDARTAPI cudaMallocHost(void **ptr, size_t size) EXPORTED __host__ cudaError_t CUDARTAPI cudaFreeHost(void *ptr) { free(ptr); } ava_end_replacement; +/// Migration: @mem_extract_tag +ava_utility void *object_extract(void *obj, size_t *length) { + // called from host + void *buffer; + printf("object_replay: object=%lx\n", (uintptr_t)obj); + cudaDeviceSynchronize(); + + *length = ava_metadata(obj)->buffer_size; + buffer = malloc(*length); + cudaMemcpy(buffer, obj, *length, cudaMemcpyDeviceToHost); + return buffer; +} + +ava_utility void object_replace(void *obj, void *data, size_t length) { + printf("object_replace: object=%lx, len=%lu\n", (uintptr_t)obj, length); + assert(length != 0); + cudaMemcpy(obj, data, length, cudaMemcpyHostToDevice); +} + __host__ __cudart_builtin__ cudaError_t CUDARTAPI cudaMalloc(void **devPtr, size_t size) { ava_argument(devPtr) { ava_out; ava_buffer(1); - ava_element ava_opaque; + ava_element { + ava_allocates; + ava_handle; + ava_object_explicit_state_functions(object_extract, object_replace); + ava_object_record; + } } + + ava_execute(); + ava_metadata(*devPtr)->buffer_size = size; } __host__ cudaError_t CUDARTAPI cudaMemcpy(void *dst, const void *src, size_t count, enum cudaMemcpyKind kind) { ava_argument(dst) { if (kind == cudaMemcpyHostToDevice) { - ava_opaque; + ava_handle; + ava_object_record; } else if (kind == cudaMemcpyDeviceToHost) { ava_out; ava_buffer(count); @@ -417,12 +446,17 @@ __host__ cudaError_t CUDARTAPI cudaMemcpy(void *dst, const void *src, size_t cou ava_in; ava_buffer(count); } else if (kind == cudaMemcpyDeviceToHost) { - ava_opaque; + ava_handle; } } } -__host__ __cudart_builtin__ cudaError_t CUDARTAPI cudaFree(void *devPtr) { ava_argument(devPtr) ava_opaque; } +__host__ __cudart_builtin__ cudaError_t CUDARTAPI cudaFree(void *devPtr) { + ava_argument(devPtr) { + ava_handle; + ava_object_record; + } +} /* Rich set of APIs */