Merge pull request #49741 from thaJeztah/atomicwriter_stricter_validate

pkg/atomicwriter: disallow symlinks for now, add more tests and touch-up GoDoc
This commit is contained in:
Sebastiaan van Stijn
2025-04-04 20:11:23 +02:00
committed by GitHub
2 changed files with 70 additions and 30 deletions

View File

@@ -1,3 +1,5 @@
// Package atomicwriter provides utilities to perform atomic writes to a
// file or set of files.
package atomicwriter
import (
@@ -6,6 +8,7 @@ import (
"io"
"os"
"path/filepath"
"syscall"
"github.com/moby/sys/sequential"
)
@@ -14,35 +17,33 @@ func validateDestination(fileName string) error {
if fileName == "" {
return errors.New("file name is empty")
}
if dir := filepath.Dir(fileName); dir != "" && dir != "." && dir != ".." {
di, err := os.Stat(dir)
if err != nil {
return fmt.Errorf("invalid output path: %w", err)
}
if !di.IsDir() {
return fmt.Errorf("invalid output path: %w", &os.PathError{Op: "stat", Path: dir, Err: syscall.ENOTDIR})
}
}
// Deliberately using Lstat here to match the behavior of [os.Rename],
// which is used when completing the write and does not resolve symlinks.
//
// TODO(thaJeztah): decide whether we want to disallow symlinks or to follow them.
if fi, err := os.Lstat(fileName); err != nil {
if !os.IsNotExist(err) {
return fmt.Errorf("failed to stat output path: %w", err)
fi, err := os.Lstat(fileName)
if err != nil {
if os.IsNotExist(err) {
return nil
}
} else if err := validateFileMode(fi.Mode()); err != nil {
return err
return fmt.Errorf("failed to stat output path: %w", err)
}
if dir := filepath.Dir(fileName); dir != "" && dir != "." {
if _, err := os.Stat(dir); errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("invalid file path: %w", err)
}
}
return nil
}
func validateFileMode(mode os.FileMode) error {
switch {
switch mode := fi.Mode(); {
case mode.IsRegular():
return nil // Regular file
case mode&os.ModeDir != 0:
return errors.New("cannot write to a directory")
// TODO(thaJeztah): decide whether we want to disallow symlinks or to follow them.
// case mode&os.ModeSymlink != 0:
// return errors.New("cannot write to a symbolic link directly")
case mode&os.ModeSymlink != 0:
return errors.New("cannot write to a symbolic link directly")
case mode&os.ModeNamedPipe != 0:
return errors.New("cannot write to a named pipe (FIFO)")
case mode&os.ModeSocket != 0:
@@ -59,8 +60,7 @@ func validateFileMode(mode os.FileMode) error {
case mode&os.ModeSticky != 0:
return errors.New("cannot write to a sticky bit file")
default:
// Unknown file mode; let's assume it works
return nil
return fmt.Errorf("unknown file mode: %[1]s (%#[1]o)", mode)
}
}
@@ -95,7 +95,12 @@ func New(filename string, perm os.FileMode) (io.WriteCloser, error) {
}, nil
}
// WriteFile atomically writes data to a file named by filename and with the specified permission bits.
// WriteFile atomically writes data to a file named by filename and with the
// specified permission bits. The given filename is created if it does not exist,
// but the destination directory must exist. It can be used as a drop-in replacement
// for [os.WriteFile], but currently does not allow the destination path to be
// a symlink. WriteFile is implemented using [New] for its implementation.
//
// NOTE: umask is not considered for the file's permissions.
func WriteFile(filename string, data []byte, perm os.FileMode) error {
f, err := New(filename, perm)

View File

@@ -7,6 +7,7 @@ import (
"path/filepath"
"runtime"
"strings"
"syscall"
"testing"
)
@@ -120,6 +121,23 @@ func TestNewInvalid(t *testing.T) {
t.Errorf("Should produce a 'not found' error, but got %[1]T (%[1]v)", err)
}
})
t.Run("target dir is not a directory", func(t *testing.T) {
tmpDir := t.TempDir()
parentPath := filepath.Join(tmpDir, "not-a-dir")
err := os.WriteFile(parentPath, nil, testMode())
if err != nil {
t.Fatalf("Error writing file: %v", err)
}
fileName := filepath.Join(parentPath, "new-file.txt")
writer, err := New(fileName, testMode())
if writer != nil {
t.Errorf("Should not have created writer")
}
// This should match the behavior of os.WriteFile, which returns a [os.PathError] with [syscall.ENOTDIR].
if !errors.Is(err, syscall.ENOTDIR) {
t.Errorf("Should produce a 'not a directory' error, but got %[1]T (%[1]v)", err)
}
})
t.Run("empty filename", func(t *testing.T) {
writer, err := New("", testMode())
if writer != nil {
@@ -139,6 +157,24 @@ func TestNewInvalid(t *testing.T) {
t.Errorf("Should produce a 'cannot write to a directory' error, but got %[1]T (%[1]v)", err)
}
})
t.Run("symlinked file", func(t *testing.T) {
tmpDir := t.TempDir()
linkTarget := filepath.Join(tmpDir, "symlink-target")
if err := os.WriteFile(linkTarget, []byte("orig content"), testMode()); err != nil {
t.Fatal(err)
}
fileName := filepath.Join(tmpDir, "symlinked-file")
if err := os.Symlink(linkTarget, fileName); err != nil {
t.Fatal(err)
}
writer, err := New(fileName, testMode())
if writer != nil {
t.Errorf("Should not have created writer")
}
if err == nil || err.Error() != "cannot write to a symbolic link directly" {
t.Errorf("Should produce a 'cannot write to a symbolic link directly' error, but got %[1]T (%[1]v)", err)
}
})
}
func TestWriteFile(t *testing.T) {
@@ -178,7 +214,9 @@ func TestWriteFile(t *testing.T) {
t.Run("symlinked file", func(t *testing.T) {
tmpDir := t.TempDir()
linkTarget := filepath.Join(tmpDir, "symlink-target")
if err := os.WriteFile(linkTarget, []byte("orig content"), testMode()); err != nil {
originalContent := []byte("original content")
fileMode := testMode()
if err := os.WriteFile(linkTarget, originalContent, fileMode); err != nil {
t.Fatal(err)
}
if err := os.Symlink(linkTarget, filepath.Join(tmpDir, "symlinked-file")); err != nil {
@@ -188,15 +226,12 @@ func TestWriteFile(t *testing.T) {
assertFileCount(t, tmpDir, origFileCount)
fileName := filepath.Join(tmpDir, "symlinked-file")
fileContent := []byte("new content")
fileMode := testMode()
if err := WriteFile(fileName, fileContent, fileMode); err != nil {
t.Fatalf("Error writing to file: %v", err)
err := WriteFile(fileName, []byte("new content"), testMode())
if err == nil || err.Error() != "cannot write to a symbolic link directly" {
t.Errorf("Should produce a 'cannot write to a symbolic link directly' error, but got %[1]T (%[1]v)", err)
}
assertFile(t, fileName, fileContent, fileMode)
assertFile(t, linkTarget, originalContent, fileMode)
assertFileCount(t, tmpDir, origFileCount)
// FIXME(thaJeztah): [os.Rename] does not resolve symlinks, so writing to a symlinked location replaces the link with a file.
// assertFile(t, linkTarget, fileContent, fileMode)
})
t.Run("symlinked directory", func(t *testing.T) {
tmpDir := t.TempDir()