Skip to content

Commit d2409a9

Browse files
daqingshuarp242
andcommitted
Add support for custom tls.Config
Allow registering a custom TLS configuration, for example to use encrypted keys. Based on #1066 Fixes #766 Fixes #789 Fixes #811 Fixes #849 Co-authored-by: Martin Tournoij <martin@arp242.net>
1 parent 6ec2ad4 commit d2409a9

File tree

4 files changed

+94
-11
lines changed

4 files changed

+94
-11
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ newer. Previously PostgreSQL 8.4 and newer were supported.
2828
12 | );
2929
^
3030

31+
- Allow using a custom `tls.Config`, for example for encrypted keys ([#1228]).
32+
3133
- Add `PQGO_DEBUG=1` print the communication with PostgreSQL to stderr, to aid
3234
in debugging, testing, and bug reports ([#1223]).
3335

@@ -88,6 +90,7 @@ newer. Previously PostgreSQL 8.4 and newer were supported.
8890
[#1223]: https://github.com/lib/pq/pull/1223
8991
[#1224]: https://github.com/lib/pq/pull/1224
9092
[#1226]: https://github.com/lib/pq/pull/1226
93+
[#1228]: https://github.com/lib/pq/pull/1228
9194

9295

9396
v1.10.9 (2023-04-26)

doc.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ Valid values for sslmode are:
6666
- verify-full - Always SSL (verify that the certification presented by
6767
the server was signed by a trusted CA and the server host name
6868
matches the one in the certificate)
69+
- A custom TLS configuration registered with [RegisterTLSConfig]. These must
70+
be prefixed with "pqgo-".
6971
7072
See http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING
7173
for more information about connection string parameters.

example_test.go

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
package pq_test
22

33
import (
4+
"crypto/tls"
5+
"crypto/x509"
46
"database/sql"
57
"fmt"
68
"log"
9+
"os"
710

811
"github.com/lib/pq"
912
)
@@ -23,6 +26,7 @@ func ExampleNewConnector() {
2326
log.Fatalf("could not start transaction: %v", err)
2427
}
2528
tx.Rollback()
29+
// Output:
2630
}
2731

2832
func ExampleConnectorWithNoticeHandler() {
@@ -44,7 +48,38 @@ func ExampleConnectorWithNoticeHandler() {
4448
if _, err := db.Exec(sql); err != nil {
4549
log.Fatal(err)
4650
}
47-
4851
// Output:
4952
// Notice sent: test notice
5053
}
54+
55+
func ExampleRegisterTLSConfig() {
56+
pem, err := os.ReadFile("testdata/init/root.crt")
57+
if err != nil {
58+
log.Fatal(err)
59+
}
60+
61+
root := x509.NewCertPool()
62+
root.AppendCertsFromPEM(pem)
63+
64+
certs, err := tls.LoadX509KeyPair("testdata/init/postgresql.crt", "testdata/init/postgresql.key")
65+
if err != nil {
66+
log.Fatal(err)
67+
}
68+
69+
pq.RegisterTLSConfig("mytls", &tls.Config{
70+
RootCAs: root,
71+
Certificates: []tls.Certificate{certs},
72+
ServerName: "postgres",
73+
})
74+
75+
db, err := sql.Open("postgres", "host=postgres dbname=pqgo sslmode=pqgo-mytls")
76+
if err != nil {
77+
log.Fatal(err)
78+
}
79+
80+
err = db.Ping()
81+
if err != nil {
82+
log.Fatal(err)
83+
}
84+
// Output:
85+
}

ssl.go

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,57 @@ import (
1010
"path/filepath"
1111
"runtime"
1212
"strings"
13+
"sync"
1314
"syscall"
1415

1516
"github.com/lib/pq/internal/pqutil"
1617
)
1718

19+
// Registry for custom tls.Configs
20+
var (
21+
tlsConfs = make(map[string]*tls.Config)
22+
tlsConfsMu sync.RWMutex
23+
)
24+
25+
// RegisterTLSConfig registers a custom [tls.Config]. They are used by using
26+
// sslmode=pqgo-«key» in the connection string.
27+
//
28+
// Set the config to nil to remove a configuration.
29+
func RegisterTLSConfig(key string, config *tls.Config) error {
30+
key = strings.TrimPrefix(key, "pqgo-")
31+
if config == nil {
32+
tlsConfsMu.Lock()
33+
delete(tlsConfs, key)
34+
tlsConfsMu.Unlock()
35+
return nil
36+
}
37+
38+
tlsConfsMu.Lock()
39+
tlsConfs[key] = config
40+
tlsConfsMu.Unlock()
41+
return nil
42+
}
43+
44+
func getTLSConfigClone(key string) *tls.Config {
45+
tlsConfsMu.RLock()
46+
if v, ok := tlsConfs[key]; ok {
47+
return v.Clone()
48+
}
49+
tlsConfsMu.RUnlock()
50+
return nil
51+
}
52+
1853
// ssl generates a function to upgrade a net.Conn based on the "sslmode" and
1954
// related settings. The function is nil when no upgrade should take place.
2055
func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
21-
verifyCaOnly := false
22-
tlsConf := tls.Config{}
23-
switch mode := o["sslmode"]; mode {
56+
var (
57+
verifyCaOnly = false
58+
tlsConf = &tls.Config{}
59+
mode = o["sslmode"]
60+
)
61+
switch {
2462
// "require" is the default.
25-
case "", "require":
63+
case mode == "" || mode == "require":
2664
// We must skip TLS's own verification since it requires full
2765
// verification since Go 1.3.
2866
tlsConf.InsecureSkipVerify = true
@@ -42,15 +80,20 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
4280
delete(o, "sslrootcert")
4381
}
4482
}
45-
case "verify-ca":
83+
case mode == "verify-ca":
4684
// We must skip TLS's own verification since it requires full
4785
// verification since Go 1.3.
4886
tlsConf.InsecureSkipVerify = true
4987
verifyCaOnly = true
50-
case "verify-full":
88+
case mode == "verify-full":
5189
tlsConf.ServerName = o["host"]
52-
case "disable":
90+
case mode == "disable":
5391
return nil, nil
92+
case strings.HasPrefix(mode, "pqgo-"):
93+
tlsConf = getTLSConfigClone(mode[5:])
94+
if tlsConf == nil {
95+
return nil, fmt.Errorf(`pq: unknown custom sslmode %q`, mode)
96+
}
5497
default:
5598
return nil, fmt.Errorf(
5699
`pq: unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`,
@@ -67,11 +110,11 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
67110
tlsConf.ServerName = o["host"]
68111
}
69112

70-
err := sslClientCertificates(&tlsConf, o)
113+
err := sslClientCertificates(tlsConf, o)
71114
if err != nil {
72115
return nil, err
73116
}
74-
err = sslCertificateAuthority(&tlsConf, o)
117+
err = sslCertificateAuthority(tlsConf, o)
75118
if err != nil {
76119
return nil, err
77120
}
@@ -84,7 +127,7 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
84127
tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient
85128

86129
return func(conn net.Conn) (net.Conn, error) {
87-
client := tls.Client(conn, &tlsConf)
130+
client := tls.Client(conn, tlsConf)
88131
if verifyCaOnly {
89132
err := client.Handshake()
90133
if err != nil {

0 commit comments

Comments
 (0)