diff --git a/file/replica_client.go b/file/replica_client.go index 2365401..05bcac5 100644 --- a/file/replica_client.go +++ b/file/replica_client.go @@ -178,11 +178,20 @@ func (c *ReplicaClient) WriteLTXFile(ctx context.Context, level int, minTXID, ma } // Write LTX file to temporary file next to destination path. - f, err := internal.CreateFile(filename+".tmp", fileInfo) + tmpFilename := filename + ".tmp" + f, err := internal.CreateFile(tmpFilename, fileInfo) if err != nil { return nil, err } - defer f.Close() + + // Clean up temp file on error. On successful rename, the temp file + // becomes the final file and should not be removed. + defer func() { + _ = f.Close() + if err != nil { + _ = os.Remove(tmpFilename) + } + }() if _, err := io.Copy(f, fullReader); err != nil { return nil, err @@ -209,7 +218,7 @@ func (c *ReplicaClient) WriteLTXFile(ctx context.Context, level int, minTXID, ma } // Move LTX file to final path when it has been written & synced to disk. - if err := os.Rename(filename+".tmp", filename); err != nil { + if err := os.Rename(tmpFilename, filename); err != nil { return nil, err } diff --git a/file/replica_client_test.go b/file/replica_client_test.go index c7148c6..7047afc 100644 --- a/file/replica_client_test.go +++ b/file/replica_client_test.go @@ -1,7 +1,16 @@ package file_test import ( + "bytes" + "context" + "fmt" + "os" + "path/filepath" + "strings" "testing" + "time" + + "github.com/superfly/ltx" "github.com/benbjohnson/litestream/file" ) @@ -19,6 +28,157 @@ func TestReplicaClient_Type(t *testing.T) { } } +// TestReplicaClient_WriteLTXFile_ErrorCleanup verifies temp files are cleaned up on errors +func TestReplicaClient_WriteLTXFile_ErrorCleanup(t *testing.T) { + t.Run("DiskFull", func(t *testing.T) { + tmpDir := t.TempDir() + c := file.NewReplicaClient(tmpDir) + + // Create a reader that fails after 50 bytes to simulate disk full + failReader := &failAfterReader{ + data: createLTXHeader(1, 2), + n: 50, + err: fmt.Errorf("no space left on device"), + } + + _, err := c.WriteLTXFile(context.Background(), 0, 1, 2, failReader) + if err == nil { + t.Fatal("expected error from failReader") + } + if !strings.Contains(err.Error(), "no space left on device") { + t.Fatalf("expected disk full error, got: %v", err) + } + + // Verify no .tmp files remain + tmpFiles := findTmpFiles(t, tmpDir) + if len(tmpFiles) > 0 { + t.Fatalf("found %d .tmp files after error: %v", len(tmpFiles), tmpFiles) + } + }) + + t.Run("SuccessNoLeaks", func(t *testing.T) { + tmpDir := t.TempDir() + c := file.NewReplicaClient(tmpDir) + + ltxData := createLTXData(1, 2, []byte("test data")) + info, err := c.WriteLTXFile(context.Background(), 0, 1, 2, bytes.NewReader(ltxData)) + if err != nil { + t.Fatal(err) + } + if info == nil { + t.Fatal("expected FileInfo") + } + + // Verify no .tmp files remain + tmpFiles := findTmpFiles(t, tmpDir) + if len(tmpFiles) > 0 { + t.Fatalf("found %d .tmp files after successful write: %v", len(tmpFiles), tmpFiles) + } + + // Verify final file exists + finalPath := c.LTXFilePath(0, 1, 2) + if _, err := os.Stat(finalPath); err != nil { + t.Fatalf("final file missing: %v", err) + } + }) + + t.Run("MultipleErrors", func(t *testing.T) { + tmpDir := t.TempDir() + c := file.NewReplicaClient(tmpDir) + + // Simulate multiple failed writes + for i := 0; i < 5; i++ { + failReader := &failAfterReader{ + data: createLTXHeader(ltx.TXID(i+1), ltx.TXID(i+1)), + n: 30, + err: fmt.Errorf("write error %d", i), + } + + _, err := c.WriteLTXFile(context.Background(), 0, ltx.TXID(i+1), ltx.TXID(i+1), failReader) + if err == nil { + t.Fatalf("iteration %d: expected error from failReader", i) + } + } + + // Verify no .tmp files accumulated + tmpFiles := findTmpFiles(t, tmpDir) + if len(tmpFiles) > 0 { + t.Fatalf("found %d .tmp files after multiple errors: %v", len(tmpFiles), tmpFiles) + } + }) +} + +// failAfterReader simulates io.Copy failure after reading n bytes +type failAfterReader struct { + data []byte + n int // fail after n bytes + pos int + err error +} + +func (r *failAfterReader) Read(p []byte) (n int, err error) { + if r.pos >= r.n { + return 0, r.err + } + remaining := r.n - r.pos + toRead := len(p) + if toRead > remaining { + toRead = remaining + } + if toRead > len(r.data)-r.pos { + toRead = len(r.data) - r.pos + } + if toRead == 0 { + return 0, r.err + } + n = copy(p, r.data[r.pos:r.pos+toRead]) + r.pos += n + return n, nil +} + +// findTmpFiles recursively finds all .tmp files in the directory +func findTmpFiles(t *testing.T, root string) []string { + t.Helper() + var tmpFiles []string + err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error { + if err != nil { + return nil + } + if !info.IsDir() && strings.HasSuffix(path, ".tmp") { + tmpFiles = append(tmpFiles, path) + } + return nil + }) + if err != nil { + t.Fatalf("walk error: %v", err) + } + return tmpFiles +} + +// createLTXData creates a minimal valid LTX file with a header for testing +func createLTXData(minTXID, maxTXID ltx.TXID, data []byte) []byte { + hdr := ltx.Header{ + Version: ltx.Version, + PageSize: 4096, + Commit: 1, + MinTXID: minTXID, + MaxTXID: maxTXID, + Timestamp: time.Now().UnixMilli(), + } + if minTXID == 1 { + hdr.PreApplyChecksum = 0 + } else { + hdr.PreApplyChecksum = ltx.ChecksumFlag + } + headerBytes, _ := hdr.MarshalBinary() + return append(headerBytes, data...) +} + +// createLTXHeader creates minimal LTX header for testing +func createLTXHeader(minTXID, maxTXID ltx.TXID) []byte { + return createLTXData(minTXID, maxTXID, nil) +} + /* func TestReplica_Sync(t *testing.T) { // Ensure replica can successfully sync after DB has sync'd.