Use libnftables in dynamically linked binary

Signed-off-by: Rob Murray <rob.murray@docker.com>
This commit is contained in:
Rob Murray
2025-09-22 11:14:17 +01:00
parent e98849831f
commit 6db6de2c20
8 changed files with 194 additions and 59 deletions

View File

@@ -501,6 +501,7 @@ RUN --mount=type=cache,sharing=locked,id=moby-dev-aptlib,target=/var/lib/apt \
jq \
libcap2-bin \
libnet1 \
libnftables-dev \
libnl-3-200 \
libprotobuf-c1 \
libyajl2 \
@@ -548,6 +549,7 @@ RUN --mount=type=cache,sharing=locked,id=moby-build-aptlib,target=/var/lib/apt \
xx-apt-get install --no-install-recommends -y \
gcc \
libc6-dev \
libnftables-dev \
libseccomp-dev \
libsystemd-dev \
pkg-config

View File

@@ -4,6 +4,7 @@ package nftabler
import (
"context"
"errors"
"github.com/containerd/log"
"github.com/moby/moby/v2/daemon/libnetwork/drivers/bridge/internal/firewaller"
@@ -50,6 +51,8 @@ type Nftabler struct {
table6 nftables.Table
}
// NewNftabler creates a new Nftabler instance, initializing the nftables tables.
// Call Close() on the returned Nftabler to release resources when done.
func NewNftabler(ctx context.Context, config firewaller.Config) (*Nftabler, error) {
nft := &Nftabler{config: config}
@@ -72,6 +75,12 @@ func NewNftabler(ctx context.Context, config firewaller.Config) (*Nftabler, erro
return nft, nil
}
// Close releases the resources held by the Nftabler's table4 and table6
// handles. The underlying nftables tables are neither modified nor deleted.
// Both handles are always closed, and any errors they report are joined into
// the returned error.
func (nft *Nftabler) Close() error {
	err4 := nft.table4.Close()
	err6 := nft.table6.Close()
	return errors.Join(err4, err6)
}
func (nft *Nftabler) init(ctx context.Context, family nftables.Family) (nftables.Table, error) {
// Instantiate the table.
table, err := nftables.NewTable(family, dockerTable)

View File

@@ -70,9 +70,11 @@ func TestNftabler(t *testing.T) {
t.Run(fmt.Sprintf("ipv4=%v/ipv6=%v/hairpin=%v/internal=%v/icc=%v/masq=%v/snat=%v/gwm=%v/bindlh=%v/wsl2mirrored=%v",
p(ipv4), p(ipv6), p(hairpin), p(internal), p(icc), p(masq), p(snat), gwmode, p(bindLocalhost), p(wsl2Mirrored)), func(t *testing.T) {
// If updating results, don't run in parallel because some of the results files are shared.
if !golden.FlagUpdate() {
t.Parallel()
}
// Tests are dynamically linked, so the nftables code under test uses cgo to call libnftables
// and they run faster without t.Parallel(). Strangely, when statically linked and exec-ing
// 'nft', they run faster in parallel.
// if !golden.FlagUpdate() { t.Parallel() }
// Combine results (golden output files) where possible to:
// - check params that should have no effect when made irrelevant by other params, and
// - minimise the number of results files.
@@ -101,7 +103,7 @@ func testNftabler(t *testing.T, tn string, config firewaller.Config, netConfig f
return
}
out := strings.ReplaceAll(res.Combined(), "type nat hook output priority -100", "type nat hook output priority dstnat")
assert.Assert(t, res.Error)
assert.Assert(t, res.Error, out)
golden.Assert(t, out, name+"__"+family+".golden")
}
@@ -129,6 +131,7 @@ func testNftabler(t *testing.T, tn string, config firewaller.Config, netConfig f
// end of the test (after deleting per-network and per-port rules).
fw, err := NewNftabler(context.Background(), config)
assert.NilError(t, err)
defer fw.Close()
checkResults("ip", rnWSL2Mirrored(fmt.Sprintf("%s/cleaned,hairpin=%v", tn, config.Hairpin)), config.IPv4)
checkResults("ip6", fmt.Sprintf("%s/cleaned,hairpin=%v", tn, config.Hairpin), config.IPv6)

View File

@@ -0,0 +1,78 @@
//go:build cgo && !static_build
package nftables
import (
"context"
"errors"
"fmt"
"unsafe"
"github.com/containerd/log"
"go.opentelemetry.io/otel"
)
// #cgo pkg-config: libnftables
// #cgo nocallback nft_run_cmd_from_buffer
// #cgo nocallback nft_ctx_get_output_buffer
// #cgo nocallback nft_ctx_get_error_buffer
// #include <stdlib.h>
// #include <nftables/libnftables.h>
import "C"
// nftHandle is the libnftables context used by a table in cgo (non-static)
// builds. It is created on first use by nftApply and freed by closeNftHandle.
type nftHandle = *C.struct_nft_ctx
// nftApply hands the nftables commands in nftCmd to libnftables for execution.
// The table's libnftables context is created on the first call and then kept
// in t.nftHandle for reuse.
//
// The caller must hold t.applyLock.
//
// A non-zero return code from libnftables is reported as an error that
// includes the contents of the library's error buffer.
func (t *table) nftApply(ctx context.Context, nftCmd []byte) error {
	ctx, span := otel.Tracer("").Start(ctx, spanPrefix+".nftApply.cgo")
	defer span.End()

	// Lazily set up the libnftables context.
	if t.nftHandle == nil {
		h, err := newNftHandle()
		if err != nil {
			return err
		}
		t.nftHandle = h
	}

	// C.CString allocates a C copy of the commands, it must be freed explicitly.
	script := C.CString(string(nftCmd))
	defer C.free(unsafe.Pointer(script))

	rc := C.nft_run_cmd_from_buffer(t.nftHandle, script)
	// Copy both buffers into Go strings straight away, whatever the result.
	outText := C.GoString(C.nft_ctx_get_output_buffer(t.nftHandle))
	errText := C.GoString(C.nft_ctx_get_error_buffer(t.nftHandle))

	if rc == 0 {
		log.G(ctx).WithFields(log.Fields{"stdout": outText, "stderr": errText}).Debug("nftables: updated via libnftables")
		return nil
	}
	return fmt.Errorf("libnftables: failed to apply commands (code %d), stderr: %s", int(rc), errText)
}
// newNftHandle creates a libnftables context and switches on buffering of its
// output and error streams, so they can be fetched after each command. If
// either buffer cannot be enabled, the context is freed before the error is
// returned.
func newNftHandle() (*C.struct_nft_ctx, error) {
	nftCtx := C.nft_ctx_new(C.NFT_CTX_DEFAULT)
	if nftCtx == nil {
		return nil, errors.New("libnftables: failed to create new nft handle")
	}

	if rc := C.nft_ctx_buffer_output(nftCtx); rc != 0 {
		C.nft_ctx_free(nftCtx)
		return nil, fmt.Errorf("libnftables: failed to set output buffer (code %d)", int(rc))
	}
	if rc := C.nft_ctx_buffer_error(nftCtx); rc != 0 {
		C.nft_ctx_free(nftCtx)
		return nil, fmt.Errorf("libnftables: failed to set error buffer (code %d)", int(rc))
	}
	return nftCtx, nil
}
// closeNftHandle frees the table's libnftables context, if it has one, and
// clears t.nftHandle. It takes t.applyLock itself, so the caller must not
// already hold it.
func (t *table) closeNftHandle() {
	t.applyLock.Lock()
	defer t.applyLock.Unlock()

	if t.nftHandle == nil {
		// Nothing was ever created, or it has already been freed.
		return
	}
	C.nft_ctx_free(t.nftHandle)
	t.nftHandle = nil
}

View File

@@ -0,0 +1,66 @@
//go:build !cgo || static_build
package nftables
import (
"context"
"fmt"
"io"
"os/exec"
"strings"
"github.com/containerd/log"
"go.opentelemetry.io/otel"
)
// nftHandle is an empty placeholder in builds without cgo, or with the
// static_build tag. Those builds exec the "nft" binary for each update, so
// there is no per-table handle state to keep.
type nftHandle = struct{}
// nftApply runs the "nft" binary, passing the nftables commands in nftCmd on
// its standard input ("nft -f -").
//
// stderr is collected by os/exec while stdout is read here, so nft cannot
// stall on a full stderr pipe while this function is still waiting for stdout.
// Once nft has been started it is always waited for, even if sending the
// commands or reading its output fails. That way the child process is reaped,
// and whatever it wrote to stderr can be included in the returned error.
//
// It returns an error if nft cannot be started, if communicating with it
// fails, or if it exits unsuccessfully.
func (t *table) nftApply(ctx context.Context, nftCmd []byte) error {
	ctx, span := otel.Tracer("").Start(ctx, spanPrefix+".nftApply.exec")
	defer span.End()

	cmd := exec.Command(nftPath, "-f", "-")
	stdinPipe, err := cmd.StdinPipe()
	if err != nil {
		return fmt.Errorf("getting stdin pipe for nft: %w", err)
	}
	stdoutPipe, err := cmd.StdoutPipe()
	if err != nil {
		return fmt.Errorf("getting stdout pipe for nft: %w", err)
	}
	// os/exec copies stderr into this buffer from its own goroutine. The buffer
	// must not be read until cmd.Wait has returned.
	var stderrBuf strings.Builder
	cmd.Stderr = &stderrBuf

	if err := cmd.Start(); err != nil {
		return fmt.Errorf("starting nft: %w", err)
	}

	// From here on, do not return before cmd.Wait: collect the error from each
	// step and report the most informative one afterwards.
	_, writeErr := stdinPipe.Write(nftCmd)
	// Always close stdin, even after a failed write, so nft sees EOF.
	closeErr := stdinPipe.Close()
	var stdoutBuf strings.Builder
	// All reads from the stdout pipe must finish before Wait is called.
	_, readErr := io.Copy(&stdoutBuf, stdoutPipe)
	waitErr := cmd.Wait()
	stdout, stderr := stdoutBuf.String(), stderrBuf.String()

	switch {
	case waitErr != nil:
		// nft's own exit status and stderr explain a failure better than a
		// secondary pipe error (for example, EPIPE after nft has exited).
		return fmt.Errorf("running nft: %s %w", stderr, waitErr)
	case writeErr != nil:
		return fmt.Errorf("sending nft commands: %w", writeErr)
	case closeErr != nil:
		return fmt.Errorf("closing nft input pipe: %w", closeErr)
	case readErr != nil:
		return fmt.Errorf("reading stdout of nft: %w", readErr)
	}
	log.G(ctx).WithFields(log.Fields{"stdout": stdout, "stderr": stderr}).Debug("nftables: updated")
	return nil
}
// closeNftHandle is a no-op in this build: nftApply execs "nft" for each
// update, so there is no handle to release.
func (t *table) closeNftHandle() {
}

View File

@@ -48,7 +48,6 @@ import (
"context"
"errors"
"fmt"
"io"
"iter"
"os/exec"
"runtime"
@@ -59,7 +58,6 @@ import (
"text/template"
"github.com/containerd/log"
"go.opentelemetry.io/otel"
)
// Prefix for OTEL span names.
@@ -193,6 +191,7 @@ type table struct {
MustFlush bool
applyLock sync.Mutex
nftHandle nftHandle // applyLock must be held to access
}
// Table is a handle for an nftables table.
@@ -235,6 +234,16 @@ func NewTable(family Family, name string) (Table, error) {
return t, nil
}
// Close releases resources associated with the table. It does not modify or delete
// the underlying nftables table. Calling Close on a Table that is not valid does
// nothing. The returned error is always nil.
//
// Close has a value receiver, so it works on a copy of the Table and cannot
// invalidate the caller's value. Clearing the copy's pointer would be an
// ineffective assignment, so the Table is deliberately left as it is.
func (t Table) Close() error {
	if t.IsValid() {
		t.t.closeNftHandle()
	}
	return nil
}
// Name returns the name of the table, or an empty string if t is not valid.
func (t Table) Name() string {
if !t.IsValid() {
@@ -245,6 +254,9 @@ func (t Table) Name() string {
// Family returns the address family of the nftables table, or an empty string
// if t is not valid.
func (t Table) Family() Family {
	if t.IsValid() {
		return t.t.Family
	}
	return ""
}
@@ -464,7 +476,7 @@ func (t *Table) Apply(ctx context.Context, tm Modifier) (retErr error) {
return fmt.Errorf("failed to execute template nft ruleset: %w", err)
}
if err := nftApply(ctx, buf.Bytes()); err != nil {
if err := t.t.nftApply(ctx, buf.Bytes()); err != nil {
// On error, log a line-numbered version of the generated "nft" input (because
// nft error messages refer to line numbers).
var sb strings.Builder
@@ -497,10 +509,15 @@ func (t Table) Reload(ctx context.Context) error {
if !t.IsValid() {
return errors.New("invalid table")
}
t.t.applyLock.Lock()
defer t.t.applyLock.Unlock()
return t.t.reload(ctx)
}
func (t *table) reload(ctx context.Context) error {
if !Enabled() {
return errors.New("nftables is not enabled")
}
ctx = log.WithLogger(ctx, log.G(ctx).WithFields(log.Fields{"table": t.Name, "family": t.Family}))
log.G(ctx).Warn("nftables: reloading table")
@@ -510,7 +527,7 @@ func (t *table) reload(ctx context.Context) error {
return fmt.Errorf("failed to execute reload template: %w", err)
}
if err := nftApply(ctx, buf.Bytes()); err != nil {
if err := t.nftApply(ctx, buf.Bytes()); err != nil {
// On error, log a line-numbered version of the generated "nft" input (because
// nft error messages refer to line numbers).
var sb strings.Builder
@@ -1067,54 +1084,3 @@ func parseTemplate() error {
}
return nil
}
// nftApply runs the "nft" command, passing the nftables commands in nftCmd on
// its standard input ("nft -f -"). It fails immediately if nftables is not
// enabled.
//
// stderr is collected by os/exec while stdout is read here, so nft cannot
// stall on a full stderr pipe while this function is still waiting for stdout.
// Once nft has been started it is always waited for, even if sending the
// commands or reading its output fails. That way the child process is reaped,
// and whatever it wrote to stderr can be included in the returned error.
func nftApply(ctx context.Context, nftCmd []byte) error {
	ctx, span := otel.Tracer("").Start(ctx, spanPrefix+".nftApply")
	defer span.End()

	if !Enabled() {
		return errors.New("nftables is not enabled")
	}
	cmd := exec.Command(nftPath, "-f", "-")
	stdinPipe, err := cmd.StdinPipe()
	if err != nil {
		return fmt.Errorf("getting stdin pipe for nft: %w", err)
	}
	stdoutPipe, err := cmd.StdoutPipe()
	if err != nil {
		return fmt.Errorf("getting stdout pipe for nft: %w", err)
	}
	// os/exec copies stderr into this buffer from its own goroutine. The buffer
	// must not be read until cmd.Wait has returned.
	var stderrBuf strings.Builder
	cmd.Stderr = &stderrBuf

	if err := cmd.Start(); err != nil {
		return fmt.Errorf("starting nft: %w", err)
	}

	// From here on, do not return before cmd.Wait: collect the error from each
	// step and report the most informative one afterwards.
	_, writeErr := stdinPipe.Write(nftCmd)
	// Always close stdin, even after a failed write, so nft sees EOF.
	closeErr := stdinPipe.Close()
	var stdoutBuf strings.Builder
	// All reads from the stdout pipe must finish before Wait is called.
	_, readErr := io.Copy(&stdoutBuf, stdoutPipe)
	waitErr := cmd.Wait()
	stdout, stderr := stdoutBuf.String(), stderrBuf.String()

	switch {
	case waitErr != nil:
		// nft's own exit status and stderr explain a failure better than a
		// secondary pipe error (for example, EPIPE after nft has exited).
		return fmt.Errorf("running nft: %s %w", stderr, waitErr)
	case writeErr != nil:
		return fmt.Errorf("sending nft commands: %w", writeErr)
	case closeErr != nil:
		return fmt.Errorf("closing nft input pipe: %w", closeErr)
	case readErr != nil:
		return fmt.Errorf("reading stdout of nft: %w", readErr)
	}
	log.G(ctx).WithFields(log.Fields{"stdout": stdout, "stderr": stderr}).Debug("nftables: updated")
	return nil
}

View File

@@ -54,8 +54,10 @@ func TestTable(t *testing.T) {
tbl4, err := NewTable(IPv4, "ipv4_table")
assert.NilError(t, err)
defer tbl4.Close()
tbl6, err := NewTable(IPv6, "ipv6_table")
assert.NilError(t, err)
defer tbl6.Close()
// Update nftables and check what happened.
applyAndCheck(t, tbl4, Modifier{}, t.Name()+"/created4.golden")
@@ -68,6 +70,7 @@ func TestChain(t *testing.T) {
// Create a table.
tbl, err := NewTable(IPv4, "this_is_a_table")
assert.NilError(t, err)
defer tbl.Close()
// Create a base chain.
const bcName = "this_is_a_base_chain"
@@ -122,6 +125,7 @@ func TestChainRuleGroups(t *testing.T) {
tbl, err := NewTable(IPv4, "testtable")
assert.NilError(t, err)
defer tbl.Close()
tm := Modifier{}
chainName := "testchain"
tm.Create(Chain{Name: chainName})
@@ -137,6 +141,7 @@ func TestIgnoreExist(t *testing.T) {
defer testSetup(t)()
tbl, err := NewTable(IPv4, "this_is_a_table")
assert.NilError(t, err)
defer tbl.Close()
tm := Modifier{}
// Create a chain with a single rule, add the rule again but drop the duplicate.
@@ -179,6 +184,7 @@ func TestVMap(t *testing.T) {
// Create a table.
tbl, err := NewTable(IPv6, "this_is_a_table")
assert.NilError(t, err)
defer tbl.Close()
tm := Modifier{}
// Create a verdict map.
@@ -203,8 +209,10 @@ func TestSet(t *testing.T) {
// Create v4 and v6 tables.
tbl4, err := NewTable(IPv4, "table4")
assert.NilError(t, err)
defer tbl4.Close()
tbl6, err := NewTable(IPv6, "table6")
assert.NilError(t, err)
defer tbl6.Close()
// Create a set in each table.
const set4Name = "set4"
@@ -234,6 +242,7 @@ func TestReload(t *testing.T) {
const tableName = "this_is_a_table"
tbl, err := NewTable(IPv4, tableName)
assert.NilError(t, err)
defer tbl.Close()
tm := Modifier{}
const bcName = "a_base_chain"
@@ -627,6 +636,7 @@ func TestValidation(t *testing.T) {
defer testSetup(t)()
tbl, err := NewTable(IPv4, "tablename")
assert.NilError(t, err)
defer tbl.Close()
tm := Modifier{cmds: tc.cmds}
err = tbl.Apply(context.Background(), tm)
assert.Check(t, err != nil, "expected error containing '%s'", tc.expErr)

View File

@@ -97,6 +97,7 @@ func (r *Resolver) setupNftablesNAT(ctx context.Context, laddr, ltcpaddr, resolv
if err != nil {
return err
}
defer table.Close()
tm := nftables.Modifier{}
const dnatChain = "dns-dnat"