diff --git a/alloc.go b/alloc.go deleted file mode 100644 index 96e83c3d..00000000 --- a/alloc.go +++ /dev/null @@ -1,97 +0,0 @@ -package lua - -import ( - "unsafe" -) - -// iface is an internal representation of the go-interface. -type iface struct { - itab unsafe.Pointer - word unsafe.Pointer -} - -const preloadLimit = 256 -const intPreloadLimit = 65536 // Common loop bounds - -var preloads [preloadLimit]LValue -var intPreloads [intPreloadLimit * 2]LValue // [-65536, 65536) - -func init() { - for i := 0; i < preloadLimit; i++ { - preloads[i] = LNumber(i) - } - for i := -intPreloadLimit; i < intPreloadLimit; i++ { - intPreloads[i+intPreloadLimit] = LInteger(i) - } -} - -// allocator is a fast bulk memory allocator for the LValue. -// Uses a single backing slice for both float64 and int64 since both are 8 bytes. -type allocator struct { - size int - ptrs []uint64 // shared backing store for float64 and int64 - - scratchValue LValue - scratchValueP *iface - scratchIntVal LValue - scratchIntValP *iface -} - -func newAllocator(size int) *allocator { - al := &allocator{ - size: size, - ptrs: nil, // lazy alloc on first use - } - al.scratchValue = LNumber(0) - al.scratchValueP = (*iface)(unsafe.Pointer(&al.scratchValue)) - al.scratchIntVal = LInteger(0) - al.scratchIntValP = (*iface)(unsafe.Pointer(&al.scratchIntVal)) - - return al -} - -// LNumber2I takes a number value and returns an interface LValue representing the same number. -// Converting an LNumber to a LValue naively, by doing: -// `var val LValue = myLNumber` -// will result in an individual heap alloc of 8 bytes for the float value. LNumber2I amortizes the cost and memory -// overhead of these allocs by allocating blocks instead. -// The downside of this is that all values on a given block have to become eligible for gc before the block -// as a whole can be gc-ed. -func (al *allocator) LNumber2I(v LNumber) LValue { - // Fast path: check for shared preloaded integers [0, preloadLimit) - iv := int(v) - if iv >= 0 && iv < preloadLimit && LNumber(iv) == v { - return preloads[iv] - } - - // check if we need a new alloc page - if cap(al.ptrs) == len(al.ptrs) { - al.ptrs = make([]uint64, 0, al.size) - } - - // alloc from shared pool, reinterpret as float64 - al.ptrs = append(al.ptrs, *(*uint64)(unsafe.Pointer(&v))) - ptr := &al.ptrs[len(al.ptrs)-1] - - al.scratchValueP.word = unsafe.Pointer(ptr) - return al.scratchValue -} - -// LInteger2I converts an LInteger to LValue with zero-alloc for values in [-65536, 65536). -func (al *allocator) LInteger2I(v LInteger) LValue { - iv := int(v) - if iv >= -intPreloadLimit && iv < intPreloadLimit { - return intPreloads[iv+intPreloadLimit] - } - - if cap(al.ptrs) == len(al.ptrs) { - al.ptrs = make([]uint64, 0, al.size) - } - - // alloc from shared pool, reinterpret as int64 - al.ptrs = append(al.ptrs, uint64(v)) - ptr := &al.ptrs[len(al.ptrs)-1] - - al.scratchIntValP.word = unsafe.Pointer(ptr) - return al.scratchIntVal -} diff --git a/alloc_test.go b/alloc_test.go deleted file mode 100644 index d02891ac..00000000 --- a/alloc_test.go +++ /dev/null @@ -1,217 +0,0 @@ -package lua - -import ( - "math" - "testing" -) - -func TestAllocatorLNumber2I_Preloaded(t *testing.T) { - al := newAllocator(32) - - // Test preloaded range [0, 256) - for i := 0; i < preloadLimit; i++ { - v := al.LNumber2I(LNumber(i)) - if v.Type() != LTNumber { - t.Errorf("preloaded %d: expected LTNumber, got %v", i, v.Type()) - } - if float64(v.(LNumber)) != float64(i) { - t.Errorf("preloaded %d: expected %d, got %v", i, i, v) - } - // Verify it's the same preloaded instance - if v != preloads[i] { - t.Errorf("preloaded %d: not using preloaded instance", i) - } - } -} - -func TestAllocatorLNumber2I_OutsidePreload(t *testing.T) { - al := newAllocator(32) - - tests := []LNumber{ - LNumber(preloadLimit), // just outside preload - LNumber(preloadLimit + 1), // outside preload - LNumber(1000), // larger value - LNumber(-1), // negative - LNumber(-1000), // larger negative - LNumber(0.5), // float - LNumber(1.5), // float that's not integer - LNumber(math.Pi), // irrational - LNumber(math.MaxFloat64), // max float - LNumber(-math.MaxFloat64), // min float - } - - for _, num := range tests { - v := al.LNumber2I(num) - if v.Type() != LTNumber { - t.Errorf("LNumber(%v): expected LTNumber, got %v", num, v.Type()) - } - if float64(v.(LNumber)) != float64(num) { - t.Errorf("LNumber(%v): expected %v, got %v", num, num, v) - } - } -} - -func TestAllocatorLNumber2I_FloatIntBoundary(t *testing.T) { - al := newAllocator(32) - - // Test that 5.0 uses preload but 5.5 doesn't - v5 := al.LNumber2I(LNumber(5)) - if v5 != preloads[5] { - t.Error("5.0 should use preloaded value") - } - - v55 := al.LNumber2I(LNumber(5.5)) - if v55 == preloads[5] { - t.Error("5.5 should not use preloaded value") - } -} - -func TestAllocatorLInteger2I_Preloaded(t *testing.T) { - al := newAllocator(32) - - // Test preloaded range [-65536, 65536) - testCases := []int64{ - 0, 1, -1, 100, -100, - intPreloadLimit - 1, - -intPreloadLimit, - } - - for _, i := range testCases { - v := al.LInteger2I(LInteger(i)) - if v.Type() != LTInteger { - t.Errorf("preloaded %d: expected LTInteger, got %v", i, v.Type()) - } - if int64(v.(LInteger)) != i { - t.Errorf("preloaded %d: expected %d, got %v", i, i, v) - } - // Verify it's the same preloaded instance - if v != intPreloads[i+intPreloadLimit] { - t.Errorf("preloaded %d: not using preloaded instance", i) - } - } -} - -func TestAllocatorLInteger2I_OutsidePreload(t *testing.T) { - al := newAllocator(32) - - tests := []LInteger{ - LInteger(intPreloadLimit), // just outside preload - LInteger(-intPreloadLimit - 1), // just outside negative preload - LInteger(100000), // larger value - LInteger(-100000), // larger negative - LInteger(math.MaxInt64), // max int - LInteger(math.MinInt64), // min int - } - - for _, num := range tests { - v := al.LInteger2I(num) - if v.Type() != LTInteger { - t.Errorf("LInteger(%v): expected LTInteger, got %v", num, v.Type()) - } - if int64(v.(LInteger)) != int64(num) { - t.Errorf("LInteger(%v): expected %v, got %v", num, num, v) - } - } -} - -func TestAllocatorPageAllocation(t *testing.T) { - al := newAllocator(4) // small page size - - // Force multiple page allocations - values := make([]LValue, 20) - for i := 0; i < 20; i++ { - values[i] = al.LNumber2I(LNumber(1000 + i)) // outside preload range - } - - // Verify all values are correct - for i := 0; i < 20; i++ { - expected := LNumber(1000 + i) - if float64(values[i].(LNumber)) != float64(expected) { - t.Errorf("value %d: expected %v, got %v", i, expected, values[i]) - } - } -} - -func TestAllocatorPreloadInit(t *testing.T) { - // Verify preloads are initialized correctly - for i := 0; i < preloadLimit; i++ { - if preloads[i].Type() != LTNumber { - t.Errorf("preloads[%d] type: expected LTNumber, got %v", i, preloads[i].Type()) - } - if float64(preloads[i].(LNumber)) != float64(i) { - t.Errorf("preloads[%d]: expected %d, got %v", i, i, preloads[i]) - } - } - - // Verify intPreloads are initialized correctly - for i := -intPreloadLimit; i < intPreloadLimit; i++ { - idx := i + intPreloadLimit - if intPreloads[idx].Type() != LTInteger { - t.Errorf("intPreloads[%d] type: expected LTInteger, got %v", i, intPreloads[idx].Type()) - } - if int64(intPreloads[idx].(LInteger)) != int64(i) { - t.Errorf("intPreloads[%d]: expected %d, got %v", i, i, intPreloads[idx]) - } - } -} - -func TestAllocatorSpecialFloats(t *testing.T) { - al := newAllocator(32) - - // Test special float values - tests := []struct { - name string - val LNumber - }{ - {"positive infinity", LNumber(math.Inf(1))}, - {"negative infinity", LNumber(math.Inf(-1))}, - {"NaN", LNumber(math.NaN())}, - {"smallest positive", LNumber(math.SmallestNonzeroFloat64)}, - {"negative zero", LNumber(math.Copysign(0, -1))}, - } - - for _, tt := range tests { - v := al.LNumber2I(tt.val) - got := float64(v.(LNumber)) - - if math.IsNaN(float64(tt.val)) { - if !math.IsNaN(got) { - t.Errorf("%s: expected NaN, got %v", tt.name, got) - } - } else if got != float64(tt.val) { - t.Errorf("%s: expected %v, got %v", tt.name, tt.val, got) - } - } -} - -func BenchmarkAllocatorLNumber2I_Preloaded(b *testing.B) { - al := newAllocator(32) - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = al.LNumber2I(LNumber(i % preloadLimit)) - } -} - -func BenchmarkAllocatorLNumber2I_NonPreloaded(b *testing.B) { - al := newAllocator(1024) - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = al.LNumber2I(LNumber(i + preloadLimit)) - } -} - -func BenchmarkAllocatorLInteger2I_Preloaded(b *testing.B) { - al := newAllocator(32) - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = al.LInteger2I(LInteger(i % intPreloadLimit)) - } -} - -func BenchmarkAllocatorLInteger2I_NonPreloaded(b *testing.B) { - al := newAllocator(1024) - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = al.LInteger2I(LInteger(i + intPreloadLimit)) - } -} diff --git a/baselib.go b/baselib.go index 9b02f6ba..ff8a23c7 100644 --- a/baselib.go +++ b/baselib.go @@ -206,7 +206,7 @@ func basePCall(L *LState) int { }() // Check for yield before any cleanup - if L.yielded { + if L.yieldState != yieldNone { return -1 } @@ -416,7 +416,7 @@ func baseXPCall(L *LState) int { L.CallK(0, MultRet, xpcallContinuation, top) // Check for yield before any cleanup - if L.yielded { + if L.yieldState != yieldNone { return -1 } diff --git a/coroutinelib.go b/coroutinelib.go index 68f5842b..57a714f8 100644 --- a/coroutinelib.go +++ b/coroutinelib.go @@ -63,8 +63,8 @@ func coCreate(L *LState) int { return 1 } -func coYield(_ *LState) int { - return -1 +func coYield(L *LState) int { + return -2 // -2 signals user yield (vs -1 for system yield) } func coResume(L *LState) int { @@ -104,9 +104,75 @@ func coResume(L *LState) int { nargs := L.GetTop() - 1 L.XMoveTo(th, nargs) } + top := L.GetTop() + th.yieldState = yieldNone + threadRun(th) + if th.yieldState != yieldSystem || L.Parent == nil { + return L.GetTop() - top + } + return coResumePropagate(L, th, top) +} + +// coResumePropagate handles system yield propagation through a coroutine boundary. +// Called only when the inner thread yielded via a Go function returning -1 +// (not coroutine.yield) and this coroutine has a parent to propagate to. +func coResumePropagate(L *LState, th *LState, top int) int { + // switchToParentThread already moved yield values to L's stack. + // For non-wrapped threads it also pushed LTrue before the values. + // Extract the raw yield values. + yieldStart := top + 1 + if !th.wrapped { + yieldStart++ // skip LTrue from switchToParentThread + } + nvals := L.GetTop() - yieldStart + 1 + + // Transfer yield values to parent thread. For non-wrapped outer coroutines, + // Resume expects [LTrue, val1, val2, ...] on the parent's stack. + parent := L.Parent + if !L.wrapped { + parent.Push(LTrue) + } + for i := 0; i < nvals; i++ { + parent.Push(L.Get(yieldStart + i)) + } + + // Clear our stack so resume values land cleanly. + L.SetTop(0) + + // Perform the thread switch manually to preserve the current frame. + // Unlike switchToParentThread, we do NOT pop the frame — the continuation + // installed below needs it to survive for the next resume. + L.G.CurrentThread = parent + L.Parent = nil + L.yieldState = yieldSystem + + // Install continuation so the next resume re-enters the inner thread. + ext := L.setFrameExt(L.currentFrame) + ext.Continuation = coResumeContinuation + ext.ContinuationCtx = th + + // callGFunction checks L.yieldState and skips switchToParentThread when set, + // preserving the frame on the stack. + return -1 +} + +// coResumeContinuation re-resumes the inner thread after a system yield was +// propagated through this coroutine boundary. Resume values are on L's stack. +func coResumeContinuation(L *LState, ctx interface{}, _ ResumeState) int { + th := ctx.(*LState) + + th.Parent = L + L.G.CurrentThread = th + nargs := L.GetTop() + L.XMoveTo(th, nargs) + th.yieldState = yieldNone + top := L.GetTop() threadRun(th) - return L.GetTop() - top + if th.yieldState != yieldSystem || L.Parent == nil { + return L.GetTop() - top + } + return coResumePropagate(L, th, top) } func coRunning(L *LState) int { diff --git a/integer_test.go b/integer_test.go index 8484a79c..4388fdc5 100644 --- a/integer_test.go +++ b/integer_test.go @@ -30,11 +30,9 @@ func TestLInteger_TypeName(t *testing.T) { } func TestLInteger_Preloads(t *testing.T) { - al := newAllocator(64) - // Values in [-256, 256) should be preloaded (zero alloc) for i := int64(-256); i < 256; i++ { - v := al.LInteger2I(LInteger(i)) + v := lintegerToValue(LInteger(i)) if v.Type() != LTInteger { t.Errorf("LInteger(%d) should have type LTInteger", i) } @@ -45,12 +43,10 @@ func TestLInteger_Preloads(t *testing.T) { } func TestLInteger_LargeValues(t *testing.T) { - al := newAllocator(64) - // Values outside preload range should also work large := []int64{1000, -1000, 1 << 62, -(1 << 62)} for _, i := range large { - v := al.LInteger2I(LInteger(i)) + v := lintegerToValue(LInteger(i)) if v.Type() != LTInteger { t.Errorf("LInteger(%d) should have type LTInteger", i) } diff --git a/registry.go b/registry.go index 422aa2ae..dc9f9b31 100644 --- a/registry.go +++ b/registry.go @@ -10,12 +10,11 @@ type registry struct { top int growBy int maxSize int - alloc *allocator handler registryHandler } -func newRegistry(handler registryHandler, initialSize int, growBy int, maxSize int, alloc *allocator) *registry { - return ®istry{make([]LValue, initialSize), 0, growBy, maxSize, alloc, handler} +func newRegistry(handler registryHandler, initialSize int, growBy int, maxSize int) *registry { + return ®istry{make([]LValue, initialSize), 0, growBy, maxSize, handler} } func (rg *registry) resize(requiredSize int) bool { // +inline-start @@ -266,7 +265,7 @@ func (rg *registry) SetNumber(regi int, vali LNumber) { // +inline-start rg.resize(requiredSize) } } - rg.array[regi] = rg.alloc.LNumber2I(vali) + rg.array[regi] = lnumberToValue(vali) if regi >= rg.top { rg.top = regi + 1 } diff --git a/registry_test.go b/registry_test.go index b5acbd78..0139d148 100644 --- a/registry_test.go +++ b/registry_test.go @@ -14,8 +14,7 @@ func (h *testRegistryHandler) registryOverflow() { func TestRegistryBasicOperations(t *testing.T) { handler := &testRegistryHandler{} - alloc := newAllocator(32) - rg := newRegistry(handler, 10, 5, 100, alloc) + rg := newRegistry(handler, 10, 5, 100) // Test initial state if rg.Top() != 0 { @@ -51,8 +50,7 @@ func TestRegistryBasicOperations(t *testing.T) { func TestRegistrySetTop(t *testing.T) { handler := &testRegistryHandler{} - alloc := newAllocator(32) - rg := newRegistry(handler, 10, 5, 100, alloc) + rg := newRegistry(handler, 10, 5, 100) // Push some values rg.Push(LNumber(1)) @@ -80,8 +78,7 @@ func TestRegistrySetTop(t *testing.T) { func TestRegistrySet(t *testing.T) { handler := &testRegistryHandler{} - alloc := newAllocator(32) - rg := newRegistry(handler, 10, 5, 100, alloc) + rg := newRegistry(handler, 10, 5, 100) // Set beyond current top rg.Set(5, LNumber(100)) @@ -101,8 +98,7 @@ func TestRegistrySet(t *testing.T) { func TestRegistryResize(t *testing.T) { handler := &testRegistryHandler{} - alloc := newAllocator(32) - rg := newRegistry(handler, 2, 2, 100, alloc) // small initial size + rg := newRegistry(handler, 2, 2, 100) // small initial size // Push beyond initial capacity for i := 0; i < 10; i++ { @@ -123,8 +119,7 @@ func TestRegistryResize(t *testing.T) { func TestRegistryOverflow(t *testing.T) { handler := &testRegistryHandler{} - alloc := newAllocator(32) - rg := newRegistry(handler, 2, 2, 5, alloc) // small max size + rg := newRegistry(handler, 2, 2, 5) // small max size // Push until we hit max size for i := 0; i < 5; i++ { @@ -147,8 +142,7 @@ func TestRegistryOverflow(t *testing.T) { func TestRegistryCopyRange(t *testing.T) { handler := &testRegistryHandler{} - alloc := newAllocator(32) - rg := newRegistry(handler, 20, 5, 100, alloc) + rg := newRegistry(handler, 20, 5, 100) // Setup: push 5 values for i := 0; i < 5; i++ { @@ -176,8 +170,7 @@ func TestRegistryCopyRange(t *testing.T) { func TestRegistryFillNil(t *testing.T) { handler := &testRegistryHandler{} - alloc := newAllocator(32) - rg := newRegistry(handler, 20, 5, 100, alloc) + rg := newRegistry(handler, 20, 5, 100) // Setup: push some values rg.Push(LNumber(1)) @@ -200,8 +193,7 @@ func TestRegistryFillNil(t *testing.T) { func TestRegistryInsert(t *testing.T) { handler := &testRegistryHandler{} - alloc := newAllocator(32) - rg := newRegistry(handler, 20, 5, 100, alloc) + rg := newRegistry(handler, 20, 5, 100) // Push initial values rg.Push(LNumber(1)) @@ -225,8 +217,7 @@ func TestRegistryInsert(t *testing.T) { func TestRegistryIsFull(t *testing.T) { handler := &testRegistryHandler{} - alloc := newAllocator(32) - rg := newRegistry(handler, 3, 2, 100, alloc) + rg := newRegistry(handler, 3, 2, 100) if rg.IsFull() { t.Error("registry should not be full initially") @@ -249,10 +240,9 @@ func TestRegistryIsFull(t *testing.T) { func TestRegistrySetNumber(t *testing.T) { handler := &testRegistryHandler{} - alloc := newAllocator(32) - rg := newRegistry(handler, 10, 5, 100, alloc) + rg := newRegistry(handler, 10, 5, 100) - // SetNumber uses allocator for non-preloaded values + // SetNumber uses number boxing helper for non-preloaded values. rg.SetNumber(0, LNumber(1000)) // outside preload range v := rg.Get(0) @@ -266,8 +256,7 @@ func TestRegistrySetNumber(t *testing.T) { func TestRegistryGetNilValue(t *testing.T) { handler := &testRegistryHandler{} - alloc := newAllocator(32) - rg := newRegistry(handler, 10, 5, 100, alloc) + rg := newRegistry(handler, 10, 5, 100) rg.SetTop(5) @@ -281,8 +270,7 @@ func TestRegistryGetNilValue(t *testing.T) { func BenchmarkRegistryPush(b *testing.B) { handler := &testRegistryHandler{} - alloc := newAllocator(1024) - rg := newRegistry(handler, 1024, 256, 10000, alloc) + rg := newRegistry(handler, 1024, 256, 10000) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -295,8 +283,7 @@ func BenchmarkRegistryPush(b *testing.B) { func BenchmarkRegistrySetNumber(b *testing.B) { handler := &testRegistryHandler{} - alloc := newAllocator(1024) - rg := newRegistry(handler, 1024, 256, 10000, alloc) + rg := newRegistry(handler, 1024, 256, 10000) b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/state.go b/state.go index ffb068c1..dacae74a 100644 --- a/state.go +++ b/state.go @@ -136,6 +136,30 @@ type Debug struct { // ctx is user-defined context passed through yield, status is the resume result. type LGContinuation func(L *LState, ctx any, status ResumeState) int +// Yield continuation types for Lua frames. When an opcode calls a function +// via callR/Call and the call yields, these identify the post-call logic +// to execute on resume. +const ( + yieldContNone uint8 = 0 + yieldContTForLoop uint8 = 1 // OP_TFORLOOP: check iterator result, update loop + yieldContGetField uint8 = 2 // getField/getFieldString result → store at RA + yieldContSetField uint8 = 3 // setField/setFieldString → no result to store + yieldContSelf uint8 = 4 // OP_SELF: store method at RA, selfobj at RA+1 + yieldContArith uint8 = 5 // objectArith result → store at RA + yieldContUnm uint8 = 6 // OP_UNM __unm result → store at RA + yieldContLen uint8 = 7 // OP_LEN __len result → store at RA + yieldContConcat uint8 = 8 // OP_CONCAT __concat result → store at RA + yieldContCompare uint8 = 9 // OP_EQ/LT/LE comparison result → affects Pc jump +) + +// Yield state: combined yielded flag + yield kind in a single field. +// 0 = not yielded, nonzero = yielded with specific kind. +const ( + yieldNone uint8 = 0 // not yielded + yieldSystem uint8 = 1 // Go function returned -1 (propagates through coroutine boundaries) + yieldUser uint8 = 2 // coroutine.yield (caught by the immediate resumer) +) + // callFrameExt holds rarely-used fields for protected calls and continuations. // Allocated lazily only when needed. type callFrameExt struct { @@ -445,8 +469,8 @@ func panicWithoutTraceback(L *LState) { func newLState(options Options) *LState { // Try to get a state from the pool if pooled := statePool.Get(); pooled != nil { - if ls, ok := pooled.(*LState); ok && ls != nil && ls.alloc != nil { - // Reuse pooled state with its allocator + if ls, ok := pooled.(*LState); ok && ls != nil { + // Reuse pooled state. ls.G = newGlobal() ls.Parent = nil ls.Panic = panicWithTraceback @@ -461,11 +485,6 @@ func newLState(options Options) *LState { ls.ctx = nil ls.ctxDone = nil - // Reset allocator slice but keep capacity - if ls.alloc.ptrs != nil { - ls.alloc.ptrs = ls.alloc.ptrs[:0] - } - // Reuse or recreate registry if ls.reg != nil && cap(ls.reg.array) >= options.RegistrySize { ls.reg.handler = ls @@ -473,7 +492,7 @@ func newLState(options Options) *LState { ls.reg.maxSize = options.RegistryMaxSize ls.reg.growBy = options.RegistryGrowStep } else { - ls.reg = newRegistry(ls, options.RegistrySize, options.RegistryGrowStep, options.RegistryMaxSize, ls.alloc) + ls.reg = newRegistry(ls, options.RegistrySize, options.RegistryGrowStep, options.RegistryMaxSize) } // Reuse auto-growing stack (can handle any size), recreate fixed stacks @@ -494,7 +513,6 @@ func newLState(options Options) *LState { } // Create fresh state - al := newAllocator(64) ls := &LState{ G: newGlobal(), Parent: nil, @@ -503,7 +521,6 @@ func newLState(options Options) *LState { Options: options, stop: 0, - alloc: al, currentFrame: nil, wrapped: false, uvcache: nil, @@ -516,7 +533,7 @@ func newLState(options Options) *LState { } else { ls.stack = newFixedCallFrameStack(options.CallStackSize) } - ls.reg = newRegistry(ls, options.RegistrySize, options.RegistryGrowStep, options.RegistryMaxSize, al) + ls.reg = newRegistry(ls, options.RegistrySize, options.RegistryGrowStep, options.RegistryMaxSize) ls.Env = ls.G.Global return ls } @@ -1040,8 +1057,10 @@ func (ls *LState) callR(nargs, nret, rbase int) { } else { ls.mainLoop(ls, ls.currentFrame) } - // Skip register adjustment if yield happened (state is already set by switchToParentThread) - if ls.yielded { + // Skip register adjustment if yield happened (state is already set by switchToParentThread). + // Save the return base so the continuation handler knows where the result lands. + if ls.yieldState != yieldNone { + ls.yieldContRB = int32(rbase) return } if nret != MultRet { @@ -1084,6 +1103,9 @@ func (ls *LState) getField(obj LValue, key LValue) LValue { ls.reg.Push(curobj) ls.reg.Push(key) ls.Call(2, 1) + if ls.yieldState != yieldNone { + return LNil + } return ls.reg.Pop() } curobj = metaindex @@ -1124,6 +1146,9 @@ func (ls *LState) getFieldString(obj LValue, key string) LValue { ls.reg.Push(curobj) ls.reg.Push(LString(key)) ls.Call(2, 1) + if ls.yieldState != yieldNone { + return LNil + } return ls.reg.Pop() } curobj = metaindex @@ -1172,7 +1197,9 @@ func (ls *LState) setField(obj LValue, key LValue, value LValue) { } curobj = metaindex } - ls.RaiseError("too many recursions in settable") + if ls.yieldState == yieldNone { + ls.RaiseError("too many recursions in settable") + } } func (ls *LState) setFieldString(obj LValue, key string, value LValue) { @@ -1215,7 +1242,9 @@ func (ls *LState) setFieldString(obj LValue, key string, value LValue) { } curobj = metaindex } - ls.RaiseError("too many recursions in settable") + if ls.yieldState == yieldNone { + ls.RaiseError("too many recursions in settable") + } } /* }}} */ @@ -1423,7 +1452,7 @@ func (ls *LState) CreateTable(acap, hcap int) *LTable { // NewThreadWithContext returns a new LState with the given context. // Pass nil for no context (faster execution without cancellation checks). func (ls *LState) NewThreadWithContext(ctx context.Context) *LState { - thread := newLStateWithGAndAlloc(ls.Options, ls.G, ls.Env, ls.alloc) + thread := newLStateWithGlobal(ls.Options, ls.G, ls.Env) if ctx != nil { thread.mainLoop = mainLoopWithContext thread.ctx = ctx @@ -1839,7 +1868,7 @@ func (ls *LState) CallK(nargs, nret int, cont LGContinuation, ctx any) { } ls.callR(nargs, nret, -1) // If yield happened, keep continuation for resume - if ls.yielded { + if ls.yieldState != yieldNone { return } // Call completed without yield - clear continuation @@ -1910,7 +1939,7 @@ func (ls *LState) PCall(nargs, nret int, errfunc *LFunction) (err error) { ls.reg.SetTop(base) } // Skip stack reset if yield happened - if ls.yielded { + if ls.yieldState != yieldNone { return } ls.stack.SetSp(sp) @@ -2035,7 +2064,7 @@ func (ls *LState) Resume(th *LState, fn *LFunction, args ...LValue) (ResumeState } } top := ls.GetTop() - th.yielded = false // Clear yield flag for new resume + th.yieldState = yieldNone // Clear yield flag for new resume threadRun(th) haserror := LVIsFalse(ls.Get(top + 1)) ret := make([]LValue, 0, ls.GetTop()) @@ -2113,7 +2142,7 @@ func (ls *LState) ResumeInto(th *LState, fn *LFunction, retBuf []LValue, args .. } } top := ls.GetTop() - th.yielded = false // Clear yield flag for new resume + th.yieldState = yieldNone // Clear yield flag for new resume threadRun(th) haserror := LVIsFalse(ls.Get(top + 1)) diff --git a/state_pool.go b/state_pool.go index 5b03ead2..daab5692 100644 --- a/state_pool.go +++ b/state_pool.go @@ -65,8 +65,8 @@ func (ls *LState) Close() { } } -// newLStateWithGAndAlloc creates a thread that shares the parent's allocator -func newLStateWithGAndAlloc(options Options, G *Global, env *LTable, parentAlloc *allocator) *LState { +// newLStateWithGlobal creates a thread that shares the parent's global/env. +func newLStateWithGlobal(options Options, G *Global, env *LTable) *LState { // Try to get a state from the pool pooledState := statePool.Get() @@ -77,17 +77,15 @@ func newLStateWithGAndAlloc(options Options, G *Global, env *LTable, parentAlloc ls.Panic = panicWithTraceback ls.Options = options ls.mainLoop = mainLoop - ls.alloc = parentAlloc ls.stop = 0 ls.ctx = nil ls.ctxDone = nil // Registry was preserved but might need resetting if options changed if ls.reg != nil && cap(ls.reg.array) != options.RegistrySize { - ls.reg = newRegistry(ls, options.RegistrySize, options.RegistryGrowStep, options.RegistryMaxSize, parentAlloc) + ls.reg = newRegistry(ls, options.RegistrySize, options.RegistryGrowStep, options.RegistryMaxSize) } else if ls.reg != nil { ls.reg.handler = ls - ls.reg.alloc = parentAlloc } return ls @@ -101,7 +99,6 @@ func newLStateWithGAndAlloc(options Options, G *Global, env *LTable, parentAlloc Dead: false, Options: options, stop: 0, - alloc: parentAlloc, currentFrame: nil, wrapped: false, uvcache: nil, @@ -116,7 +113,7 @@ func newLStateWithGAndAlloc(options Options, G *Global, env *LTable, parentAlloc ls.stack = newFixedCallFrameStack(options.CallStackSize) } - ls.reg = newRegistry(ls, options.RegistrySize, options.RegistryGrowStep, options.RegistryMaxSize, parentAlloc) + ls.reg = newRegistry(ls, options.RegistrySize, options.RegistryGrowStep, options.RegistryMaxSize) ls.Env = env return ls diff --git a/state_test.go b/state_test.go index c70459f5..3e8dc1f0 100644 --- a/state_test.go +++ b/state_test.go @@ -746,9 +746,8 @@ func (registryTestHandler) registryOverflow() { // test pushing and popping from the registry func BenchmarkRegistryPushPopAutoGrow(t *testing.B) { - al := newAllocator(32) sz := 256 * 20 - reg := newRegistry(registryTestHandler(0), sz/2, 64, sz, al) + reg := newRegistry(registryTestHandler(0), sz/2, 64, sz) value := LString("test") t.ResetTimer() @@ -764,9 +763,8 @@ func BenchmarkRegistryPushPopAutoGrow(t *testing.B) { } func BenchmarkRegistryPushPopFixed(t *testing.B) { - al := newAllocator(32) sz := 256 * 20 - reg := newRegistry(registryTestHandler(0), sz, 0, sz, al) + reg := newRegistry(registryTestHandler(0), sz, 0, sz) value := LString("test") t.ResetTimer() @@ -782,9 +780,8 @@ func BenchmarkRegistryPushPopFixed(t *testing.B) { } func BenchmarkRegistrySetTop(t *testing.B) { - al := newAllocator(32) sz := 256 * 20 - reg := newRegistry(registryTestHandler(0), sz, 32, sz*2, al) + reg := newRegistry(registryTestHandler(0), sz, 32, sz*2) t.ResetTimer() diff --git a/utils_test.go b/utils_test.go index 2eda3743..58b11f5b 100644 --- a/utils_test.go +++ b/utils_test.go @@ -61,29 +61,25 @@ func TestIsArrayKey(t *testing.T) { } func TestLNumber2IPreload(t *testing.T) { - al := newAllocator(32) - for i := 0; i < preloadLimit; i++ { - v := al.LNumber2I(LNumber(i)) - if v != preloads[i] { - t.Errorf("LNumber2I(%d) did not return preloaded value", i) + v := lnumberToValue(LNumber(i)) + if v != preloadedNumbers[i] { + t.Errorf("lnumberToValue(%d) did not return preloaded value", i) } } - v := al.LNumber2I(LNumber(preloadLimit)) - if v == preloads[int(preloadLimit)-1] { - t.Errorf("LNumber2I(%d) should not return preloaded value", preloadLimit) + v := lnumberToValue(LNumber(preloadLimit)) + if v == preloadedNumbers[int(preloadLimit)-1] { + t.Errorf("lnumberToValue(%d) should not return preloaded value", preloadLimit) } } func TestLNumber2INonInteger(t *testing.T) { - al := newAllocator(32) - tests := []LNumber{0.5, 1.5, -0.5, math.Pi, 127.5} for _, v := range tests { - result := al.LNumber2I(v) + result := lnumberToValue(v) if n, ok := result.(LNumber); !ok || n != v { - t.Errorf("LNumber2I(%v) returned incorrect value", v) + t.Errorf("lnumberToValue(%v) returned incorrect value", v) } } } diff --git a/value.go b/value.go index e5a43464..fdf55f82 100644 --- a/value.go +++ b/value.go @@ -225,7 +225,6 @@ type LState struct { stop int32 reg *registry stack callFrameStack - alloc *allocator currentFrame *callFrame wrapped bool uvcache *Upvalue @@ -235,7 +234,11 @@ type LState struct { ctxCancelFn context.CancelFunc ctxDone <-chan struct{} frameExt map[int16]*callFrameExt // lazy-allocated frame extensions keyed by Idx - yielded bool // set when coroutine yields without panic + yieldState uint8 // 0=not yielded, 1=system yield, 2=user yield + yieldCont uint8 // pending yield continuation type for Lua frames + yieldContRA int32 // target register for continuation result + yieldContRB int32 // call's ReturnBase (where the result lands) + yieldContIdx int16 // frame Idx that owns this continuation } func (ls *LState) String() string { return fmt.Sprintf("thread: %p", ls) } diff --git a/value_boxing.go b/value_boxing.go new file mode 100644 index 00000000..03a905f2 --- /dev/null +++ b/value_boxing.go @@ -0,0 +1,32 @@ +package lua + +const preloadLimit = 256 +const intPreloadLimit = 65536 // Common loop bounds + +var preloadedNumbers [preloadLimit]LValue +var preloadedIntegers [intPreloadLimit * 2]LValue // [-65536, 65536) + +func init() { + for i := 0; i < preloadLimit; i++ { + preloadedNumbers[i] = LNumber(i) + } + for i := -intPreloadLimit; i < intPreloadLimit; i++ { + preloadedIntegers[i+intPreloadLimit] = LInteger(i) + } +} + +func lnumberToValue(v LNumber) LValue { + iv := int(v) + if iv >= 0 && iv < preloadLimit && LNumber(iv) == v { + return preloadedNumbers[iv] + } + return v +} + +func lintegerToValue(v LInteger) LValue { + iv := int(v) + if iv >= -intPreloadLimit && iv < intPreloadLimit { + return preloadedIntegers[iv+intPreloadLimit] + } + return v +} diff --git a/vm.go b/vm.go index 77d46394..1a0ef19f 100644 --- a/vm.go +++ b/vm.go @@ -66,6 +66,14 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { inst = cf.Fn.Proto.Code[cf.Pc] cf.Pc++ + // Handle yield continuation: when an opcode's inner call yielded and has + // now completed, finish the originating opcode's post-call work. Only fires + // when the current frame is the one that owns the continuation. + if L.yieldCont != 0 && cf.Idx == L.yieldContIdx { + handleYieldContinuation(L, cf, inst) + continue + } + // Note: Some opcodes (CALL, TAILCALL, RETURN) may need to `return` from mainLoop // Others just `continue` to next instruction switch int(inst >> 26) { @@ -302,6 +310,13 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { Bx := int(inst & 0x3ffff) //GETBX //reg.Set(RA, L.getField(cf.Fn.Env, cf.Fn.Proto.Constants[Bx])) v := L.getFieldString(cf.Fn.Env, cf.Fn.Proto.stringConstants[Bx]) + if L.yieldState != yieldNone { + L.yieldCont = yieldContGetField + L.yieldContRA = int32(RA) + L.yieldContIdx = cf.Idx + cf.Pc-- + return + } // this section is inlined by go-inline // source function is 'func (rg *registry) Set(regi int, vali LValue) ' in '_state.go' { @@ -363,6 +378,13 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { B := int(inst & 0x1ff) //GETB C := int(inst>>9) & 0x1ff //GETC v := L.getField(reg.Get(int(lbase)+B), L.rkValue(C)) + if L.yieldState != yieldNone { + L.yieldCont = yieldContGetField + L.yieldContRA = int32(RA) + L.yieldContIdx = cf.Idx + cf.Pc-- + return + } // this section is inlined by go-inline // source function is 'func (rg *registry) Set(regi int, vali LValue) ' in '_state.go' { @@ -392,6 +414,13 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { B := int(inst & 0x1ff) //GETB C := int(inst>>9) & 0x1ff //GETC v := L.getFieldString(reg.Get(int(lbase)+B), L.rkString(C)) + if L.yieldState != yieldNone { + L.yieldCont = yieldContGetField + L.yieldContRA = int32(RA) + L.yieldContIdx = cf.Idx + cf.Pc-- + return + } // this section is inlined by go-inline // source function is 'func (rg *registry) Set(regi int, vali LValue) ' in '_state.go' { @@ -421,6 +450,12 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { Bx := int(inst & 0x3ffff) //GETBX value := reg.Get(RA) L.setFieldString(cf.Fn.Env, cf.Fn.Proto.stringConstants[Bx], value) + if L.yieldState != yieldNone { + L.yieldCont = yieldContSetField + L.yieldContIdx = cf.Idx + cf.Pc-- + return + } case OP_SETUPVAL: reg := L.reg @@ -439,6 +474,12 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { B := int(inst & 0x1ff) //GETB C := int(inst>>9) & 0x1ff //GETC L.setField(reg.Get(RA), L.rkValue(B), L.rkValue(C)) + if L.yieldState != yieldNone { + L.yieldCont = yieldContSetField + L.yieldContIdx = cf.Idx + cf.Pc-- + return + } case OP_SETTABLEKS: reg := L.reg @@ -448,6 +489,12 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { B := int(inst & 0x1ff) //GETB C := int(inst>>9) & 0x1ff //GETC L.setFieldString(reg.Get(RA), L.rkString(B), L.rkValue(C)) + if L.yieldState != yieldNone { + L.yieldCont = yieldContSetField + L.yieldContIdx = cf.Idx + cf.Pc-- + return + } case OP_NEWTABLE: reg := L.reg @@ -487,6 +534,13 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { C := int(inst>>9) & 0x1ff //GETC selfobj := reg.Get(int(lbase) + B) v := L.getFieldString(selfobj, L.rkString(C)) + if L.yieldState != yieldNone { + L.yieldCont = yieldContSelf + L.yieldContRA = int32(RA) + L.yieldContIdx = cf.Idx + cf.Pc-- + return + } // this section is inlined by go-inline // source function is 'func (rg *registry) Set(regi int, vali LValue) ' in '_state.go' { @@ -552,7 +606,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { // Fast path: both integers if lhsI, ok1 := lhs.(LInteger); ok1 { if rhsI, ok2 := rhs.(LInteger); ok2 { - v := reg.alloc.LInteger2I(lhsI + rhsI) + v := lintegerToValue(lhsI + rhsI) newSize := RA + 1 if newSize > cap(reg.array) { reg.resize(newSize) @@ -567,7 +621,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { // Fast path: both numbers if lhsN, ok1 := lhs.(LNumber); ok1 { if rhsN, ok2 := rhs.(LNumber); ok2 { - v := reg.alloc.LNumber2I(lhsN + rhsN) + v := lnumberToValue(lhsN + rhsN) newSize := RA + 1 if newSize > cap(reg.array) { reg.resize(newSize) @@ -583,7 +637,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { v1, ok1 := toNumber(lhs) v2, ok2 := toNumber(rhs) if ok1 && ok2 { - v := reg.alloc.LNumber2I(v1 + v2) + v := lnumberToValue(v1 + v2) newSize := RA + 1 if newSize > cap(reg.array) { reg.resize(newSize) @@ -594,6 +648,13 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { } } else { v := objectArith(L, OP_ADD, lhs, rhs) + if L.yieldState != yieldNone { + L.yieldCont = yieldContArith + L.yieldContRA = int32(RA) + L.yieldContIdx = cf.Idx + cf.Pc-- + return + } newSize := RA + 1 if newSize > cap(reg.array) { reg.resize(newSize) @@ -628,7 +689,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { // Fast path: both integers if lhsI, ok1 := lhs.(LInteger); ok1 { if rhsI, ok2 := rhs.(LInteger); ok2 { - v := reg.alloc.LInteger2I(lhsI - rhsI) + v := lintegerToValue(lhsI - rhsI) newSize := RA + 1 if newSize > cap(reg.array) { reg.resize(newSize) @@ -643,7 +704,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { // Fast path: both numbers if lhsN, ok1 := lhs.(LNumber); ok1 { if rhsN, ok2 := rhs.(LNumber); ok2 { - v := reg.alloc.LNumber2I(lhsN - rhsN) + v := lnumberToValue(lhsN - rhsN) newSize := RA + 1 if newSize > cap(reg.array) { reg.resize(newSize) @@ -659,7 +720,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { v1, ok1 := toNumber(lhs) v2, ok2 := toNumber(rhs) if ok1 && ok2 { - v := reg.alloc.LNumber2I(v1 - v2) + v := lnumberToValue(v1 - v2) newSize := RA + 1 if newSize > cap(reg.array) { reg.resize(newSize) @@ -670,6 +731,13 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { } } else { v := objectArith(L, OP_SUB, lhs, rhs) + if L.yieldState != yieldNone { + L.yieldCont = yieldContArith + L.yieldContRA = int32(RA) + L.yieldContIdx = cf.Idx + cf.Pc-- + return + } newSize := RA + 1 if newSize > cap(reg.array) { reg.resize(newSize) @@ -704,7 +772,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { // Fast path: both integers if lhsI, ok1 := lhs.(LInteger); ok1 { if rhsI, ok2 := rhs.(LInteger); ok2 { - v := reg.alloc.LInteger2I(lhsI * rhsI) + v := lintegerToValue(lhsI * rhsI) newSize := RA + 1 if newSize > cap(reg.array) { reg.resize(newSize) @@ -719,7 +787,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { // Fast path: both numbers if lhsN, ok1 := lhs.(LNumber); ok1 { if rhsN, ok2 := rhs.(LNumber); ok2 { - v := reg.alloc.LNumber2I(lhsN * rhsN) + v := lnumberToValue(lhsN * rhsN) newSize := RA + 1 if newSize > cap(reg.array) { reg.resize(newSize) @@ -735,7 +803,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { v1, ok1 := toNumber(lhs) v2, ok2 := toNumber(rhs) if ok1 && ok2 { - v := reg.alloc.LNumber2I(v1 * v2) + v := lnumberToValue(v1 * v2) newSize := RA + 1 if newSize > cap(reg.array) { reg.resize(newSize) @@ -746,6 +814,13 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { } } else { v := objectArith(L, OP_MUL, lhs, rhs) + if L.yieldState != yieldNone { + L.yieldCont = yieldContArith + L.yieldContRA = int32(RA) + L.yieldContIdx = cf.Idx + cf.Pc-- + return + } newSize := RA + 1 if newSize > cap(reg.array) { reg.resize(newSize) @@ -784,13 +859,20 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { rg.resize(requiredSize) } } - rg.array[regi] = rg.alloc.LNumber2I(vali) + rg.array[regi] = lnumberToValue(vali) if regi >= rg.top { rg.top = regi + 1 } } } else { v := objectArith(L, OP_DIV, lhs, rhs) + if L.yieldState != yieldNone { + L.yieldCont = yieldContArith + L.yieldContRA = int32(RA) + L.yieldContIdx = cf.Idx + cf.Pc-- + return + } // this section is inlined by go-inline // source function is 'func (rg *registry) Set(regi int, vali LValue) ' in '_state.go' { @@ -841,13 +923,20 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { rg.resize(requiredSize) } } - rg.array[regi] = rg.alloc.LNumber2I(vali) + rg.array[regi] = lnumberToValue(vali) if regi >= rg.top { rg.top = regi + 1 } } } else { v := objectArith(L, OP_MOD, lhs, rhs) + if L.yieldState != yieldNone { + L.yieldCont = yieldContArith + L.yieldContRA = int32(RA) + L.yieldContIdx = cf.Idx + cf.Pc-- + return + } // this section is inlined by go-inline // source function is 'func (rg *registry) Set(regi int, vali LValue) ' in '_state.go' { @@ -898,13 +987,20 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { rg.resize(requiredSize) } } - rg.array[regi] = rg.alloc.LNumber2I(vali) + rg.array[regi] = lnumberToValue(vali) if regi >= rg.top { rg.top = regi + 1 } } } else { v := objectArith(L, OP_POW, lhs, rhs) + if L.yieldState != yieldNone { + L.yieldCont = yieldContArith + L.yieldContRA = int32(RA) + L.yieldContIdx = cf.Idx + cf.Pc-- + return + } // this section is inlined by go-inline // source function is 'func (rg *registry) Set(regi int, vali LValue) ' in '_state.go' { @@ -967,7 +1063,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { if (l^r) < 0 && l%r != 0 { q-- } - vali := rg.alloc.LInteger2I(LInteger(q)) + vali := lintegerToValue(LInteger(q)) newSize := regi + 1 if newSize > cap(rg.array) { rg.resize(newSize) @@ -1009,7 +1105,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { { rg := reg regi := RA - vali := rg.alloc.LInteger2I(LInteger(l & r)) + vali := lintegerToValue(LInteger(l & r)) newSize := regi + 1 if newSize > cap(rg.array) { rg.resize(newSize) @@ -1051,7 +1147,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { { rg := reg regi := RA - vali := rg.alloc.LInteger2I(LInteger(l | r)) + vali := lintegerToValue(LInteger(l | r)) newSize := regi + 1 if newSize > cap(rg.array) { rg.resize(newSize) @@ -1093,7 +1189,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { { rg := reg regi := RA - vali := rg.alloc.LInteger2I(LInteger(l ^ r)) + vali := lintegerToValue(LInteger(l ^ r)) newSize := regi + 1 if newSize > cap(rg.array) { rg.resize(newSize) @@ -1143,7 +1239,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { { rg := reg regi := RA - vali := rg.alloc.LInteger2I(LInteger(result)) + vali := lintegerToValue(LInteger(result)) newSize := regi + 1 if newSize > cap(rg.array) { rg.resize(newSize) @@ -1193,7 +1289,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { { rg := reg regi := RA - vali := rg.alloc.LInteger2I(LInteger(result)) + vali := lintegerToValue(LInteger(result)) newSize := regi + 1 if newSize > cap(rg.array) { rg.resize(newSize) @@ -1217,7 +1313,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { { rg := reg regi := RA - vali := rg.alloc.LNumber2I(-nm) + vali := lnumberToValue(-nm) newSize := regi + 1 // this section is inlined by go-inline // source function is 'func (rg *registry) checkSize(requiredSize int) ' in '_state.go' @@ -1238,6 +1334,13 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { reg.Push(op) reg.Push(unaryv) L.Call(1, 1) + if L.yieldState != yieldNone { + L.yieldCont = yieldContUnm + L.yieldContRA = int32(RA) + L.yieldContIdx = cf.Idx + cf.Pc-- + return + } // this section is inlined by go-inline // source function is 'func (rg *registry) Set(regi int, vali LValue) ' in '_state.go' { @@ -1308,7 +1411,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { { rg := reg regi := RA - vali := rg.alloc.LInteger2I(LInteger(^n)) + vali := lintegerToValue(LInteger(^n)) newSize := regi + 1 if newSize > cap(rg.array) { rg.resize(newSize) @@ -1392,7 +1495,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { rg.resize(requiredSize) } } - rg.array[regi] = rg.alloc.LNumber2I(vali) + rg.array[regi] = lnumberToValue(vali) if regi >= rg.top { rg.top = regi + 1 } @@ -1403,6 +1506,13 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { reg.Push(op) reg.Push(lv) L.Call(1, 1) + if L.yieldState != yieldNone { + L.yieldCont = yieldContLen + L.yieldContRA = int32(RA) + L.yieldContIdx = cf.Idx + cf.Pc-- + return + } ret := reg.Pop() if ret.Type() == LTNumber { v, _ := ret.(LNumber) @@ -1421,7 +1531,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { rg.resize(requiredSize) } } - rg.array[regi] = rg.alloc.LNumber2I(vali) + rg.array[regi] = lnumberToValue(vali) if regi >= rg.top { rg.top = regi + 1 } @@ -1464,7 +1574,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { rg.resize(requiredSize) } } - rg.array[regi] = rg.alloc.LNumber2I(vali) + rg.array[regi] = lnumberToValue(vali) if regi >= rg.top { rg.top = regi + 1 } @@ -1484,6 +1594,13 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { RC := int(lbase) + C RB := int(lbase) + B v := stringConcat(L, RC-RB+1, RC) + if L.yieldState != yieldNone { + L.yieldCont = yieldContConcat + L.yieldContRA = int32(RA) + L.yieldContIdx = cf.Idx + cf.Pc-- + return + } // this section is inlined by go-inline // source function is 'func (rg *registry) Set(regi int, vali LValue) ' in '_state.go' { @@ -1517,6 +1634,13 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { B := int(inst & 0x1ff) //GETB C := int(inst>>9) & 0x1ff //GETC ret := equals(L, L.rkValue(B), L.rkValue(C), false) + if L.yieldState != yieldNone { + L.yieldCont = yieldContCompare + L.yieldContRA = int32(A) + L.yieldContIdx = cf.Idx + cf.Pc-- + return + } v := 1 if ret { v = 0 @@ -1530,6 +1654,13 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { B := int(inst & 0x1ff) //GETB C := int(inst>>9) & 0x1ff //GETC ret := lessThan(L, L.rkValue(B), L.rkValue(C)) + if L.yieldState != yieldNone { + L.yieldCont = yieldContCompare + L.yieldContRA = int32(A) + L.yieldContIdx = cf.Idx + cf.Pc-- + return + } v := 1 if ret { v = 0 @@ -1567,8 +1698,21 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { ret = true case 0: ret = false + case -2: + L.yieldCont = yieldContCompare + L.yieldContRA = int32(A) + L.yieldContIdx = cf.Idx + cf.Pc-- + return default: ret = !objectRationalWithError(L, rhs, lhs, "__lt") + if L.yieldState != yieldNone { + L.yieldCont = yieldContCompare + L.yieldContRA = int32(A) + L.yieldContIdx = cf.Idx + cf.Pc-- + return + } } } } @@ -1835,9 +1979,20 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { if callGFunction(L, true) { return } - if L.currentFrame == nil || L.currentFrame.GoFunc != nil || (L.currentFrame.Fn != nil && L.currentFrame.Fn.IsG) || luaframe == baseframe { + if L.currentFrame == nil || luaframe == baseframe { return } + // If tail call returned to a Go frame, check for continuation (e.g. pcall) + if L.currentFrame.GoFunc != nil || (L.currentFrame.Fn != nil && L.currentFrame.Fn.IsG) { + ext := L.getFrameExt(L.currentFrame) + if ext != nil && ext.Continuation != nil { + if callGFunction(L, false) { + return + } + } else { + return + } + } } else { base := cf.Base cf.Fn = callable @@ -2237,7 +2392,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { init := int64(initI) + int64(stepI) limit := int64(limitI) step := int64(stepI) - v := reg.alloc.LInteger2I(LInteger(init)) + v := lintegerToValue(LInteger(init)) newSize := RA + 1 if newSize > cap(reg.array) { reg.resize(newSize) @@ -2295,7 +2450,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { rg.resize(requiredSize) } } - rg.array[regi] = rg.alloc.LNumber2I(vali) + rg.array[regi] = lnumberToValue(vali) if regi >= rg.top { rg.top = regi + 1 } @@ -2314,7 +2469,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { rg.resize(requiredSize) } } - rg.array[regi] = rg.alloc.LNumber2I(vali) + rg.array[regi] = lnumberToValue(vali) if regi >= rg.top { rg.top = regi + 1 } @@ -2373,7 +2528,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { if initI, ok1 := initVal.(LInteger); ok1 { if stepI, ok2 := stepVal.(LInteger); ok2 { result := int64(initI) - int64(stepI) - v := reg.alloc.LInteger2I(LInteger(result)) + v := lintegerToValue(LInteger(result)) newSize := RA + 1 if newSize > cap(reg.array) { reg.resize(newSize) @@ -2397,7 +2552,7 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { if newSize > cap(rg.array) { rg.resize(newSize) } - rg.array[regi] = rg.alloc.LNumber2I(vali) + rg.array[regi] = lnumberToValue(vali) if regi >= rg.top { rg.top = regi + 1 } @@ -2511,6 +2666,13 @@ func mainLoopWithContext(L *LState, baseframe *callFrame) { } } L.callR(2, nret, RA+3) + if L.yieldState != yieldNone { + L.yieldCont = yieldContTForLoop + L.yieldContRA = int32(RA) + L.yieldContIdx = cf.Idx + cf.Pc-- + return + } if value := reg.Get(RA + 3); value != LNil { // this section is inlined by go-inline // source function is 'func (rg *registry) Set(regi int, vali LValue) ' in '_state.go' @@ -2712,9 +2874,12 @@ func switchToParentThread(L *LState, nargs int, haserror bool, kill bool) { L.kill() } - // For yield, mark as yielded (no panic needed - callers check L.yielded) + // For yield, mark as yielded (no panic needed - callers check L.yieldState). + // Preserve yieldUser if already set by coYield before this call. if !haserror && !kill { - L.yielded = true + if L.yieldState == yieldNone { + L.yieldState = yieldSystem + } } } @@ -2778,6 +2943,65 @@ func returnFromTailcall(L *LState, baseframe *callFrame, cf *callFrame, RA int, return false } +// handleYieldContinuation finishes an opcode whose inner call yielded. +// The called function has completed and OP_RETURN placed the result at +// yieldContRB. We execute only the post-call logic of the originating opcode. +func handleYieldContinuation(L *LState, cf *callFrame, inst uint32) { + contType := L.yieldCont + ra := int(L.yieldContRA) + rb := int(L.yieldContRB) + reg := L.reg + + // Clear continuation state before executing post-call logic. + L.yieldCont = yieldContNone + L.yieldContRA = 0 + L.yieldContRB = 0 + L.yieldContIdx = 0 + + switch contType { + case yieldContGetField, yieldContArith, yieldContUnm, yieldContLen, yieldContConcat: + // All these place a single return value at RA. + v := reg.Get(rb) + reg.Set(ra, v) + + case yieldContSetField: + // No result to store, execution continues. + + case yieldContSelf: + // OP_SELF: store method result at RA, self object at RA+1. + v := reg.Get(rb) + reg.Set(ra, v) + // Re-extract B from instruction to get the self object. + lbase := cf.LocalBase + B := int(inst & 0x1ff) //GETB + selfobj := reg.Get(int(lbase) + B) + reg.Set(ra+1, selfobj) + + case yieldContCompare: + // RA holds the A operand from the comparison instruction (not a register index). + // The metamethod result is at rb. Evaluate as bool and apply the skip logic. + v := reg.Get(rb) + result := 1 + if LVAsBool(v) { + result = 0 + } + if result == int(ra) { + cf.Pc++ + } + + case yieldContTForLoop: + // OP_TFORLOOP: iterator results are at RA+3..RA+3+nret-1 (placed by callR/OP_RETURN). + // Check first result: if nil, loop ends. Otherwise update control variable. + if value := reg.Get(ra + 3); value != LNil { + reg.Set(ra+2, value) + // Read the JMP instruction that follows OP_TFORLOOP. + pc := cf.Fn.Proto.Code[cf.Pc] + cf.Pc += int32(int(pc&0x3ffff) - opMaxArgSbx) + } + cf.Pc++ + } +} + func callGFunction(L *LState, tailcall bool) bool { frame := L.currentFrame var gfnret int @@ -2800,10 +3024,15 @@ func callGFunction(L *LState, tailcall bool) bool { } if gfnret < 0 { - // Only call switchToParentThread if not already yielded - // (yield propagation through pcall/xpcall returns -1 with L.yielded already set) - if !L.yielded { + // Only call switchToParentThread for the first yield in the chain. + // Subsequent Go functions returning -1 (pcall, xpcall, coResume) detect + // the yield via yieldState and propagate it without a second thread switch. + if L.yieldState == yieldNone { switchToParentThread(L, L.GetTop(), false, false) + // -2 = user yield (coroutine.yield), -1 = system yield (Go function) + if gfnret == -2 { + L.yieldState = yieldUser + } } return true } @@ -3021,6 +3250,9 @@ func objectArith(L *LState, opcode int, lhs, rhs LValue) LValue { L.reg.Push(rhs) L.Call(2, 1) + if L.yieldState != yieldNone { + return LNil + } return L.reg.Pop() } if str, ok := lhs.(LString); ok { @@ -3070,6 +3302,9 @@ func stringConcat(L *LState, total, last int) LValue { L.reg.Push(lhs) L.reg.Push(rhs) L.Call(2, 1) + if L.yieldState != yieldNone { + return LNil + } rhs = L.reg.Pop() total-- i-- @@ -3188,6 +3423,8 @@ func equals(L *LState, lhs, rhs LValue, raw bool) bool { switch objectRational(L, lhs, rhs, "__eq") { case 1: ret = true + case -2: + return false // yield happened, caller checks L.yieldState default: ret = false } @@ -3217,6 +3454,8 @@ func objectRationalWithError(L *LState, lhs, rhs LValue, event string) bool { return true case 0: return false + case -2: + return false // yield happened, caller checks L.yieldState } L.RaiseError("attempt to compare %v with %v", lhs.Type().String(), rhs.Type().String()) return false @@ -3230,6 +3469,9 @@ func objectRational(L *LState, lhs, rhs LValue, event string) int { L.reg.Push(lhs) L.reg.Push(rhs) L.Call(2, 1) + if L.yieldState != yieldNone { + return -2 + } if LVAsBool(L.reg.Pop()) { return 1 } diff --git a/yield_across_boundaries_test.go b/yield_across_boundaries_test.go new file mode 100644 index 00000000..2afe6bb6 --- /dev/null +++ b/yield_across_boundaries_test.go @@ -0,0 +1,1794 @@ +package lua + +import ( + "context" + "strings" + "testing" +) + +// Tests for yielding across call boundaries that use callR/Call internally. +// These opcodes use nested mainLoop calls which don't propagate yields correctly. +// +// Affected paths: +// - OP_TFORLOOP: generic for loop iterator calls (callR) +// - getFieldString: __index metamethod (Call) +// - setField: __newindex metamethod (Call) +// - objectArith: __add/__sub/__mul/__div/__mod/__pow metamethods (Call) +// - objectConcat: __concat metamethod (Call) +// - objectRational: __eq/__lt/__le metamethods (Call) +// - OP_UNM: __unm metamethod (Call) +// - OP_LEN: __len metamethod (Call) + +// helper: creates a Go function that yields the first argument. +func yieldingGoFunc(L *LState) int { + return L.Yield(L.Get(1)) +} + +// helper: resume coroutine expecting yield, return yielded values. +func expectYield(t *testing.T, L *LState, co *LState, fn *LFunction, args ...LValue) []LValue { + t.Helper() + state, results, err := L.Resume(co, fn, args...) + if err != nil { + t.Fatalf("Resume failed: %v", err) + } + if state != ResumeYield { + t.Fatalf("Expected ResumeYield, got %v (results: %v)", state, results) + } + return results +} + +// helper: resume coroutine expecting completion, return results. +func expectDone(t *testing.T, L *LState, co *LState, fn *LFunction, args ...LValue) []LValue { + t.Helper() + state, results, err := L.Resume(co, fn, args...) + if err != nil { + t.Fatalf("Resume failed: %v", err) + } + if state != ResumeOK { + t.Fatalf("Expected ResumeOK, got %v (results: %v)", state, results) + } + return results +} + +// --------------------------------------------------------------------------- +// OP_TFORLOOP: yield from generic for-loop iterator +// --------------------------------------------------------------------------- + +func TestYieldFromForInIterator_LuaYield(t *testing.T) { + L := NewState() + defer L.Close() + + // Iterator that yields each value before returning it + if err := L.DoString(` + function yielding_iter(items) + local i = 0 + return function() + i = i + 1 + if i > #items then return nil end + coroutine.yield("producing:" .. items[i]) + return items[i] + end + end + + function test() + local results = {} + for val in yielding_iter({"a", "b", "c"}) do + results[#results + 1] = val + end + return table.concat(results, ",") + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + // Each iteration yields before returning the value + r := expectYield(t, L, co, fn) + if r[0].String() != "producing:a" { + t.Fatalf("Expected 'producing:a', got %v", r[0]) + } + + r = expectYield(t, L, co, fn) + if r[0].String() != "producing:b" { + t.Fatalf("Expected 'producing:b', got %v", r[0]) + } + + r = expectYield(t, L, co, fn) + if r[0].String() != "producing:c" { + t.Fatalf("Expected 'producing:c', got %v", r[0]) + } + + // Final resume completes + results := expectDone(t, L, co, fn) + if results[0].String() != "a,b,c" { + t.Errorf("Expected 'a,b,c', got %v", results[0]) + } +} + +func TestYieldFromForInIterator_GoFunctionYield(t *testing.T) { + L := NewState() + defer L.Close() + + // Go function that yields its argument + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + function yielding_iter(items) + local i = 0 + return function() + i = i + 1 + if i > #items then return nil end + go_yield("producing:" .. items[i]) + return items[i] + end + end + + function test() + local results = {} + for val in yielding_iter({"a", "b", "c"}) do + results[#results + 1] = val + end + return table.concat(results, ",") + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + // Each iteration: Go function yields, then resume returns to iterator + for _, expected := range []string{"producing:a", "producing:b", "producing:c"} { + r := expectYield(t, L, co, fn) + if r[0].String() != expected { + t.Fatalf("Expected %q, got %v", expected, r[0]) + } + } + + results := expectDone(t, L, co, fn) + if results[0].String() != "a,b,c" { + t.Errorf("Expected 'a,b,c', got %v", results[0]) + } +} + +func TestYieldFromForInIterator_ResumeValues(t *testing.T) { + L := NewState() + defer L.Close() + + // Iterator that yields a request and uses the resume value + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + function lazy_iter(ids) + local i = 0 + return function() + i = i + 1 + if i > #ids then return nil end + -- yield request, receive loaded data on resume + local data = go_yield("load:" .. ids[i]) + return data + end + end + + function test() + local results = {} + for val in lazy_iter({"x", "y"}) do + results[#results + 1] = val + end + return table.concat(results, ",") + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + // First iteration yields load request + r := expectYield(t, L, co, fn) + if r[0].String() != "load:x" { + t.Fatalf("Expected 'load:x', got %v", r[0]) + } + + // Resume with loaded data + r = expectYield(t, L, co, fn, LString("data_x")) + if r[0].String() != "load:y" { + t.Fatalf("Expected 'load:y', got %v", r[0]) + } + + // Resume with second loaded data, completes + results := expectDone(t, L, co, fn, LString("data_y")) + if results[0].String() != "data_x,data_y" { + t.Errorf("Expected 'data_x,data_y', got %v", results[0]) + } +} + +func TestYieldFromForInIterator_MultipleReturnValues(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + // Iterator returns key, value pairs (like pairs/ipairs) + if err := L.DoString(` + function kv_iter(tbl) + local keys = {} + for k in pairs(tbl) do keys[#keys + 1] = k end + table.sort(keys) + local i = 0 + return function() + i = i + 1 + if i > #keys then return nil end + go_yield("fetch:" .. keys[i]) + return keys[i], tbl[keys[i]] + end + end + + function test() + local results = {} + for k, v in kv_iter({a=1, b=2}) do + results[#results + 1] = k .. "=" .. v + end + table.sort(results) + return table.concat(results, ",") + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + expectYield(t, L, co, fn) // fetch:a + expectYield(t, L, co, fn) // fetch:b + results := expectDone(t, L, co, fn) + if results[0].String() != "a=1,b=2" { + t.Errorf("Expected 'a=1,b=2', got %v", results[0]) + } +} + +func TestYieldFromForInIterator_WithPcall(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + function yielding_iter(n) + local i = 0 + return function() + i = i + 1 + if i > n then return nil end + go_yield("item:" .. i) + return i + end + end + + function test() + local ok, result = pcall(function() + local sum = 0 + for v in yielding_iter(3) do + sum = sum + v + end + return sum + end) + return ok, result + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + expectYield(t, L, co, fn) // item:1 + expectYield(t, L, co, fn) // item:2 + expectYield(t, L, co, fn) // item:3 + + results := expectDone(t, L, co, fn) + if results[0] != LTrue { + t.Errorf("Expected pcall success, got %v", results[0]) + } + if LVAsNumber(results[1]) != 6 { + t.Errorf("Expected 6, got %v", results[1]) + } +} + +func TestYieldFromForInIterator_ErrorAfterYield(t *testing.T) { + L := NewState() + defer L.Close() + + // Track how many times the iterator is called per resume cycle. + // The bug causes double-execution: iterator called twice per single resume. + callCount := 0 + L.SetGlobal("go_yield", L.NewFunction(func(L *LState) int { + callCount++ + return L.Yield(L.Get(1)) + })) + + if err := L.DoString(` + function failing_iter() + local i = 0 + return function() + i = i + 1 + if i > 2 then error("iterator exhausted badly") end + go_yield("item:" .. i) + return i + end + end + + function test() + local ok, err = pcall(function() + local sum = 0 + for v in failing_iter() do + sum = sum + v + end + return sum + end) + return ok, err + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + callCount = 0 + r := expectYield(t, L, co, fn) + if r[0].String() != "item:1" { + t.Fatalf("Expected 'item:1', got %v", r[0]) + } + if callCount != 1 { + t.Fatalf("Expected go_yield called once per resume, called %d times (double-execution bug)", callCount) + } + + callCount = 0 + r = expectYield(t, L, co, fn) + if r[0].String() != "item:2" { + t.Fatalf("Expected 'item:2', got %v", r[0]) + } + if callCount != 1 { + t.Fatalf("Expected go_yield called once per resume, called %d times (double-execution bug)", callCount) + } + + // Third iteration errors before yield + results := expectDone(t, L, co, fn) + if results[0] != LFalse { + t.Errorf("Expected pcall failure, got %v", results[0]) + } +} + +func TestYieldFromForInIterator_NestedForLoops(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + function yiter(items) + local i = 0 + return function() + i = i + 1 + if i > #items then return nil end + go_yield("y:" .. items[i]) + return items[i] + end + end + + function test() + local results = {} + for a in yiter({"x", "y"}) do + for b in yiter({"1", "2"}) do + results[#results + 1] = a .. b + end + end + return table.concat(results, ",") + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + // outer "x", inner "1", inner "2", outer "y", inner "1", inner "2" + expectedYields := []string{"y:x", "y:1", "y:2", "y:y", "y:1", "y:2"} + for _, expected := range expectedYields { + r := expectYield(t, L, co, fn) + if r[0].String() != expected { + t.Fatalf("Expected %q, got %v", expected, r[0]) + } + } + + results := expectDone(t, L, co, fn) + if results[0].String() != "x1,x2,y1,y2" { + t.Errorf("Expected 'x1,x2,y1,y2', got %v", results[0]) + } +} + +// --------------------------------------------------------------------------- +// __index metamethod: yield from field access +// --------------------------------------------------------------------------- + +func TestYieldFromIndexMetamethod(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + local mt = { + __index = function(self, key) + go_yield("index:" .. key) + return rawget(self, "_data")[key] + end + } + + function make_proxy(data) + local obj = {_data = data} + setmetatable(obj, mt) + return obj + end + + function test() + local p = make_proxy({name = "alice", age = 30}) + local n = p.name + local a = p.age + return n .. ":" .. a + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + r := expectYield(t, L, co, fn) + if r[0].String() != "index:name" { + t.Fatalf("Expected 'index:name', got %v", r[0]) + } + + r = expectYield(t, L, co, fn) + if r[0].String() != "index:age" { + t.Fatalf("Expected 'index:age', got %v", r[0]) + } + + results := expectDone(t, L, co, fn) + if results[0].String() != "alice:30" { + t.Errorf("Expected 'alice:30', got %v", results[0]) + } +} + +func TestYieldFromIndexMetamethod_MethodCall(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + // OP_SELF uses getFieldString which calls __index + if err := L.DoString(` + local mt = { + __index = function(self, key) + go_yield("lookup:" .. key) + if key == "greet" then + return function(self) + return "hello " .. rawget(self, "name") + end + end + return rawget(self, key) + end + } + + function test() + local obj = setmetatable({name = "bob"}, mt) + return obj:greet() + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + r := expectYield(t, L, co, fn) + if r[0].String() != "lookup:greet" { + t.Fatalf("Expected 'lookup:greet', got %v", r[0]) + } + + results := expectDone(t, L, co, fn) + if results[0].String() != "hello bob" { + t.Errorf("Expected 'hello bob', got %v", results[0]) + } +} + +// --------------------------------------------------------------------------- +// __newindex metamethod: yield from field assignment +// --------------------------------------------------------------------------- + +func TestYieldFromNewIndexMetamethod(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + local log = {} + local mt = { + __newindex = function(self, key, value) + go_yield("set:" .. key .. "=" .. tostring(value)) + rawset(self, key, value) + end + } + + function test() + local obj = setmetatable({}, mt) + obj.x = 10 + obj.y = 20 + return obj.x + obj.y + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + r := expectYield(t, L, co, fn) + if r[0].String() != "set:x=10" { + t.Fatalf("Expected 'set:x=10', got %v", r[0]) + } + + r = expectYield(t, L, co, fn) + if r[0].String() != "set:y=20" { + t.Fatalf("Expected 'set:y=20', got %v", r[0]) + } + + results := expectDone(t, L, co, fn) + if LVAsNumber(results[0]) != 30 { + t.Errorf("Expected 30, got %v", results[0]) + } +} + +// --------------------------------------------------------------------------- +// __add metamethod: yield from arithmetic +// --------------------------------------------------------------------------- + +func TestYieldFromAddMetamethod(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + local mt = { + __add = function(a, b) + go_yield("add:" .. tostring(a.v) .. "+" .. tostring(b.v)) + return setmetatable({v = a.v + b.v}, getmetatable(a)) + end + } + + function num(v) + return setmetatable({v = v}, mt) + end + + function test() + local a = num(10) + local b = num(20) + local c = a + b + return c.v + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + r := expectYield(t, L, co, fn) + if r[0].String() != "add:10+20" { + t.Fatalf("Expected 'add:10+20', got %v", r[0]) + } + + results := expectDone(t, L, co, fn) + if LVAsNumber(results[0]) != 30 { + t.Errorf("Expected 30, got %v", results[0]) + } +} + +func TestYieldFromSubMetamethod(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + local mt = { + __sub = function(a, b) + go_yield("sub") + return setmetatable({v = a.v - b.v}, getmetatable(a)) + end + } + function num(v) return setmetatable({v = v}, mt) end + + function test() + local r = num(50) - num(8) + return r.v + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + expectYield(t, L, co, fn) + results := expectDone(t, L, co, fn) + if LVAsNumber(results[0]) != 42 { + t.Errorf("Expected 42, got %v", results[0]) + } +} + +func TestYieldFromMulMetamethod(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + local mt = { + __mul = function(a, b) + go_yield("mul") + return setmetatable({v = a.v * b.v}, getmetatable(a)) + end + } + function num(v) return setmetatable({v = v}, mt) end + + function test() + local r = num(6) * num(7) + return r.v + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + expectYield(t, L, co, fn) + results := expectDone(t, L, co, fn) + if LVAsNumber(results[0]) != 42 { + t.Errorf("Expected 42, got %v", results[0]) + } +} + +// --------------------------------------------------------------------------- +// __concat metamethod: yield from string concatenation +// --------------------------------------------------------------------------- + +func TestYieldFromConcatMetamethod(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + local mt = { + __concat = function(a, b) + local av = type(a) == "table" and a.v or tostring(a) + local bv = type(b) == "table" and b.v or tostring(b) + go_yield("concat:" .. av .. ".." .. bv) + return setmetatable({v = av .. bv}, getmetatable(a) or getmetatable(b)) + end + } + function str(v) return setmetatable({v = v}, mt) end + + function test() + local r = str("hello") .. str(" world") + return r.v + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + r := expectYield(t, L, co, fn) + if r[0].String() != "concat:hello.. world" { + t.Fatalf("Expected 'concat:hello.. world', got %v", r[0]) + } + + results := expectDone(t, L, co, fn) + if results[0].String() != "hello world" { + t.Errorf("Expected 'hello world', got %v", results[0]) + } +} + +// --------------------------------------------------------------------------- +// __unm metamethod: yield from unary minus +// --------------------------------------------------------------------------- + +func TestYieldFromUnmMetamethod(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + local mt = { + __unm = function(a) + go_yield("neg:" .. a.v) + return setmetatable({v = -a.v}, getmetatable(a)) + end + } + function num(v) return setmetatable({v = v}, mt) end + + function test() + local r = -num(42) + return r.v + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + r := expectYield(t, L, co, fn) + if r[0].String() != "neg:42" { + t.Fatalf("Expected 'neg:42', got %v", r[0]) + } + + results := expectDone(t, L, co, fn) + if LVAsNumber(results[0]) != -42 { + t.Errorf("Expected -42, got %v", results[0]) + } +} + +// --------------------------------------------------------------------------- +// __len metamethod: yield from # operator +// --------------------------------------------------------------------------- + +func TestYieldFromLenMetamethod(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + local mt = { + __len = function(a) + go_yield("len") + return #rawget(a, "items") + end + } + + function test() + local obj = setmetatable({items = {1, 2, 3, 4, 5}}, mt) + return #obj + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + expectYield(t, L, co, fn) + results := expectDone(t, L, co, fn) + if LVAsNumber(results[0]) != 5 { + t.Errorf("Expected 5, got %v", results[0]) + } +} + +// --------------------------------------------------------------------------- +// __eq metamethod: yield from equality comparison +// --------------------------------------------------------------------------- + +func TestYieldFromEqMetamethod(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + local mt = { + __eq = function(a, b) + go_yield("eq:" .. a.v .. "==" .. b.v) + return a.v == b.v + end + } + function val(v) return setmetatable({v = v}, mt) end + + function test() + local a = val(42) + local b = val(42) + local c = val(99) + local r1 = a == b + local r2 = a == c + return r1, r2 + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + r := expectYield(t, L, co, fn) + if r[0].String() != "eq:42==42" { + t.Fatalf("Expected 'eq:42==42', got %v", r[0]) + } + + r = expectYield(t, L, co, fn) + if r[0].String() != "eq:42==99" { + t.Fatalf("Expected 'eq:42==99', got %v", r[0]) + } + + results := expectDone(t, L, co, fn) + if results[0] != LTrue { + t.Errorf("Expected true for 42==42, got %v", results[0]) + } + if results[1] != LFalse { + t.Errorf("Expected false for 42==99, got %v", results[1]) + } +} + +// --------------------------------------------------------------------------- +// __lt metamethod: yield from less-than comparison +// --------------------------------------------------------------------------- + +func TestYieldFromLtMetamethod(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + local mt = { + __lt = function(a, b) + go_yield("lt:" .. a.v .. "<" .. b.v) + return a.v < b.v + end + } + function val(v) return setmetatable({v = v}, mt) end + + function test() + local r1 = val(1) < val(2) + local r2 = val(5) < val(3) + return r1, r2 + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + expectYield(t, L, co, fn) // 1<2 + expectYield(t, L, co, fn) // 5<3 + + results := expectDone(t, L, co, fn) + if results[0] != LTrue { + t.Errorf("Expected true for 1<2, got %v", results[0]) + } + if results[1] != LFalse { + t.Errorf("Expected false for 5<3, got %v", results[1]) + } +} + +// --------------------------------------------------------------------------- +// __le metamethod: yield from less-than-or-equal comparison +// --------------------------------------------------------------------------- + +func TestYieldFromLeMetamethod(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + local mt = { + __le = function(a, b) + go_yield("le:" .. a.v .. "<=" .. b.v) + return a.v <= b.v + end + } + function val(v) return setmetatable({v = v}, mt) end + + function test() + local r1 = val(3) <= val(3) + local r2 = val(4) <= val(3) + return r1, r2 + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + expectYield(t, L, co, fn) // 3<=3 + expectYield(t, L, co, fn) // 4<=3 + + results := expectDone(t, L, co, fn) + if results[0] != LTrue { + t.Errorf("Expected true for 3<=3, got %v", results[0]) + } + if results[1] != LFalse { + t.Errorf("Expected false for 4<=3, got %v", results[1]) + } +} + +// --------------------------------------------------------------------------- +// Combined: yield from multiple boundary types in one coroutine +// --------------------------------------------------------------------------- + +func TestYieldFromMixedBoundaries(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + local mt = { + __index = function(self, key) + go_yield("index:" .. key) + return rawget(self, "_d")[key] + end, + __add = function(a, b) + go_yield("add") + return a._d.v + b._d.v + end + } + + function wrap(tbl) + return setmetatable({_d = tbl}, mt) + end + + function yielding_iter(items) + local i = 0 + return function() + i = i + 1 + if i > #items then return nil end + go_yield("iter:" .. i) + return items[i] + end + end + + function test() + local a = wrap({v = 10}) + local b = wrap({v = 20}) + + -- triggers __index yield + local av = a.v + + -- triggers __add yield + local sum = a + b + + -- triggers iterator yield + local items = {} + for val in yielding_iter({av, sum}) do + items[#items + 1] = val + end + + return table.concat(items, ",") + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + // __index for a.v + r := expectYield(t, L, co, fn) + if r[0].String() != "index:v" { + t.Fatalf("Expected 'index:v', got %v", r[0]) + } + + // __add for a + b + r = expectYield(t, L, co, fn) + if r[0].String() != "add" { + t.Fatalf("Expected 'add', got %v", r[0]) + } + + // iterator yields + r = expectYield(t, L, co, fn) + if r[0].String() != "iter:1" { + t.Fatalf("Expected 'iter:1', got %v", r[0]) + } + + r = expectYield(t, L, co, fn) + if r[0].String() != "iter:2" { + t.Fatalf("Expected 'iter:2', got %v", r[0]) + } + + results := expectDone(t, L, co, fn) + if results[0].String() != "10,30" { + t.Errorf("Expected '10,30', got %v", results[0]) + } +} + +// --------------------------------------------------------------------------- +// Edge case: yield from iterator inside pcall inside another for loop +// --------------------------------------------------------------------------- + +func TestYieldFromIterator_InsidePcallInsideForLoop(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + function yiter(items) + local i = 0 + return function() + i = i + 1 + if i > #items then return nil end + go_yield("y:" .. items[i]) + return items[i] + end + end + + function test() + local results = {} + for outer in yiter({"A", "B"}) do + local ok, inner_result = pcall(function() + local inner = {} + for v in yiter({"1", "2"}) do + inner[#inner + 1] = v + end + return table.concat(inner, "+") + end) + if ok then + results[#results + 1] = outer .. ":" .. inner_result + end + end + return table.concat(results, ",") + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + // A, 1, 2, B, 1, 2 + expected := []string{"y:A", "y:1", "y:2", "y:B", "y:1", "y:2"} + for _, exp := range expected { + r := expectYield(t, L, co, fn) + if r[0].String() != exp { + t.Fatalf("Expected %q, got %v", exp, r[0]) + } + } + + results := expectDone(t, L, co, fn) + if results[0].String() != "A:1+2,B:1+2" { + t.Errorf("Expected 'A:1+2,B:1+2', got %v", results[0]) + } +} + +// --------------------------------------------------------------------------- +// coroutine.wrap: yield from iterator using coroutine.wrap +// --------------------------------------------------------------------------- + +func TestYieldFromCoroutineWrapIterator(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + function make_iter(items) + return coroutine.wrap(function() + for i = 1, #items do + go_yield("load:" .. items[i]) + coroutine.yield(items[i]) + end + end) + end + + function test() + local results = {} + for val in make_iter({"a", "b", "c"}) do + results[#results + 1] = val + end + return table.concat(results, ",") + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + for _, item := range []string{"a", "b", "c"} { + r := expectYield(t, L, co, fn) + if r[0].String() != "load:"+item { + t.Fatalf("Expected 'load:%s', got %v", item, r[0]) + } + } + + results := expectDone(t, L, co, fn) + if results[0].String() != "a,b,c" { + t.Errorf("Expected 'a,b,c', got %v", results[0]) + } +} + +// --------------------------------------------------------------------------- +// System yield from coroutine.resume (non-wrapped, explicit resume) +// --------------------------------------------------------------------------- + +func TestYieldFromCoroutineResumeExplicit(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + function test() + local th = coroutine.create(function() + go_yield("sys:1") + coroutine.yield("user:1") + go_yield("sys:2") + return "done" + end) + + local results = {} + while true do + local ok, val = coroutine.resume(th) + if not ok then break end + results[#results + 1] = val + if coroutine.status(th) == "dead" then break end + end + return table.concat(results, ",") + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + r := expectYield(t, L, co, fn) + if r[0].String() != "sys:1" { + t.Fatalf("Expected 'sys:1', got %v", r[0]) + } + + r = expectYield(t, L, co, fn) + if r[0].String() != "sys:2" { + t.Fatalf("Expected 'sys:2', got %v", r[0]) + } + + // System yields propagate to the host transparently — they don't appear + // in the Lua code's coroutine.resume return values. The Lua code only sees + // the user yield and the final return. + results := expectDone(t, L, co, fn) + if results[0].String() != "user:1,done" { + t.Errorf("Expected 'user:1,done', got %v", results[0]) + } +} + +// --------------------------------------------------------------------------- +// System yield through nested coroutine.wrap (2 levels deep) +// --------------------------------------------------------------------------- + +func TestYieldFromNestedCoroutineWrap(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + function inner_iter(items) + return coroutine.wrap(function() + for _, v in ipairs(items) do + go_yield("inner:" .. v) + coroutine.yield(v) + end + end) + end + + function outer_iter() + return coroutine.wrap(function() + for v in inner_iter({"x", "y"}) do + go_yield("outer:" .. v) + coroutine.yield("got:" .. v) + end + end) + end + + function test() + local results = {} + for val in outer_iter() do + results[#results + 1] = val + end + return table.concat(results, ",") + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + expected := []string{"inner:x", "outer:x", "inner:y", "outer:y"} + for _, exp := range expected { + r := expectYield(t, L, co, fn) + if r[0].String() != exp { + t.Fatalf("Expected %q, got %v", exp, r[0]) + } + } + + results := expectDone(t, L, co, fn) + if results[0].String() != "got:x,got:y" { + t.Errorf("Expected 'got:x,got:y', got %v", results[0]) + } +} + +// --------------------------------------------------------------------------- +// System yield interleaved with pcall inside coroutine.wrap +// --------------------------------------------------------------------------- + +func TestYieldFromCoroutineWrapWithPcall(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + function test() + local iter = coroutine.wrap(function() + go_yield("before_pcall") + local ok, err = pcall(function() + go_yield("inside_pcall") + error("planned_error") + end) + coroutine.yield(ok) + go_yield("after_pcall") + coroutine.yield(tostring(err)) + end) + + local results = {} + for val in iter do + results[#results + 1] = tostring(val) + end + return table.concat(results, ",") + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + r := expectYield(t, L, co, fn) + if r[0].String() != "before_pcall" { + t.Fatalf("Expected 'before_pcall', got %v", r[0]) + } + + r = expectYield(t, L, co, fn) + if r[0].String() != "inside_pcall" { + t.Fatalf("Expected 'inside_pcall', got %v", r[0]) + } + + r = expectYield(t, L, co, fn) + if r[0].String() != "after_pcall" { + t.Fatalf("Expected 'after_pcall', got %v", r[0]) + } + + results := expectDone(t, L, co, fn) + got := results[0].String() + if !strings.Contains(got, "false") || !strings.Contains(got, "planned_error") { + t.Errorf("Expected result containing 'false' and 'planned_error', got %v", got) + } +} + +// --------------------------------------------------------------------------- +// System yield from __index inside coroutine.wrap iterator +// --------------------------------------------------------------------------- + +func TestYieldFromMetamethodInsideCoroutineWrap(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + local mt = { + __index = function(t, k) + go_yield("index:" .. k) + return rawget(t, "_" .. k) + end + } + + function test() + local obj = setmetatable({_name = "alice", _age = "30"}, mt) + local iter = coroutine.wrap(function() + coroutine.yield(obj.name) + coroutine.yield(obj.age) + end) + + local results = {} + for val in iter do + results[#results + 1] = val + end + return table.concat(results, ",") + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + r := expectYield(t, L, co, fn) + if r[0].String() != "index:name" { + t.Fatalf("Expected 'index:name', got %v", r[0]) + } + + r = expectYield(t, L, co, fn) + if r[0].String() != "index:age" { + t.Fatalf("Expected 'index:age', got %v", r[0]) + } + + results := expectDone(t, L, co, fn) + if results[0].String() != "alice,30" { + t.Errorf("Expected 'alice,30', got %v", results[0]) + } +} + +// --------------------------------------------------------------------------- +// System yield with multiple return values propagated through wrap +// --------------------------------------------------------------------------- + +func TestYieldFromCoroutineWrapMultipleValues(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield_multi", L.NewFunction(func(L *LState) int { + return L.Yield(L.Get(1), L.Get(2)) + })) + + if err := L.DoString(` + function test() + local iter = coroutine.wrap(function() + go_yield_multi("a", "b") + coroutine.yield("single") + go_yield_multi("c", "d") + end) + + local results = {} + for val in iter do + results[#results + 1] = val + end + return table.concat(results, ",") + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + r := expectYield(t, L, co, fn) + if len(r) < 2 || r[0].String() != "a" || r[1].String() != "b" { + t.Fatalf("Expected [a, b], got %v", r) + } + + r = expectYield(t, L, co, fn) + if len(r) < 2 || r[0].String() != "c" || r[1].String() != "d" { + t.Fatalf("Expected [c, d], got %v", r) + } + + results := expectDone(t, L, co, fn) + if results[0].String() != "single" { + t.Errorf("Expected 'single', got %v", results[0]) + } +} + +// --------------------------------------------------------------------------- +// System yield from arithmetic metamethod inside coroutine.wrap +// --------------------------------------------------------------------------- + +func TestYieldFromArithInsideCoroutineWrap(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + local mt = { + __add = function(a, b) + go_yield("add:" .. rawget(a, "v") .. "+" .. rawget(b, "v")) + return setmetatable({v = rawget(a, "v") + rawget(b, "v")}, getmetatable(a)) + end + } + + function num(x) + return setmetatable({v = x}, mt) + end + + function test() + local iter = coroutine.wrap(function() + local a = num(10) + local b = num(20) + local c = a + b + coroutine.yield(rawget(c, "v")) + local d = c + num(5) + coroutine.yield(rawget(d, "v")) + end) + + local results = {} + for val in iter do + results[#results + 1] = tostring(val) + end + return table.concat(results, ",") + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + r := expectYield(t, L, co, fn) + if r[0].String() != "add:10+20" { + t.Fatalf("Expected 'add:10+20', got %v", r[0]) + } + + r = expectYield(t, L, co, fn) + if r[0].String() != "add:30+5" { + t.Fatalf("Expected 'add:30+5', got %v", r[0]) + } + + results := expectDone(t, L, co, fn) + if results[0].String() != "30,35" { + t.Errorf("Expected '30,35', got %v", results[0]) + } +} + +// --------------------------------------------------------------------------- +// System yield from for-in iterator with resume values passed back +// --------------------------------------------------------------------------- + +func TestYieldFromIteratorWithResumeValues(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_request", L.NewFunction(func(L *LState) int { + query := L.CheckString(1) + L.SetTop(0) + L.Push(LString(query)) + return -1 + })) + + if err := L.DoString(` + function test() + local iter = coroutine.wrap(function() + local resp1 = go_request("get_name") + coroutine.yield(resp1) + local resp2 = go_request("get_age") + coroutine.yield(resp2) + end) + + local results = {} + for val in iter do + results[#results + 1] = val + end + return table.concat(results, ",") + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + r := expectYield(t, L, co, fn) + if r[0].String() != "get_name" { + t.Fatalf("Expected 'get_name', got %v", r[0]) + } + + r = expectYield(t, L, co, fn, LString("Alice")) + if r[0].String() != "get_age" { + t.Fatalf("Expected 'get_age', got %v", r[0]) + } + + results := expectDone(t, L, co, fn, LString("30")) + if results[0].String() != "Alice,30" { + t.Errorf("Expected 'Alice,30', got %v", results[0]) + } +} + +// --------------------------------------------------------------------------- +// System yield through coroutine.wrap that dies mid-iteration +// --------------------------------------------------------------------------- + +func TestYieldFromCoroutineWrapWithError(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + function test() + local iter = coroutine.wrap(function() + go_yield("step1") + coroutine.yield("ok1") + go_yield("step2") + error("boom") + end) + + local results = {} + local ok, err = pcall(function() + for val in iter do + results[#results + 1] = val + end + end) + results[#results + 1] = tostring(err) + return table.concat(results, "|") + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + r := expectYield(t, L, co, fn) + if r[0].String() != "step1" { + t.Fatalf("Expected 'step1', got %v", r[0]) + } + + r = expectYield(t, L, co, fn) + if r[0].String() != "step2" { + t.Fatalf("Expected 'step2', got %v", r[0]) + } + + results := expectDone(t, L, co, fn) + got := results[0].String() + if !strings.HasPrefix(got, "ok1|") || !strings.Contains(got, "boom") { + t.Errorf("Expected 'ok1|...boom', got %v", got) + } +} + +// --------------------------------------------------------------------------- +// System yield with __newindex inside coroutine.wrap +// --------------------------------------------------------------------------- + +func TestYieldFromNewIndexInsideCoroutineWrap(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + local store = {} + local mt = { + __newindex = function(t, k, v) + go_yield("set:" .. k .. "=" .. tostring(v)) + store[k] = v + end, + __index = function(t, k) + return store[k] + end + } + + function test() + local obj = setmetatable({}, mt) + local iter = coroutine.wrap(function() + obj.x = 10 + obj.y = 20 + coroutine.yield(obj.x + obj.y) + end) + + local results = {} + for val in iter do + results[#results + 1] = tostring(val) + end + return table.concat(results, ",") + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + r := expectYield(t, L, co, fn) + if r[0].String() != "set:x=10" { + t.Fatalf("Expected 'set:x=10', got %v", r[0]) + } + + r = expectYield(t, L, co, fn) + if r[0].String() != "set:y=20" { + t.Fatalf("Expected 'set:y=20', got %v", r[0]) + } + + results := expectDone(t, L, co, fn) + if results[0].String() != "30" { + t.Errorf("Expected '30', got %v", results[0]) + } +} + +// --------------------------------------------------------------------------- +// High-frequency system yields: many yields in tight loop +// --------------------------------------------------------------------------- + +func TestYieldFromIteratorHighFrequency(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + function test() + local sum = 0 + for i = 1, 100 do + go_yield(i) + sum = sum + i + end + return sum + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + for i := 1; i <= 100; i++ { + r := expectYield(t, L, co, fn) + if LVAsNumber(r[0]) != LNumber(i) { + t.Fatalf("Yield %d: expected %d, got %v", i, i, r[0]) + } + } + + results := expectDone(t, L, co, fn) + if LVAsNumber(results[0]) != LNumber(5050) { + t.Errorf("Expected 5050, got %v", results[0]) + } +} + +// --------------------------------------------------------------------------- +// System yield from comparison metamethod inside coroutine.wrap +// --------------------------------------------------------------------------- + +func TestYieldFromComparisonInsideCoroutineWrap(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + local mt = { + __lt = function(a, b) + go_yield("cmp:" .. rawget(a, "v") .. "<" .. rawget(b, "v")) + return rawget(a, "v") < rawget(b, "v") + end + } + + function val(x) + return setmetatable({v = x}, mt) + end + + function test() + local iter = coroutine.wrap(function() + local a, b = val(3), val(7) + if a < b then + coroutine.yield("less") + else + coroutine.yield("greater") + end + local c = val(10) + if c < b then + coroutine.yield("less2") + else + coroutine.yield("greater2") + end + end) + + local results = {} + for val in iter do + results[#results + 1] = val + end + return table.concat(results, ",") + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + r := expectYield(t, L, co, fn) + if r[0].String() != "cmp:3<7" { + t.Fatalf("Expected 'cmp:3<7', got %v", r[0]) + } + + r = expectYield(t, L, co, fn) + if r[0].String() != "cmp:10<7" { + t.Fatalf("Expected 'cmp:10<7', got %v", r[0]) + } + + results := expectDone(t, L, co, fn) + if results[0].String() != "less,greater2" { + t.Errorf("Expected 'less,greater2', got %v", results[0]) + } +} + +// --------------------------------------------------------------------------- +// System yield from __len and __concat inside coroutine.wrap +// --------------------------------------------------------------------------- + +func TestYieldFromLenConcatInsideCoroutineWrap(t *testing.T) { + L := NewState() + defer L.Close() + + L.SetGlobal("go_yield", L.NewFunction(yieldingGoFunc)) + + if err := L.DoString(` + local mt = { + __len = function(t) + go_yield("len:" .. rawget(t, "name")) + return rawget(t, "size") + end, + __concat = function(a, b) + local av = type(a) == "table" and rawget(a, "name") or tostring(a) + local bv = type(b) == "table" and rawget(b, "name") or tostring(b) + go_yield("concat:" .. av .. ".." .. bv) + return av .. bv + end + } + + function obj(name, size) + return setmetatable({name = name, size = size}, mt) + end + + function test() + local iter = coroutine.wrap(function() + local a = obj("foo", 3) + local b = obj("bar", 5) + coroutine.yield(#a) + coroutine.yield(#b) + coroutine.yield(a .. b) + end) + + local results = {} + for val in iter do + results[#results + 1] = tostring(val) + end + return table.concat(results, ",") + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("test").(*LFunction) + + r := expectYield(t, L, co, fn) + if r[0].String() != "len:foo" { + t.Fatalf("Expected 'len:foo', got %v", r[0]) + } + + r = expectYield(t, L, co, fn) + if r[0].String() != "len:bar" { + t.Fatalf("Expected 'len:bar', got %v", r[0]) + } + + r = expectYield(t, L, co, fn) + if r[0].String() != "concat:foo..bar" { + t.Fatalf("Expected 'concat:foo..bar', got %v", r[0]) + } + + results := expectDone(t, L, co, fn) + if results[0].String() != "3,5,foobar" { + t.Errorf("Expected '3,5,foobar', got %v", results[0]) + } +}