libnet: remove struct endpointCnt

endpointCnt is a refcounter used to track how many endpoints use a
network, and how many networks references a config-only network. It's
stored separately from the network.

This is only used to determine if a network can be removed.

This commit removes the `endpointCnt` struct and all its references. The
refcounter is replaced by two lookups in the newly introduced `networks`
and `endpoints` caches added to the `Controller`.

Signed-off-by: Albin Kerouanton <albinker@gmail.com>
This commit is contained in:
Albin Kerouanton
2025-04-03 10:05:53 +02:00
parent d377cd3810
commit 51d7f95c4b
10 changed files with 127 additions and 288 deletions

View File

@@ -0,0 +1,14 @@
// FIXME(thaJeztah): remove once we are a module; the go:build directive prevents go from downgrading language version to go1.16:
//go:build go1.22
package maputil
func FilterValues[K comparable, V any](in map[K]V, fn func(V) bool) []V {
var out []V
for _, v := range in {
if fn(v) {
out = append(out, v)
}
}
return out
}

View File

@@ -554,8 +554,6 @@ func (c *Controller) NewNetwork(networkType, name string, id string, options ...
var (
caps driverapi.Capability
err error
skipCfgEpCount bool
)
// Reset network types, force local scope and skip allocation and
@@ -590,13 +588,6 @@ func (c *Controller) NewNetwork(networkType, name string, id string, options ...
if err := configNetwork.applyConfigurationTo(nw); err != nil {
return nil, types.InternalErrorf("Failed to apply configuration: %v", err)
}
defer func() {
if retErr == nil && !skipCfgEpCount {
if err := configNetwork.getEpCnt().IncEndpointCnt(); err != nil {
log.G(context.TODO()).Warnf("Failed to update reference count for configuration network %q on creation of network %q: %v", configNetwork.Name(), nw.name, err)
}
}
}()
}
// At this point the network scope is still unknown if not set by user
@@ -662,11 +653,7 @@ func (c *Controller) NewNetwork(networkType, name string, id string, options ...
//
// To cut a long story short: if this broke anything, you know who to blame :)
if err := c.addNetwork(nw); err != nil {
if _, ok := err.(types.MaskableError); ok { //nolint:gosimple
// This error can be ignored and set this boolean
// value to skip a refcount increment for configOnly networks
skipCfgEpCount = true
} else {
if _, ok := err.(types.MaskableError); !ok { //nolint:gosimple
return nil, err
}
}
@@ -694,22 +681,6 @@ func (c *Controller) NewNetwork(networkType, name string, id string, options ...
}
addToStore:
// First store the endpoint count, then the network. To avoid to
// end up with a datastore containing a network and not an epCnt,
// in case of an ungraceful shutdown during this function call.
epCnt := &endpointCnt{n: nw}
if err := c.updateToStore(context.TODO(), epCnt); err != nil {
return nil, err
}
defer func() {
if retErr != nil {
if err := c.deleteFromStore(epCnt); err != nil {
log.G(context.TODO()).Warnf("could not rollback from store, epCnt %v on failure (%v): %v", epCnt, retErr, err)
}
}
}()
nw.epCnt = epCnt
if err := c.storeNetwork(context.TODO(), nw); err != nil {
return nil, err
}

View File

@@ -1047,10 +1047,6 @@ func (ep *Endpoint) Delete(ctx context.Context, force bool) error {
ep.releaseAddress()
if err := n.getEpCnt().DecEndpointCnt(); err != nil {
log.G(ctx).Warnf("failed to decrement endpoint count for ep %s: %v", ep.ID(), err)
}
return nil
}
@@ -1399,20 +1395,6 @@ func (c *Controller) cleanupLocalEndpoints() error {
log.G(context.TODO()).Warnf("Could not delete local endpoint %s during endpoint cleanup: %v", ep.name, err)
}
}
epl, err = n.getEndpointsFromStore()
if err != nil {
log.G(context.TODO()).Warnf("Could not get list of endpoints in network %s for count update: %v", n.name, err)
continue
}
epCnt := n.getEpCnt().EndpointCnt()
if epCnt != uint64(len(epl)) {
log.G(context.TODO()).Infof("Fixing inconsistent endpoint_cnt for network %s. Expected=%d, Actual=%d", n.name, len(epl), epCnt)
if err := n.getEpCnt().setCnt(uint64(len(epl))); err != nil {
log.G(context.TODO()).WithField("network", n.name).WithError(err).Warn("Error while fixing inconsistent endpoint_cnt for network")
}
}
}
return nil

View File

@@ -1,173 +0,0 @@
package libnetwork
import (
"context"
"encoding/json"
"fmt"
"sync"
"github.com/docker/docker/libnetwork/datastore"
)
type endpointCnt struct {
n *Network
Count uint64
dbIndex uint64
dbExists bool
sync.Mutex
}
const epCntKeyPrefix = "endpoint_count"
func (ec *endpointCnt) Key() []string {
ec.Lock()
defer ec.Unlock()
return []string{epCntKeyPrefix, ec.n.id}
}
func (ec *endpointCnt) KeyPrefix() []string {
ec.Lock()
defer ec.Unlock()
return []string{epCntKeyPrefix, ec.n.id}
}
func (ec *endpointCnt) Value() []byte {
ec.Lock()
defer ec.Unlock()
b, err := json.Marshal(ec)
if err != nil {
return nil
}
return b
}
func (ec *endpointCnt) SetValue(value []byte) error {
ec.Lock()
defer ec.Unlock()
return json.Unmarshal(value, &ec)
}
func (ec *endpointCnt) Index() uint64 {
ec.Lock()
defer ec.Unlock()
return ec.dbIndex
}
func (ec *endpointCnt) SetIndex(index uint64) {
ec.Lock()
ec.dbIndex = index
ec.dbExists = true
ec.Unlock()
}
func (ec *endpointCnt) Exists() bool {
ec.Lock()
defer ec.Unlock()
return ec.dbExists
}
func (ec *endpointCnt) Skip() bool {
ec.Lock()
defer ec.Unlock()
return !ec.n.persist
}
func (ec *endpointCnt) New() datastore.KVObject {
ec.Lock()
defer ec.Unlock()
return &endpointCnt{
n: ec.n,
}
}
func (ec *endpointCnt) CopyTo(o datastore.KVObject) error {
ec.Lock()
defer ec.Unlock()
dstEc := o.(*endpointCnt)
dstEc.n = ec.n
dstEc.Count = ec.Count
dstEc.dbExists = ec.dbExists
dstEc.dbIndex = ec.dbIndex
return nil
}
func (ec *endpointCnt) EndpointCnt() uint64 {
ec.Lock()
defer ec.Unlock()
return ec.Count
}
func (ec *endpointCnt) updateStore() error {
c := ec.n.getController()
// make a copy of count and n to avoid being overwritten by store.GetObject
count := ec.EndpointCnt()
n := ec.n
for {
if err := c.updateToStore(context.TODO(), ec); err == nil || err != datastore.ErrKeyModified {
return err
}
if err := c.store.GetObject(ec); err != nil {
return fmt.Errorf("could not update the kvobject to latest on endpoint count update: %v", err)
}
ec.Lock()
ec.Count = count
ec.n = n
ec.Unlock()
}
}
func (ec *endpointCnt) setCnt(cnt uint64) error {
ec.Lock()
ec.Count = cnt
ec.Unlock()
return ec.updateStore()
}
func (ec *endpointCnt) atomicIncDecEpCnt(inc bool) error {
store := ec.n.getController().store
tmp := &endpointCnt{n: ec.n}
if err := store.GetObject(tmp); err != nil {
return err
}
retry:
ec.Lock()
if inc {
ec.Count++
} else {
if ec.Count > 0 {
ec.Count--
}
}
ec.Unlock()
if err := ec.n.getController().updateToStore(context.TODO(), ec); err != nil {
if err == datastore.ErrKeyModified {
if err := store.GetObject(ec); err != nil {
return fmt.Errorf("could not update the kvobject to latest when trying to atomic add endpoint count: %v", err)
}
goto retry
}
return err
}
return nil
}
func (ec *endpointCnt) IncEndpointCnt() error {
return ec.atomicIncDecEpCnt(true)
}
func (ec *endpointCnt) DecEndpointCnt() error {
return ec.atomicIncDecEpCnt(false)
}

View File

@@ -1,6 +1,13 @@
// FIXME(thaJeztah): remove once we are a module; the go:build directive prevents go from downgrading language version to go1.16:
//go:build go1.22
package libnetwork
import "context"
import (
"context"
"github.com/docker/docker/internal/maputil"
)
// storeEndpoint inserts or updates the endpoint in the store and the in-memory
// cache maintained by the Controller.
@@ -39,3 +46,21 @@ func (c *Controller) cacheEndpoint(ep *Endpoint) {
defer c.endpointsMu.Unlock()
c.endpoints[ep.id] = ep
}
// findEndpoints looks for all endpoints matching the filter from the in-memory
// cache of endpoints maintained by the Controller.
//
// This method is thread-safe, but do not use it unless you're sure your code
// uses the returned endpoints in thread-safe way (see the comment on
// Controller.endpoints).
func (c *Controller) findEndpoints(filter func(ep *Endpoint) bool) []*Endpoint {
c.endpointsMu.Lock()
defer c.endpointsMu.Unlock()
return maputil.FilterValues(c.endpoints, filter)
}
func filterEndpointByNetworkId(expected string) func(ep *Endpoint) bool {
return func(ep *Endpoint) bool {
return ep.network != nil && ep.network.id == expected
}
}

View File

@@ -25,10 +25,21 @@ func TestEndpointStore(t *testing.T) {
err = c.storeEndpoint(context.Background(), ep2)
assert.NilError(t, err)
// Check that we can find both endpoints
found := c.findEndpoints(filterEndpointByNetworkId("testNetwork"))
assert.Equal(t, len(found), 2)
assert.Equal(t, found[0], ep1)
assert.Equal(t, found[1], ep2)
// Delete the first endpoint
err = c.deleteStoredEndpoint(ep1)
assert.NilError(t, err)
// Check that we can only find the second endpoint
found = c.findEndpoints(filterEndpointByNetworkId("testNetwork"))
assert.Equal(t, len(found), 1)
assert.Equal(t, found[0], ep2)
// Store the second endpoint again
err = c.storeEndpoint(context.Background(), ep2)
assert.NilError(t, err)

View File

@@ -190,7 +190,6 @@ type Network struct {
ipamV6Info []*IpamInfo
enableIPv4 bool
enableIPv6 bool
epCnt *endpointCnt
generic options.Generic
dbIndex uint64
dbExists bool
@@ -542,13 +541,6 @@ func (n *Network) CopyTo(o datastore.KVObject) error {
return nil
}
func (n *Network) getEpCnt() *endpointCnt {
n.mu.Lock()
defer n.mu.Unlock()
return n.epCnt
}
func (n *Network) validateAdvertiseAddrConfig() error {
var errs []error
_, err := n.validatedAdvertiseAddrNMsgs()
@@ -1040,15 +1032,20 @@ func (n *Network) delete(force bool, rmLBEndpoint bool) error {
return &ActiveEndpointsError{name: n.name, id: n.id}
}
if !force && n.configOnly {
refNws := c.findNetworks(filterNetworkByConfigFrom(n.name))
if len(refNws) > 0 {
return types.ForbiddenErrorf("configuration network %q is in use", n.Name())
}
}
// Check that the network is empty
var emptyCount uint64
var emptyCount int
if n.hasLoadBalancerEndpoint() {
emptyCount = 1
}
if !force && n.getEpCnt().EndpointCnt() > emptyCount {
if n.configOnly {
return types.ForbiddenErrorf("configuration network %q is in use", n.Name())
}
eps := c.findEndpoints(filterEndpointByNetworkId(n.id))
if !force && len(eps) > emptyCount {
return &ActiveEndpointsError{name: n.name, id: n.id}
}
@@ -1062,11 +1059,6 @@ func (n *Network) delete(force bool, rmLBEndpoint bool) error {
// continue deletion when force is true even on error
log.G(context.TODO()).Warnf("Error deleting load balancer sandbox: %v", err)
}
// Reload the network from the store to update the epcnt.
n, err = c.getNetworkFromStore(id)
if err != nil {
return errdefs.NotFound(fmt.Errorf("unknown network %s id %s", name, id))
}
}
// Up to this point, errors that we returned were recoverable.
@@ -1080,17 +1072,6 @@ func (n *Network) delete(force bool, rmLBEndpoint bool) error {
return fmt.Errorf("error marking network %s (%s) for deletion: %v", n.Name(), n.ID(), err)
}
if n.ConfigFrom() != "" {
if t, err := c.getConfigNetwork(n.ConfigFrom()); err == nil {
if err := t.getEpCnt().DecEndpointCnt(); err != nil {
log.G(context.TODO()).Warnf("Failed to update reference count for configuration network %q on removal of network %q: %v",
t.Name(), n.Name(), err)
}
} else {
log.G(context.TODO()).Warnf("Could not find configuration network %q during removal of network %q", n.configFrom, n.Name())
}
}
if n.configOnly {
goto removeFromStore
}
@@ -1127,16 +1108,6 @@ func (n *Network) delete(force bool, rmLBEndpoint bool) error {
}
removeFromStore:
// deleteFromStore performs an atomic delete operation and the
// Network.epCnt will help prevent any possible
// race between endpoint join and network delete
if err = c.deleteFromStore(n.getEpCnt()); err != nil {
if !force {
return fmt.Errorf("error deleting network endpoint count from store: %v", err)
}
log.G(context.TODO()).Debugf("Error deleting endpoint count from store for stale network %s (%s) for deletion: %v", n.Name(), n.ID(), err)
}
if err = c.deleteStoredNetwork(n); err != nil {
return fmt.Errorf("error deleting network from store: %v", err)
}
@@ -1293,11 +1264,6 @@ func (n *Network) createEndpoint(ctx context.Context, name string, options ...En
}()
}
// Increment endpoint count to indicate completion of endpoint addition
if err = n.getEpCnt().IncEndpointCnt(); err != nil {
return nil, err
}
return ep, nil
}

View File

@@ -1,7 +1,12 @@
// FIXME(thaJeztah): remove once we are a module; the go:build directive prevents go from downgrading language version to go1.16:
//go:build go1.22
package libnetwork
import (
"context"
"github.com/docker/docker/internal/maputil"
)
// storeNetwork inserts or updates the network in the store and the in-memory
@@ -41,3 +46,21 @@ func (c *Controller) cacheNetwork(n *Network) {
defer c.networksMu.Unlock()
c.networks[n.ID()] = n
}
// findNetworks looks for all networks matching the filter from the in-memory
// cache of networks maintained by the Controller.
//
// This method is thread-safe, but do not use it unless you're sure your code
// uses the returned networks in thread-safe way (see the comment on
// Controller.networks).
func (c *Controller) findNetworks(filter func(nw *Network) bool) []*Network {
c.networksMu.Lock()
defer c.networksMu.Unlock()
return maputil.FilterValues(c.networks, filter)
}
func filterNetworkByConfigFrom(expected string) func(nw *Network) bool {
return func(nw *Network) bool {
return nw.configFrom == expected
}
}

View File

@@ -2,6 +2,7 @@ package libnetwork
import (
"context"
"slices"
"testing"
"github.com/docker/docker/libnetwork/config"
@@ -24,10 +25,51 @@ func TestNetworkStore(t *testing.T) {
err = c.storeNetwork(context.Background(), nw2)
assert.NilError(t, err)
netSorter := func(a, b *Network) int {
if a.name < b.name {
return -1
}
if a.name > b.name {
return 1
}
return 0
}
for _, tc := range []struct {
name string
filter func(nw *Network) bool
expNetworks []*Network
}{
{
name: "no filter",
filter: func(nw *Network) bool { return true },
expNetworks: []*Network{nw1, nw2},
},
{
name: "filter by configFrom",
filter: filterNetworkByConfigFrom("config-network"),
expNetworks: []*Network{nw1},
},
} {
t.Run(tc.name, func(t *testing.T) {
found := c.findNetworks(tc.filter)
assert.Equal(t, len(found), len(tc.expNetworks))
slices.SortFunc(found, netSorter)
for i, nw := range tc.expNetworks {
assert.Check(t, found[i] == nw, "got: %s; expected: %s", found[i].name, nw.name)
}
})
}
// Delete the first network
err = c.deleteStoredNetwork(nw1)
assert.NilError(t, err)
// Check that we can only find the second network
found := c.findNetworks(func(nw *Network) bool { return true })
assert.Equal(t, len(found), 1)
assert.Check(t, found[0] == nw2)
// Store the second network again
err = c.storeNetwork(context.Background(), nw2)
assert.NilError(t, err)

View File

@@ -3,7 +3,6 @@ package libnetwork
import (
"context"
"fmt"
"strings"
"github.com/containerd/log"
"github.com/docker/docker/libnetwork/datastore"
@@ -32,15 +31,6 @@ func (c *Controller) getNetworks() ([]*Network, error) {
n := kvo.(*Network)
n.ctrlr = c
c.cacheNetwork(n)
ec := &endpointCnt{n: n}
err = c.store.GetObject(ec)
if err != nil && !n.inDelete {
log.G(context.TODO()).Warnf("Could not find endpoint count key %s for network %s while listing: %v", datastore.Key(ec.Key()...), n.Name(), err)
continue
}
n.epCnt = ec
if n.scope == "" {
n.scope = scope.Local
}
@@ -61,22 +51,10 @@ func (c *Controller) getNetworksFromStore(ctx context.Context) []*Network { // F
return nil
}
kvep, err := c.store.Map(datastore.Key(epCntKeyPrefix), &endpointCnt{})
if err != nil && err != datastore.ErrKeyNotFound {
log.G(ctx).Warnf("failed to get endpoint_count map from store: %v", err)
}
for _, kvo := range kvol {
n := kvo.(*Network)
n.mu.Lock()
n.ctrlr = c
ec := &endpointCnt{n: n}
// Trim the leading & trailing "/" to make it consistent across all stores
if val, ok := kvep[strings.Trim(datastore.Key(ec.Key()...), "/")]; ok {
ec = val.(*endpointCnt)
ec.n = n
n.epCnt = ec
}
if n.scope == "" {
n.scope = scope.Local
}