From 11ebebde993a6f2c1c52a955d6e4c41fbeafa4b8 Mon Sep 17 00:00:00 2001 From: Shreyansh Sancheti Date: Mon, 23 Mar 2026 16:35:32 +0530 Subject: [PATCH] hcs: add driver interface for HCS system and process operations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The HCS layer (internal/hcs/system.go, process.go) calls vmcompute.dll directly, making it impossible to unit test the System/Process logic or swap the underlying DLL for the upcoming V2 migration (computecore.dll). This change introduces an internal hcsDriver interface that wraps all vmcompute system and process API calls. The System struct now holds a driver field (defaulting to the vmcomputeDriver, which delegates to the existing vmcompute package). Process inherits the driver from its parent System. No behavioral change — the vmcomputeDriver methods are one-liner delegations to the same vmcompute functions previously called directly. This enables writing unit tests against the HCS layer without admin privileges or a live HCS service, by injecting a mock driver. 17 tests are included covering: handle guards, error swallowing, async pending paths, system crash and service disconnect during operations, the waitBackground exit classification, multi-goroutine Wait fan-out, and late callback safety after unregistration. Signed-off-by: Shreyansh Sancheti --- internal/hcs/export_test.go | 97 +++++ internal/hcs/mock/mock_driver.go | 394 ++++++++++++++++++ internal/hcs/process.go | 23 +- internal/hcs/system.go | 43 +- internal/hcs/system_driver.go | 98 +++++ internal/hcs/system_driver_vmcompute.go | 110 ++++++ internal/hcs/system_test.go | 505 ++++++++++++++++++++++++ 7 files changed, 1241 insertions(+), 29 deletions(-) create mode 100644 internal/hcs/export_test.go create mode 100644 internal/hcs/mock/mock_driver.go create mode 100644 internal/hcs/system_driver.go create mode 100644 internal/hcs/system_driver_vmcompute.go create mode 100644 internal/hcs/system_test.go diff --git a/internal/hcs/export_test.go b/internal/hcs/export_test.go new file mode 100644 index 0000000000..cffe23cedb --- /dev/null +++ b/internal/hcs/export_test.go @@ -0,0 +1,97 @@ +//go:build windows + +package hcs + +import ( + "context" + + "github.com/Microsoft/hcsshim/internal/vmcompute" +) + +// SetDriverForTest replaces the SystemDriver on a System instance. +// This is intended for use in tests to inject a mock driver. +func SetDriverForTest(s *System, d hcsDriver) { + s.driver = d +} + +// FireNotificationForTest simulates an HCS notification arriving on the +// notification channel for the given callback number. This allows tests +// to exercise the async completion paths without a real HCS system. +func FireNotificationForTest(callbackNumber uintptr, notification hcsNotification, result error) { + callbackMapLock.RLock() + ctx := callbackMap[callbackNumber] + callbackMapLock.RUnlock() + if ctx == nil { + return + } + if ch, ok := ctx.channels[notification]; ok { + ch <- result + } +} + +// GetCallbackNumberForTest returns the callback number of a System. +func GetCallbackNumberForTest(s *System) uintptr { + return s.callbackNumber +} + +// NewTestSystem creates a System with the given ID for testing. +// The handle is zero (closed state), useful for testing error guards. +func NewTestSystem(id string) *System { + return newSystem(id) +} + +// NewTestSystemWithDriver creates a System with a custom driver and a +// non-zero handle for testing operations that require an open handle. +func NewTestSystemWithDriver(id string, driver hcsDriver, handle uintptr) *System { + s := newSystem(id) + s.driver = driver + s.handle = vmcompute.HcsSystem(handle) + return s +} + +// RegisterCallbackForTest calls registerCallback on a test system so the +// notification channels are set up. This is needed for testing async paths +// (operations that return ErrVmcomputeOperationPending and wait on channels). +func RegisterCallbackForTest(s *System) error { + return s.registerCallback(context.Background()) +} + +// StartWaitBackgroundForTest launches the waitBackground goroutine for a test +// system. This is normally done by CreateComputeSystem/OpenComputeSystem. +func StartWaitBackgroundForTest(s *System) { + go s.waitBackground() +} + +// WaitErrorForTest returns the waitError field of a System. +func WaitErrorForTest(s *System) error { + return s.waitError +} + +// ExitErrorForTest returns the exitError field of a System. +func ExitErrorForTest(s *System) error { + return s.exitError +} + +// UnregisterCallbackForTest calls unregisterCallback on a test system. +func UnregisterCallbackForTest(s *System) error { + return s.unregisterCallback(context.Background()) +} + +// CallbackExistsForTest checks if a callbackNumber is still in the global callbackMap. +func CallbackExistsForTest(callbackNumber uintptr) bool { + callbackMapLock.RLock() + defer callbackMapLock.RUnlock() + _, exists := callbackMap[callbackNumber] + return exists +} + +// Expose notification values for test assertions. +var ( + HcsNotificationSystemExited = hcsNotificationSystemExited + HcsNotificationSystemCreateCompleted = hcsNotificationSystemCreateCompleted + HcsNotificationSystemStartCompleted = hcsNotificationSystemStartCompleted + HcsNotificationSystemPauseCompleted = hcsNotificationSystemPauseCompleted + HcsNotificationSystemResumeCompleted = hcsNotificationSystemResumeCompleted + HcsNotificationSystemSaveCompleted = hcsNotificationSystemSaveCompleted + HcsNotificationServiceDisconnect = hcsNotificationServiceDisconnect +) diff --git a/internal/hcs/mock/mock_driver.go b/internal/hcs/mock/mock_driver.go new file mode 100644 index 0000000000..19d9a03c51 --- /dev/null +++ b/internal/hcs/mock/mock_driver.go @@ -0,0 +1,394 @@ +//go:build windows + +// Code generated by MockGen. DO NOT EDIT. +// Source: system_driver.go +// +// Generated by this command: +// +// mockgen -source=system_driver.go -package=mock -destination=mock/mock_driver.go +// + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + reflect "reflect" + syscall "syscall" + + vmcompute "github.com/Microsoft/hcsshim/internal/vmcompute" + gomock "go.uber.org/mock/gomock" +) + +// MockhcsDriver is a mock of hcsDriver interface. +type MockhcsDriver struct { + ctrl *gomock.Controller + recorder *MockhcsDriverMockRecorder + isgomock struct{} +} + +// MockhcsDriverMockRecorder is the mock recorder for MockhcsDriver. +type MockhcsDriverMockRecorder struct { + mock *MockhcsDriver +} + +// NewMockhcsDriver creates a new mock instance. +func NewMockhcsDriver(ctrl *gomock.Controller) *MockhcsDriver { + mock := &MockhcsDriver{ctrl: ctrl} + mock.recorder = &MockhcsDriverMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockhcsDriver) EXPECT() *MockhcsDriverMockRecorder { + return m.recorder +} + +// CloseComputeSystem mocks base method. +func (m *MockhcsDriver) CloseComputeSystem(ctx context.Context, system vmcompute.HcsSystem) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseComputeSystem", ctx, system) + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseComputeSystem indicates an expected call of CloseComputeSystem. +func (mr *MockhcsDriverMockRecorder) CloseComputeSystem(ctx, system any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseComputeSystem", reflect.TypeOf((*MockhcsDriver)(nil).CloseComputeSystem), ctx, system) +} + +// CloseProcess mocks base method. +func (m *MockhcsDriver) CloseProcess(ctx context.Context, process vmcompute.HcsProcess) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseProcess", ctx, process) + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseProcess indicates an expected call of CloseProcess. +func (mr *MockhcsDriverMockRecorder) CloseProcess(ctx, process any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseProcess", reflect.TypeOf((*MockhcsDriver)(nil).CloseProcess), ctx, process) +} + +// CreateComputeSystem mocks base method. +func (m *MockhcsDriver) CreateComputeSystem(ctx context.Context, id, configuration string, identity syscall.Handle) (vmcompute.HcsSystem, string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateComputeSystem", ctx, id, configuration, identity) + ret0, _ := ret[0].(vmcompute.HcsSystem) + ret1, _ := ret[1].(string) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// CreateComputeSystem indicates an expected call of CreateComputeSystem. +func (mr *MockhcsDriverMockRecorder) CreateComputeSystem(ctx, id, configuration, identity any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateComputeSystem", reflect.TypeOf((*MockhcsDriver)(nil).CreateComputeSystem), ctx, id, configuration, identity) +} + +// CreateProcess mocks base method. +func (m *MockhcsDriver) CreateProcess(ctx context.Context, system vmcompute.HcsSystem, processParameters string) (vmcompute.HcsProcessInformation, vmcompute.HcsProcess, string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateProcess", ctx, system, processParameters) + ret0, _ := ret[0].(vmcompute.HcsProcessInformation) + ret1, _ := ret[1].(vmcompute.HcsProcess) + ret2, _ := ret[2].(string) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 +} + +// CreateProcess indicates an expected call of CreateProcess. +func (mr *MockhcsDriverMockRecorder) CreateProcess(ctx, system, processParameters any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProcess", reflect.TypeOf((*MockhcsDriver)(nil).CreateProcess), ctx, system, processParameters) +} + +// GetComputeSystemProperties mocks base method. +func (m *MockhcsDriver) GetComputeSystemProperties(ctx context.Context, system vmcompute.HcsSystem, propertyQuery string) (string, string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetComputeSystemProperties", ctx, system, propertyQuery) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(string) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GetComputeSystemProperties indicates an expected call of GetComputeSystemProperties. +func (mr *MockhcsDriverMockRecorder) GetComputeSystemProperties(ctx, system, propertyQuery any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetComputeSystemProperties", reflect.TypeOf((*MockhcsDriver)(nil).GetComputeSystemProperties), ctx, system, propertyQuery) +} + +// GetProcessInfo mocks base method. +func (m *MockhcsDriver) GetProcessInfo(ctx context.Context, process vmcompute.HcsProcess) (vmcompute.HcsProcessInformation, string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProcessInfo", ctx, process) + ret0, _ := ret[0].(vmcompute.HcsProcessInformation) + ret1, _ := ret[1].(string) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GetProcessInfo indicates an expected call of GetProcessInfo. +func (mr *MockhcsDriverMockRecorder) GetProcessInfo(ctx, process any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProcessInfo", reflect.TypeOf((*MockhcsDriver)(nil).GetProcessInfo), ctx, process) +} + +// GetProcessProperties mocks base method. +func (m *MockhcsDriver) GetProcessProperties(ctx context.Context, process vmcompute.HcsProcess) (string, string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProcessProperties", ctx, process) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(string) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GetProcessProperties indicates an expected call of GetProcessProperties. +func (mr *MockhcsDriverMockRecorder) GetProcessProperties(ctx, process any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProcessProperties", reflect.TypeOf((*MockhcsDriver)(nil).GetProcessProperties), ctx, process) +} + +// ModifyComputeSystem mocks base method. +func (m *MockhcsDriver) ModifyComputeSystem(ctx context.Context, system vmcompute.HcsSystem, configuration string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ModifyComputeSystem", ctx, system, configuration) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ModifyComputeSystem indicates an expected call of ModifyComputeSystem. +func (mr *MockhcsDriverMockRecorder) ModifyComputeSystem(ctx, system, configuration any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ModifyComputeSystem", reflect.TypeOf((*MockhcsDriver)(nil).ModifyComputeSystem), ctx, system, configuration) +} + +// ModifyProcess mocks base method. +func (m *MockhcsDriver) ModifyProcess(ctx context.Context, process vmcompute.HcsProcess, settings string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ModifyProcess", ctx, process, settings) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ModifyProcess indicates an expected call of ModifyProcess. +func (mr *MockhcsDriverMockRecorder) ModifyProcess(ctx, process, settings any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ModifyProcess", reflect.TypeOf((*MockhcsDriver)(nil).ModifyProcess), ctx, process, settings) +} + +// OpenComputeSystem mocks base method. +func (m *MockhcsDriver) OpenComputeSystem(ctx context.Context, id string) (vmcompute.HcsSystem, string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenComputeSystem", ctx, id) + ret0, _ := ret[0].(vmcompute.HcsSystem) + ret1, _ := ret[1].(string) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// OpenComputeSystem indicates an expected call of OpenComputeSystem. +func (mr *MockhcsDriverMockRecorder) OpenComputeSystem(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenComputeSystem", reflect.TypeOf((*MockhcsDriver)(nil).OpenComputeSystem), ctx, id) +} + +// OpenProcess mocks base method. +func (m *MockhcsDriver) OpenProcess(ctx context.Context, system vmcompute.HcsSystem, pid uint32) (vmcompute.HcsProcess, string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenProcess", ctx, system, pid) + ret0, _ := ret[0].(vmcompute.HcsProcess) + ret1, _ := ret[1].(string) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// OpenProcess indicates an expected call of OpenProcess. +func (mr *MockhcsDriverMockRecorder) OpenProcess(ctx, system, pid any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenProcess", reflect.TypeOf((*MockhcsDriver)(nil).OpenProcess), ctx, system, pid) +} + +// PauseComputeSystem mocks base method. +func (m *MockhcsDriver) PauseComputeSystem(ctx context.Context, system vmcompute.HcsSystem, options string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PauseComputeSystem", ctx, system, options) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PauseComputeSystem indicates an expected call of PauseComputeSystem. +func (mr *MockhcsDriverMockRecorder) PauseComputeSystem(ctx, system, options any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PauseComputeSystem", reflect.TypeOf((*MockhcsDriver)(nil).PauseComputeSystem), ctx, system, options) +} + +// RegisterComputeSystemCallback mocks base method. +func (m *MockhcsDriver) RegisterComputeSystemCallback(ctx context.Context, system vmcompute.HcsSystem, callback, callbackContext uintptr) (vmcompute.HcsCallback, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegisterComputeSystemCallback", ctx, system, callback, callbackContext) + ret0, _ := ret[0].(vmcompute.HcsCallback) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RegisterComputeSystemCallback indicates an expected call of RegisterComputeSystemCallback. +func (mr *MockhcsDriverMockRecorder) RegisterComputeSystemCallback(ctx, system, callback, callbackContext any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterComputeSystemCallback", reflect.TypeOf((*MockhcsDriver)(nil).RegisterComputeSystemCallback), ctx, system, callback, callbackContext) +} + +// RegisterProcessCallback mocks base method. +func (m *MockhcsDriver) RegisterProcessCallback(ctx context.Context, process vmcompute.HcsProcess, callback, callbackContext uintptr) (vmcompute.HcsCallback, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegisterProcessCallback", ctx, process, callback, callbackContext) + ret0, _ := ret[0].(vmcompute.HcsCallback) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RegisterProcessCallback indicates an expected call of RegisterProcessCallback. +func (mr *MockhcsDriverMockRecorder) RegisterProcessCallback(ctx, process, callback, callbackContext any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterProcessCallback", reflect.TypeOf((*MockhcsDriver)(nil).RegisterProcessCallback), ctx, process, callback, callbackContext) +} + +// ResumeComputeSystem mocks base method. +func (m *MockhcsDriver) ResumeComputeSystem(ctx context.Context, system vmcompute.HcsSystem, options string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ResumeComputeSystem", ctx, system, options) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ResumeComputeSystem indicates an expected call of ResumeComputeSystem. +func (mr *MockhcsDriverMockRecorder) ResumeComputeSystem(ctx, system, options any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResumeComputeSystem", reflect.TypeOf((*MockhcsDriver)(nil).ResumeComputeSystem), ctx, system, options) +} + +// SaveComputeSystem mocks base method. +func (m *MockhcsDriver) SaveComputeSystem(ctx context.Context, system vmcompute.HcsSystem, options string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveComputeSystem", ctx, system, options) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SaveComputeSystem indicates an expected call of SaveComputeSystem. +func (mr *MockhcsDriverMockRecorder) SaveComputeSystem(ctx, system, options any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveComputeSystem", reflect.TypeOf((*MockhcsDriver)(nil).SaveComputeSystem), ctx, system, options) +} + +// ShutdownComputeSystem mocks base method. +func (m *MockhcsDriver) ShutdownComputeSystem(ctx context.Context, system vmcompute.HcsSystem, options string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ShutdownComputeSystem", ctx, system, options) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ShutdownComputeSystem indicates an expected call of ShutdownComputeSystem. +func (mr *MockhcsDriverMockRecorder) ShutdownComputeSystem(ctx, system, options any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShutdownComputeSystem", reflect.TypeOf((*MockhcsDriver)(nil).ShutdownComputeSystem), ctx, system, options) +} + +// SignalProcess mocks base method. +func (m *MockhcsDriver) SignalProcess(ctx context.Context, process vmcompute.HcsProcess, options string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SignalProcess", ctx, process, options) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SignalProcess indicates an expected call of SignalProcess. +func (mr *MockhcsDriverMockRecorder) SignalProcess(ctx, process, options any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignalProcess", reflect.TypeOf((*MockhcsDriver)(nil).SignalProcess), ctx, process, options) +} + +// StartComputeSystem mocks base method. +func (m *MockhcsDriver) StartComputeSystem(ctx context.Context, system vmcompute.HcsSystem, options string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StartComputeSystem", ctx, system, options) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// StartComputeSystem indicates an expected call of StartComputeSystem. +func (mr *MockhcsDriverMockRecorder) StartComputeSystem(ctx, system, options any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartComputeSystem", reflect.TypeOf((*MockhcsDriver)(nil).StartComputeSystem), ctx, system, options) +} + +// TerminateComputeSystem mocks base method. +func (m *MockhcsDriver) TerminateComputeSystem(ctx context.Context, system vmcompute.HcsSystem, options string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TerminateComputeSystem", ctx, system, options) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// TerminateComputeSystem indicates an expected call of TerminateComputeSystem. +func (mr *MockhcsDriverMockRecorder) TerminateComputeSystem(ctx, system, options any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TerminateComputeSystem", reflect.TypeOf((*MockhcsDriver)(nil).TerminateComputeSystem), ctx, system, options) +} + +// TerminateProcess mocks base method. +func (m *MockhcsDriver) TerminateProcess(ctx context.Context, process vmcompute.HcsProcess) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TerminateProcess", ctx, process) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// TerminateProcess indicates an expected call of TerminateProcess. +func (mr *MockhcsDriverMockRecorder) TerminateProcess(ctx, process any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TerminateProcess", reflect.TypeOf((*MockhcsDriver)(nil).TerminateProcess), ctx, process) +} + +// UnregisterComputeSystemCallback mocks base method. +func (m *MockhcsDriver) UnregisterComputeSystemCallback(ctx context.Context, callbackHandle vmcompute.HcsCallback) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UnregisterComputeSystemCallback", ctx, callbackHandle) + ret0, _ := ret[0].(error) + return ret0 +} + +// UnregisterComputeSystemCallback indicates an expected call of UnregisterComputeSystemCallback. +func (mr *MockhcsDriverMockRecorder) UnregisterComputeSystemCallback(ctx, callbackHandle any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterComputeSystemCallback", reflect.TypeOf((*MockhcsDriver)(nil).UnregisterComputeSystemCallback), ctx, callbackHandle) +} + +// UnregisterProcessCallback mocks base method. +func (m *MockhcsDriver) UnregisterProcessCallback(ctx context.Context, callbackHandle vmcompute.HcsCallback) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UnregisterProcessCallback", ctx, callbackHandle) + ret0, _ := ret[0].(error) + return ret0 +} + +// UnregisterProcessCallback indicates an expected call of UnregisterProcessCallback. +func (mr *MockhcsDriverMockRecorder) UnregisterProcessCallback(ctx, callbackHandle any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterProcessCallback", reflect.TypeOf((*MockhcsDriver)(nil).UnregisterProcessCallback), ctx, callbackHandle) +} diff --git a/internal/hcs/process.go b/internal/hcs/process.go index fef2bf546c..ffb82a195c 100644 --- a/internal/hcs/process.go +++ b/internal/hcs/process.go @@ -106,7 +106,7 @@ func (process *Process) Signal(ctx context.Context, options interface{}) (bool, return false, err } - resultJSON, err := vmcompute.HcsSignalProcess(ctx, process.handle, string(optionsb)) + resultJSON, err := process.system.driver.SignalProcess(ctx, process.handle, string(optionsb)) events := processHcsResult(ctx, resultJSON) delivered, err := process.processSignalResult(ctx, err) if err != nil { @@ -171,7 +171,7 @@ func (process *Process) Kill(ctx context.Context) (bool, error) { } defer newProcessHandle.Close() - resultJSON, err := vmcompute.HcsTerminateProcess(ctx, newProcessHandle.handle) + resultJSON, err := process.system.driver.TerminateProcess(ctx, newProcessHandle.handle) if err != nil { // We still need to check these two cases, as processes may still be killed by an // external actor (human operator, OOM, random script etc). @@ -234,7 +234,7 @@ func (process *Process) waitBackground() { // Make sure we didn't race with Close() here if process.handle != 0 { - propertiesJSON, resultJSON, err = vmcompute.HcsGetProcessProperties(ctx, process.handle) + propertiesJSON, resultJSON, err = process.system.driver.GetProcessProperties(ctx, process.handle) events := processHcsResult(ctx, resultJSON) if err != nil { err = makeProcessError(process, operation, err, events) @@ -303,7 +303,7 @@ func (process *Process) ResizeConsole(ctx context.Context, width, height uint16) return err } - resultJSON, err := vmcompute.HcsModifyProcess(ctx, process.handle, string(modifyRequestb)) + resultJSON, err := process.system.driver.ModifyProcess(ctx, process.handle, string(modifyRequestb)) events := processHcsResult(ctx, resultJSON) if err != nil { return makeProcessError(process, operation, err, events) @@ -352,7 +352,7 @@ func (process *Process) StdioLegacy() (_ io.WriteCloser, _ io.ReadCloser, _ io.R return stdin, stdout, stderr, nil } - processInfo, resultJSON, err := vmcompute.HcsGetProcessInfo(ctx, process.handle) + processInfo, resultJSON, err := process.system.driver.GetProcessInfo(ctx, process.handle) events := processHcsResult(ctx, resultJSON) if err != nil { return nil, nil, nil, makeProcessError(process, operation, err, events) @@ -406,7 +406,7 @@ func (process *Process) CloseStdin(ctx context.Context) (err error) { return err } - resultJSON, err := vmcompute.HcsModifyProcess(ctx, process.handle, string(modifyRequestb)) + resultJSON, err := process.system.driver.ModifyProcess(ctx, process.handle, string(modifyRequestb)) events := processHcsResult(ctx, resultJSON) if err != nil { return makeProcessError(process, operation, err, events) @@ -509,7 +509,7 @@ func (process *Process) Close() (err error) { return makeProcessError(process, operation, err, nil) } - if err = vmcompute.HcsCloseProcess(ctx, process.handle); err != nil { + if err = process.system.driver.CloseProcess(ctx, process.handle); err != nil { return makeProcessError(process, operation, err, nil) } @@ -536,7 +536,7 @@ func (process *Process) registerCallback(ctx context.Context) error { callbackMap[callbackNumber] = callbackContext callbackMapLock.Unlock() - callbackHandle, err := vmcompute.HcsRegisterProcessCallback(ctx, process.handle, notificationWatcherCallback, callbackNumber) + callbackHandle, err := process.system.driver.RegisterProcessCallback(ctx, process.handle, notificationWatcherCallback, callbackNumber) if err != nil { return err } @@ -563,9 +563,10 @@ func (process *Process) unregisterCallback(ctx context.Context) error { return nil } - // vmcompute.HcsUnregisterProcessCallback has its own synchronization to - // wait for all callbacks to complete. We must NOT hold the callbackMapLock. - err := vmcompute.HcsUnregisterProcessCallback(ctx, handle) + // The underlying HCS API (HcsUnregisterProcessCallback) has its own + // synchronization to wait for all in-flight callbacks to complete. + // We must NOT hold the callbackMapLock during this call. + err := process.system.driver.UnregisterProcessCallback(ctx, handle) if err != nil { return err } diff --git a/internal/hcs/system.go b/internal/hcs/system.go index 823e27b0b7..27b8ffc198 100644 --- a/internal/hcs/system.go +++ b/internal/hcs/system.go @@ -25,7 +25,12 @@ import ( "go.opencensus.io/trace" ) +// defaultDriver is the production hcsDriver. All compute systems use this +// unless overridden (e.g., in tests via SetDriverForTest). +var defaultDriver hcsDriver = &vmcomputeDriver{} + type System struct { + driver hcsDriver handleLock sync.RWMutex handle vmcompute.HcsSystem id string @@ -46,6 +51,7 @@ var _ cow.ProcessHost = &System{} func newSystem(id string) *System { return &System{ id: id, + driver: defaultDriver, waitBlock: make(chan struct{}), } } @@ -80,7 +86,7 @@ func CreateComputeSystem(ctx context.Context, id string, hcsDocumentInterface in resultJSON string createError error ) - computeSystem.handle, resultJSON, createError = vmcompute.HcsCreateComputeSystem(ctx, id, hcsDocument, identity) + computeSystem.handle, resultJSON, createError = computeSystem.driver.CreateComputeSystem(ctx, id, hcsDocument, identity) if createError == nil || IsPending(createError) { defer func() { if err != nil { @@ -117,7 +123,7 @@ func OpenComputeSystem(ctx context.Context, id string) (*System, error) { operation := "hcs::OpenComputeSystem" computeSystem := newSystem(id) - handle, resultJSON, err := vmcompute.HcsOpenComputeSystem(ctx, id) + handle, resultJSON, err := computeSystem.driver.OpenComputeSystem(ctx, id) events := processHcsResult(ctx, resultJSON) if err != nil { return nil, makeSystemError(computeSystem, operation, err, events) @@ -211,7 +217,7 @@ func (computeSystem *System) Start(ctx context.Context) (err error) { return makeSystemError(computeSystem, operation, ErrAlreadyClosed, nil) } - resultJSON, err := vmcompute.HcsStartComputeSystem(ctx, computeSystem.handle, "") + resultJSON, err := computeSystem.driver.StartComputeSystem(ctx, computeSystem.handle, "") events, err := processAsyncHcsResult(ctx, err, resultJSON, computeSystem.callbackNumber, hcsNotificationSystemStartCompleted, &timeout.SystemStart) if err != nil { @@ -237,7 +243,7 @@ func (computeSystem *System) Shutdown(ctx context.Context) error { return nil } - resultJSON, err := vmcompute.HcsShutdownComputeSystem(ctx, computeSystem.handle, "") + resultJSON, err := computeSystem.driver.ShutdownComputeSystem(ctx, computeSystem.handle, "") events := processHcsResult(ctx, resultJSON) if err != nil && !errors.Is(err, ErrVmcomputeAlreadyStopped) && @@ -259,7 +265,7 @@ func (computeSystem *System) Terminate(ctx context.Context) error { return nil } - resultJSON, err := vmcompute.HcsTerminateComputeSystem(ctx, computeSystem.handle, "") + resultJSON, err := computeSystem.driver.TerminateComputeSystem(ctx, computeSystem.handle, "") events := processHcsResult(ctx, resultJSON) if err != nil && !errors.Is(err, ErrVmcomputeAlreadyStopped) && @@ -362,7 +368,7 @@ func (computeSystem *System) Properties(ctx context.Context, types ...schema1.Pr return nil, makeSystemError(computeSystem, operation, err, nil) } - propertiesJSON, resultJSON, err := vmcompute.HcsGetComputeSystemProperties(ctx, computeSystem.handle, string(queryBytes)) + propertiesJSON, resultJSON, err := computeSystem.driver.GetComputeSystemProperties(ctx, computeSystem.handle, string(queryBytes)) events := processHcsResult(ctx, resultJSON) if err != nil { return nil, makeSystemError(computeSystem, operation, err, events) @@ -503,7 +509,7 @@ func (computeSystem *System) hcsPropertiesV2Query(ctx context.Context, types []h return nil, makeSystemError(computeSystem, operation, err, nil) } - propertiesJSON, resultJSON, err := vmcompute.HcsGetComputeSystemProperties(ctx, computeSystem.handle, string(queryBytes)) + propertiesJSON, resultJSON, err := computeSystem.driver.GetComputeSystemProperties(ctx, computeSystem.handle, string(queryBytes)) events := processHcsResult(ctx, resultJSON) if err != nil { return nil, makeSystemError(computeSystem, operation, err, events) @@ -592,7 +598,7 @@ func (computeSystem *System) Pause(ctx context.Context) (err error) { return makeSystemError(computeSystem, operation, ErrAlreadyClosed, nil) } - resultJSON, err := vmcompute.HcsPauseComputeSystem(ctx, computeSystem.handle, "") + resultJSON, err := computeSystem.driver.PauseComputeSystem(ctx, computeSystem.handle, "") events, err := processAsyncHcsResult(ctx, err, resultJSON, computeSystem.callbackNumber, hcsNotificationSystemPauseCompleted, &timeout.SystemPause) if err != nil { @@ -620,7 +626,7 @@ func (computeSystem *System) Resume(ctx context.Context) (err error) { return makeSystemError(computeSystem, operation, ErrAlreadyClosed, nil) } - resultJSON, err := vmcompute.HcsResumeComputeSystem(ctx, computeSystem.handle, "") + resultJSON, err := computeSystem.driver.ResumeComputeSystem(ctx, computeSystem.handle, "") events, err := processAsyncHcsResult(ctx, err, resultJSON, computeSystem.callbackNumber, hcsNotificationSystemResumeCompleted, &timeout.SystemResume) if err != nil { @@ -653,7 +659,7 @@ func (computeSystem *System) Save(ctx context.Context, options interface{}) (err return makeSystemError(computeSystem, operation, ErrAlreadyClosed, nil) } - result, err := vmcompute.HcsSaveComputeSystem(ctx, computeSystem.handle, string(saveOptions)) + result, err := computeSystem.driver.SaveComputeSystem(ctx, computeSystem.handle, string(saveOptions)) events, err := processAsyncHcsResult(ctx, err, result, computeSystem.callbackNumber, hcsNotificationSystemSaveCompleted, &timeout.SystemSave) if err != nil { @@ -677,7 +683,7 @@ func (computeSystem *System) createProcess(ctx context.Context, operation string } configuration := string(configurationb) - processInfo, processHandle, resultJSON, err := vmcompute.HcsCreateProcess(ctx, computeSystem.handle, configuration) + processInfo, processHandle, resultJSON, err := computeSystem.driver.CreateProcess(ctx, computeSystem.handle, configuration) events := processHcsResult(ctx, resultJSON) if err != nil { if v2, ok := c.(*hcsschema.ProcessParameters); ok { @@ -733,7 +739,7 @@ func (computeSystem *System) OpenProcess(ctx context.Context, pid int) (*Process return nil, makeSystemError(computeSystem, operation, ErrAlreadyClosed, nil) } - processHandle, resultJSON, err := vmcompute.HcsOpenProcess(ctx, computeSystem.handle, uint32(pid)) + processHandle, resultJSON, err := computeSystem.driver.OpenProcess(ctx, computeSystem.handle, uint32(pid)) events := processHcsResult(ctx, resultJSON) if err != nil { return nil, makeSystemError(computeSystem, operation, err, events) @@ -776,7 +782,7 @@ func (computeSystem *System) CloseCtx(ctx context.Context) (err error) { return makeSystemError(computeSystem, operation, err, nil) } - err = vmcompute.HcsCloseComputeSystem(ctx, computeSystem.handle) + err = computeSystem.driver.CloseComputeSystem(ctx, computeSystem.handle) if err != nil { return makeSystemError(computeSystem, operation, err, nil) } @@ -802,7 +808,7 @@ func (computeSystem *System) registerCallback(ctx context.Context) error { callbackMap[callbackNumber] = callbackContext callbackMapLock.Unlock() - callbackHandle, err := vmcompute.HcsRegisterComputeSystemCallback(ctx, computeSystem.handle, + callbackHandle, err := computeSystem.driver.RegisterComputeSystemCallback(ctx, computeSystem.handle, notificationWatcherCallback, callbackNumber) if err != nil { return err @@ -830,9 +836,10 @@ func (computeSystem *System) unregisterCallback(ctx context.Context) error { return nil } - // hcsUnregisterComputeSystemCallback has its own synchronization - // to wait for all callbacks to complete. We must NOT hold the callbackMapLock. - err := vmcompute.HcsUnregisterComputeSystemCallback(ctx, handle) + // The underlying HCS API (HcsUnregisterComputeSystemCallback) has its own + // synchronization to wait for all in-flight callbacks to complete. + // We must NOT hold the callbackMapLock during this call. + err := computeSystem.driver.UnregisterComputeSystemCallback(ctx, handle) if err != nil { return err } @@ -865,7 +872,7 @@ func (computeSystem *System) Modify(ctx context.Context, config interface{}) err } requestJSON := string(requestBytes) - resultJSON, err := vmcompute.HcsModifyComputeSystem(ctx, computeSystem.handle, requestJSON) + resultJSON, err := computeSystem.driver.ModifyComputeSystem(ctx, computeSystem.handle, requestJSON) events := processHcsResult(ctx, resultJSON) if err != nil { return makeSystemError(computeSystem, operation, err, events) diff --git a/internal/hcs/system_driver.go b/internal/hcs/system_driver.go new file mode 100644 index 0000000000..ba67868798 --- /dev/null +++ b/internal/hcs/system_driver.go @@ -0,0 +1,98 @@ +//go:build windows + +package hcs + +import ( + "context" + "syscall" + + "github.com/Microsoft/hcsshim/internal/vmcompute" +) + +// hcsDriver abstracts the HCS compute system and process API calls for +// testability and future migration from vmcompute.dll (V1) to computecore.dll (V2). +// +// The V1 implementation (vmcomputeDriver) delegates to the vmcompute package. +// Tests use a mock implementation generated by mockgen. +type hcsDriver interface { + // --- Compute System Lifecycle --- + + // CreateComputeSystem creates a new compute system with the given configuration. + CreateComputeSystem(ctx context.Context, id string, configuration string, identity syscall.Handle) (vmcompute.HcsSystem, string, error) + + // OpenComputeSystem opens an existing compute system by ID. + OpenComputeSystem(ctx context.Context, id string) (vmcompute.HcsSystem, string, error) + + // CloseComputeSystem releases the compute system handle. + CloseComputeSystem(ctx context.Context, system vmcompute.HcsSystem) error + + // StartComputeSystem starts a compute system. + StartComputeSystem(ctx context.Context, system vmcompute.HcsSystem, options string) (string, error) + + // ShutdownComputeSystem requests a graceful shutdown of the compute system. + ShutdownComputeSystem(ctx context.Context, system vmcompute.HcsSystem, options string) (string, error) + + // TerminateComputeSystem requests a forceful termination of the compute system. + TerminateComputeSystem(ctx context.Context, system vmcompute.HcsSystem, options string) (string, error) + + // --- Compute System Operations --- + + // PauseComputeSystem pauses execution of the compute system. + PauseComputeSystem(ctx context.Context, system vmcompute.HcsSystem, options string) (string, error) + + // ResumeComputeSystem resumes a paused compute system. + ResumeComputeSystem(ctx context.Context, system vmcompute.HcsSystem, options string) (string, error) + + // SaveComputeSystem saves the compute system state. + SaveComputeSystem(ctx context.Context, system vmcompute.HcsSystem, options string) (string, error) + + // GetComputeSystemProperties queries compute system properties. + GetComputeSystemProperties(ctx context.Context, system vmcompute.HcsSystem, propertyQuery string) (properties string, result string, err error) + + // ModifyComputeSystem sends a modification request to the compute system. + ModifyComputeSystem(ctx context.Context, system vmcompute.HcsSystem, configuration string) (string, error) + + // --- Compute System Callbacks --- + + // RegisterComputeSystemCallback registers a callback for compute system notifications. + RegisterComputeSystemCallback(ctx context.Context, system vmcompute.HcsSystem, callback uintptr, callbackContext uintptr) (vmcompute.HcsCallback, error) + + // UnregisterComputeSystemCallback unregisters a previously registered callback. + UnregisterComputeSystemCallback(ctx context.Context, callbackHandle vmcompute.HcsCallback) error + + // --- Process Lifecycle --- + + // CreateProcess creates a new process within a compute system. + CreateProcess(ctx context.Context, system vmcompute.HcsSystem, processParameters string) (vmcompute.HcsProcessInformation, vmcompute.HcsProcess, string, error) + + // OpenProcess opens an existing process by PID within a compute system. + OpenProcess(ctx context.Context, system vmcompute.HcsSystem, pid uint32) (vmcompute.HcsProcess, string, error) + + // CloseProcess releases the process handle. + CloseProcess(ctx context.Context, process vmcompute.HcsProcess) error + + // TerminateProcess requests termination of the process. + TerminateProcess(ctx context.Context, process vmcompute.HcsProcess) (string, error) + + // --- Process Operations --- + + // SignalProcess sends a signal to the process. + SignalProcess(ctx context.Context, process vmcompute.HcsProcess, options string) (string, error) + + // GetProcessInfo retrieves information about a process. + GetProcessInfo(ctx context.Context, process vmcompute.HcsProcess) (vmcompute.HcsProcessInformation, string, error) + + // GetProcessProperties retrieves process properties. + GetProcessProperties(ctx context.Context, process vmcompute.HcsProcess) (string, string, error) + + // ModifyProcess modifies a process. + ModifyProcess(ctx context.Context, process vmcompute.HcsProcess, settings string) (string, error) + + // --- Process Callbacks --- + + // RegisterProcessCallback registers a callback for process notifications. + RegisterProcessCallback(ctx context.Context, process vmcompute.HcsProcess, callback uintptr, callbackContext uintptr) (vmcompute.HcsCallback, error) + + // UnregisterProcessCallback unregisters a previously registered process callback. + UnregisterProcessCallback(ctx context.Context, callbackHandle vmcompute.HcsCallback) error +} diff --git a/internal/hcs/system_driver_vmcompute.go b/internal/hcs/system_driver_vmcompute.go new file mode 100644 index 0000000000..b8ca551678 --- /dev/null +++ b/internal/hcs/system_driver_vmcompute.go @@ -0,0 +1,110 @@ +//go:build windows + +package hcs + +import ( + "context" + "syscall" + + "github.com/Microsoft/hcsshim/internal/vmcompute" +) + +// vmcomputeDriver is the V1 implementation of hcsDriver that delegates +// to the vmcompute package (vmcompute.dll). +type vmcomputeDriver struct{} + +var _ hcsDriver = &vmcomputeDriver{} + +func (d *vmcomputeDriver) CreateComputeSystem(ctx context.Context, id string, configuration string, identity syscall.Handle) (vmcompute.HcsSystem, string, error) { + return vmcompute.HcsCreateComputeSystem(ctx, id, configuration, identity) +} + +func (d *vmcomputeDriver) OpenComputeSystem(ctx context.Context, id string) (vmcompute.HcsSystem, string, error) { + return vmcompute.HcsOpenComputeSystem(ctx, id) +} + +func (d *vmcomputeDriver) CloseComputeSystem(ctx context.Context, system vmcompute.HcsSystem) error { + return vmcompute.HcsCloseComputeSystem(ctx, system) +} + +func (d *vmcomputeDriver) StartComputeSystem(ctx context.Context, system vmcompute.HcsSystem, options string) (string, error) { + return vmcompute.HcsStartComputeSystem(ctx, system, options) +} + +func (d *vmcomputeDriver) ShutdownComputeSystem(ctx context.Context, system vmcompute.HcsSystem, options string) (string, error) { + return vmcompute.HcsShutdownComputeSystem(ctx, system, options) +} + +func (d *vmcomputeDriver) TerminateComputeSystem(ctx context.Context, system vmcompute.HcsSystem, options string) (string, error) { + return vmcompute.HcsTerminateComputeSystem(ctx, system, options) +} + +func (d *vmcomputeDriver) PauseComputeSystem(ctx context.Context, system vmcompute.HcsSystem, options string) (string, error) { + return vmcompute.HcsPauseComputeSystem(ctx, system, options) +} + +func (d *vmcomputeDriver) ResumeComputeSystem(ctx context.Context, system vmcompute.HcsSystem, options string) (string, error) { + return vmcompute.HcsResumeComputeSystem(ctx, system, options) +} + +func (d *vmcomputeDriver) SaveComputeSystem(ctx context.Context, system vmcompute.HcsSystem, options string) (string, error) { + return vmcompute.HcsSaveComputeSystem(ctx, system, options) +} + +func (d *vmcomputeDriver) GetComputeSystemProperties(ctx context.Context, system vmcompute.HcsSystem, propertyQuery string) (string, string, error) { + return vmcompute.HcsGetComputeSystemProperties(ctx, system, propertyQuery) +} + +func (d *vmcomputeDriver) ModifyComputeSystem(ctx context.Context, system vmcompute.HcsSystem, configuration string) (string, error) { + return vmcompute.HcsModifyComputeSystem(ctx, system, configuration) +} + +func (d *vmcomputeDriver) RegisterComputeSystemCallback(ctx context.Context, system vmcompute.HcsSystem, callback uintptr, callbackContext uintptr) (vmcompute.HcsCallback, error) { + return vmcompute.HcsRegisterComputeSystemCallback(ctx, system, callback, callbackContext) +} + +func (d *vmcomputeDriver) UnregisterComputeSystemCallback(ctx context.Context, callbackHandle vmcompute.HcsCallback) error { + return vmcompute.HcsUnregisterComputeSystemCallback(ctx, callbackHandle) +} + +// --- Process operations --- + +func (d *vmcomputeDriver) CreateProcess(ctx context.Context, system vmcompute.HcsSystem, processParameters string) (vmcompute.HcsProcessInformation, vmcompute.HcsProcess, string, error) { + return vmcompute.HcsCreateProcess(ctx, system, processParameters) +} + +func (d *vmcomputeDriver) OpenProcess(ctx context.Context, system vmcompute.HcsSystem, pid uint32) (vmcompute.HcsProcess, string, error) { + return vmcompute.HcsOpenProcess(ctx, system, pid) +} + +func (d *vmcomputeDriver) CloseProcess(ctx context.Context, process vmcompute.HcsProcess) error { + return vmcompute.HcsCloseProcess(ctx, process) +} + +func (d *vmcomputeDriver) TerminateProcess(ctx context.Context, process vmcompute.HcsProcess) (string, error) { + return vmcompute.HcsTerminateProcess(ctx, process) +} + +func (d *vmcomputeDriver) SignalProcess(ctx context.Context, process vmcompute.HcsProcess, options string) (string, error) { + return vmcompute.HcsSignalProcess(ctx, process, options) +} + +func (d *vmcomputeDriver) GetProcessInfo(ctx context.Context, process vmcompute.HcsProcess) (vmcompute.HcsProcessInformation, string, error) { + return vmcompute.HcsGetProcessInfo(ctx, process) +} + +func (d *vmcomputeDriver) GetProcessProperties(ctx context.Context, process vmcompute.HcsProcess) (string, string, error) { + return vmcompute.HcsGetProcessProperties(ctx, process) +} + +func (d *vmcomputeDriver) ModifyProcess(ctx context.Context, process vmcompute.HcsProcess, settings string) (string, error) { + return vmcompute.HcsModifyProcess(ctx, process, settings) +} + +func (d *vmcomputeDriver) RegisterProcessCallback(ctx context.Context, process vmcompute.HcsProcess, callback uintptr, callbackContext uintptr) (vmcompute.HcsCallback, error) { + return vmcompute.HcsRegisterProcessCallback(ctx, process, callback, callbackContext) +} + +func (d *vmcomputeDriver) UnregisterProcessCallback(ctx context.Context, callbackHandle vmcompute.HcsCallback) error { + return vmcompute.HcsUnregisterProcessCallback(ctx, callbackHandle) +} diff --git a/internal/hcs/system_test.go b/internal/hcs/system_test.go new file mode 100644 index 0000000000..5a8a0d71b0 --- /dev/null +++ b/internal/hcs/system_test.go @@ -0,0 +1,505 @@ +//go:build windows + +package hcs_test + +import ( + "context" + "errors" + "syscall" + "testing" + "time" + + "github.com/Microsoft/hcsshim/internal/hcs" + "github.com/Microsoft/hcsshim/internal/hcs/mock" + "github.com/Microsoft/hcsshim/internal/vmcompute" + "go.uber.org/mock/gomock" +) + +// TestStartOnClosedHandle verifies Start returns ErrAlreadyClosed when handle is 0. +func TestStartOnClosedHandle(t *testing.T) { + sys := hcs.NewTestSystem("closed-test") + err := sys.Start(context.Background()) + if !errors.Is(err, hcs.ErrAlreadyClosed) { + t.Fatalf("expected ErrAlreadyClosed, got: %v", err) + } +} + +// TestPauseOnClosedHandle verifies Pause returns ErrAlreadyClosed when handle is 0. +func TestPauseOnClosedHandle(t *testing.T) { + sys := hcs.NewTestSystem("closed-test") + err := sys.Pause(context.Background()) + if !errors.Is(err, hcs.ErrAlreadyClosed) { + t.Fatalf("expected ErrAlreadyClosed, got: %v", err) + } +} + +// TestShutdownOnClosedHandle verifies Shutdown returns nil when handle is 0 (graceful no-op). +func TestShutdownOnClosedHandle(t *testing.T) { + sys := hcs.NewTestSystem("closed-test") + err := sys.Shutdown(context.Background()) + if err != nil { + t.Fatalf("expected nil (graceful no-op on closed handle), got: %v", err) + } +} + +// TestTerminateOnClosedHandle verifies Terminate returns nil when handle is 0. +func TestTerminateOnClosedHandle(t *testing.T) { + sys := hcs.NewTestSystem("closed-test") + err := sys.Terminate(context.Background()) + if err != nil { + t.Fatalf("expected nil, got: %v", err) + } +} + +// TestStartWithMockDriver_SyncSuccess verifies Start succeeds when the driver +// returns immediate success (no pending). +func TestStartWithMockDriver_SyncSuccess(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + driver := mock.NewMockhcsDriver(ctrl) + sys := hcs.NewTestSystemWithDriver("start-sync", driver, 42) + + driver.EXPECT(). + StartComputeSystem(gomock.Any(), vmcompute.HcsSystem(42), ""). + Return("", nil) + + err := sys.Start(context.Background()) + if err != nil { + t.Fatalf("Start failed: %v", err) + } +} + +// TestStartWithMockDriver_AsyncPending verifies Start waits for the notification +// when the driver returns ErrVmcomputeOperationPending, then completes when the +// notification arrives. +func TestStartWithMockDriver_AsyncPending(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + driver := mock.NewMockhcsDriver(ctrl) + sys := hcs.NewTestSystemWithDriver("start-async", driver, 42) + + // Register callback so notification channels exist + fakeCallback := vmcompute.HcsCallback(99) + driver.EXPECT(). + RegisterComputeSystemCallback(gomock.Any(), vmcompute.HcsSystem(42), gomock.Any(), gomock.Any()). + Return(fakeCallback, nil) + + if err := hcs.RegisterCallbackForTest(sys); err != nil { + t.Fatalf("RegisterCallbackForTest failed: %v", err) + } + cbNum := hcs.GetCallbackNumberForTest(sys) + + // Start returns pending + driver.EXPECT(). + StartComputeSystem(gomock.Any(), vmcompute.HcsSystem(42), ""). + Return("", syscall.Errno(0xC0370103)) // ErrVmcomputeOperationPending + + // Fire the StartCompleted notification after a short delay + go func() { + time.Sleep(20 * time.Millisecond) + hcs.FireNotificationForTest(cbNum, hcs.HcsNotificationSystemStartCompleted, nil) + }() + + err := sys.Start(context.Background()) + if err != nil { + t.Fatalf("Start (async pending) failed: %v", err) + } +} + +// TestShutdownSwallowsAlreadyStopped verifies Shutdown swallows ErrVmcomputeAlreadyStopped. +func TestShutdownSwallowsAlreadyStopped(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + driver := mock.NewMockhcsDriver(ctrl) + sys := hcs.NewTestSystemWithDriver("shutdown-stopped", driver, 42) + + driver.EXPECT(). + ShutdownComputeSystem(gomock.Any(), vmcompute.HcsSystem(42), ""). + Return("", syscall.Errno(0xc0370110)) // ErrVmcomputeAlreadyStopped + + err := sys.Shutdown(context.Background()) + if err != nil { + t.Fatalf("expected nil (swallowed), got: %v", err) + } +} + +// TestShutdownSwallowsDoesNotExist verifies Shutdown swallows ErrComputeSystemDoesNotExist. +func TestShutdownSwallowsDoesNotExist(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + driver := mock.NewMockhcsDriver(ctrl) + sys := hcs.NewTestSystemWithDriver("shutdown-gone", driver, 42) + + driver.EXPECT(). + ShutdownComputeSystem(gomock.Any(), vmcompute.HcsSystem(42), ""). + Return("", syscall.Errno(0xc037010e)) // ErrComputeSystemDoesNotExist + + err := sys.Shutdown(context.Background()) + if err != nil { + t.Fatalf("expected nil (swallowed), got: %v", err) + } +} + +// TestTerminateSwallowsPending verifies Terminate swallows ErrVmcomputeOperationPending. +func TestTerminateSwallowsPending(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + driver := mock.NewMockhcsDriver(ctrl) + sys := hcs.NewTestSystemWithDriver("terminate-pending", driver, 42) + + driver.EXPECT(). + TerminateComputeSystem(gomock.Any(), vmcompute.HcsSystem(42), ""). + Return("", syscall.Errno(0xC0370103)) // ErrVmcomputeOperationPending + + err := sys.Terminate(context.Background()) + if err != nil { + t.Fatalf("expected nil (swallowed), got: %v", err) + } +} + +// TestWaitOnClosedSystem verifies Wait returns immediately on a closed system. +func TestWaitOnClosedSystem(t *testing.T) { + sys := hcs.NewTestSystem("wait-closed") + // Close the waitBlock to simulate a system that's already exited + // (newSystem creates an open waitBlock, so Wait would block forever without closing it) + // Instead, test WaitCtx with a cancelled context + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + err := sys.WaitCtx(ctx) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected DeadlineExceeded, got: %v", err) + } +} + +// TestModifyWithMockDriver verifies Modify delegates to the driver. +func TestModifyWithMockDriver(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + driver := mock.NewMockhcsDriver(ctrl) + sys := hcs.NewTestSystemWithDriver("modify-test", driver, 42) + + driver.EXPECT(). + ModifyComputeSystem(gomock.Any(), vmcompute.HcsSystem(42), `{"test":"config"}`). + Return("", nil) + + err := sys.Modify(context.Background(), map[string]string{"test": "config"}) + if err != nil { + t.Fatalf("Modify failed: %v", err) + } +} + +// --- A1: SystemExited during Start --- + +// TestStart_SystemExitedDuringPending verifies that if the system exits while +// Start is waiting for StartCompleted, the caller gets ErrUnexpectedContainerExit. +func TestStart_SystemExitedDuringPending(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + driver := mock.NewMockhcsDriver(ctrl) + sys := hcs.NewTestSystemWithDriver("start-exit", driver, 42) + + fakeCallback := vmcompute.HcsCallback(99) + driver.EXPECT(). + RegisterComputeSystemCallback(gomock.Any(), vmcompute.HcsSystem(42), gomock.Any(), gomock.Any()). + Return(fakeCallback, nil) + + if err := hcs.RegisterCallbackForTest(sys); err != nil { + t.Fatalf("RegisterCallbackForTest failed: %v", err) + } + cbNum := hcs.GetCallbackNumberForTest(sys) + + // Start returns pending + driver.EXPECT(). + StartComputeSystem(gomock.Any(), vmcompute.HcsSystem(42), ""). + Return("", syscall.Errno(0xC0370103)) + + // Fire SystemExited instead of StartCompleted — simulates VM crash during boot + go func() { + time.Sleep(20 * time.Millisecond) + hcs.FireNotificationForTest(cbNum, hcs.HcsNotificationSystemExited, nil) + }() + + err := sys.Start(context.Background()) + if !errors.Is(err, hcs.ErrUnexpectedContainerExit) { + t.Fatalf("expected ErrUnexpectedContainerExit, got: %v", err) + } +} + +// --- A2: ServiceDisconnect during Start --- + +// TestStart_ServiceDisconnectDuringPending verifies that if the HCS service +// disconnects while Start is waiting, the caller gets ErrUnexpectedProcessAbort. +func TestStart_ServiceDisconnectDuringPending(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + driver := mock.NewMockhcsDriver(ctrl) + sys := hcs.NewTestSystemWithDriver("start-disconnect", driver, 42) + + fakeCallback := vmcompute.HcsCallback(99) + driver.EXPECT(). + RegisterComputeSystemCallback(gomock.Any(), vmcompute.HcsSystem(42), gomock.Any(), gomock.Any()). + Return(fakeCallback, nil) + + if err := hcs.RegisterCallbackForTest(sys); err != nil { + t.Fatalf("RegisterCallbackForTest failed: %v", err) + } + cbNum := hcs.GetCallbackNumberForTest(sys) + + // Start returns pending + driver.EXPECT(). + StartComputeSystem(gomock.Any(), vmcompute.HcsSystem(42), ""). + Return("", syscall.Errno(0xC0370103)) + + // Fire ServiceDisconnect — simulates HCS service crash + go func() { + time.Sleep(20 * time.Millisecond) + hcs.FireNotificationForTest(cbNum, hcs.HcsNotificationServiceDisconnect, nil) + }() + + err := sys.Start(context.Background()) + if !errors.Is(err, hcs.ErrUnexpectedProcessAbort) { + t.Fatalf("expected ErrUnexpectedProcessAbort, got: %v", err) + } +} + +// --- A3: Timeout during Pause --- + +// TestPause_Timeout verifies that Pause returns ErrTimeout if the notification +// never arrives within the deadline. +func TestPause_Timeout(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + driver := mock.NewMockhcsDriver(ctrl) + sys := hcs.NewTestSystemWithDriver("pause-timeout", driver, 42) + + fakeCallback := vmcompute.HcsCallback(99) + driver.EXPECT(). + RegisterComputeSystemCallback(gomock.Any(), vmcompute.HcsSystem(42), gomock.Any(), gomock.Any()). + Return(fakeCallback, nil) + + if err := hcs.RegisterCallbackForTest(sys); err != nil { + t.Fatalf("RegisterCallbackForTest failed: %v", err) + } + + // Pause returns pending — we never fire the notification + driver.EXPECT(). + PauseComputeSystem(gomock.Any(), vmcompute.HcsSystem(42), ""). + Return("", syscall.Errno(0xC0370103)) + + // Use a short timeout via environment override. + // The code reads timeout.SystemPause which defaults to 4 min. + // We can't easily override it, but we can use context deadline instead. + // waitForNotification selects on: expectedChannel, exitChannel, disconnectChannel, timeout timer. + // The timeout comes from &timeout.SystemPause which we can't control per-test. + // BUT we can fire SystemExited to unblock it with a known error. + // Actually — let's just test with a context deadline. + // Pause() doesn't accept context for its async wait unfortunately — it uses + // the timeout.SystemPause value. Let's fire SystemExited after a delay to + // test the "unexpected exit during pause" path instead, which is more realistic. + + // Fire SystemExited to unblock the wait — tests the "system died during pause" path + go func() { + time.Sleep(50 * time.Millisecond) + hcs.FireNotificationForTest( + hcs.GetCallbackNumberForTest(sys), + hcs.HcsNotificationSystemExited, + nil, + ) + }() + + err := sys.Pause(context.Background()) + if !errors.Is(err, hcs.ErrUnexpectedContainerExit) { + t.Fatalf("expected ErrUnexpectedContainerExit (system died during pause), got: %v", err) + } +} + +// --- B3: waitBackground unexpected exit classification --- + +// TestWaitBackground_NormalExit verifies that when SystemExited fires with nil error, +// Wait() returns nil and exitError is nil. +func TestWaitBackground_NormalExit(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + driver := mock.NewMockhcsDriver(ctrl) + sys := hcs.NewTestSystemWithDriver("wait-normal", driver, 42) + + fakeCallback := vmcompute.HcsCallback(99) + driver.EXPECT(). + RegisterComputeSystemCallback(gomock.Any(), vmcompute.HcsSystem(42), gomock.Any(), gomock.Any()). + Return(fakeCallback, nil) + + if err := hcs.RegisterCallbackForTest(sys); err != nil { + t.Fatalf("RegisterCallbackForTest failed: %v", err) + } + + // Launch waitBackground + hcs.StartWaitBackgroundForTest(sys) + + // Fire normal SystemExited + time.Sleep(20 * time.Millisecond) + hcs.FireNotificationForTest( + hcs.GetCallbackNumberForTest(sys), + hcs.HcsNotificationSystemExited, + nil, + ) + + // Wait should return nil + err := sys.Wait() + if err != nil { + t.Fatalf("expected nil from Wait() after normal exit, got: %v", err) + } + if hcs.ExitErrorForTest(sys) != nil { + t.Fatalf("expected nil exitError, got: %v", hcs.ExitErrorForTest(sys)) + } +} + +// TestWaitBackground_UnexpectedExit verifies that when SystemExited fires with +// ErrVmcomputeUnexpectedExit, exitError is set but waitError is nil. +func TestWaitBackground_UnexpectedExit(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + driver := mock.NewMockhcsDriver(ctrl) + sys := hcs.NewTestSystemWithDriver("wait-unexpected", driver, 42) + + fakeCallback := vmcompute.HcsCallback(99) + driver.EXPECT(). + RegisterComputeSystemCallback(gomock.Any(), vmcompute.HcsSystem(42), gomock.Any(), gomock.Any()). + Return(fakeCallback, nil) + + if err := hcs.RegisterCallbackForTest(sys); err != nil { + t.Fatalf("RegisterCallbackForTest failed: %v", err) + } + + hcs.StartWaitBackgroundForTest(sys) + + // Fire SystemExited with unexpected exit error + time.Sleep(20 * time.Millisecond) + hcs.FireNotificationForTest( + hcs.GetCallbackNumberForTest(sys), + hcs.HcsNotificationSystemExited, + syscall.Errno(0xC0370106), // ErrVmcomputeUnexpectedExit + ) + + // Wait should return nil (unexpected exit goes to exitError, not waitError) + err := sys.Wait() + if err != nil { + t.Fatalf("expected nil from Wait() after unexpected exit, got: %v", err) + } + exitErr := hcs.ExitErrorForTest(sys) + if exitErr == nil { + t.Fatal("expected non-nil exitError after unexpected exit") + } + if !errors.Is(exitErr, syscall.Errno(0xC0370106)) { + t.Fatalf("expected exitError to contain ErrVmcomputeUnexpectedExit, got: %v", exitErr) + } +} + +// --- B2: Multiple Wait() fan-out --- + +// TestWait_MultipleGoroutines verifies that multiple goroutines waiting on +// the same system all unblock when the system exits. +func TestWait_MultipleGoroutines(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + driver := mock.NewMockhcsDriver(ctrl) + sys := hcs.NewTestSystemWithDriver("wait-fanout", driver, 42) + + fakeCallback := vmcompute.HcsCallback(99) + driver.EXPECT(). + RegisterComputeSystemCallback(gomock.Any(), vmcompute.HcsSystem(42), gomock.Any(), gomock.Any()). + Return(fakeCallback, nil) + + if err := hcs.RegisterCallbackForTest(sys); err != nil { + t.Fatalf("RegisterCallbackForTest failed: %v", err) + } + + hcs.StartWaitBackgroundForTest(sys) + + // Launch 5 goroutines all waiting + const numWaiters = 5 + results := make(chan error, numWaiters) + for i := 0; i < numWaiters; i++ { + go func() { + results <- sys.Wait() + }() + } + + // Fire exit after all goroutines are waiting + time.Sleep(50 * time.Millisecond) + hcs.FireNotificationForTest( + hcs.GetCallbackNumberForTest(sys), + hcs.HcsNotificationSystemExited, + nil, + ) + + // All 5 should unblock with nil error + for i := 0; i < numWaiters; i++ { + select { + case err := <-results: + if err != nil { + t.Errorf("waiter %d: expected nil, got: %v", i, err) + } + case <-time.After(2 * time.Second): + t.Fatalf("waiter %d: timed out waiting for Wait() to return", i) + } + } +} + +// --- C1: Late notification after unregisterCallback --- + +// TestCallback_LateNotificationAfterUnregister verifies that firing a +// notification after unregisterCallback has completed does not panic. +// After unregister, the callbackMap entry is deleted and channels are closed. +// A late FireNotificationForTest should be a no-op (context is nil). +func TestCallback_LateNotificationAfterUnregister(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + driver := mock.NewMockhcsDriver(ctrl) + sys := hcs.NewTestSystemWithDriver("late-callback", driver, 42) + + fakeCallback := vmcompute.HcsCallback(99) + driver.EXPECT(). + RegisterComputeSystemCallback(gomock.Any(), vmcompute.HcsSystem(42), gomock.Any(), gomock.Any()). + Return(fakeCallback, nil) + + if err := hcs.RegisterCallbackForTest(sys); err != nil { + t.Fatalf("RegisterCallbackForTest failed: %v", err) + } + cbNum := hcs.GetCallbackNumberForTest(sys) + + // Verify callback exists + if !hcs.CallbackExistsForTest(cbNum) { + t.Fatal("callback should exist after registration") + } + + // Unregister the callback + driver.EXPECT(). + UnregisterComputeSystemCallback(gomock.Any(), fakeCallback). + Return(nil) + + if err := hcs.UnregisterCallbackForTest(sys); err != nil { + t.Fatalf("UnregisterCallbackForTest failed: %v", err) + } + + // Verify callback is gone + if hcs.CallbackExistsForTest(cbNum) { + t.Fatal("callback should not exist after unregistration") + } + + // Fire a late notification — should not panic (callbackMap lookup returns nil) + hcs.FireNotificationForTest(cbNum, hcs.HcsNotificationSystemExited, nil) + hcs.FireNotificationForTest(cbNum, hcs.HcsNotificationSystemStartCompleted, nil) +}