portallocator: always check for ports allocated for 0.0.0.0/::

We set SO_REUSEADDR on sockets used for host port mappings by
docker-proxy - which means it's possible to bind the same port
on a specific address as well as 0.0.0.0/::.

For TCP sockets, an error is raised when listen() is called on
both sockets - and the port allocator will be called again to
avoid the clash (if the port was allocated from a range, otherwise
the container will just fail to start).

But, for UDP sockets, there's no listen() - so take more care
to avoid the clash in the portallocator.

The port allocator keeps a set of allocated ports for each of
the host IP addresses it's seen, including 0.0.0.0/::. So, if a
mapping to 0.0.0.0/:: is requested, find a port that's free in
the range for each of the known IP addresses (but still only
mark it as allocated against 0.0.0.0/::). And, if a port is
requested for specific host addresses, make sure it's also
free in the corresponding 0.0.0.0/:: set (but only mark it as
allocated against the specific addresses - because the same
port can be allocated against a different specific address).

Signed-off-by: Rob Murray <rob.murray@docker.com>
This commit is contained in:
Rob Murray
2025-05-22 17:55:13 +01:00
parent ae2fc2ddd1
commit d6620915db
3 changed files with 158 additions and 60 deletions

View File

@@ -1286,21 +1286,42 @@ func TestSkipRawRules(t *testing.T) {
// Regression test for https://github.com/docker/compose/issues/12846
func TestMixAnyWithSpecificHostAddrs(t *testing.T) {
ctx := setupTest(t)
// Start a new daemon, so the port allocator will start with new/empty ephemeral port ranges,
// making a clash more likely.
d := daemon.New(t)
d.StartWithBusybox(ctx, t)
defer d.Stop(t)
c := d.NewClientT(t)
defer c.Close()
ctrId := container.Run(ctx, t, c,
container.WithExposedPorts("80/tcp", "81/tcp", "82/tcp"),
container.WithPortMap(nat.PortMap{
"81/tcp": {{}},
"82/tcp": {{}},
"80/tcp": {{HostIP: "127.0.0.1"}},
}),
)
defer c.ContainerRemove(ctx, ctrId, containertypes.RemoveOptions{Force: true})
for _, proto := range []string{"tcp", "udp"} {
t.Run(proto, func(t *testing.T) {
// Start a new daemon, so the port allocator will start with new/empty ephemeral port ranges,
// making a clash more likely.
d := daemon.New(t)
d.StartWithBusybox(ctx, t)
defer d.Stop(t)
c := d.NewClientT(t)
defer c.Close()
ctrId := container.Run(ctx, t, c,
container.WithExposedPorts("80/"+proto, "81/"+proto, "82/"+proto),
container.WithPortMap(nat.PortMap{
nat.Port("81/" + proto): {{}},
nat.Port("82/" + proto): {{}},
nat.Port("80/" + proto): {{HostIP: "127.0.0.1"}},
}),
)
defer c.ContainerRemove(ctx, ctrId, containertypes.RemoveOptions{Force: true})
insp := container.Inspect(ctx, t, c, ctrId)
hostPorts := map[string]struct{}{}
for cp, hps := range insp.NetworkSettings.Ports {
// Check each of the container ports is mapped to a different host port.
p := hps[0].HostPort
if _, ok := hostPorts[p]; ok {
t.Errorf("host port %s is mapped to different container ports: %v", p, insp.NetworkSettings.Ports)
}
hostPorts[p] = struct{}{}
// For this container port, check the same host port is mapped for each host address (0.0.0.0 and ::).
for _, hp := range hps {
assert.Check(t, p == hp.HostPort, "container port %d is mapped to different host ports: %v", cp, hps)
}
}
})
}
}

View File

@@ -1,6 +1,3 @@
// FIXME(thaJeztah): remove once we are a module; the go:build directive prevents go from downgrading language version to go1.16:
//go:build go1.23
package portallocator
import (
@@ -9,7 +6,6 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"sync"
"github.com/containerd/log"
@@ -78,16 +74,36 @@ func Get() *PortAllocator {
}
func newInstance() *PortAllocator {
start, end, err := getDynamicPortRange()
begin, end := dynamicPortRange()
return &PortAllocator{
ipMap: makeIpMapping(begin, end),
defaultIP: net.IPv4zero,
begin: begin,
end: end,
}
}
func dynamicPortRange() (start, end int) {
begin, end, err := getDynamicPortRange()
if err != nil {
log.G(context.TODO()).WithError(err).Infof("falling back to default port range %d-%d", defaultPortRangeStart, defaultPortRangeEnd)
start, end = defaultPortRangeStart, defaultPortRangeEnd
return defaultPortRangeStart, defaultPortRangeEnd
}
return &PortAllocator{
ipMap: ipMapping{},
defaultIP: net.IPv4zero,
begin: start,
end: end,
return begin, end
}
func makeIpMapping(begin, end int) ipMapping {
return ipMapping{
netip.IPv4Unspecified(): makeProtoMap(begin, end),
netip.IPv6Unspecified(): makeProtoMap(begin, end),
}
}
func makeProtoMap(begin, end int) protoMap {
return protoMap{
"tcp": newPortMap(begin, end),
"udp": newPortMap(begin, end),
"sctp": newPortMap(begin, end),
}
}
@@ -120,44 +136,82 @@ func (p *PortAllocator) RequestPortsInRange(ips []net.IP, proto string, portStar
if proto != "tcp" && proto != "udp" && proto != "sctp" {
return 0, errUnknownProtocol
}
if portStart != 0 || portEnd != 0 {
// Validate custom port-range
if portStart == 0 || portEnd == 0 || portEnd < portStart {
return 0, fmt.Errorf("invalid port range: %d-%d", portStart, portEnd)
}
}
if len(ips) == 0 {
return 0, fmt.Errorf("no IP addresses specified")
}
p.mutex.Lock()
defer p.mutex.Unlock()
// Make sure there are maps for each ip address.
pMaps := make([]*portMap, len(ips))
for i, ip := range ips {
// Collect the portMap for the required proto and each of the IP addresses.
// If there's a new IP address, create portMap objects for each of the protocols
// and collect the one that's needed for this request.
// Mark these portMap objects as needing port allocations.
type portMapRef struct {
portMap *portMap
allocate bool
}
ipToPortMapRef := map[netip.Addr]*portMapRef{}
var ips4, ips6 bool
for _, ip := range ips {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return 0, fmt.Errorf("invalid IP address: %s", ip)
}
addr = addr.Unmap()
if _, ok := p.ipMap[addr]; !ok {
p.ipMap[addr] = protoMap{
"tcp": newPortMap(p.begin, p.end),
"udp": newPortMap(p.begin, p.end),
"sctp": newPortMap(p.begin, p.end),
}
if addr.Is4() {
ips4 = true
} else {
ips6 = true
}
// Make sure addr -> protoMap[proto] -> portMap exists.
if _, ok := p.ipMap[addr]; !ok {
p.ipMap[addr] = makeProtoMap(p.begin, p.end)
}
// Remember the protoMap[proto] portMap, it needs the port allocation.
ipToPortMapRef[addr] = &portMapRef{
portMap: p.ipMap[addr][proto],
allocate: true,
}
}
// If ips includes an unspecified address, the port needs to be free in all ipMaps
// for that address family. Otherwise, the port needs only needs to be free in the
// per-address maps for ips, and the map for 0.0.0.0/::.
//
// Collect the additional portMaps where the port needs to be free, but
// don't mark them as needing port allocation.
for _, unspecAddr := range []netip.Addr{netip.IPv4Unspecified(), netip.IPv6Unspecified()} {
if _, ok := ipToPortMapRef[unspecAddr]; ok {
for addr, ipm := range p.ipMap {
if unspecAddr.Is4() == addr.Is4() {
if _, ok := ipToPortMapRef[addr]; !ok {
ipToPortMapRef[addr] = &portMapRef{portMap: ipm[proto]}
}
}
}
} else if (unspecAddr.Is4() && ips4) || (unspecAddr.Is6() && ips6) {
ipToPortMapRef[unspecAddr] = &portMapRef{portMap: p.ipMap[unspecAddr][proto]}
}
pMaps[i] = p.ipMap[addr][proto]
}
// Handle a request for a specific port.
if portStart > 0 && portStart == portEnd {
for i, pMap := range pMaps {
if _, allocated := pMap.p[portStart]; allocated {
return 0, alreadyAllocatedErr{ip: ips[i].String(), port: portStart}
for addr, pMap := range ipToPortMapRef {
if _, allocated := pMap.portMap.p[portStart]; allocated {
return 0, alreadyAllocatedErr{ip: addr.String(), port: portStart}
}
}
for _, pMap := range pMaps {
pMap.p[portStart] = struct{}{}
for _, pMap := range ipToPortMapRef {
if pMap.allocate {
pMap.portMap.p[portStart] = struct{}{}
}
}
return portStart, nil
}
@@ -165,27 +219,37 @@ func (p *PortAllocator) RequestPortsInRange(ips []net.IP, proto string, portStar
// Handle a request for a port range.
// Create/fetch ranges for each portMap.
pRanges := make([]*portRange, len(pMaps))
for i, pMap := range pMaps {
pRanges[i] = pMap.getPortRange(portStart, portEnd)
pRanges := map[netip.Addr]*portRange{}
for addr, pMap := range ipToPortMapRef {
pRanges[addr] = pMap.portMap.getPortRange(portStart, portEnd)
}
// Starting after the last port allocated for the first address, search
// Arbitrarily starting after the last port allocated for the first address, search
// for a port that's available in all ranges.
port := pRanges[0].last
for i := pRanges[0].begin; i <= pRanges[0].end; i++ {
firstAddr, _ := netip.AddrFromSlice(ips[0])
firstRange := pRanges[firstAddr.Unmap()]
port := firstRange.last
for i := firstRange.begin; i <= firstRange.end; i++ {
port++
if port > pRanges[0].end {
port = pRanges[0].begin
if port > firstRange.end {
port = firstRange.begin
}
if !slices.ContainsFunc(pMaps, func(pMap *portMap) bool {
_, allocated := pMap.p[port]
return allocated
}) {
for pi, pMap := range pMaps {
pMap.p[port] = struct{}{}
pRanges[pi].last = port
portAlreadyAllocated := func() bool {
for _, pMap := range ipToPortMapRef {
if _, ok := pMap.portMap.p[port]; ok {
return true
}
}
return false
}
if !portAlreadyAllocated() {
for addr, pMap := range ipToPortMapRef {
if pMap.allocate {
pMap.portMap.p[port] = struct{}{}
pRanges[addr].last = port
}
}
return port, nil
}
@@ -214,8 +278,9 @@ func (p *PortAllocator) ReleasePort(ip net.IP, proto string, port int) {
// ReleaseAll releases all ports for all ips.
func (p *PortAllocator) ReleaseAll() {
begin, end := dynamicPortRange()
p.mutex.Lock()
p.ipMap = ipMapping{}
p.ipMap = makeIpMapping(begin, end)
p.mutex.Unlock()
}

View File

@@ -319,7 +319,7 @@ func TestRequestPortForMultipleIPs(t *testing.T) {
// Same single-port range, expect an error.
_, err = p.RequestPortsInRange(addrs, "tcp", 10000, 10000)
assert.Check(t, is.Error(err, "Bind for 127.0.0.1:10000 failed: port is already allocated"))
assert.Check(t, is.ErrorContains(err, "port is already allocated"))
// Release the port from one address.
p.ReleasePort(addrs[0], "tcp", 10000)
@@ -344,3 +344,15 @@ func TestRequestPortForMultipleIPs(t *testing.T) {
assert.Check(t, is.Equal(port, i))
}
}
func TestMixUnspecAndSpecificAddrs(t *testing.T) {
p := newInstance()
port, err := p.RequestPort(net.IPv4(127, 0, 0, 1), "udp", 0)
assert.Check(t, err)
assert.Check(t, is.Equal(port, p.begin))
port, err = p.RequestPort(net.IPv4zero, "udp", 0)
assert.Check(t, err)
assert.Check(t, is.Equal(port, p.begin+1))
}