Files
moby/daemon/devices_nvidia_linux.go
Sudheendra Gopinath e32715ec03 Added support for AMD GPUs in "docker run --gpus".
Added backend code to support the exact same interface
used today for Nvidia GPUs, allowing customers to use
the same docker commands for both Nvidia and AMD GPUs.

Signed-off-by: Sudheendra Gopinath <sudheendra.gopinath@amd.com>

Reused common functions from nvidia_linux.go.

Removed duplicate code in amd_linux.go by reusing
the init() and countToDevices() functions in
nvidia_linux.go. The AMD driver is registered in init().

Signed-off-by: Sudheendra Gopinath <sudheendra.gopinath@amd.com>

Renamed amd-container-runtime constant

Signed-off-by: Sudheendra Gopinath <sudheendra.gopinath@amd.com>

Removed an empty branch to satisfy the linter.

Also renamed amd_linux.go to gpu_amd_linux.go.

Signed-off-by: Sudheendra Gopinath <sudheendra.gopinath@amd.com>

Renamed nvidia_linux.go and gpu_amd_linux.go.

Signed-off-by: Sudheendra Gopinath <sudheendra.gopinath@amd.com>
2025-06-05 14:44:18 +00:00

128 lines
3.7 KiB
Go

package daemon
import (
"os"
"os/exec"
"strconv"
"strings"
"github.com/containerd/containerd/v2/contrib/nvidia"
"github.com/docker/docker/daemon/internal/capabilities"
"github.com/opencontainers/runtime-spec/specs-go"
"github.com/pkg/errors"
)
// TODO: nvidia should not be hard-coded, and should be a device plugin instead on the daemon object.
// TODO: add list of device capabilities in daemon/node info

