diff --git a/cmd/audit2rbac/audit2rbac.go b/cmd/audit2rbac/audit2rbac.go index f7250d6..f60137a 100644 --- a/cmd/audit2rbac/audit2rbac.go +++ b/cmd/audit2rbac/audit2rbac.go @@ -59,16 +59,33 @@ func NewAudit2RBACCommand(stdout, stderr io.Writer) *cobra.Command { showVersion := false + outputFilename := "" + cmd := &cobra.Command{ Use: "audit2rbac --filename=audit.log [ --user=bob | --serviceaccount=my-namespace:my-sa ]", Short: "", Long: "", Run: func(cmd *cobra.Command, args []string) { + var tmpFile *os.File + var err error + if showVersion { fmt.Fprintln(stdout, "audit2rbac version "+pkg.Version) return } + if outputFilename != "" { + + tmpFile, err = os.CreateTemp("", "audit2rbac") + if err != nil { + fmt.Fprintln(stderr, err) + fmt.Fprintln(stderr) + cmd.Help() + os.Exit(1) + } + options.Stdout = tmpFile + } + checkErr(stderr, options.Complete(serviceAccount, args, name, annotations, labels)) if err := options.Validate(); err != nil { @@ -79,10 +96,12 @@ func NewAudit2RBACCommand(stdout, stderr io.Writer) *cobra.Command { } checkErr(stderr, options.Run()) + checkErr(stderr, os.Rename(tmpFile.Name(), outputFilename)) }, } cmd.Flags().StringArrayVarP(&options.AuditSources, "filename", "f", options.AuditSources, "File, URL, or - for STDIN to read audit events from") + cmd.Flags().StringVar(&outputFilename, "output-filename", name, "File to write output manifests to") cmd.Flags().StringVar(&options.User, "user", options.User, "User to filter audit events to and generate role bindings for") cmd.Flags().StringVar(&serviceAccount, "serviceaccount", serviceAccount, "Service account to filter audit events to and generate role bindings for, in format :") @@ -292,23 +311,23 @@ func (a *Audit2RBACOptions) Run() error { firstSeparator = false return } - fmt.Fprintln(os.Stdout, "---") + fmt.Fprintln(a.Stdout, "---") } for _, obj := range generated.Roles { printSeparator() - pkg.Output(os.Stdout, obj, "yaml") + pkg.Output(a.Stdout, obj, "yaml") } for _, obj := range generated.ClusterRoles { printSeparator() - pkg.Output(os.Stdout, obj, "yaml") + pkg.Output(a.Stdout, obj, "yaml") } for _, obj := range generated.RoleBindings { printSeparator() - pkg.Output(os.Stdout, obj, "yaml") + pkg.Output(a.Stdout, obj, "yaml") } for _, obj := range generated.ClusterRoleBindings { printSeparator() - pkg.Output(os.Stdout, obj, "yaml") + pkg.Output(a.Stdout, obj, "yaml") } fmt.Fprintln(a.Stderr, "Complete!")