fix(file): clean up .ltx.tmp files on all error paths (#991)

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
Cory LaNou
2026-01-12 09:20:01 -06:00
committed by GitHub
parent c77ecef421
commit 33639d49f1
2 changed files with 172 additions and 3 deletions

View File

@@ -178,11 +178,20 @@ func (c *ReplicaClient) WriteLTXFile(ctx context.Context, level int, minTXID, ma
}
// Write LTX file to temporary file next to destination path.
f, err := internal.CreateFile(filename+".tmp", fileInfo)
tmpFilename := filename + ".tmp"
f, err := internal.CreateFile(tmpFilename, fileInfo)
if err != nil {
return nil, err
}
defer f.Close()
// Clean up temp file on error. On successful rename, the temp file
// becomes the final file and should not be removed.
defer func() {
_ = f.Close()
if err != nil {
_ = os.Remove(tmpFilename)
}
}()
if _, err := io.Copy(f, fullReader); err != nil {
return nil, err
@@ -209,7 +218,7 @@ func (c *ReplicaClient) WriteLTXFile(ctx context.Context, level int, minTXID, ma
}
// Move LTX file to final path when it has been written & synced to disk.
if err := os.Rename(filename+".tmp", filename); err != nil {
if err := os.Rename(tmpFilename, filename); err != nil {
return nil, err
}

View File

@@ -1,7 +1,16 @@
package file_test
import (
"bytes"
"context"
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/superfly/ltx"
"github.com/benbjohnson/litestream/file"
)
@@ -19,6 +28,157 @@ func TestReplicaClient_Type(t *testing.T) {
}
}
// TestReplicaClient_WriteLTXFile_ErrorCleanup checks that WriteLTXFile leaves
// no ".tmp" files under the replica directory, both when the source reader
// fails part-way through and when the write succeeds.
func TestReplicaClient_WriteLTXFile_ErrorCleanup(t *testing.T) {
	t.Run("DiskFull", func(t *testing.T) {
		dir := t.TempDir()
		client := file.NewReplicaClient(dir)

		// The source yields 50 bytes and then fails with a disk-full style error.
		src := &failAfterReader{
			data: createLTXHeader(1, 2),
			n:    50,
			err:  fmt.Errorf("no space left on device"),
		}

		_, err := client.WriteLTXFile(context.Background(), 0, 1, 2, src)
		switch {
		case err == nil:
			t.Fatal("expected error from failReader")
		case !strings.Contains(err.Error(), "no space left on device"):
			t.Fatalf("expected disk full error, got: %v", err)
		}

		// The failed write must not leave a temp file behind.
		if leaked := findTmpFiles(t, dir); len(leaked) != 0 {
			t.Fatalf("found %d .tmp files after error: %v", len(leaked), leaked)
		}
	})

	t.Run("SuccessNoLeaks", func(t *testing.T) {
		dir := t.TempDir()
		client := file.NewReplicaClient(dir)

		payload := bytes.NewReader(createLTXData(1, 2, []byte("test data")))
		info, err := client.WriteLTXFile(context.Background(), 0, 1, 2, payload)
		if err != nil {
			t.Fatal(err)
		}
		if info == nil {
			t.Fatal("expected FileInfo")
		}

		// After a successful write no temp file may remain...
		if leaked := findTmpFiles(t, dir); len(leaked) != 0 {
			t.Fatalf("found %d .tmp files after successful write: %v", len(leaked), leaked)
		}

		// ...and the file must exist at its final path.
		if _, statErr := os.Stat(client.LTXFilePath(0, 1, 2)); statErr != nil {
			t.Fatalf("final file missing: %v", statErr)
		}
	})

	t.Run("MultipleErrors", func(t *testing.T) {
		dir := t.TempDir()
		client := file.NewReplicaClient(dir)

		// Five consecutive failed writes, each for a different single-TXID range.
		for i := 0; i < 5; i++ {
			txID := ltx.TXID(i + 1)
			src := &failAfterReader{
				data: createLTXHeader(txID, txID),
				n:    30,
				err:  fmt.Errorf("write error %d", i),
			}

			if _, err := client.WriteLTXFile(context.Background(), 0, txID, txID, src); err == nil {
				t.Fatalf("iteration %d: expected error from failReader", i)
			}
		}

		// Temp files must not pile up across repeated failures.
		if leaked := findTmpFiles(t, dir); len(leaked) != 0 {
			t.Fatalf("found %d .tmp files after multiple errors: %v", len(leaked), leaked)
		}
	})
}
// failAfterReader is a reader used to simulate a source that breaks
// mid-stream: Read serves bytes from data until either n bytes have been
// served or data runs out, and returns err from then on.
type failAfterReader struct {
	data []byte // bytes served by Read
	n    int    // byte budget: Read returns err once pos reaches n
	pos  int    // bytes served so far, i.e. the next offset into data
	err  error  // returned by Read when nothing more can be served
}

// Read copies up to len(p) bytes from the unread part of data into p, never
// going past the byte budget r.n or the end of data. It returns (0, r.err)
// when the budget is spent, when data is exhausted, or when p is empty;
// otherwise it returns the number of bytes copied and a nil error.
func (r *failAfterReader) Read(p []byte) (n int, err error) {
	if r.pos >= r.n {
		return 0, r.err
	}

	// Clamp the read size to the caller's buffer, the remaining budget and
	// the remaining data, in that order.
	count := len(p)
	for _, bound := range [...]int{r.n - r.pos, len(r.data) - r.pos} {
		if count > bound {
			count = bound
		}
	}
	if count == 0 {
		return 0, r.err
	}

	n = copy(p, r.data[r.pos:r.pos+count])
	r.pos += n
	return n, nil
}
// findTmpFiles walks the tree rooted at root and returns the path of every
// non-directory entry whose path ends in ".tmp", in the order filepath.Walk
// visits them. It returns a nil slice when there are none.
//
// Any error encountered while walking fails the test via t.Fatalf: a tree
// that cannot be fully inspected must not be reported as free of temp files.
func findTmpFiles(t *testing.T, root string) []string {
	t.Helper()

	var tmpFiles []string
	err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			// Propagate instead of swallowing: returning nil here would make
			// the walk always succeed and hide unreadable entries.
			return err
		}
		if !info.IsDir() && strings.HasSuffix(path, ".tmp") {
			tmpFiles = append(tmpFiles, path)
		}
		return nil
	})
	if err != nil {
		t.Fatalf("walk error: %v", err)
	}
	return tmpFiles
}
// createLTXData returns test input made of a marshaled ltx.Header followed by
// the raw bytes of data. The header uses the current ltx.Version, a 4096-byte
// page size, a commit value of 1, the given TXID range and the current time
// in milliseconds.
//
// NOTE(review): only a header plus caller-supplied bytes is produced here;
// whether that is a complete, valid LTX file depends on data — confirm
// against the ltx format before relying on it.
//
// It panics if the header cannot be marshaled, since the helper has no
// *testing.T to report through and silently returning bad bytes would make
// the calling test misleading.
func createLTXData(minTXID, maxTXID ltx.TXID, data []byte) []byte {
	hdr := ltx.Header{
		Version:   ltx.Version,
		PageSize:  4096,
		Commit:    1,
		MinTXID:   minTXID,
		MaxTXID:   maxTXID,
		Timestamp: time.Now().UnixMilli(),
	}

	// A range starting at TXID 1 keeps the zero pre-apply checksum; any other
	// range gets ltx.ChecksumFlag (presumably what header validation expects
	// for non-initial transactions — verify against the ltx package).
	if minTXID != 1 {
		hdr.PreApplyChecksum = ltx.ChecksumFlag
	}

	headerBytes, err := hdr.MarshalBinary()
	if err != nil {
		panic(fmt.Sprintf("createLTXData: marshal LTX header: %v", err))
	}
	return append(headerBytes, data...)
}
// createLTXHeader returns what createLTXData produces for the given TXID
// range when no payload bytes are appended, i.e. the marshaled header alone.
func createLTXHeader(minTXID, maxTXID ltx.TXID) []byte {
	var noPayload []byte
	return createLTXData(minTXID, maxTXID, noPayload)
}
/*
func TestReplica_Sync(t *testing.T) {
// Ensure replica can successfully sync after DB has sync'd.