diff --git a/pkg/atomicwriter/atomicwriter_test.go b/pkg/atomicwriter/atomicwriter_test.go index 3a2f48f91d..a4702ba576 100644 --- a/pkg/atomicwriter/atomicwriter_test.go +++ b/pkg/atomicwriter/atomicwriter_test.go @@ -2,48 +2,57 @@ package atomicwriter import ( "bytes" + "errors" "os" "path/filepath" "runtime" "testing" ) -var testMode os.FileMode = 0o640 - -func init() { - // Windows does not support full Linux file mode +// testMode returns the file-mode to use in tests, accounting for Windows +// not supporting full Linux file mode. +func testMode() os.FileMode { if runtime.GOOS == "windows" { - testMode = 0o666 + return 0o666 } + return 0o640 } -func TestAtomicWriteToFile(t *testing.T) { - tmpDir := t.TempDir() - - expected := []byte("barbaz") - if err := WriteFile(filepath.Join(tmpDir, "foo"), expected, testMode); err != nil { - t.Fatalf("Error writing to file: %v", err) - } - - actual, err := os.ReadFile(filepath.Join(tmpDir, "foo")) +// assertFile asserts the given fileName to exist, and to have the expected +// content and mode. +func assertFile(t *testing.T, fileName string, fileContent []byte, expectedMode os.FileMode) { + t.Helper() + actual, err := os.ReadFile(fileName) if err != nil { t.Fatalf("Error reading from file: %v", err) } - if !bytes.Equal(actual, expected) { - t.Fatalf("Data mismatch, expected %q, got %q", expected, actual) + if !bytes.Equal(actual, fileContent) { + t.Errorf("Data mismatch, expected %q, got %q", fileContent, actual) } - st, err := os.Stat(filepath.Join(tmpDir, "foo")) + st, err := os.Stat(fileName) if err != nil { t.Fatalf("Error statting file: %v", err) } - if expected := testMode; st.Mode() != expected { - t.Fatalf("Mode mismatched, expected %o, got %o", expected, st.Mode()) + if st.Mode() != expectedMode { + t.Errorf("Mode mismatched, expected %o, got %o", expectedMode, st.Mode()) } } -func TestAtomicWriteSetCommit(t *testing.T) { +func TestWriteFile(t *testing.T) { + tmpDir := t.TempDir() + + fileName := filepath.Join(tmpDir, "test.txt") + fileContent := []byte("file content") + fileMode := testMode() + if err := WriteFile(fileName, fileContent, fileMode); err != nil { + t.Fatalf("Error writing to file: %v", err) + } + assertFile(t, fileName, fileContent, fileMode) +} + +func TestWriteSetCommit(t *testing.T) { tmpDir := t.TempDir() if err := os.Mkdir(filepath.Join(tmpDir, "tmp"), 0o700); err != nil { @@ -56,8 +65,10 @@ func TestAtomicWriteSetCommit(t *testing.T) { t.Fatalf("Error creating atomic write set: %s", err) } - expected := []byte("barbaz") - if err := ws.WriteFile("foo", expected, testMode); err != nil { + fileContent := []byte("file content") + fileMode := testMode() + + if err := ws.WriteFile("foo", fileContent, fileMode); err != nil { t.Fatalf("Error writing to file: %v", err) } @@ -69,25 +80,10 @@ func TestAtomicWriteSetCommit(t *testing.T) { t.Fatalf("Error committing file: %s", err) } - actual, err := os.ReadFile(filepath.Join(targetDir, "foo")) - if err != nil { - t.Fatalf("Error reading from file: %v", err) - } - - if !bytes.Equal(actual, expected) { - t.Fatalf("Data mismatch, expected %q, got %q", expected, actual) - } - - st, err := os.Stat(filepath.Join(targetDir, "foo")) - if err != nil { - t.Fatalf("Error statting file: %v", err) - } - if expected := testMode; st.Mode() != expected { - t.Fatalf("Mode mismatched, expected %o, got %o", expected, st.Mode()) - } + assertFile(t, filepath.Join(targetDir, "foo"), fileContent, fileMode) } -func TestAtomicWriteSetCancel(t *testing.T) { +func TestWriteSetCancel(t *testing.T) { tmpDir := t.TempDir() if err := os.Mkdir(filepath.Join(tmpDir, "tmp"), 0o700); err != nil { @@ -99,8 +95,9 @@ func TestAtomicWriteSetCancel(t *testing.T) { t.Fatalf("Error creating atomic write set: %s", err) } - expected := []byte("barbaz") - if err := ws.WriteFile("foo", expected, testMode); err != nil { + fileContent := []byte("file content") + fileMode := testMode() + if err := ws.WriteFile("foo", fileContent, fileMode); err != nil { t.Fatalf("Error writing to file: %v", err) } @@ -110,7 +107,7 @@ func TestAtomicWriteSetCancel(t *testing.T) { if _, err := os.ReadFile(filepath.Join(tmpDir, "target", "foo")); err == nil { t.Fatalf("Expected error reading file where should not exist") - } else if !os.IsNotExist(err) { + } else if !errors.Is(err, os.ErrNotExist) { t.Fatalf("Unexpected error reading file: %s", err) } }