Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"path/filepath"
"regexp"
"strings"
"sync"
"text/template"
"time"

Expand Down Expand Up @@ -60,6 +61,7 @@ var (
year = flag.String("y", fmt.Sprint(time.Now().Year()), "copyright year(s)")
verbose = flag.Bool("v", false, "verbose mode: print the name of the files that are modified or were skipped")
checkonly = flag.Bool("check", false, "check only mode: verify presence of license headers and exit with non-zero code if missing")
strict = flag.Bool("strict", false, "check license headers in strict mode (files with unknown extensions treated as errors)")
)

func init() {
Expand Down Expand Up @@ -114,6 +116,10 @@ func main() {
os.Exit(1)
}

if *strict && !(*checkonly) {
log.Fatal("-strict flag must only be used together with -check flag")
}

// convert -skip flags to -ignore equivalents
for _, s := range skipExtensionFlags {
ignorePatterns = append(ignorePatterns, fmt.Sprintf("**/*.%s", s))
Expand Down Expand Up @@ -148,6 +154,8 @@ func main() {
// process at most 1000 files in parallel
ch := make(chan *file, 1000)
done := make(chan struct{})
mu := sync.Mutex{} // Protect access to the nonTerminalErrors slice
nonTerminalErrors := []error{}
go func() {
var wg errgroup.Group
for f := range ch {
Expand All @@ -161,6 +169,13 @@ func main() {
return err
}
if lic == nil { // Unknown fileExtension
if *strict {
mu.Lock()
nonTerminalErrors = append(nonTerminalErrors, fmt.Errorf("unknown extension: %s", f.path))
mu.Unlock()
} else {
log.Printf("unknown extension: %s", f.path)
}
return nil
}
// Check if file has a license
Expand All @@ -187,10 +202,21 @@ func main() {
})
}
err := wg.Wait()
close(done)

foundNonTerminalErrors := len(nonTerminalErrors) > 0
if foundNonTerminalErrors {
for _, e := range nonTerminalErrors {
log.Printf("%v", e)
}
}
if err != nil {
log.Printf("%v", err)
}
if foundNonTerminalErrors || err != nil {
os.Exit(1)
}

close(done)
}()

for _, d := range flag.Args() {
Expand Down