diff --git a/daemon/devices_amd_linux.go b/daemon/devices_amd_linux.go new file mode 100644 index 0000000000..a728c7074e --- /dev/null +++ b/daemon/devices_amd_linux.go @@ -0,0 +1,27 @@ +package daemon + +import ( + "strings" + + "github.com/opencontainers/runtime-spec/specs-go" +) + +func setAMDGPUs(s *specs.Spec, dev *deviceInstance) error { + req := dev.req + if req.Count != 0 && len(req.DeviceIDs) > 0 { + return errConflictCountDeviceIDs + } + + switch { + case len(req.DeviceIDs) > 0: + s.Process.Env = append(s.Process.Env, "AMD_VISIBLE_DEVICES="+strings.Join(req.DeviceIDs, ",")) + case req.Count > 0: + s.Process.Env = append(s.Process.Env, "AMD_VISIBLE_DEVICES="+countToDevices(req.Count)) + case req.Count < 0: + s.Process.Env = append(s.Process.Env, "AMD_VISIBLE_DEVICES=all") + case req.Count == 0: + s.Process.Env = append(s.Process.Env, "AMD_VISIBLE_DEVICES=void") + } + + return nil +} diff --git a/daemon/nvidia_linux.go b/daemon/devices_nvidia_linux.go similarity index 78% rename from daemon/nvidia_linux.go rename to daemon/devices_nvidia_linux.go index abc6b4a351..8a30343134 100644 --- a/daemon/nvidia_linux.go +++ b/daemon/devices_nvidia_linux.go @@ -17,7 +17,10 @@ import ( var errConflictCountDeviceIDs = errors.New("cannot set both Count and DeviceIDs on device request") -const nvidiaHook = "nvidia-container-runtime-hook" +const ( + nvidiaHook = "nvidia-container-runtime-hook" + amdContainerRuntimeExecutableName = "amd-container-runtime" +) // These are NVIDIA-specific capabilities stolen from github.com/containerd/containerd/contrib/nvidia.allCaps var allNvidiaCaps = map[nvidia.Capability]struct{}{ @@ -30,19 +33,29 @@ var allNvidiaCaps = map[nvidia.Capability]struct{}{ } func init() { - if _, err := exec.LookPath(nvidiaHook); err != nil { - // do not register Nvidia driver if helper binary is not present. + // Register Nvidia driver if Nvidia helper binary is present. + if _, err := exec.LookPath(nvidiaHook); err == nil { + capset := capabilities.Set{"gpu": struct{}{}, "nvidia": struct{}{}} + for c := range allNvidiaCaps { + capset[string(c)] = struct{}{} + } + registerDeviceDriver("nvidia", &deviceDriver{ + capset: capset, + updateSpec: setNvidiaGPUs, + }) return } - capset := capabilities.Set{"gpu": struct{}{}, "nvidia": struct{}{}} - nvidiaDriver := &deviceDriver{ - capset: capset, - updateSpec: setNvidiaGPUs, + + // Register AMD driver if AMD helper binary is present. + if _, err := exec.LookPath(amdContainerRuntimeExecutableName); err == nil { + registerDeviceDriver("amd", &deviceDriver{ + capset: capabilities.Set{"gpu": struct{}{}, "amd": struct{}{}}, + updateSpec: setAMDGPUs, + }) + return } - for c := range allNvidiaCaps { - nvidiaDriver.capset[string(c)] = struct{}{} - } - registerDeviceDriver("nvidia", nvidiaDriver) + + // No "gpu" capability } func setNvidiaGPUs(s *specs.Spec, dev *deviceInstance) error {