@@ -3,18 +3,74 @@ package pq
33import (
44 "crypto/tls"
55 "crypto/x509"
6+ "fmt"
67 "io/ioutil"
78 "net"
89 "os"
910 "os/user"
1011 "path/filepath"
12+ "strings"
13+ "sync"
1114)
1215
16+ // Registry for custom tls.Configs
17+ var (
18+ tlsConfigLock sync.RWMutex
19+ tlsConfigRegistry map [string ]* tls.Config
20+ )
21+
22+ func RegisterTLSConfig (key string , config * tls.Config ) error {
23+ if _ , isBool := readBool (key ); isBool || strings .ToLower (key ) == "require" || strings .ToLower (key ) == "verify-ca" || strings .ToLower (key ) == "verify-full" || strings .ToLower (key ) == "disable" {
24+ return fmt .Errorf ("key '%s' is reserved" , key )
25+ }
26+
27+ tlsConfigLock .Lock ()
28+ if tlsConfigRegistry == nil {
29+ tlsConfigRegistry = make (map [string ]* tls.Config )
30+ }
31+
32+ tlsConfigRegistry [key ] = config
33+ tlsConfigLock .Unlock ()
34+ return nil
35+ }
36+
37+ // DeregisterTLSConfig removes the tls.Config associated with key.
38+ func DeregisterTLSConfig (key string ) {
39+ tlsConfigLock .Lock ()
40+ if tlsConfigRegistry != nil {
41+ delete (tlsConfigRegistry , key )
42+ }
43+ tlsConfigLock .Unlock ()
44+ }
45+
46+ func getTLSConfigClone (key string ) (config * tls.Config ) {
47+ tlsConfigLock .RLock ()
48+ if v , ok := tlsConfigRegistry [key ]; ok {
49+ config = v .Clone ()
50+ }
51+ tlsConfigLock .RUnlock ()
52+ return
53+ }
54+
55+ // Returns the bool value of the input.
56+ // The 2nd return value indicates if the input was a valid bool value
57+ func readBool (input string ) (value bool , valid bool ) {
58+ switch input {
59+ case "1" , "true" , "TRUE" , "True" :
60+ return true , true
61+ case "0" , "false" , "FALSE" , "False" :
62+ return false , true
63+ }
64+
65+ // Not a valid bool value
66+ return
67+ }
68+
1369// ssl generates a function to upgrade a net.Conn based on the "sslmode" and
1470// related settings. The function is nil when no upgrade should take place.
1571func ssl (o values ) (func (net.Conn ) (net.Conn , error ), error ) {
1672 verifyCaOnly := false
17- tlsConf := tls.Config {}
73+ tlsConf := & tls.Config {}
1874 switch mode := o ["sslmode" ]; mode {
1975 // "require" is the default.
2076 case "" , "require" :
@@ -47,14 +103,19 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
47103 case "disable" :
48104 return nil , nil
49105 default :
50- return nil , fmterrorf (`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported` , mode )
106+ {
107+ tlsConf = getTLSConfigClone (mode )
108+ if tlsConf == nil {
109+ return nil , fmterrorf (`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported` , mode )
110+ }
111+ }
51112 }
52113
53- err := sslClientCertificates (& tlsConf , o )
114+ err := sslClientCertificates (tlsConf , o )
54115 if err != nil {
55116 return nil , err
56117 }
57- err = sslCertificateAuthority (& tlsConf , o )
118+ err = sslCertificateAuthority (tlsConf , o )
58119 if err != nil {
59120 return nil , err
60121 }
@@ -67,9 +128,9 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
67128 tlsConf .Renegotiation = tls .RenegotiateFreelyAsClient
68129
69130 return func (conn net.Conn ) (net.Conn , error ) {
70- client := tls .Client (conn , & tlsConf )
131+ client := tls .Client (conn , tlsConf )
71132 if verifyCaOnly {
72- err := sslVerifyCertificateAuthority (client , & tlsConf )
133+ err := sslVerifyCertificateAuthority (client , tlsConf )
73134 if err != nil {
74135 return nil , err
75136 }
0 commit comments