Skip to content

Commit 4bfa492

Browse files
authored
Merge pull request #4 from OpenSystemsLab/feature/auto-reconnect
feat: auto reconnect to remote mcp server
2 parents c06bb13 + 3b33b6c commit 4bfa492

File tree

1 file changed

+70
-58
lines changed

1 file changed

+70
-58
lines changed

main.go

Lines changed: 70 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"sort"
1212
"strconv"
1313
"strings"
14+
"time"
1415

1516
"github.com/charmbracelet/bubbles/key"
1617
"github.com/charmbracelet/bubbles/list"
@@ -83,37 +84,26 @@ var sseCmd = &cobra.Command{
8384
Args: cobra.ExactArgs(1),
8485
Run: func(cmd *cobra.Command, args []string) {
8586
url := args[0]
86-
if verbose {
87-
log.Printf("URL: %s", url)
88-
}
89-
9087
headerStrings, _ := cmd.Flags().GetStringSlice("header")
91-
var httpClient *http.Client
92-
if len(headerStrings) > 0 {
93-
headers := parseHeaders(headerStrings)
94-
httpClient = &http.Client{
95-
Transport: &headerTransport{
96-
base: http.DefaultTransport,
97-
headers: headers,
98-
},
99-
}
100-
}
101-
10288
ctx := context.Background()
103-
client := mcp.NewClient(&mcp.Implementation{Name: "mcp-cli", Version: "v0.1.0"}, nil)
10489

105-
transport := &mcp.SSEClientTransport{Endpoint: url, HTTPClient: httpClient}
106-
session, err := client.Connect(ctx, transport, nil)
107-
if err != nil {
108-
log.Fatalf("Failed to connect to SSE server: %v", err)
109-
}
110-
defer session.Close()
111-
112-
if verbose {
113-
log.Println("Connected to SSE server")
90+
connect := func() (*mcp.ClientSession, error) {
91+
var httpClient *http.Client
92+
if len(headerStrings) > 0 {
93+
headers := parseHeaders(headerStrings)
94+
httpClient = &http.Client{
95+
Transport: &headerTransport{
96+
base: http.DefaultTransport,
97+
headers: headers,
98+
},
99+
}
100+
}
101+
client := mcp.NewClient(&mcp.Implementation{Name: "mcp-cli", Version: "v0.1.0"}, nil)
102+
transport := &mcp.SSEClientTransport{Endpoint: url, HTTPClient: httpClient}
103+
return client.Connect(ctx, transport, nil)
114104
}
115105

116-
handleSession(ctx, session)
106+
runSessionWithReconnect(ctx, connect)
117107
},
118108
}
119109

@@ -123,37 +113,26 @@ var httpCmd = &cobra.Command{
123113
Args: cobra.ExactArgs(1),
124114
Run: func(cmd *cobra.Command, args []string) {
125115
url := args[0]
126-
if verbose {
127-
log.Printf("URL: %s", url)
128-
}
129-
130116
headerStrings, _ := cmd.Flags().GetStringSlice("header")
131-
var httpClient *http.Client
132-
if len(headerStrings) > 0 {
133-
headers := parseHeaders(headerStrings)
134-
httpClient = &http.Client{
135-
Transport: &headerTransport{
136-
base: http.DefaultTransport,
137-
headers: headers,
138-
},
139-
}
140-
}
141-
142117
ctx := context.Background()
143-
client := mcp.NewClient(&mcp.Implementation{Name: "mcp-cli", Version: "v0.1.0"}, nil)
144118

145-
transport := &mcp.StreamableClientTransport{Endpoint: url, HTTPClient: httpClient}
146-
session, err := client.Connect(ctx, transport, nil)
147-
if err != nil {
148-
log.Fatalf("Failed to connect to streamable HTTP server: %v", err)
149-
}
150-
defer session.Close()
151-
152-
if verbose {
153-
log.Println("Connected to streamable HTTP server")
119+
connect := func() (*mcp.ClientSession, error) {
120+
var httpClient *http.Client
121+
if len(headerStrings) > 0 {
122+
headers := parseHeaders(headerStrings)
123+
httpClient = &http.Client{
124+
Transport: &headerTransport{
125+
base: http.DefaultTransport,
126+
headers: headers,
127+
},
128+
}
129+
}
130+
client := mcp.NewClient(&mcp.Implementation{Name: "mcp-cli", Version: "v0.1.0"}, nil)
131+
transport := &mcp.StreamableClientTransport{Endpoint: url, HTTPClient: httpClient}
132+
return client.Connect(ctx, transport, nil)
154133
}
155134

156-
handleSession(ctx, session)
135+
runSessionWithReconnect(ctx, connect)
157136
},
158137
}
159138

@@ -184,6 +163,31 @@ func parseHeaders(headerStrings []string) http.Header {
184163
return headers
185164
}
186165

166+
type connectFn func() (*mcp.ClientSession, error)
167+
168+
func runSessionWithReconnect(ctx context.Context, connect connectFn) {
169+
for {
170+
log.Println("Attempting to connect to server...")
171+
session, err := connect()
172+
if err != nil {
173+
log.Printf("Failed to connect: %v. Retrying in 5 seconds...", err)
174+
time.Sleep(5 * time.Second)
175+
continue
176+
}
177+
178+
log.Println("Connected to server.")
179+
err = handleSession(ctx, session)
180+
session.Close()
181+
182+
if err != nil {
183+
log.Printf("Session ended with error: %v. Reconnecting...", err)
184+
} else {
185+
log.Println("Session closed cleanly. Exiting.")
186+
break
187+
}
188+
}
189+
}
190+
187191
// -- Bubble Tea TUI -----------------------------------------------------------
188192

189193
type viewState int
@@ -384,7 +388,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
384388
case toolResult:
385389
if msg.err != nil {
386390
m.err = msg.err
387-
return m, nil
391+
return m, tea.Quit
388392
}
389393
if verbose {
390394
m.logf("Tool result received")
@@ -396,7 +400,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
396400
case resourceResult:
397401
if msg.err != nil {
398402
m.err = msg.err
399-
return m, nil
403+
return m, tea.Quit
400404
}
401405
if verbose {
402406
m.logf("Resource result received")
@@ -778,7 +782,7 @@ func (m *AppModel) readResourceCmd() tea.Cmd {
778782
}
779783
}
780784

781-
func handleSession(ctx context.Context, session *mcp.ClientSession) {
785+
func handleSession(ctx context.Context, session *mcp.ClientSession) error {
782786
if verbose {
783787
f, err := tea.LogToFile("debug.log", "debug")
784788
if err != nil {
@@ -787,10 +791,18 @@ func handleSession(ctx context.Context, session *mcp.ClientSession) {
787791
}
788792
defer f.Close()
789793
}
790-
p := tea.NewProgram(initialModel(ctx, session), tea.WithAltScreen(), tea.WithMouseCellMotion())
791-
if _, err := p.Run(); err != nil {
792-
log.Fatalf("Error running program: %v", err)
794+
model := initialModel(ctx, session)
795+
p := tea.NewProgram(model, tea.WithAltScreen(), tea.WithMouseCellMotion())
796+
finalModel, err := p.Run()
797+
if err != nil {
798+
return fmt.Errorf("error running program: %w", err)
793799
}
800+
801+
appModel, ok := finalModel.(*AppModel)
802+
if !ok {
803+
return fmt.Errorf("unexpected model type: %T", finalModel)
804+
805+
return appModel.err
794806
}
795807

796808
func main() {

0 commit comments

Comments
 (0)