diff --git a/registry/auth.go b/registry/auth.go index 8f35dfff9c..f3b7551512 100644 --- a/registry/auth.go +++ b/registry/auth.go @@ -66,11 +66,11 @@ func (scs staticCredentialStore) SetRefreshToken(*url.URL, string, string) { // loginV2 tries to login to the v2 registry server. The given registry // endpoint will be pinged to get authorization challenges. These challenges // will be used to authenticate against the registry to validate credentials. -func loginV2(authConfig *registry.AuthConfig, endpoint APIEndpoint, userAgent string) (token string, _ error) { +func loginV2(ctx context.Context, authConfig *registry.AuthConfig, endpoint APIEndpoint, userAgent string) (token string, _ error) { endpointStr := strings.TrimRight(endpoint.URL.String(), "/") + "/v2/" - log.G(context.TODO()).Debugf("attempting v2 login to registry endpoint %s", endpointStr) + log.G(ctx).WithField("endpoint", endpointStr).Debug("attempting v2 login to registry endpoint") - req, err := http.NewRequest(http.MethodGet, endpointStr, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpointStr, nil) if err != nil { return "", err } diff --git a/registry/registry.go b/registry/registry.go index 80a9fecf2c..09cf912087 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -138,15 +138,13 @@ func newTransport(tlsConfig *tls.Config) http.RoundTripper { tlsConfig = tlsconfig.ServerDefault() } - direct := &net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - } - return otelhttp.NewTransport( &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: direct.DialContext, + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, TLSHandshakeTimeout: 10 * time.Second, TLSClientConfig: tlsConfig, // TODO(dmcgowan): Call close idle connections when complete and use keep alive diff --git a/registry/search.go b/registry/search.go index 2dec4d1318..0564a62e25 100644 --- a/registry/search.go +++ b/registry/search.go @@ -125,7 +125,7 @@ func (s *Service) searchUnfiltered(ctx context.Context, term string, limit int, client = v2Client } else { client = endpoint.client - if err := authorizeClient(client, authConfig, endpoint); err != nil { + if err := authorizeClient(ctx, client, authConfig, endpoint); err != nil { return nil, err } } diff --git a/registry/search_session.go b/registry/search_session.go index 87a3199612..efc22430af 100644 --- a/registry/search_session.go +++ b/registry/search_session.go @@ -173,18 +173,18 @@ func (tr *authTransport) CancelRequest(req *http.Request) { } } -func authorizeClient(client *http.Client, authConfig *registry.AuthConfig, endpoint *v1Endpoint) error { +func authorizeClient(ctx context.Context, client *http.Client, authConfig *registry.AuthConfig, endpoint *v1Endpoint) error { var alwaysSetBasicAuth bool // If we're working with a standalone private registry over HTTPS, send Basic Auth headers // alongside all our requests. if endpoint.String() != IndexServer && endpoint.URL.Scheme == "https" { - info, err := endpoint.ping(context.TODO()) + info, err := endpoint.ping(ctx) if err != nil { return err } if info.Standalone && authConfig != nil { - log.G(context.TODO()).Debugf("Endpoint %s is eligible for private registry. Enabling decorator.", endpoint.String()) + log.G(ctx).WithField("endpoint", endpoint.String()).Debug("Endpoint is eligible for private registry; enabling alwaysSetBasicAuth") alwaysSetBasicAuth = true } } diff --git a/registry/search_test.go b/registry/search_test.go index f02928bd99..3584b3c1dc 100644 --- a/registry/search_test.go +++ b/registry/search_test.go @@ -26,7 +26,7 @@ func spawnTestRegistrySession(t *testing.T) *session { tr = transport.NewTransport(newAuthTransport(tr, authConfig, false), Headers(userAgent, nil)...) client := httpClient(tr) - if err := authorizeClient(client, authConfig, endpoint); err != nil { + if err := authorizeClient(context.Background(), client, authConfig, endpoint); err != nil { t.Fatal(err) } r := newSession(client, endpoint) diff --git a/registry/service.go b/registry/service.go index 48726a7929..3d87cfb478 100644 --- a/registry/service.go +++ b/registry/service.go @@ -74,17 +74,20 @@ func (s *Service) Auth(ctx context.Context, authConfig *registry.AuthConfig, use // Lookup endpoints for authentication but exclude mirrors to prevent // sending credentials of the upstream registry to a mirror. s.mu.RLock() - endpoints, err := s.lookupV2Endpoints(registryHostName, false) + endpoints, err := s.lookupV2Endpoints(ctx, registryHostName, false) s.mu.RUnlock() if err != nil { + if errdefs.IsContext(err) { + return "", "", err + } return "", "", invalidParam(err) } var lastErr error for _, endpoint := range endpoints { - authToken, err := loginV2(authConfig, endpoint, userAgent) + authToken, err := loginV2(ctx, authConfig, endpoint, userAgent) if err != nil { - if errdefs.IsUnauthorized(err) { + if errdefs.IsContext(err) || errdefs.IsUnauthorized(err) { // Failed to authenticate; don't continue with (non-TLS) endpoints. return "", "", err } @@ -149,7 +152,7 @@ func (s *Service) LookupPullEndpoints(hostname string) (endpoints []APIEndpoint, s.mu.RLock() defer s.mu.RUnlock() - return s.lookupV2Endpoints(hostname, true) + return s.lookupV2Endpoints(context.TODO(), hostname, true) } // LookupPushEndpoints creates a list of v2 endpoints to try to push to, in order of preference. @@ -158,7 +161,7 @@ func (s *Service) LookupPushEndpoints(hostname string) (endpoints []APIEndpoint, s.mu.RLock() defer s.mu.RUnlock() - return s.lookupV2Endpoints(hostname, false) + return s.lookupV2Endpoints(context.TODO(), hostname, false) } // IsInsecureRegistry returns true if the registry at given host is configured as diff --git a/registry/service_v2.go b/registry/service_v2.go index f96bc0e035..ee1ee0d44b 100644 --- a/registry/service_v2.go +++ b/registry/service_v2.go @@ -8,12 +8,14 @@ import ( "github.com/docker/go-connections/tlsconfig" ) -func (s *Service) lookupV2Endpoints(hostname string, includeMirrors bool) ([]APIEndpoint, error) { - ctx := context.TODO() +func (s *Service) lookupV2Endpoints(ctx context.Context, hostname string, includeMirrors bool) ([]APIEndpoint, error) { var endpoints []APIEndpoint if hostname == DefaultNamespace || hostname == IndexHostname { if includeMirrors { for _, mirror := range s.config.Mirrors { + if ctx.Err() != nil { + return nil, ctx.Err() + } if !strings.HasPrefix(mirror, "http://") && !strings.HasPrefix(mirror, "https://") { mirror = "https://" + mirror }