pkg/atomicwriter: don't overwrite destination on close without write

Creating a writer (`atomicwriter.New()`) and closing it without a write
ever happening, would replace the destination file with an empty file.

This patch adds a check whether a write was performed (either successful
or unsuccessful); if no write happened, we cleanup the tempfile without
replacing the destination file.

Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
This commit is contained in:
Sebastiaan van Stijn
2025-03-10 10:44:39 +01:00
parent 88a5bca43c
commit ff061e28c1
2 changed files with 13 additions and 4 deletions

View File

@@ -1,6 +1,7 @@
package atomicwriter
import (
"errors"
"io"
"os"
"path/filepath"
@@ -49,10 +50,12 @@ type atomicFileWriter struct {
f *os.File
fn string
writeErr error
written bool
perm os.FileMode
}
func (w *atomicFileWriter) Write(dt []byte) (int, error) {
w.written = true
n, err := w.f.Write(dt)
if err != nil {
w.writeErr = err
@@ -62,12 +65,12 @@ func (w *atomicFileWriter) Write(dt []byte) (int, error) {
func (w *atomicFileWriter) Close() (retErr error) {
defer func() {
if retErr != nil || w.writeErr != nil {
os.Remove(w.f.Name())
if err := os.Remove(w.f.Name()); !errors.Is(err, os.ErrNotExist) && retErr == nil {
retErr = err
}
}()
if err := w.f.Sync(); err != nil {
w.f.Close()
_ = w.f.Close()
return err
}
if err := w.f.Close(); err != nil {
@@ -76,7 +79,7 @@ func (w *atomicFileWriter) Close() (retErr error) {
if err := os.Chmod(w.f.Name(), w.perm); err != nil {
return err
}
if w.writeErr == nil {
if w.writeErr == nil && w.written {
return os.Rename(w.f.Name(), w.fn)
}
return nil

View File

@@ -93,9 +93,15 @@ func TestNew(t *testing.T) {
t.Errorf("Unexpected file name for temp-file: %s", tmpFileName)
}
// Closing the writer without writing should clean up the temp-file,
// and should not replace the destination file.
if err = writer.Close(); err != nil {
t.Errorf("Error closing writer: %v", err)
}
assertFileCount(t, actualParentDir, origFileCount)
if tc == "existing-file" {
assertFile(t, fileName, []byte("original content"), testMode())
}
})
}
})