Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions lua/claudecode/server/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,12 @@ local M = {}
---@field server table|nil The TCP server instance
---@field port number|nil The port server is running on
---@field auth_token string|nil The authentication token for validating connections
---@field clients table<string, WebSocketClient> A list of connected clients
---@field handlers table Message handlers by method name
---@field ping_timer table|nil Timer for sending pings
M.state = {
server = nil,
port = nil,
auth_token = nil,
clients = {},
handlers = {},
ping_timer = nil,
}
Expand Down Expand Up @@ -53,8 +51,6 @@ function M.start(config, auth_token)
M._handle_message(client, message)
end,
on_connect = function(client)
M.state.clients[client.id] = client

-- Log connection with auth status
if M.state.auth_token then
logger.debug("server", "Authenticated WebSocket client connected:", client.id)
Expand All @@ -71,7 +67,6 @@ function M.start(config, auth_token)
end
end,
on_disconnect = function(client, code, reason)
M.state.clients[client.id] = nil
logger.debug(
"server",
"WebSocket client disconnected:",
Expand Down Expand Up @@ -124,8 +119,6 @@ function M.stop()
M.state.server = nil
M.state.port = nil
M.state.auth_token = nil
M.state.clients = {}

return true
end

Expand Down Expand Up @@ -213,8 +206,6 @@ end
local module_instance_id = math.random(10000, 99999)
logger.debug("server", "Server module loaded with instance ID:", module_instance_id)

-- Note: debug_deferred_table function removed as deferred_responses table is no longer used

function M._setup_deferred_response(deferred_info)
local co = deferred_info.coroutine

Expand Down
44 changes: 34 additions & 10 deletions lua/claudecode/server/mock.lua
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ local tools = require("claudecode.tools.init")
M.state = {
server = nil,
port = nil,
clients = {},
handlers = {},
messages = {}, -- Store messages for testing
}
Expand Down Expand Up @@ -74,7 +73,6 @@ function M.stop()
-- Reset state
M.state.server = nil
M.state.port = nil
M.state.clients = {}
M.state.messages = {}

return true
Expand All @@ -101,29 +99,36 @@ end
---@param client_id string A unique client identifier
---@return table client The client object
function M.add_client(client_id)
assert(type(client_id) == "string", "Expected client_id to be a string")
if not M.state.server then
error("Server not running")
end
assert(type(M.state.server.clients) == "table", "Expected mock server.clients to be a table")

local client = {
id = client_id,
connected = true,
messages = {},
}

M.state.clients[client_id] = client
M.state.server.clients[client_id] = client
return client
end

---Remove a client from the server
---@param client_id string The client identifier
---@return boolean success Whether removal was successful
function M.remove_client(client_id)
if not M.state.server or not M.state.clients[client_id] then
assert(type(client_id) == "string", "Expected client_id to be a string")
if not M.state.server or type(M.state.server.clients) ~= "table" then
return false
end

M.state.clients[client_id] = nil
if not M.state.server.clients[client_id] then
return false
end

M.state.server.clients[client_id] = nil
return true
end

Expand All @@ -136,7 +141,10 @@ function M.send(client, method, params)
local client_obj

if type(client) == "string" then
client_obj = M.state.clients[client]
if not M.state.server or type(M.state.server.clients) ~= "table" then
return false
end
client_obj = M.state.server.clients[client]
else
client_obj = client
end
Expand Down Expand Up @@ -172,7 +180,10 @@ function M.send_response(client, id, result, error)
local client_obj

if type(client) == "string" then
client_obj = M.state.clients[client]
if not M.state.server or type(M.state.server.clients) ~= "table" then
return false
end
client_obj = M.state.server.clients[client]
else
client_obj = client
end
Expand Down Expand Up @@ -208,9 +219,13 @@ end
---@param params table The parameters to send
---@return boolean success Whether broadcasting was successful
function M.broadcast(method, params)
if not M.state.server or type(M.state.server.clients) ~= "table" then
return false
end

local success = true

for client_id, _ in pairs(M.state.clients) do
for client_id, _ in pairs(M.state.server.clients) do
local send_success = M.send(client_id, method, params)
success = success and send_success
end
Expand All @@ -223,7 +238,12 @@ end
---@param message table The message to process
---@return table|nil response The response if any
function M.simulate_message(client_id, message)
local client = M.state.clients[client_id]
assert(type(client_id) == "string", "Expected client_id to be a string")
if not M.state.server or type(M.state.server.clients) ~= "table" then
return nil
end

local client = M.state.server.clients[client_id]

if not client then
return nil
Expand Down Expand Up @@ -255,7 +275,11 @@ end
function M.clear_messages()
M.state.messages = {}

for _, client in pairs(M.state.clients) do
if not M.state.server or type(M.state.server.clients) ~= "table" then
return
end

for _, client in pairs(M.state.server.clients) do
client.messages = {}
end
end
Expand Down
49 changes: 42 additions & 7 deletions lua/claudecode/server/tcp.lua
Original file line number Diff line number Diff line change
Expand Up @@ -124,33 +124,68 @@ function M._handle_new_connection(server)
-- Set up data handler
client_tcp:read_start(function(err, data)
if err then
server.on_error("Client read error: " .. err)
M._remove_client(server, client)
local error_msg = "Client read error: " .. err
server.on_error(error_msg)
M._disconnect_client(server, client, 1006, error_msg)
return
end

if not data then
-- EOF - client disconnected
M._remove_client(server, client)
M._disconnect_client(server, client, 1006, "EOF")
return
end

-- Process incoming data
client_manager.process_data(client, data, function(cl, message)
server.on_message(cl, message)
end, function(cl, code, reason)
server.on_disconnect(cl, code, reason)
M._remove_client(server, cl)
M._disconnect_client(server, cl, code, reason)
end, function(cl, error_msg)
server.on_error("Client " .. cl.id .. " error: " .. error_msg)
M._remove_client(server, cl)
M._disconnect_client(server, cl, 1006, "Client error: " .. error_msg)
end, server.auth_token)
end)

-- Notify about new connection
server.on_connect(client)
end

---Disconnect a client and remove it from the server.
---This ensures `server.on_disconnect` is invoked for every disconnect path
---(EOF, read errors, protocol errors, timeouts), and only once per client.
---@param server TCPServer The server object
---@param client WebSocketClient The client to disconnect
---@param code number|nil WebSocket close code
---@param reason string|nil WebSocket close reason
function M._disconnect_client(server, client, code, reason)
assert(type(server) == "table", "Expected server to be a table")
local on_disconnect_type = type(server.on_disconnect)
local on_disconnect_mt = on_disconnect_type == "table" and getmetatable(server.on_disconnect) or nil
assert(
on_disconnect_type == "function" or (on_disconnect_mt ~= nil and type(on_disconnect_mt.__call) == "function"),
"Expected server.on_disconnect to be callable"
)
assert(type(server.clients) == "table", "Expected server.clients to be a table")
assert(type(client) == "table", "Expected client to be a table")
assert(type(client.id) == "string", "Expected client.id to be a string")
if code ~= nil then
assert(type(code) == "number", "Expected code to be a number")
end
if reason ~= nil then
assert(type(reason) == "string", "Expected reason to be a string")
end

-- Idempotency: a client can hit multiple disconnect paths (e.g. CLOSE frame
-- followed by a TCP EOF). Only notify/remove once.
if not server.clients[client.id] then
return
end

server.on_disconnect(client, code, reason)
M._remove_client(server, client)
end

---Remove a client from the server
---@param server TCPServer The server object
---@param client WebSocketClient The client to remove
Expand Down Expand Up @@ -293,7 +328,7 @@ function M.start_ping_timer(server, interval)
string.format("Client %s keepalive timeout (%ds idle), closing connection", client.id, time_since_pong)
)
client_manager.close_client(client, 1006, "Connection timeout")
M._remove_client(server, client)
M._disconnect_client(server, client, 1006, "Connection timeout")
end
end
end
Expand Down
1 change: 1 addition & 0 deletions tests/mocks/vim.lua
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,7 @@ local vim = {
return true
end,
read_start = function(self, callback)
self._read_cb = callback
return true
end,
write = function(self, data, callback)
Expand Down
8 changes: 5 additions & 3 deletions tests/server_test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@ describe("Server module", function()
assert(type(server.state) == "table")
assert(server.state.server == nil)
assert(server.state.port == nil)
assert(type(server.state.clients) == "table")
assert(type(server.state.handlers) == "table")
end)

Expand Down Expand Up @@ -259,8 +258,11 @@ describe("Server module", function()
assert(stop_success == true)
assert(server.state.server == nil)
assert(server.state.port == nil)
assert(type(server.state.clients) == "table")
assert(0 == #server.state.clients)

local status = server.get_status()
assert(status.running == false)
assert(status.port == nil)
assert(status.client_count == 0)
end)

it("should not stop the server if not running", function()
Expand Down
Loading