Skip to content

Commit 9298252

Browse files
committed
feat(bridge) get default value from flag and add to schema
1 parent ac23673 commit 9298252

File tree

2 files changed

+112
-2
lines changed

2 files changed

+112
-2
lines changed

internal/bridge/tool.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package bridge
22

33
import (
4+
"encoding/json"
45
"fmt"
56
"log/slog"
7+
"strconv"
68
"strings"
79

810
"github.com/google/jsonschema-go/jsonschema"
@@ -204,9 +206,100 @@ func addFlagToSchema(schema *jsonschema.Schema, flag *pflag.Flag) {
204206
slog.Debug("unknown flag type, defaulting to string", "flag", flag.Name, "type", t)
205207
}
206208

209+
setDefaultValue(flagSchema, flag.DefValue)
207210
schema.Properties[flag.Name] = flagSchema
208211
}
209212

213+
// setDefaultValue sets the default value for a flag schema if it's not a zero value.
214+
func setDefaultValue(flagSchema *jsonschema.Schema, defValue string) {
215+
if defValue == "" {
216+
return
217+
}
218+
219+
setDefault := func(val any) {
220+
if raw, err := json.Marshal(val); err == nil {
221+
flagSchema.Default = json.RawMessage(raw)
222+
}
223+
}
224+
225+
// Parse the default value based on the schema type
226+
switch flagSchema.Type {
227+
case "boolean":
228+
if val, err := strconv.ParseBool(defValue); err == nil {
229+
setDefault(val)
230+
}
231+
case "integer":
232+
if val, err := strconv.ParseInt(defValue, 10, 64); err == nil {
233+
setDefault(val)
234+
}
235+
case "number":
236+
if val, err := strconv.ParseFloat(defValue, 64); err == nil {
237+
setDefault(val)
238+
}
239+
case "string":
240+
setDefault(defValue)
241+
case "array":
242+
// Handle array types (slices)
243+
// pflag represents empty slices as "[]"
244+
if defValue == "[]" {
245+
return
246+
}
247+
// pflag represents arrays as "[item1,item2,item3]"
248+
// We need to manually parse this into an actual JSON array
249+
// --- Ewwww Gross ---
250+
if strings.HasPrefix(defValue, "[") && strings.HasSuffix(defValue, "]") {
251+
// Remove the brackets
252+
inner := defValue[1 : len(defValue)-1]
253+
if inner == "" {
254+
return // Empty array
255+
}
256+
// Split by comma
257+
parts := strings.Split(inner, ",")
258+
// Determine the array item type from the schema
259+
if flagSchema.Items != nil {
260+
switch flagSchema.Items.Type {
261+
case "integer":
262+
// Parse as integer array
263+
intArr := make([]int64, 0, len(parts))
264+
for _, p := range parts {
265+
if val, err := strconv.ParseInt(strings.TrimSpace(p), 10, 64); err == nil {
266+
intArr = append(intArr, val)
267+
}
268+
}
269+
setDefault(intArr)
270+
case "number":
271+
// Parse as float array
272+
floatArr := make([]float64, 0, len(parts))
273+
for _, p := range parts {
274+
if val, err := strconv.ParseFloat(strings.TrimSpace(p), 64); err == nil {
275+
floatArr = append(floatArr, val)
276+
}
277+
}
278+
setDefault(floatArr)
279+
case "boolean":
280+
// Parse as boolean array
281+
boolArr := make([]bool, 0, len(parts))
282+
for _, p := range parts {
283+
if val, err := strconv.ParseBool(strings.TrimSpace(p)); err == nil {
284+
boolArr = append(boolArr, val)
285+
}
286+
}
287+
setDefault(boolArr)
288+
case "string":
289+
// String array - trim whitespace from each element
290+
strArr := make([]string, 0, len(parts))
291+
for _, p := range parts {
292+
strArr = append(strArr, strings.TrimSpace(p))
293+
}
294+
setDefault(strArr)
295+
}
296+
297+
// there are no array of objects in pflag, so we don't handle that case
298+
}
299+
}
300+
}
301+
}
302+
210303
// enhanceArgsSchema adds detailed argument information to the args property.
211304
func enhanceArgsSchema(schema *jsonschema.Schema, cmd *cobra.Command) {
212305
description := "Positional command line arguments"

internal/bridge/tool_test.go

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ func TestCreateToolFromCmd(t *testing.T) {
2323
// Add some flags
2424
cmd.Flags().String("output", "", "Output file")
2525
cmd.Flags().Bool("verbose", false, "Verbose output")
26-
cmd.Flags().StringSlice("include", []string{}, "Include patterns")
26+
cmd.Flags().IntSlice("include", []int{}, "Include patterns")
27+
cmd.Flags().StringSlice("greeting", []string{"hello", "world"}, "Include patterns")
2728
cmd.Flags().Int("count", 10, "Number of items")
2829

2930
// Add a hidden flag
@@ -73,6 +74,7 @@ func TestCreateToolFromCmd(t *testing.T) {
7374
assert.Contains(t, flagsSchema.Properties, "verbose")
7475
assert.Contains(t, flagsSchema.Properties, "include")
7576
assert.Contains(t, flagsSchema.Properties, "count")
77+
assert.Contains(t, flagsSchema.Properties, "greeting")
7678

7779
// Verify excluded flags
7880
assert.NotContains(t, flagsSchema.Properties, "hidden", "Should not include hidden flag")
@@ -83,15 +85,30 @@ func TestCreateToolFromCmd(t *testing.T) {
8385
assert.Equal(t, "boolean", flagsSchema.Properties["verbose"].Type)
8486
assert.Equal(t, "array", flagsSchema.Properties["include"].Type)
8587
assert.Equal(t, "integer", flagsSchema.Properties["count"].Type)
88+
assert.Equal(t, "array", flagsSchema.Properties["greeting"].Type)
8689

8790
// Verify required flags
8891
require.Len(t, flagsSchema.Required, 1, "Should have 1 required flag")
8992
assert.Contains(t, flagsSchema.Required, "count", "count flag should be marked as required")
9093

94+
// Verify default values
95+
assert.NotNil(t, flagsSchema.Properties["verbose"].Default)
96+
assert.JSONEq(t, "false", string(flagsSchema.Properties["verbose"].Default))
97+
assert.NotNil(t, flagsSchema.Properties["count"].Default)
98+
assert.JSONEq(t, "10", string(flagsSchema.Properties["count"].Default))
99+
assert.NotNil(t, flagsSchema.Properties["greeting"].Default)
100+
assert.JSONEq(t, `["hello","world"]`, string(flagsSchema.Properties["greeting"].Default))
101+
// Empty string and empty array should not have defaults set
102+
assert.Nil(t, flagsSchema.Properties["output"].Default)
103+
assert.Nil(t, flagsSchema.Properties["include"].Default)
104+
91105
// Verify array items schema
92106
includeSchema := flagsSchema.Properties["include"]
93107
assert.NotNil(t, includeSchema.Items)
94-
assert.Equal(t, "string", includeSchema.Items.Type)
108+
assert.Equal(t, "integer", includeSchema.Items.Type)
109+
greetingSchema := flagsSchema.Properties["greeting"]
110+
assert.NotNil(t, greetingSchema.Items)
111+
assert.Equal(t, "string", greetingSchema.Items.Type)
95112

96113
// Verify persistent flag from parent command
97114
assert.Contains(t, flagsSchema.Properties, "config", "Should include persistent flag from parent command")

0 commit comments

Comments
 (0)