diff --git a/example/gen.go b/example/gen.go index d5f0a9d..ea7c82d 100644 --- a/example/gen.go +++ b/example/gen.go @@ -1,4 +1,4 @@ package main -//go:generate go run ../cmd/cfgx/main.go generate --in config/config.toml --out config/config.go --pkg config -//go:generate go run ../cmd/cfgx/main.go generate --in config/config.toml --out getter_config/config.go --pkg getter_config --mode getter +//go:generate go run ../cmd/cfgx generate --in config/config.toml --out config/config.go --pkg config +//go:generate go run ../cmd/cfgx generate --in config/config.toml --out getter_config/config.go --pkg getter_config --mode getter diff --git a/example/getter_config/config.go b/example/getter_config/config.go index 3e43b18..b5b70c1 100644 --- a/example/getter_config/config.go +++ b/example/getter_config/config.go @@ -427,13 +427,19 @@ func (serviceConfig) Weights() []float64 { return []float64{1, 2.5, 3.7} } +func Name() string { + if v := os.Getenv("CONFIG_NAME"); v != "" { + return v + } + return "cfgx" +} + var ( App appConfig Cache cacheConfig Database databaseConfig Endpoints []endpointsItem Features []featuresItem - Name string Server serverConfig Service serviceConfig ) diff --git a/internal/generator/generator_test.go b/internal/generator/generator_test.go index 577e11b..6844028 100644 --- a/internal/generator/generator_test.go +++ b/internal/generator/generator_test.go @@ -344,3 +344,52 @@ tls_key = "file:files/small.txt" require.Contains(t, outputStr, "func (serverConfig) TlsKey() []byte", "output missing TlsKey getter") require.Contains(t, outputStr, `os.Getenv("CONFIG_SERVER_TLS_KEY")`, "output missing file path env var check for key") } + +func TestGenerator_GetterMode_TopLevelVariables(t *testing.T) { + data := []byte(` +name = "myapp" +version = "1.0.0" +port = 8080 +debug = true +timeout = "30s" + +[server] +addr = ":8080" +`) + + gen := New(WithPackageName("config"), WithMode("getter")) + output, err := gen.Generate(data) + require.NoError(t, err, "Generate() should not error") + + outputStr := string(output) + + // Check top-level getters are generated (not vars) + require.Contains(t, outputStr, "func Name() string", "output missing Name getter") + require.Contains(t, outputStr, "func Version() string", "output missing Version getter") + require.Contains(t, outputStr, "func Port() int64", "output missing Port getter") + require.Contains(t, outputStr, "func Debug() bool", "output missing Debug getter") + require.Contains(t, outputStr, "func Timeout() time.Duration", "output missing Timeout getter") + + // Check env var overrides in top-level getters + require.Contains(t, outputStr, `os.Getenv("CONFIG_NAME")`, "output missing CONFIG_NAME env var check") + require.Contains(t, outputStr, `os.Getenv("CONFIG_VERSION")`, "output missing CONFIG_VERSION env var check") + require.Contains(t, outputStr, `os.Getenv("CONFIG_PORT")`, "output missing CONFIG_PORT env var check") + require.Contains(t, outputStr, `os.Getenv("CONFIG_DEBUG")`, "output missing CONFIG_DEBUG env var check") + require.Contains(t, outputStr, `os.Getenv("CONFIG_TIMEOUT")`, "output missing CONFIG_TIMEOUT env var check") + + // Check default values are returned + require.Contains(t, outputStr, `return "myapp"`, "output missing name default value") + require.Contains(t, outputStr, `return "1.0.0"`, "output missing version default value") + require.Contains(t, outputStr, "return 8080", "output missing port default value") + require.Contains(t, outputStr, "return true", "output missing debug default value") + require.Contains(t, outputStr, "return 30 * time.Second", "output missing timeout default value") + + // Verify top-level simple variables are NOT in var block + require.NotContains(t, outputStr, "Name string", "top-level Name should be a getter, not a var") + require.NotContains(t, outputStr, "Version string", "top-level Version should be a getter, not a var") + require.NotContains(t, outputStr, "Port int64", "top-level Port should be a getter, not a var") + + // But structs should still be in var block + require.Contains(t, outputStr, "var (", "output missing var block") + require.Contains(t, outputStr, "Server serverConfig", "output missing Server var declaration") +} diff --git a/internal/generator/struct_gen.go b/internal/generator/struct_gen.go index f2e9610..dde3ca0 100644 --- a/internal/generator/struct_gen.go +++ b/internal/generator/struct_gen.go @@ -412,13 +412,42 @@ func (g *Generator) generateStructsAndGetters(buf *bytes.Buffer, data map[string } } - // Generate var declarations + // Generate top-level getter functions for simple variables + for _, key := range keys { + value := data[key] + + // Only generate getters for non-struct, non-array-of-structs values + switch val := value.(type) { + case map[string]any, []map[string]any: + // Skip structs - they will be var declarations + continue + case []any: + // Check if it's an array of maps (structs) + if len(val) > 0 { + if _, ok := val[0].(map[string]any); ok { + // Skip array of structs + continue + } + } + // Generate getter for array of primitives + if err := g.generateTopLevelGetter(buf, key, value); err != nil { + return err + } + default: + // Generate getter for simple types + if err := g.generateTopLevelGetter(buf, key, value); err != nil { + return err + } + } + } + + // Generate var declarations (only for structs and arrays of structs) buf.WriteString("var (\n") for _, key := range keys { varName := sx.PascalCase(key) value := data[key] - switch value.(type) { + switch val := value.(type) { case map[string]any: structName := sx.CamelCase(key) + "Config" fmt.Fprintf(buf, "\t%s %s\n", varName, structName) @@ -426,11 +455,13 @@ func (g *Generator) generateStructsAndGetters(buf *bytes.Buffer, data map[string structName := sx.CamelCase(key) + "Item" fmt.Fprintf(buf, "\t%s []%s\n", varName, structName) case []any: - goType := g.toGoType(value) - fmt.Fprintf(buf, "\t%s %s\n", varName, goType) - default: - goType := g.toGoType(value) - fmt.Fprintf(buf, "\t%s %s\n", varName, goType) + // Check if it's an array of maps (structs) + if len(val) > 0 { + if _, ok := val[0].(map[string]any); ok { + structName := sx.CamelCase(key) + "Item" + fmt.Fprintf(buf, "\t%s []%s\n", varName, structName) + } + } } } buf.WriteString(")\n") @@ -546,7 +577,26 @@ func (g *Generator) generateGetterMethods(buf *bytes.Buffer, structName string, // generateGetterMethod generates a single getter method with env var override. func (g *Generator) generateGetterMethod(buf *bytes.Buffer, structName, fieldName, goType, envVarName string, defaultValue any) error { fmt.Fprintf(buf, "func (%s) %s() %s {\n", structName, fieldName, goType) + g.writeGetterBody(buf, goType, envVarName, defaultValue) + buf.WriteString("}\n\n") + return nil +} + +// generateTopLevelGetter generates a top-level getter function (not a method) for simple variables. +func (g *Generator) generateTopLevelGetter(buf *bytes.Buffer, varName string, defaultValue any) error { + funcName := sx.PascalCase(varName) + goType := g.toGoType(defaultValue) + envVarName := "CONFIG_" + strings.ToUpper(varName) + fmt.Fprintf(buf, "func %s() %s {\n", funcName, goType) + g.writeGetterBody(buf, goType, envVarName, defaultValue) + buf.WriteString("}\n\n") + return nil +} + +// writeGetterBody generates the common body logic for getter functions/methods. +// This handles env var checking, type conversion, and default value fallback. +func (g *Generator) writeGetterBody(buf *bytes.Buffer, goType, envVarName string, defaultValue any) { // Special handling for []byte (file references) - check for file path in env var if goType == "[]byte" { buf.WriteString("\t// Check for file path to load\n") @@ -559,8 +609,7 @@ func (g *Generator) generateGetterMethod(buf *bytes.Buffer, structName, fieldNam buf.WriteString("\treturn ") g.writeValue(buf, defaultValue) buf.WriteString("\n") - buf.WriteString("}\n\n") - return nil + return } // For other types, check env var with type conversion @@ -599,9 +648,6 @@ func (g *Generator) generateGetterMethod(buf *bytes.Buffer, structName, fieldNam buf.WriteString("\treturn ") g.writeValue(buf, defaultValue) buf.WriteString("\n") - - buf.WriteString("}\n\n") - return nil } // envVarName generates an environment variable name from a struct name and field name.