Skip to content

Commit cc2f69a

Browse files
committed
wip: matcher wasm module
Signed-off-by: Hank Donnay <[email protected]>
1 parent f992bec commit cc2f69a

File tree

8 files changed

+749
-0
lines changed

8 files changed

+749
-0
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ require (
2222
github.com/quay/goval-parser v0.8.8
2323
github.com/remind101/migrate v0.0.0-20170729031349-52c1edff7319
2424
github.com/spdx/tools-golang v0.5.6
25+
github.com/tetratelabs/wazero v1.11.0
2526
github.com/ulikunitz/xz v0.5.15
2627
go.opentelemetry.io/otel v1.39.0
2728
go.opentelemetry.io/otel/trace v1.39.0

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5
105105
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
106106
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
107107
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
108+
github.com/tetratelabs/wazero v1.11.0 h1:+gKemEuKCTevU4d7ZTzlsvgd1uaToIDtlQlmNbwqYhA=
109+
github.com/tetratelabs/wazero v1.11.0/go.mod h1:eV28rsN8Q+xwjogd7f4/Pp4xFxO7uOGbLcD/LzB1wiU=
108110
github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY=
109111
github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
110112
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=

internal/matcher/wasm/host.go

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
package wasm
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"reflect"
7+
"slices"
8+
"strings"
9+
"sync"
10+
"unsafe"
11+
12+
"github.com/tetratelabs/wazero"
13+
"github.com/tetratelabs/wazero/api"
14+
15+
"github.com/quay/claircore"
16+
)
17+
18+
// PtrMember is a helper to take a pointer to a Go struct, then return a
19+
// pointer that's contained as a field.
20+
func ptrMember(off uintptr) api.GoModuleFunc {
21+
return func(ctx context.Context, mod api.Module, stack []uint64) {
22+
// Take in *A, which has a *B at offset "off".
23+
ref := unsafe.Pointer(api.DecodeExternref(stack[0])) // Shouldn't be nil.
24+
ptrField := unsafe.Add(ref, off) // This pointer can't be nil.
25+
ptr := *(*unsafe.Pointer)(ptrField) // Can be nil.
26+
stack[0] = api.EncodeExternref(uintptr(ptr))
27+
}
28+
}
29+
30+
// PtrToMember is a helper to take a pointer to a Go struct, then return a
31+
// pointer to a contained field.
32+
func ptrToMember(off uintptr) api.GoModuleFunc {
33+
return func(ctx context.Context, mod api.Module, stack []uint64) {
34+
// Take in *A, which has a B at offset "off".
35+
ref := unsafe.Pointer(api.DecodeExternref(stack[0])) // Shouldn't be nil.
36+
ptr := unsafe.Add(ref, off) // This pointer can't be nil.
37+
stack[0] = api.EncodeExternref(uintptr(ptr))
38+
}
39+
}
40+
41+
// StringMember is a helper to take a pointer to a Go struct, then return a
42+
// copy of a string member to a caller-allocated buffer.
43+
func stringMember(off uintptr) api.GoModuleFunc {
44+
return func(ctx context.Context, mod api.Module, stack []uint64) {
45+
// Unsure of another way to get this length information.
46+
h := (*reflect.StringHeader)(unsafe.Pointer(api.DecodeExternref(stack[0]) + off))
47+
offset := api.DecodeU32(stack[1])
48+
lim := int(api.DecodeU32(stack[2]))
49+
s := unsafe.String((*byte)(unsafe.Pointer(h.Data)), h.Len)
50+
sz := min(lim, len(s))
51+
if sz == 0 {
52+
stack[0] = api.EncodeI32(0)
53+
return
54+
}
55+
s = s[:sz]
56+
mem := mod.ExportedMemory("memory")
57+
if mem.WriteString(offset, s) {
58+
stack[0] = api.EncodeI32(int32(sz))
59+
} else {
60+
stack[0] = api.EncodeI32(0)
61+
}
62+
}
63+
}
64+
65+
// StringerMember is a helper to take a pointer to a Go struct, then place the
66+
// string representation of a member into a caller-allocated buffer.
67+
func stringerMember(off uintptr) api.GoModuleFunc {
68+
return func(ctx context.Context, mod api.Module, stack []uint64) {
69+
iface := (any)(unsafe.Pointer(api.DecodeExternref(stack[0]) + off)).(fmt.Stringer)
70+
offset := api.DecodeU32(stack[1])
71+
lim := int(api.DecodeU32(stack[2]))
72+
s := iface.String()
73+
sz := min(lim, len(s))
74+
if mod.ExportedMemory("memory").WriteString(offset, s[:sz]) {
75+
stack[0] = api.EncodeI32(int32(sz))
76+
} else {
77+
stack[0] = api.EncodeI32(0)
78+
}
79+
}
80+
}
81+
82+
// NotNil checks that the passed externref is not-nil.
83+
//
84+
// This is needed because externrefs are unobservable from within WASM; they
85+
// can only be handed back to the host and not manipulated in any way.
86+
func notNil(ctx context.Context, mod api.Module, stack []uint64) {
87+
if api.DecodeExternref(stack[0]) != 0 {
88+
stack[0] = api.EncodeI32(1)
89+
} else {
90+
stack[0] = api.EncodeI32(0)
91+
}
92+
}
93+
94+
type methodSpec struct {
95+
Name string
96+
Func api.GoModuleFunc
97+
Params []api.ValueType
98+
ParamNames []string
99+
Results []api.ValueType
100+
ResultNames []string
101+
}
102+
103+
func (s *methodSpec) Build(b wazero.HostModuleBuilder) wazero.HostModuleBuilder {
104+
return b.NewFunctionBuilder().
105+
WithName(s.Name).
106+
WithParameterNames(s.ParamNames...).
107+
WithResultNames(s.ResultNames...).
108+
WithGoModuleFunction(s.Func, s.Params, s.Results).
109+
Export(s.Name)
110+
}
111+
112+
func gettersFor[T any]() []methodSpec {
113+
t := reflect.TypeFor[T]()
114+
recv := strings.ToLower(t.Name())
115+
out := make([]methodSpec, 0, t.NumField())
116+
117+
switch t {
118+
// These types are passed-in and always valid.
119+
case reflect.TypeFor[claircore.IndexRecord](),
120+
reflect.TypeFor[claircore.Vulnerability]():
121+
default:
122+
out = append(out, methodSpec{
123+
Name: fmt.Sprintf("%s_valid", recv),
124+
Func: notNil,
125+
Params: []api.ValueType{api.ValueTypeExternref},
126+
Results: []api.ValueType{api.ValueTypeI32},
127+
ParamNames: []string{recv + "Ref"},
128+
ResultNames: []string{"ok"},
129+
})
130+
}
131+
for i := 0; i < t.NumField(); i++ {
132+
sf := t.Field(i)
133+
if !sf.IsExported() {
134+
continue
135+
}
136+
if sf.Name == "ID" { // Skip "id" fields.
137+
continue
138+
}
139+
140+
ft := sf.Type
141+
tgt := strings.ToLower(sf.Name)
142+
// Do some fixups:
143+
switch tgt {
144+
case "dist":
145+
tgt = "distribution"
146+
case "arch":
147+
tgt = "architecture"
148+
case "repo":
149+
tgt = "repository"
150+
}
151+
mi := len(out)
152+
out = append(out, methodSpec{})
153+
m := &out[mi]
154+
m.Name = fmt.Sprintf("%s_get_%s", recv, tgt)
155+
switch ft.Kind() {
156+
case reflect.Pointer:
157+
m.Func = ptrMember(sf.Offset)
158+
m.Params = []api.ValueType{api.ValueTypeExternref}
159+
m.Results = []api.ValueType{api.ValueTypeExternref}
160+
m.ParamNames = []string{recv + "Ref"}
161+
m.ResultNames = []string{tgt + "Ref"}
162+
case reflect.String:
163+
m.Func = stringMember(sf.Offset)
164+
m.Params = []api.ValueType{api.ValueTypeExternref, api.ValueTypeI32, api.ValueTypeI32}
165+
m.Results = []api.ValueType{api.ValueTypeI32}
166+
m.ParamNames = []string{recv + "Ref", "buf", "buf_len"}
167+
m.ResultNames = []string{"len"}
168+
case reflect.Struct:
169+
switch {
170+
case ft == reflect.TypeFor[claircore.Version]():
171+
m.Func = ptrToMember(sf.Offset)
172+
m.Params = []api.ValueType{api.ValueTypeExternref}
173+
m.Results = []api.ValueType{api.ValueTypeExternref}
174+
m.ParamNames = []string{recv + "Ref"}
175+
m.ResultNames = []string{tgt + "Ref"}
176+
case ft.Implements(reflect.TypeFor[fmt.Stringer]()):
177+
m.Func = stringerMember(sf.Offset)
178+
m.Params = []api.ValueType{api.ValueTypeExternref, api.ValueTypeI32, api.ValueTypeI32}
179+
m.Results = []api.ValueType{api.ValueTypeI32}
180+
m.ParamNames = []string{recv + "Ref", "buf", "buf_len"}
181+
m.ResultNames = []string{"len"}
182+
default:
183+
out = out[:mi]
184+
}
185+
default:
186+
out = out[:mi]
187+
}
188+
}
189+
190+
return slices.Clip(out)
191+
}
192+
193+
var hostV1Interface = sync.OnceValue(func() []methodSpec {
194+
return slices.Concat(
195+
gettersFor[claircore.IndexRecord](),
196+
gettersFor[claircore.Detector](),
197+
gettersFor[claircore.Distribution](),
198+
gettersFor[claircore.Package](),
199+
gettersFor[claircore.Range](),
200+
gettersFor[claircore.Repository](),
201+
gettersFor[claircore.Version](),
202+
gettersFor[claircore.Vulnerability](),
203+
)
204+
})
205+
206+
func buildHostV1Interface(rt wazero.Runtime) wazero.HostModuleBuilder {
207+
b := rt.NewHostModuleBuilder("claircore_matcher_1")
208+
for _, spec := range hostV1Interface() {
209+
b = spec.Build(b)
210+
}
211+
return b
212+
}

