From 6db6de2c20ae80770c89000d8621eb5f833cee79 Mon Sep 17 00:00:00 2001 From: Rob Murray Date: Mon, 22 Sep 2025 11:14:17 +0100 Subject: [PATCH] Use libnftables in dynamically linked binary Signed-off-by: Rob Murray --- Dockerfile | 2 + .../bridge/internal/nftabler/nftabler.go | 9 +++ .../bridge/internal/nftabler/nftabler_test.go | 11 ++- .../internal/nftables/nft_cgo_linux.go | 78 +++++++++++++++++++ .../internal/nftables/nft_exec_linux.go | 66 ++++++++++++++++ .../internal/nftables/nftables_linux.go | 76 +++++------------- .../internal/nftables/nftables_linux_test.go | 10 +++ daemon/libnetwork/resolver_unix.go | 1 + 8 files changed, 194 insertions(+), 59 deletions(-) create mode 100644 daemon/libnetwork/internal/nftables/nft_cgo_linux.go create mode 100644 daemon/libnetwork/internal/nftables/nft_exec_linux.go diff --git a/Dockerfile b/Dockerfile index 803c6106d1..9e97232af7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -501,6 +501,7 @@ RUN --mount=type=cache,sharing=locked,id=moby-dev-aptlib,target=/var/lib/apt \ jq \ libcap2-bin \ libnet1 \ + libnftables-dev \ libnl-3-200 \ libprotobuf-c1 \ libyajl2 \ @@ -548,6 +549,7 @@ RUN --mount=type=cache,sharing=locked,id=moby-build-aptlib,target=/var/lib/apt \ xx-apt-get install --no-install-recommends -y \ gcc \ libc6-dev \ + libnftables-dev \ libseccomp-dev \ libsystemd-dev \ pkg-config diff --git a/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler.go b/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler.go index 1f90f78171..afe8c442c4 100644 --- a/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler.go +++ b/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler.go @@ -4,6 +4,7 @@ package nftabler import ( "context" + "errors" "github.com/containerd/log" "github.com/moby/moby/v2/daemon/libnetwork/drivers/bridge/internal/firewaller" @@ -50,6 +51,8 @@ type Nftabler struct { table6 nftables.Table } +// NewNftabler creates a new Nftabler instance, initializing the nftables tables. +// Call Close() on the returned Nftabler to release resources when done. func NewNftabler(ctx context.Context, config firewaller.Config) (*Nftabler, error) { nft := &Nftabler{config: config} @@ -72,6 +75,12 @@ func NewNftabler(ctx context.Context, config firewaller.Config) (*Nftabler, erro return nft, nil } +// Close releases resources held by the Nftabler, the underlying nftables tables +// are not modified or deleted. +func (nft *Nftabler) Close() error { + return errors.Join(nft.table4.Close(), nft.table6.Close()) +} + func (nft *Nftabler) init(ctx context.Context, family nftables.Family) (nftables.Table, error) { // Instantiate the table. table, err := nftables.NewTable(family, dockerTable) diff --git a/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler_test.go b/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler_test.go index d00d5f99c4..2aa857819e 100644 --- a/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler_test.go +++ b/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler_test.go @@ -70,9 +70,11 @@ func TestNftabler(t *testing.T) { t.Run(fmt.Sprintf("ipv4=%v/ipv6=%v/hairpin=%v/internal=%v/icc=%v/masq=%v/snat=%v/gwm=%v/bindlh=%v/wsl2mirrored=%v", p(ipv4), p(ipv6), p(hairpin), p(internal), p(icc), p(masq), p(snat), gwmode, p(bindLocalhost), p(wsl2Mirrored)), func(t *testing.T) { // If updating results, don't run in parallel because some of the results files are shared. - if !golden.FlagUpdate() { - t.Parallel() - } + // Tests are dynamically linked, so the nftables code under test uses cgo to call libnftables + // and they run faster without t.Parallel(). Strangely, when statically linked and exec-ing + // 'nft', they run faster in parallel. + // if !golden.FlagUpdate() { t.Parallel() } + // Combine results (golden output files) where possible to: // - check params that should have no effect when made irrelevant by other params, and // - minimise the number of results files. @@ -101,7 +103,7 @@ func testNftabler(t *testing.T, tn string, config firewaller.Config, netConfig f return } out := strings.ReplaceAll(res.Combined(), "type nat hook output priority -100", "type nat hook output priority dstnat") - assert.Assert(t, res.Error) + assert.Assert(t, res.Error, out) golden.Assert(t, out, name+"__"+family+".golden") } @@ -129,6 +131,7 @@ func testNftabler(t *testing.T, tn string, config firewaller.Config, netConfig f // end of the test (after deleting per-network and per-port rules). fw, err := NewNftabler(context.Background(), config) assert.NilError(t, err) + defer fw.Close() checkResults("ip", rnWSL2Mirrored(fmt.Sprintf("%s/cleaned,hairpin=%v", tn, config.Hairpin)), config.IPv4) checkResults("ip6", fmt.Sprintf("%s/cleaned,hairpin=%v", tn, config.Hairpin), config.IPv6) diff --git a/daemon/libnetwork/internal/nftables/nft_cgo_linux.go b/daemon/libnetwork/internal/nftables/nft_cgo_linux.go new file mode 100644 index 0000000000..7bdac59001 --- /dev/null +++ b/daemon/libnetwork/internal/nftables/nft_cgo_linux.go @@ -0,0 +1,78 @@ +//go:build cgo && !static_build + +package nftables + +import ( + "context" + "errors" + "fmt" + "unsafe" + + "github.com/containerd/log" + "go.opentelemetry.io/otel" +) + +// #cgo pkg-config: libnftables +// #cgo nocallback nft_run_cmd_from_buffer +// #cgo nocallback nft_ctx_get_output_buffer +// #cgo nocallback nft_ctx_get_error_buffer +// #include +// #include +import "C" + +type nftHandle = *C.struct_nft_ctx + +// nftApply calls libnftables to execute the nftables commands in nftCmd. +// Acquire t.applyLock before calling this function. +func (t *table) nftApply(ctx context.Context, nftCmd []byte) error { + ctx, span := otel.Tracer("").Start(ctx, spanPrefix+".nftApply.cgo") + defer span.End() + + if t.nftHandle == nil { + handle, err := newNftHandle() + if err != nil { + return err + } + t.nftHandle = handle + } + + cCmd := C.CString(string(nftCmd)) + defer C.free(unsafe.Pointer(cCmd)) + + ret := C.nft_run_cmd_from_buffer(t.nftHandle, cCmd) + stdout := C.GoString(C.nft_ctx_get_output_buffer(t.nftHandle)) + stderr := C.GoString(C.nft_ctx_get_error_buffer(t.nftHandle)) + if ret != 0 { + return fmt.Errorf("libnftables: failed to apply commands (code %d), stderr: %s", int(ret), stderr) + } + log.G(ctx).WithFields(log.Fields{"stdout": stdout, "stderr": stderr}).Debug("nftables: updated via libnftables") + return nil +} + +func newNftHandle() (_ *C.struct_nft_ctx, retErr error) { + handle := C.nft_ctx_new(C.NFT_CTX_DEFAULT) + if handle == nil { + return nil, errors.New("libnftables: failed to create new nft handle") + } + defer func() { + if retErr != nil { + C.nft_ctx_free(handle) + } + }() + if ret := C.nft_ctx_buffer_output(handle); ret != 0 { + return nil, fmt.Errorf("libnftables: failed to set output buffer (code %d)", int(ret)) + } + if ret := C.nft_ctx_buffer_error(handle); ret != 0 { + return nil, fmt.Errorf("libnftables: failed to set error buffer (code %d)", int(ret)) + } + return handle, nil +} + +func (t *table) closeNftHandle() { + t.applyLock.Lock() + defer t.applyLock.Unlock() + if t.nftHandle != nil { + C.nft_ctx_free(t.nftHandle) + t.nftHandle = nil + } +} diff --git a/daemon/libnetwork/internal/nftables/nft_exec_linux.go b/daemon/libnetwork/internal/nftables/nft_exec_linux.go new file mode 100644 index 0000000000..87409c4f44 --- /dev/null +++ b/daemon/libnetwork/internal/nftables/nft_exec_linux.go @@ -0,0 +1,66 @@ +//go:build !cgo || static_build + +package nftables + +import ( + "context" + "fmt" + "io" + "os/exec" + "strings" + + "github.com/containerd/log" + "go.opentelemetry.io/otel" +) + +type nftHandle = struct{} + +func (t *table) nftApply(ctx context.Context, nftCmd []byte) error { + ctx, span := otel.Tracer("").Start(ctx, spanPrefix+".nftApply.exec") + defer span.End() + + cmd := exec.Command(nftPath, "-f", "-") + stdinPipe, err := cmd.StdinPipe() + if err != nil { + return fmt.Errorf("getting stdin pipe for nft: %w", err) + } + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + return fmt.Errorf("getting stdout pipe for nft: %w", err) + } + stderrPipe, err := cmd.StderrPipe() + if err != nil { + return fmt.Errorf("getting stderr pipe for nft: %w", err) + } + + if err := cmd.Start(); err != nil { + return fmt.Errorf("starting nft: %w", err) + } + if _, err := stdinPipe.Write(nftCmd); err != nil { + return fmt.Errorf("sending nft commands: %w", err) + } + if err := stdinPipe.Close(); err != nil { + return fmt.Errorf("closing nft input pipe: %w", err) + } + + stdoutBuf := strings.Builder{} + if _, err := io.Copy(&stdoutBuf, stdoutPipe); err != nil { + return fmt.Errorf("reading stdout of nft: %w", err) + } + stdout := stdoutBuf.String() + stderrBuf := strings.Builder{} + if _, err := io.Copy(&stderrBuf, stderrPipe); err != nil { + return fmt.Errorf("reading stderr of nft: %w", err) + } + stderr := stderrBuf.String() + + err = cmd.Wait() + if err != nil { + return fmt.Errorf("running nft: %s %w", stderr, err) + } + log.G(ctx).WithFields(log.Fields{"stdout": stdout, "stderr": stderr}).Debug("nftables: updated") + return nil +} + +func (t *table) closeNftHandle() { +} diff --git a/daemon/libnetwork/internal/nftables/nftables_linux.go b/daemon/libnetwork/internal/nftables/nftables_linux.go index b2a2c9c4b3..7efdcb8ce5 100644 --- a/daemon/libnetwork/internal/nftables/nftables_linux.go +++ b/daemon/libnetwork/internal/nftables/nftables_linux.go @@ -48,7 +48,6 @@ import ( "context" "errors" "fmt" - "io" "iter" "os/exec" "runtime" @@ -59,7 +58,6 @@ import ( "text/template" "github.com/containerd/log" - "go.opentelemetry.io/otel" ) // Prefix for OTEL span names. @@ -193,6 +191,7 @@ type table struct { MustFlush bool applyLock sync.Mutex + nftHandle nftHandle // applyLock must be held to access } // Table is a handle for an nftables table. @@ -235,6 +234,16 @@ func NewTable(family Family, name string) (Table, error) { return t, nil } +// Close releases resources associated with the table. It does not modify or delete +// the underlying nftables table. +func (t Table) Close() error { + if t.IsValid() { + t.t.closeNftHandle() + t.t = nil + } + return nil +} + // Name returns the name of the table, or an empty string if t is not valid. func (t Table) Name() string { if !t.IsValid() { @@ -245,6 +254,9 @@ func (t Table) Name() string { // Family returns the address family of the nftables table described by [TableRef]. func (t Table) Family() Family { + if !t.IsValid() { + return "" + } return t.t.Family } @@ -464,7 +476,7 @@ func (t *Table) Apply(ctx context.Context, tm Modifier) (retErr error) { return fmt.Errorf("failed to execute template nft ruleset: %w", err) } - if err := nftApply(ctx, buf.Bytes()); err != nil { + if err := t.t.nftApply(ctx, buf.Bytes()); err != nil { // On error, log a line-numbered version of the generated "nft" input (because // nft error messages refer to line numbers). var sb strings.Builder @@ -497,10 +509,15 @@ func (t Table) Reload(ctx context.Context) error { if !t.IsValid() { return errors.New("invalid table") } + t.t.applyLock.Lock() + defer t.t.applyLock.Unlock() return t.t.reload(ctx) } func (t *table) reload(ctx context.Context) error { + if !Enabled() { + return errors.New("nftables is not enabled") + } ctx = log.WithLogger(ctx, log.G(ctx).WithFields(log.Fields{"table": t.Name, "family": t.Family})) log.G(ctx).Warn("nftables: reloading table") @@ -510,7 +527,7 @@ func (t *table) reload(ctx context.Context) error { return fmt.Errorf("failed to execute reload template: %w", err) } - if err := nftApply(ctx, buf.Bytes()); err != nil { + if err := t.nftApply(ctx, buf.Bytes()); err != nil { // On error, log a line-numbered version of the generated "nft" input (because // nft error messages refer to line numbers). var sb strings.Builder @@ -1067,54 +1084,3 @@ func parseTemplate() error { } return nil } - -// nftApply runs the "nft" command. -func nftApply(ctx context.Context, nftCmd []byte) error { - ctx, span := otel.Tracer("").Start(ctx, spanPrefix+".nftApply") - defer span.End() - - if !Enabled() { - return errors.New("nftables is not enabled") - } - cmd := exec.Command(nftPath, "-f", "-") - stdinPipe, err := cmd.StdinPipe() - if err != nil { - return fmt.Errorf("getting stdin pipe for nft: %w", err) - } - stdoutPipe, err := cmd.StdoutPipe() - if err != nil { - return fmt.Errorf("getting stdout pipe for nft: %w", err) - } - stderrPipe, err := cmd.StderrPipe() - if err != nil { - return fmt.Errorf("getting stderr pipe for nft: %w", err) - } - - if err := cmd.Start(); err != nil { - return fmt.Errorf("starting nft: %w", err) - } - if _, err := stdinPipe.Write(nftCmd); err != nil { - return fmt.Errorf("sending nft commands: %w", err) - } - if err := stdinPipe.Close(); err != nil { - return fmt.Errorf("closing nft input pipe: %w", err) - } - - stdoutBuf := strings.Builder{} - if _, err := io.Copy(&stdoutBuf, stdoutPipe); err != nil { - return fmt.Errorf("reading stdout of nft: %w", err) - } - stdout := stdoutBuf.String() - stderrBuf := strings.Builder{} - if _, err := io.Copy(&stderrBuf, stderrPipe); err != nil { - return fmt.Errorf("reading stderr of nft: %w", err) - } - stderr := stderrBuf.String() - - err = cmd.Wait() - if err != nil { - return fmt.Errorf("running nft: %s %w", stderr, err) - } - log.G(ctx).WithFields(log.Fields{"stdout": stdout, "stderr": stderr}).Debug("nftables: updated") - return nil -} diff --git a/daemon/libnetwork/internal/nftables/nftables_linux_test.go b/daemon/libnetwork/internal/nftables/nftables_linux_test.go index 6949ad342b..de24ff1173 100644 --- a/daemon/libnetwork/internal/nftables/nftables_linux_test.go +++ b/daemon/libnetwork/internal/nftables/nftables_linux_test.go @@ -54,8 +54,10 @@ func TestTable(t *testing.T) { tbl4, err := NewTable(IPv4, "ipv4_table") assert.NilError(t, err) + defer tbl4.Close() tbl6, err := NewTable(IPv6, "ipv6_table") assert.NilError(t, err) + defer tbl6.Close() // Update nftables and check what happened. applyAndCheck(t, tbl4, Modifier{}, t.Name()+"/created4.golden") @@ -68,6 +70,7 @@ func TestChain(t *testing.T) { // Create a table. tbl, err := NewTable(IPv4, "this_is_a_table") assert.NilError(t, err) + defer tbl.Close() // Create a base chain. const bcName = "this_is_a_base_chain" @@ -122,6 +125,7 @@ func TestChainRuleGroups(t *testing.T) { tbl, err := NewTable(IPv4, "testtable") assert.NilError(t, err) + defer tbl.Close() tm := Modifier{} chainName := "testchain" tm.Create(Chain{Name: chainName}) @@ -137,6 +141,7 @@ func TestIgnoreExist(t *testing.T) { defer testSetup(t)() tbl, err := NewTable(IPv4, "this_is_a_table") assert.NilError(t, err) + defer tbl.Close() tm := Modifier{} // Create a chain with a single rule, add the rule again but drop the duplicate. @@ -179,6 +184,7 @@ func TestVMap(t *testing.T) { // Create a table. tbl, err := NewTable(IPv6, "this_is_a_table") assert.NilError(t, err) + defer tbl.Close() tm := Modifier{} // Create a verdict map. @@ -203,8 +209,10 @@ func TestSet(t *testing.T) { // Create v4 and v6 tables. tbl4, err := NewTable(IPv4, "table4") assert.NilError(t, err) + defer tbl4.Close() tbl6, err := NewTable(IPv6, "table6") assert.NilError(t, err) + defer tbl6.Close() // Create a set in each table. const set4Name = "set4" @@ -234,6 +242,7 @@ func TestReload(t *testing.T) { const tableName = "this_is_a_table" tbl, err := NewTable(IPv4, tableName) assert.NilError(t, err) + defer tbl.Close() tm := Modifier{} const bcName = "a_base_chain" @@ -627,6 +636,7 @@ func TestValidation(t *testing.T) { defer testSetup(t)() tbl, err := NewTable(IPv4, "tablename") assert.NilError(t, err) + defer tbl.Close() tm := Modifier{cmds: tc.cmds} err = tbl.Apply(context.Background(), tm) assert.Check(t, err != nil, "expected error containing '%s'", tc.expErr) diff --git a/daemon/libnetwork/resolver_unix.go b/daemon/libnetwork/resolver_unix.go index b186da27ab..716fce3fac 100644 --- a/daemon/libnetwork/resolver_unix.go +++ b/daemon/libnetwork/resolver_unix.go @@ -97,6 +97,7 @@ func (r *Resolver) setupNftablesNAT(ctx context.Context, laddr, ltcpaddr, resolv if err != nil { return err } + defer table.Close() tm := nftables.Modifier{} const dnatChain = "dns-dnat"