From f9242e4f6abe817e62d39d8c6a85ed4a25c40df6 Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Fri, 23 Jan 2026 23:04:20 +0100 Subject: [PATCH 1/3] fix(server): fire disconnect callbacks on EOF/errors Change-Id: I53259ea2f51407e4bde33258ba0639cf00b2c29a Signed-off-by: Thomas Kosiewski --- lua/claudecode/server/init.lua | 2 +- lua/claudecode/server/tcp.lua | 52 +++++++++++-- tests/mocks/vim.lua | 1 + tests/unit/server/tcp_spec.lua | 131 +++++++++++++++++++++++++++++++++ 4 files changed, 177 insertions(+), 9 deletions(-) create mode 100644 tests/unit/server/tcp_spec.lua diff --git a/lua/claudecode/server/init.lua b/lua/claudecode/server/init.lua index 288c4914..81fe9a64 100644 --- a/lua/claudecode/server/init.lua +++ b/lua/claudecode/server/init.lua @@ -12,7 +12,7 @@ local M = {} ---@field server table|nil The TCP server instance ---@field port number|nil The port server is running on ---@field auth_token string|nil The authentication token for validating connections ----@field clients table A list of connected clients +---@field clients table Mirrored view of connected clients (updated via tcp callbacks) ---@field handlers table Message handlers by method name ---@field ping_timer table|nil Timer for sending pings M.state = { diff --git a/lua/claudecode/server/tcp.lua b/lua/claudecode/server/tcp.lua index 297aa209..e4d6ae4d 100644 --- a/lua/claudecode/server/tcp.lua +++ b/lua/claudecode/server/tcp.lua @@ -8,7 +8,7 @@ local M = {} ---@field server table The vim.loop TCP server handle ---@field port number The port the server is listening on ---@field auth_token string|nil The authentication token for validating connections ----@field clients table Table of connected clients +---@field clients table Transport-level registry of connected clients (canonical) ---@field on_message function Callback for WebSocket messages ---@field on_connect function Callback for new connections ---@field on_disconnect function Callback for client disconnections @@ -124,14 +124,15 @@ function M._handle_new_connection(server) -- Set up data handler client_tcp:read_start(function(err, data) if err then - server.on_error("Client read error: " .. err) - M._remove_client(server, client) + local error_msg = "Client read error: " .. err + server.on_error(error_msg) + M._disconnect_client(server, client, 1006, error_msg) return end if not data then -- EOF - client disconnected - M._remove_client(server, client) + M._disconnect_client(server, client, 1006, "EOF") return end @@ -139,11 +140,10 @@ function M._handle_new_connection(server) client_manager.process_data(client, data, function(cl, message) server.on_message(cl, message) end, function(cl, code, reason) - server.on_disconnect(cl, code, reason) - M._remove_client(server, cl) + M._disconnect_client(server, cl, code, reason) end, function(cl, error_msg) server.on_error("Client " .. cl.id .. " error: " .. error_msg) - M._remove_client(server, cl) + M._disconnect_client(server, cl, 1006, "Client error: " .. error_msg) end, server.auth_token) end) @@ -151,6 +151,42 @@ function M._handle_new_connection(server) server.on_connect(client) end +---Disconnect a client and remove it from the server. +---This ensures `server.on_disconnect` is invoked for every disconnect path +---(EOF, read errors, protocol errors, timeouts), keeping higher-level client +---state in sync. +---@param server TCPServer The server object +---@param client WebSocketClient The client to disconnect +---@param code number|nil WebSocket close code +---@param reason string|nil WebSocket close reason +function M._disconnect_client(server, client, code, reason) + assert(type(server) == "table", "Expected server to be a table") + local on_disconnect_type = type(server.on_disconnect) + local on_disconnect_mt = on_disconnect_type == "table" and getmetatable(server.on_disconnect) or nil + assert( + on_disconnect_type == "function" or (on_disconnect_mt ~= nil and type(on_disconnect_mt.__call) == "function"), + "Expected server.on_disconnect to be callable" + ) + assert(type(server.clients) == "table", "Expected server.clients to be a table") + assert(type(client) == "table", "Expected client to be a table") + assert(type(client.id) == "string", "Expected client.id to be a string") + if code ~= nil then + assert(type(code) == "number", "Expected code to be a number") + end + if reason ~= nil then + assert(type(reason) == "string", "Expected reason to be a string") + end + + -- Idempotency: a client can hit multiple disconnect paths (e.g. CLOSE frame + -- followed by a TCP EOF). Only notify/remove once. + if not server.clients[client.id] then + return + end + + server.on_disconnect(client, code, reason) + M._remove_client(server, client) +end + ---Remove a client from the server ---@param server TCPServer The server object ---@param client WebSocketClient The client to remove @@ -293,7 +329,7 @@ function M.start_ping_timer(server, interval) string.format("Client %s keepalive timeout (%ds idle), closing connection", client.id, time_since_pong) ) client_manager.close_client(client, 1006, "Connection timeout") - M._remove_client(server, client) + M._disconnect_client(server, client, 1006, "Connection timeout") end end end diff --git a/tests/mocks/vim.lua b/tests/mocks/vim.lua index 77a33021..2c1e5b9e 100644 --- a/tests/mocks/vim.lua +++ b/tests/mocks/vim.lua @@ -881,6 +881,7 @@ local vim = { return true end, read_start = function(self, callback) + self._read_cb = callback return true end, write = function(self, data, callback) diff --git a/tests/unit/server/tcp_spec.lua b/tests/unit/server/tcp_spec.lua new file mode 100644 index 00000000..83a96e3d --- /dev/null +++ b/tests/unit/server/tcp_spec.lua @@ -0,0 +1,131 @@ +require("tests.busted_setup") + +local client_manager = require("claudecode.server.client") + +describe("TCP server disconnect handling", function() + local tcp + local original_process_data + + before_each(function() + package.loaded["claudecode.server.tcp"] = nil + tcp = require("claudecode.server.tcp") + original_process_data = client_manager.process_data + end) + + after_each(function() + client_manager.process_data = original_process_data + end) + + it("should call on_disconnect and remove client on EOF", function() + local callbacks = { + on_message = spy.new(function() end), + on_connect = spy.new(function() end), + on_disconnect = spy.new(function() end), + on_error = spy.new(function() end), + } + + local config = { port_range = { min = 10000, max = 10000 } } + local server, err = tcp.create_server(config, callbacks, nil) + assert.is_nil(err) + assert.is_table(server) + + tcp._handle_new_connection(server) + + assert.spy(callbacks.on_connect).was_called(1) + local client = callbacks.on_connect.calls[1].vals[1] + assert.is_table(client) + assert.is_table(client.tcp_handle) + assert.is_function(client.tcp_handle._read_cb) + + -- Simulate client abruptly disconnecting (e.g. CLI terminated via Ctrl-C) + client.tcp_handle._read_cb(nil, nil) + + assert.spy(callbacks.on_disconnect).was_called(1) + assert.spy(callbacks.on_disconnect).was_called_with(client, 1006, "EOF") + expect(server.clients[client.id]).to_be_nil() + end) + + it("should call on_disconnect and remove client on TCP read error", function() + local callbacks = { + on_message = spy.new(function() end), + on_connect = spy.new(function() end), + on_disconnect = spy.new(function() end), + on_error = spy.new(function() end), + } + + local config = { port_range = { min = 10000, max = 10000 } } + local server, err = tcp.create_server(config, callbacks, nil) + assert.is_nil(err) + assert.is_table(server) + + tcp._handle_new_connection(server) + + local client = callbacks.on_connect.calls[1].vals[1] + client.tcp_handle._read_cb("boom", nil) + + assert.spy(callbacks.on_disconnect).was_called(1) + assert.spy(callbacks.on_disconnect).was_called_with(client, 1006, "Client read error: boom") + expect(server.clients[client.id]).to_be_nil() + + assert.spy(callbacks.on_error).was_called(1) + assert.spy(callbacks.on_error).was_called_with("Client read error: boom") + end) + + it("should call on_disconnect when client manager reports an error", function() + client_manager.process_data = function(cl, data, on_message, on_close, on_error, auth_token) + on_error(cl, "Protocol error") + end + + local callbacks = { + on_message = spy.new(function() end), + on_connect = spy.new(function() end), + on_disconnect = spy.new(function() end), + on_error = spy.new(function() end), + } + + local config = { port_range = { min = 10000, max = 10000 } } + local server, err = tcp.create_server(config, callbacks, nil) + assert.is_nil(err) + assert.is_table(server) + + tcp._handle_new_connection(server) + + local client = callbacks.on_connect.calls[1].vals[1] + client.tcp_handle._read_cb(nil, "some data") + + assert.spy(callbacks.on_disconnect).was_called(1) + assert.spy(callbacks.on_disconnect).was_called_with(client, 1006, "Client error: Protocol error") + expect(server.clients[client.id]).to_be_nil() + end) + + it("should only call on_disconnect once if multiple disconnect paths fire", function() + client_manager.process_data = function(cl, data, on_message, on_close, on_error, auth_token) + on_close(cl, 1000, "bye") + end + + local callbacks = { + on_message = spy.new(function() end), + on_connect = spy.new(function() end), + on_disconnect = spy.new(function() end), + on_error = spy.new(function() end), + } + + local config = { port_range = { min = 10000, max = 10000 } } + local server, err = tcp.create_server(config, callbacks, nil) + assert.is_nil(err) + assert.is_table(server) + + tcp._handle_new_connection(server) + + local client = callbacks.on_connect.calls[1].vals[1] + client.tcp_handle._read_cb(nil, "data") + + assert.spy(callbacks.on_disconnect).was_called(1) + assert.spy(callbacks.on_disconnect).was_called_with(client, 1000, "bye") + expect(server.clients[client.id]).to_be_nil() + + -- Simulate a later EOF after the CLOSE path already removed the client. + client.tcp_handle._read_cb(nil, nil) + assert.spy(callbacks.on_disconnect).was_called(1) + end) +end) From b7dc62368c4c06b78f8245625cec61802af76600 Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Mon, 26 Jan 2026 14:25:54 +0100 Subject: [PATCH 2/3] refactor(server): remove client mirror Change-Id: Ieee0d2842aa3ca84645dae4cd2552c7ee48dde6d Signed-off-by: Thomas Kosiewski --- lua/claudecode/server/init.lua | 7 ------ lua/claudecode/server/mock.lua | 44 ++++++++++++++++++++++++++-------- tests/server_test.lua | 8 ++++--- tests/unit/server_spec.lua | 6 +++-- 4 files changed, 43 insertions(+), 22 deletions(-) diff --git a/lua/claudecode/server/init.lua b/lua/claudecode/server/init.lua index 81fe9a64..ad5a8b9d 100644 --- a/lua/claudecode/server/init.lua +++ b/lua/claudecode/server/init.lua @@ -12,14 +12,12 @@ local M = {} ---@field server table|nil The TCP server instance ---@field port number|nil The port server is running on ---@field auth_token string|nil The authentication token for validating connections ----@field clients table Mirrored view of connected clients (updated via tcp callbacks) ---@field handlers table Message handlers by method name ---@field ping_timer table|nil Timer for sending pings M.state = { server = nil, port = nil, auth_token = nil, - clients = {}, handlers = {}, ping_timer = nil, } @@ -53,8 +51,6 @@ function M.start(config, auth_token) M._handle_message(client, message) end, on_connect = function(client) - M.state.clients[client.id] = client - -- Log connection with auth status if M.state.auth_token then logger.debug("server", "Authenticated WebSocket client connected:", client.id) @@ -71,7 +67,6 @@ function M.start(config, auth_token) end end, on_disconnect = function(client, code, reason) - M.state.clients[client.id] = nil logger.debug( "server", "WebSocket client disconnected:", @@ -124,8 +119,6 @@ function M.stop() M.state.server = nil M.state.port = nil M.state.auth_token = nil - M.state.clients = {} - return true end diff --git a/lua/claudecode/server/mock.lua b/lua/claudecode/server/mock.lua index 11b5ba18..e3cdfe02 100644 --- a/lua/claudecode/server/mock.lua +++ b/lua/claudecode/server/mock.lua @@ -12,7 +12,6 @@ local tools = require("claudecode.tools.init") M.state = { server = nil, port = nil, - clients = {}, handlers = {}, messages = {}, -- Store messages for testing } @@ -74,7 +73,6 @@ function M.stop() -- Reset state M.state.server = nil M.state.port = nil - M.state.clients = {} M.state.messages = {} return true @@ -101,9 +99,11 @@ end ---@param client_id string A unique client identifier ---@return table client The client object function M.add_client(client_id) + assert(type(client_id) == "string", "Expected client_id to be a string") if not M.state.server then error("Server not running") end + assert(type(M.state.server.clients) == "table", "Expected mock server.clients to be a table") local client = { id = client_id, @@ -111,7 +111,7 @@ function M.add_client(client_id) messages = {}, } - M.state.clients[client_id] = client + M.state.server.clients[client_id] = client return client end @@ -119,11 +119,16 @@ end ---@param client_id string The client identifier ---@return boolean success Whether removal was successful function M.remove_client(client_id) - if not M.state.server or not M.state.clients[client_id] then + assert(type(client_id) == "string", "Expected client_id to be a string") + if not M.state.server or type(M.state.server.clients) ~= "table" then return false end - M.state.clients[client_id] = nil + if not M.state.server.clients[client_id] then + return false + end + + M.state.server.clients[client_id] = nil return true end @@ -136,7 +141,10 @@ function M.send(client, method, params) local client_obj if type(client) == "string" then - client_obj = M.state.clients[client] + if not M.state.server or type(M.state.server.clients) ~= "table" then + return false + end + client_obj = M.state.server.clients[client] else client_obj = client end @@ -172,7 +180,10 @@ function M.send_response(client, id, result, error) local client_obj if type(client) == "string" then - client_obj = M.state.clients[client] + if not M.state.server or type(M.state.server.clients) ~= "table" then + return false + end + client_obj = M.state.server.clients[client] else client_obj = client end @@ -208,9 +219,13 @@ end ---@param params table The parameters to send ---@return boolean success Whether broadcasting was successful function M.broadcast(method, params) + if not M.state.server or type(M.state.server.clients) ~= "table" then + return false + end + local success = true - for client_id, _ in pairs(M.state.clients) do + for client_id, _ in pairs(M.state.server.clients) do local send_success = M.send(client_id, method, params) success = success and send_success end @@ -223,7 +238,12 @@ end ---@param message table The message to process ---@return table|nil response The response if any function M.simulate_message(client_id, message) - local client = M.state.clients[client_id] + assert(type(client_id) == "string", "Expected client_id to be a string") + if not M.state.server or type(M.state.server.clients) ~= "table" then + return nil + end + + local client = M.state.server.clients[client_id] if not client then return nil @@ -255,7 +275,11 @@ end function M.clear_messages() M.state.messages = {} - for _, client in pairs(M.state.clients) do + if not M.state.server or type(M.state.server.clients) ~= "table" then + return + end + + for _, client in pairs(M.state.server.clients) do client.messages = {} end end diff --git a/tests/server_test.lua b/tests/server_test.lua index 9bb27e7c..52ccdc2d 100644 --- a/tests/server_test.lua +++ b/tests/server_test.lua @@ -226,7 +226,6 @@ describe("Server module", function() assert(type(server.state) == "table") assert(server.state.server == nil) assert(server.state.port == nil) - assert(type(server.state.clients) == "table") assert(type(server.state.handlers) == "table") end) @@ -259,8 +258,11 @@ describe("Server module", function() assert(stop_success == true) assert(server.state.server == nil) assert(server.state.port == nil) - assert(type(server.state.clients) == "table") - assert(0 == #server.state.clients) + + local status = server.get_status() + assert(status.running == false) + assert(status.port == nil) + assert(status.client_count == 0) end) it("should not stop the server if not running", function() diff --git a/tests/unit/server_spec.lua b/tests/unit/server_spec.lua index d2df77a6..14148f34 100644 --- a/tests/unit/server_spec.lua +++ b/tests/unit/server_spec.lua @@ -91,8 +91,10 @@ describe("WebSocket Server", function() expect(success).to_be_true() expect(server.state.server).to_be_nil() expect(server.state.port).to_be_nil() - expect(server.state.clients).to_be_table() - expect(#server.state.clients).to_be(0) + + local status = server.get_status() + expect(status.running).to_be_false() + expect(status.client_count).to_be(0) end) it("should not stop server if not running", function() From a00cd93b279287b3f2f5166d1b27508aaaf4e27a Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Mon, 26 Jan 2026 15:12:36 +0100 Subject: [PATCH 3/3] docs(server): clean up client registry comments Change-Id: If645c9f567afe7dbbf64c1dcd6f1c252ea18d266 Signed-off-by: Thomas Kosiewski --- lua/claudecode/server/init.lua | 2 -- lua/claudecode/server/tcp.lua | 5 ++--- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/lua/claudecode/server/init.lua b/lua/claudecode/server/init.lua index ad5a8b9d..9da47080 100644 --- a/lua/claudecode/server/init.lua +++ b/lua/claudecode/server/init.lua @@ -206,8 +206,6 @@ end local module_instance_id = math.random(10000, 99999) logger.debug("server", "Server module loaded with instance ID:", module_instance_id) --- Note: debug_deferred_table function removed as deferred_responses table is no longer used - function M._setup_deferred_response(deferred_info) local co = deferred_info.coroutine diff --git a/lua/claudecode/server/tcp.lua b/lua/claudecode/server/tcp.lua index e4d6ae4d..4aac69e6 100644 --- a/lua/claudecode/server/tcp.lua +++ b/lua/claudecode/server/tcp.lua @@ -8,7 +8,7 @@ local M = {} ---@field server table The vim.loop TCP server handle ---@field port number The port the server is listening on ---@field auth_token string|nil The authentication token for validating connections ----@field clients table Transport-level registry of connected clients (canonical) +---@field clients table Table of connected clients ---@field on_message function Callback for WebSocket messages ---@field on_connect function Callback for new connections ---@field on_disconnect function Callback for client disconnections @@ -153,8 +153,7 @@ end ---Disconnect a client and remove it from the server. ---This ensures `server.on_disconnect` is invoked for every disconnect path ----(EOF, read errors, protocol errors, timeouts), keeping higher-level client ----state in sync. +---(EOF, read errors, protocol errors, timeouts), and only once per client. ---@param server TCPServer The server object ---@param client WebSocketClient The client to disconnect ---@param code number|nil WebSocket close code