internal/matcher/wasm/host_test.go

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
package wasm
2+
3+
import (
4+
"maps"
5+
"os"
6+
"slices"
7+
"strings"
8+
"testing"
9+
10+
"github.com/google/go-cmp/cmp"
11+
"github.com/tetratelabs/wazero"
12+
"github.com/tetratelabs/wazero/api"
13+
14+
"github.com/quay/claircore"
15+
"github.com/quay/claircore/libvuln/driver"
16+
)
17+
18+
func testRuntimeConfig(t *testing.T) wazero.RuntimeConfig {
19+
t.Helper()
20+
const pages = 1024
21+
cache, err := wazero.NewCompilationCacheWithDir(t.TempDir())
22+
if err != nil {
23+
t.Logf("unable to create cache: %v", err)
24+
cache = wazero.NewCompilationCache()
25+
}
26+
return wazero.NewRuntimeConfig().
27+
WithCloseOnContextDone(true).
28+
WithCompilationCache(cache).
29+
WithCustomSections(true).
30+
WithMemoryLimitPages(pages).
31+
WithCoreFeatures(api.CoreFeaturesV2)
32+
}
33+
34+
func TestHostV1(t *testing.T) {
35+
ctx := t.Context()
36+
rt := wazero.NewRuntimeWithConfig(ctx, testRuntimeConfig(t))
37+
mod, err := buildHostV1Interface(rt).Compile(ctx)
38+
if err != nil {
39+
t.Fatal(err)
40+
}
41+
fns := mod.ExportedFunctions()
42+
keys := slices.Collect(maps.Keys(fns))
43+
slices.Sort(keys)
44+
var b strings.Builder
45+
46+
writelist := func(ts []api.ValueType, ns []string) {
47+
b.WriteByte('(')
48+
for i := range ts {
49+
if i != 0 {
50+
b.WriteString(", ")
51+
}
52+
b.WriteString(ns[i])
53+
b.WriteString(": ")
54+
switch ts[i] {
55+
case api.ValueTypeExternref:
56+
b.WriteString("externref")
57+
case api.ValueTypeI32:
58+
b.WriteString("i32")
59+
case api.ValueTypeI64:
60+
b.WriteString("i64")
61+
case api.ValueTypeF32:
62+
b.WriteString("f32")
63+
case api.ValueTypeF64:
64+
b.WriteString("f64")
65+
default:
66+
b.WriteString("???")
67+
}
68+
}
69+
b.WriteByte(')')
70+
}
71+
for _, k := range keys {
72+
v := fns[k]
73+
b.Reset()
74+
b.WriteString(v.DebugName())
75+
writelist(v.ParamTypes(), v.ParamNames())
76+
b.WriteString(" -> ")
77+
writelist(v.ResultTypes(), v.ResultNames())
78+
79+
t.Log(b.String())
80+
}
81+
}
82+
83+
func TestTrivial(t *testing.T) {
84+
ctx := t.Context()
85+
f, err := os.Open("testdata/trivial.wasm")
86+
if err != nil {
87+
t.Fatal(err)
88+
}
89+
defer f.Close()
90+
91+
m, err := NewMatcher(ctx, "trivial", f)
92+
if err != nil {
93+
t.Fatal(err)
94+
}
95+
96+
t.Run("Query", func(t *testing.T) {
97+
want := []driver.MatchConstraint{driver.PackageName, driver.HasFixedInVersion}
98+
got := m.Query()
99+
if !cmp.Equal(got, want) {
100+
t.Error(cmp.Diff(got, want))
101+
}
102+
})
103+
104+
t.Log(`testing trvial matcher: "Filter() == true" when "len(IndexRecord.Package.Name) != 0"`)
105+
r := &claircore.IndexRecord{
106+
Package: &claircore.Package{Name: "pkg"},
107+
}
108+
ok := m.Filter(r)
109+
t.Logf("package name %q: %v", r.Package.Name, ok)
110+
if !ok {
111+
t.Fail()
112+
}
113+
114+
r.Package = new(claircore.Package)
115+
ok = m.Filter(r)
116+
t.Logf("package name %q: %v", r.Package.Name, ok)
117+
if ok {
118+
t.Fail()
119+
}
120+
}

0 commit comments

Comments
 (0)