diff --git a/daemon/libnetwork/portallocator/osallocator_linux.go b/daemon/libnetwork/portallocator/osallocator_linux.go index 999429eebf..bd7f323b8a 100644 --- a/daemon/libnetwork/portallocator/osallocator_linux.go +++ b/daemon/libnetwork/portallocator/osallocator_linux.go @@ -2,15 +2,19 @@ package portallocator import ( "context" + "errors" "fmt" "net" "net/netip" "os" + "runtime" "syscall" "github.com/containerd/log" "github.com/ishidawataru/sctp" "github.com/moby/moby/v2/daemon/libnetwork/types" + "golang.org/x/net/bpf" + "golang.org/x/sys/unix" ) type OSAllocator struct { @@ -27,16 +31,15 @@ func NewOSAllocator() OSAllocator { } // RequestPortsInRange reserves a port available in the range [portStart, portEnd] -// for all the specified addrs, and then try to bind those addresses to allocate -// the port from the OS. It returns the allocated port, and all the sockets -// bound, or an error if the reserved port isn't available. Callers must take -// care of closing the returned sockets. +// for all the specified addrs, and then try to bind/listen those addresses to +// allocate the port from the OS. // -// Due to the semantic of SO_REUSEADDR, the OSAllocator can't fully determine -// if a port is free when binding 0.0.0.0 or ::. If another socket is binding -// the same port, but it's not listening to it yet, the bind will succeed but a -// subsequent listen might fail. For this reason, RequestPortsInRange doesn't -// retry on failure — it's caller's responsibility. +// It returns the allocated port, and all the sockets bound, or an error if the +// reserved port isn't available. These sockets have a filter set to ensure that +// the kernel doesn't accept connections on these. Callers must take care of +// calling DetachSocketFilter once they're ready to accept connections (e.g. after +// setting up DNAT rules, and before starting the userland proxy), and they must +// take care of closing the returned sockets. // // It's safe for concurrent use. func (pa OSAllocator) RequestPortsInRange(addrs []net.IP, proto types.Protocol, portStart, portEnd int) (_ int, _ []*os.File, retErr error) { @@ -73,11 +76,11 @@ func (pa OSAllocator) RequestPortsInRange(addrs []net.IP, proto types.Protocol, var sock *os.File switch proto { case types.TCP: - sock, err = bindTCPOrUDP(addrPort, syscall.SOCK_STREAM, syscall.IPPROTO_TCP) + sock, err = listenTCP(addrPort) case types.UDP: sock, err = bindTCPOrUDP(addrPort, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP) case types.SCTP: - sock, err = bindSCTP(addrPort) + sock, err = listenSCTP(addrPort) default: return 0, nil, fmt.Errorf("protocol %s not supported", proto) } @@ -101,6 +104,20 @@ func (pa OSAllocator) ReleasePorts(addrs []net.IP, proto types.Protocol, port in } } +func listenTCP(addr netip.AddrPort) (_ *os.File, retErr error) { + boundSocket, err := bindTCPOrUDP(addr, syscall.SOCK_STREAM, syscall.IPPROTO_TCP) + if err != nil { + return nil, err + } + + somaxconn := -1 // silently capped to "/proc/sys/net/core/somaxconn" + if err := syscall.Listen(int(boundSocket.Fd()), somaxconn); err != nil { + return nil, fmt.Errorf("failed to listen on tcp socket: %w", err) + } + + return boundSocket, nil +} + func bindTCPOrUDP(addr netip.AddrPort, typ int, proto types.Protocol) (_ *os.File, retErr error) { var domain int var sa syscall.Sockaddr @@ -128,6 +145,16 @@ func bindTCPOrUDP(addr netip.AddrPort, typ int, proto types.Protocol) (_ *os.Fil } } + // We need to listen to make sure that the port is free, and no other process is racing against us to acquire this + // port. But listening means that connections could be accepted before DNAT rules are inserted, and they'd never + // reach the container. To avoid this, set a socket filter to drop all connections — TCP SYNs will be + // re-transmitted anyway. Callers must call DetachSocketFilter. + // + // Set the socket filter _before_ binding the socket to make sure that no UDP datagrams will fill the queue. + if err := setSocketFilter(sd); err != nil { + return nil, fmt.Errorf("failed to set drop packets filter for %s/%s: %w", addr, proto, err) + } + if domain == syscall.AF_INET6 { syscall.SetsockoptInt(sd, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 1) } @@ -158,8 +185,21 @@ func bindTCPOrUDP(addr netip.AddrPort, typ int, proto types.Protocol) (_ *os.Fil return boundSocket, nil } -// bindSCTP is based on sctp.ListenSCTP. The socket is created and bound, but -// does not start listening. +// listenSCTP is based on sctp.ListenSCTP. +func listenSCTP(addr netip.AddrPort) (_ *os.File, retErr error) { + boundSocket, err := bindSCTP(addr) + if err != nil { + return nil, err + } + + somaxconn := -1 // silently capped to "/proc/sys/net/core/somaxconn" + if err := syscall.Listen(int(boundSocket.Fd()), somaxconn); err != nil { + return nil, fmt.Errorf("failed to listen on sctp socket: %w", err) + } + + return boundSocket, nil +} + func bindSCTP(addr netip.AddrPort) (_ *os.File, retErr error) { domain := syscall.AF_INET if addr.Addr().Unmap().Is6() { @@ -190,9 +230,58 @@ func bindSCTP(addr netip.AddrPort) (_ *os.File, retErr error) { return nil, fmt.Errorf("failed to bind host port %s/sctp: %w", addr, err) } + // We need to listen to make sure that the port is free, and no other process is racing against us to acquire this + // port. But listening means that connections could be accepted before DNAT rules are inserted, and they'd never + // reach the container. To avoid this, set a socket filter to drop all connections — SCTP handshake will be + // re-transmitted anyway. Callers must call DetachSocketFilter. + if err := setSocketFilter(sd); err != nil { + return nil, fmt.Errorf("failed to set drop packets filter for %s/sctp: %w", addr, err) + } + boundSocket := os.NewFile(uintptr(sd), "listener") if boundSocket == nil { return nil, fmt.Errorf("failed to convert socket %s/sctp", addr) } return boundSocket, nil } + +// DetachSocketFilter removes the BPF filter set during port allocation to prevent the kernel from accepting connections +// before DNAT rules are inserted. +func DetachSocketFilter(f *os.File) error { + return unix.SetsockoptInt(int(f.Fd()), syscall.SOL_SOCKET, syscall.SO_DETACH_FILTER, 0 /* ignored */) +} + +// setSocketFilter sets a cBPF program on socket sd to drop all packets. To start receiving packets on this socket, +// callers must call DetachSocketFilter. +func setSocketFilter(sd int) error { + asm, err := bpf.Assemble([]bpf.Instruction{ + // A cBPF program attached to a socket with SO_ATTACH_FILTER and + // returning 0 tells the kernel to drop all packets. + bpf.RetConstant{Val: 0x0}, + }) + if err != nil { + // (bpf.RetConstant).Assemble() doesn't return an error, so this should + // be unreachable code. + return fmt.Errorf("attaching socket filter: %w", err) + } + // Make sure the asm slice is not GC'd before setsockopt is called + defer runtime.KeepAlive(asm) + + if len(asm) == 0 { + return errors.New("attaching socket filter: empty BPF program") + } + + f := make([]unix.SockFilter, len(asm)) + for i := range asm { + f[i] = unix.SockFilter{ + Code: asm[i].Op, + Jt: asm[i].Jt, + Jf: asm[i].Jf, + K: asm[i].K, + } + } + return unix.SetsockoptSockFprog(sd, syscall.SOL_SOCKET, syscall.SO_ATTACH_FILTER, &unix.SockFprog{ + Len: uint16(len(f)), + Filter: &f[0], + }) +} diff --git a/daemon/libnetwork/portallocator/osallocator_linux_test.go b/daemon/libnetwork/portallocator/osallocator_linux_test.go index efeda1a447..662a1ead2a 100644 --- a/daemon/libnetwork/portallocator/osallocator_linux_test.go +++ b/daemon/libnetwork/portallocator/osallocator_linux_test.go @@ -1,12 +1,18 @@ package portallocator import ( + "fmt" "io" "net" "net/netip" "os" + "os/exec" + "strconv" + "strings" + "sync/atomic" "syscall" "testing" + "time" "github.com/ishidawataru/sctp" "github.com/moby/moby/v2/daemon/libnetwork/netutils" @@ -228,3 +234,159 @@ func TestOnlyOneSocketBindsUDPPort(t *testing.T) { assert.ErrorContains(t, err, "failed to bind host port") assert.Equal(t, len(socks), 0) } + +// TestSocketBacklogEqualsSomaxconn verifies that the listen syscall made for +// TCP / SCTP sockets has a backlog size equal to somaxconn. +func TestSocketBacklogEqualsSomaxconn(t *testing.T) { + // Retrieve and parse sysctl net.core.somaxconn + somaxconnSysctl, err := os.ReadFile("/proc/sys/net/core/somaxconn") + assert.NilError(t, err) + somaxconn, err := strconv.Atoi(strings.TrimSpace(string(somaxconnSysctl))) + assert.NilError(t, err) + + // UDP isn't included in the list of protos to test because it doesn't have a backlog, and the ss Send-Q column + // reports memory allocation instead of the socket's max backlog size (unlike TCP and SCTP). + // + // This is where the kernel writes the max backlog size into the sk struct: https://elixir.bootlin.com/linux/v6.16/source/net/ipv4/af_inet.c#L199 + // + // And here's where the kernel writes the 'idiag_wqueue' field used by ss: + // + // - For TCP: https://elixir.bootlin.com/linux/v6.16/source/net/ipv4/tcp_diag.c#L25 + // - For UDP: https://elixir.bootlin.com/linux/v6.16/source/net/ipv4/udp_diag.c#L163 + // - For SCTP: https://elixir.bootlin.com/linux/v6.16/source/net/sctp/diag.c#L414 + for _, proto := range []types.Protocol{ + types.TCP, + types.SCTP, + } { + t.Run(proto.String(), func(t *testing.T) { + // Allocate an ephemeral port using the OSAllocator. + alloc := NewOSAllocator() + port, socks, err := alloc.RequestPortsInRange([]net.IP{net.IPv4zero}, proto, 0, 0) + assert.NilError(t, err) + defer closeSocks(t, socks) + + // 'ss' output looks like that: + // + // Netid State Recv-Q Send-Q Local Address:Port Peer Address:Port Process + // tcp LISTEN 0 4096 0.0.0.0:32768 0.0.0.0:* + // + // The max backlog size ('idiag_wqueue' field of 'struct inet_diag_msg' in the kernel) is the 4th field in + // the output. + out, err := exec.Command("ss", "-Stl", "sport", "=", fmt.Sprintf("inet:%d", port)).Output() + assert.NilError(t, err) + + t.Logf("ss output:\n" + string(out)) + + lines := strings.Split(string(out), "\n") + assert.Assert(t, len(lines) >= 2) + + fields := strings.Fields(lines[1]) + assert.Equal(t, len(fields), 6) + + backlog, err := strconv.Atoi(fields[3]) + assert.NilError(t, err) + + assert.Equal(t, fields[4], "0.0.0.0:"+strconv.Itoa(port)) + assert.Equal(t, backlog, somaxconn, "socket backlog should be equal to net.core.somaxconn") + }) + } +} + +// TestPacketsAreDroppedUntilDetachSocketFilter tests that SYN packets are +// dropped until DetachSocketFilter is called on the socket. +func TestPacketsAreDroppedUntilDetachSocketFilter(t *testing.T) { + const port = 61100 + addr := net.ParseIP("127.0.0.1") + + var detached atomic.Bool + dialCh, readCh := make(chan error), make(chan error) + + alloc := NewOSAllocator() + _, socks, err := alloc.RequestPortsInRange([]net.IP{addr}, types.TCP, port, port) + assert.NilError(t, err) + assert.Check(t, len(socks) > 0) + + // Start a goroutine that attempts to connect to a listening socket. It'll send SYN packets until + // DetachSocketFilter is called. If no filter is attached, the connection will succeed immediately, and it'll send + // a payload of 0x0 (or the call to DetachSocketFilter will fail with an error). When the filter is detached, it'll + // send a payload of 0x1, which will be read by the other goroutine. + go func() { + defer close(dialCh) + + c, err := net.Dial("tcp", net.JoinHostPort(addr.String(), strconv.Itoa(port))) + if err != nil { + dialCh <- fmt.Errorf("net.Dial: %w", err) + return + } + defer c.Close() + + payload := []byte{0x0} + if detached.Load() { + payload = []byte{0x1} + } + + n, err := c.Write(payload) + if err != nil { + dialCh <- fmt.Errorf("c.Write: %w", err) + return + } + if n != len(payload) { + dialCh <- fmt.Errorf("expected to write %d bytes, but wrote %d", len(payload), n) + } + }() + + // Start a goroutine that accepts a connection on the listening socket created by RequestPortsInRange, and reads + // the payload sent by the 1st goroutine. It should not receive any new connection until DetachSocketFilter is + // called on the socket. + go func() { + defer close(readCh) + + // net.FileListener dup's the fd, so DetachSocketFilter will have no effect. Use raw syscalls instead. + sd := int(socks[0].Fd()) + + var err error + connfd, _, err := syscall.Accept(sd) + if err != nil { + readCh <- fmt.Errorf("syscall.Accept: %w", err) + return + } + + payload := make([]byte, 1) + n, err := syscall.Read(connfd, payload) + if err != nil { + readCh <- fmt.Errorf("c.Read: %w", err) + return + } + if n != 1 { + readCh <- fmt.Errorf("expected to read 1 byte, but read %d", n) + return + } + + if payload[0] != 0x1 { + readCh <- fmt.Errorf("expected payload 0x1, but got %x", payload[0]) + } + }() + + // Sleep for a bit to make sure that both goroutines were scheduled. + time.Sleep(500 * time.Millisecond) + + detached.Store(true) + err = DetachSocketFilter(socks[0]) + assert.NilError(t, err) + + var dialStopped, readStopped bool + for { + if dialStopped && readStopped { + return + } + + select { + case err, ok := <-dialCh: + dialStopped = !ok + assert.NilError(t, err) + case err, ok := <-readCh: + readStopped = !ok + assert.NilError(t, err) + } + } +} diff --git a/daemon/libnetwork/portmappers/nat/mapper_linux.go b/daemon/libnetwork/portmappers/nat/mapper_linux.go index bc40aabf0e..64aa7b44be 100644 --- a/daemon/libnetwork/portmappers/nat/mapper_linux.go +++ b/daemon/libnetwork/portmappers/nat/mapper_linux.go @@ -8,7 +8,6 @@ import ( "net/netip" "os" "strconv" - "syscall" "github.com/containerd/log" "github.com/moby/moby/v2/daemon/libnetwork/internal/rlkclient" @@ -105,6 +104,9 @@ func (pm PortMapper) MapPorts(ctx context.Context, cfg []portmapperapi.PortBindi if bindings[i].BoundSocket == nil || bindings[i].RootlesskitUnsupported || bindings[i].StopProxy != nil { continue } + if err := portallocator.DetachSocketFilter(bindings[i].BoundSocket); err != nil { + return nil, fmt.Errorf("failed to detach socket filter for port mapping %s: %w", bindings[i].PortBinding, err) + } var err error bindings[i].StopProxy, err = pm.startProxy( bindings[i].ChildPortBinding(), bindings[i].BoundSocket, @@ -226,17 +228,6 @@ func (pm PortMapper) attemptBindHostPorts( if err := fwn.AddPorts(ctx, mergeChildHostIPs(res)); err != nil { return nil, err } - // Now the firewall rules are set up, it's safe to listen on the socket. (Listening - // earlier could result in dropped connections if the proxy becomes unreachable due - // to NAT rules sending packets directly to the container.) - // - // If not starting the proxy, nothing will ever accept a connection on the - // socket. Listen here anyway because SO_REUSEADDR is set, so bind() won't notice - // the problem if a port's bound to both INADDR_ANY and a specific address. (Also - // so the binding shows up in "netstat -at".) - if err := listenBoundPorts(res, pm.enableProxy); err != nil { - return nil, err - } return res, nil } @@ -297,29 +288,3 @@ func configPortDriver(ctx context.Context, pbs []portmapperapi.PortBinding, pdc } return nil } - -func listenBoundPorts(pbs []portmapperapi.PortBinding, proxyEnabled bool) error { - for i := range pbs { - if pbs[i].BoundSocket == nil || pbs[i].RootlesskitUnsupported || pbs[i].Proto == types.UDP { - continue - } - rc, err := pbs[i].BoundSocket.SyscallConn() - if err != nil { - return fmt.Errorf("raw conn not available on %d socket: %w", pbs[i].Proto, err) - } - if errC := rc.Control(func(fd uintptr) { - somaxconn := 0 - // SCTP sockets do not support somaxconn=0 - if proxyEnabled || pbs[i].Proto == types.SCTP { - somaxconn = -1 // silently capped to "/proc/sys/net/core/somaxconn" - } - err = syscall.Listen(int(fd), somaxconn) - }); errC != nil { - return fmt.Errorf("failed to Control %s socket: %w", pbs[i].Proto, err) - } - if err != nil { - return fmt.Errorf("failed to listen on %s socket: %w", pbs[i].Proto, err) - } - } - return nil -}