var (
	// errConflictCountDeviceIDs is returned when a device request sets both a
	// device Count and an explicit list of DeviceIDs; the two are mutually
	// exclusive ways of selecting GPUs.
	errConflictCountDeviceIDs = errors.New("cannot set both Count and DeviceIDs on device request")
)
const (
	// nvidiaHook is the NVIDIA helper binary. Finding it with exec.LookPath
	// enables the "nvidia" device driver, and it is installed into the
	// container spec as a prestart hook.
	nvidiaHook = "nvidia-container-runtime-hook"

	// amdContainerRuntimeExecutableName is the AMD helper binary. Finding it
	// with exec.LookPath enables the "amd" device driver when the NVIDIA hook
	// is not present.
	amdContainerRuntimeExecutableName = "amd-container-runtime"
)
// allNvidiaCaps is the set of NVIDIA driver capabilities this daemon knows
// about. It is used both to advertise the "nvidia" driver's capabilities and
// to decide which selected capabilities are forwarded through
// NVIDIA_DRIVER_CAPABILITIES.
//
// The list was copied from allCaps in the containerd contrib/nvidia package;
// the constants themselves come from the containerd v2 module imported here.
var allNvidiaCaps = map[nvidia.Capability]struct{}{
	nvidia.Compute:  {},
	nvidia.Compat32: {},
	nvidia.Graphics: {},
	nvidia.Utility:  {},
	nvidia.Video:    {},
	nvidia.Display:  {},
}
// init registers at most one GPU device driver. The choice depends on which
// vendor helper binary exec.LookPath finds: the NVIDIA prestart hook takes
// precedence over the AMD container runtime, and the AMD binary is only
// looked up when the NVIDIA one is missing. If neither is found, no driver
// is registered and there is no "gpu" capability.
func init() {
	onPath := func(name string) bool {
		_, err := exec.LookPath(name)
		return err == nil
	}

	switch {
	case onPath(nvidiaHook):
		// The NVIDIA capset holds "gpu", "nvidia" and the name of every
		// NVIDIA driver capability.
		nvidiaCapset := capabilities.Set{"gpu": struct{}{}, "nvidia": struct{}{}}
		for capability := range allNvidiaCaps {
			nvidiaCapset[string(capability)] = struct{}{}
		}
		registerDeviceDriver("nvidia", &deviceDriver{
			capset:     nvidiaCapset,
			updateSpec: setNvidiaGPUs,
		})
	case onPath(amdContainerRuntimeExecutableName):
		registerDeviceDriver("amd", &deviceDriver{
			capset:     capabilities.Set{"gpu": struct{}{}, "amd": struct{}{}},
			updateSpec: setAMDGPUs,
		})
	}
	// Neither helper binary found: no "gpu" capability is registered.
}
// setNvidiaGPUs updates s so that the NVIDIA prestart hook exposes the GPUs
// requested by dev.req to the container.
//
// GPU selection is passed through NVIDIA_VISIBLE_DEVICES:
//   - explicit DeviceIDs are joined with ","
//   - a positive Count selects devices 0..Count-1
//   - a negative Count selects "all"
//   - a zero Count with no DeviceIDs selects "void"
//
// Selected capabilities that are NVIDIA driver capabilities (allNvidiaCaps)
// are passed through NVIDIA_DRIVER_CAPABILITIES. Finally, the
// nvidia-container-runtime-hook binary is appended as a prestart hook that
// receives the daemon's environment.
//
// It returns errConflictCountDeviceIDs when both Count and DeviceIDs are set,
// and the exec.LookPath error when the hook binary cannot be found. In both
// error cases s is left unmodified.
func setNvidiaGPUs(s *specs.Spec, dev *deviceInstance) error {
	req := dev.req
	if req.Count != 0 && len(req.DeviceIDs) > 0 {
		return errConflictCountDeviceIDs
	}

	// Resolve the hook before touching the spec, so that a missing binary
	// does not leave s with NVIDIA environment variables but no hook.
	path, err := exec.LookPath(nvidiaHook)
	if err != nil {
		return err
	}

	var visibleDevices string
	switch {
	case len(req.DeviceIDs) > 0:
		visibleDevices = strings.Join(req.DeviceIDs, ",")
	case req.Count > 0:
		visibleDevices = countToDevices(req.Count)
	case req.Count < 0:
		visibleDevices = "all"
	default: // req.Count == 0 and no DeviceIDs
		visibleDevices = "void"
	}
	s.Process.Env = append(s.Process.Env, "NVIDIA_VISIBLE_DEVICES="+visibleDevices)

	// dev.selectedCaps contains device capabilities, some but not all of
	// which are NVIDIA driver capabilities; only the latter are forwarded.
	var nvidiaCaps []string
	for _, c := range dev.selectedCaps {
		if _, isNvidiaCap := allNvidiaCaps[nvidia.Capability(c)]; isNvidiaCap {
			nvidiaCaps = append(nvidiaCaps, c)
		}
		// TODO: nvidia.WithRequiredCUDAVersion
		// for now we let the prestart hook verify cuda versions but errors are not pretty.
	}
	if len(nvidiaCaps) > 0 {
		s.Process.Env = append(s.Process.Env, "NVIDIA_DRIVER_CAPABILITIES="+strings.Join(nvidiaCaps, ","))
	}

	if s.Hooks == nil {
		s.Hooks = &specs.Hooks{}
	}
	// This implementation uses prestart hooks, which are deprecated.
	// CreateRuntime is the closest equivalent, and executed in the same
	// locations as prestart-hooks, but depending on what these hooks do,
	// possibly one of the other hooks could be used instead (such as
	// CreateContainer or StartContainer).
	s.Hooks.Prestart = append(s.Hooks.Prestart, specs.Hook{ //nolint:staticcheck // FIXME(thaJeztah); replace prestart hook with a non-deprecated one.
		Path: path,
		Args: []string{
			nvidiaHook,
			"prestart",
		},
		Env: os.Environ(),
	})
	return nil
}
// countToDevices returns the comma-separated device IDs "0,1,...,count-1"
// for the first count devices.
//
// A count of zero or less yields the empty string. Negative counts are
// rejected explicitly because make would otherwise panic on a negative
// length.
func countToDevices(count int) string {
	if count <= 0 {
		return ""
	}
	devices := make([]string, count)
	for i := range devices {
		devices[i] = strconv.Itoa(i)
	}
	return strings.Join(devices, ",")
}