feat(replicate): add IPC control commands for dynamic start/stop (#1010)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Cory LaNou
2026-01-17 10:32:41 -06:00
committed by GitHub
parent 87d1f0d781
commit ed28f2ee62
8 changed files with 603 additions and 2 deletions

View File

@@ -159,6 +159,10 @@ func (m *Main) Run(ctx context.Context, args []string) (err error) {
slog.Info("litestream shut down")
return err
case "start":
return (&StartCommand{}).Run(ctx, args)
case "stop":
return (&StopCommand{}).Run(ctx, args)
case "reset":
return (&ResetCommand{}).Run(ctx, args)
case "restore":
@@ -198,7 +202,9 @@ The commands are:
replicate runs a server to replicate databases
reset reset local state for a database
restore recovers database backup from a replica
start start replication for a database
status display replication status for databases
stop stop replication for a database
version prints the binary version
`[1:])
}
@@ -211,6 +217,9 @@ type Config struct {
// Bind address for serving metrics.
Addr string `yaml:"addr"`
// Socket configuration for control commands.
Socket litestream.SocketConfig `yaml:"socket"`
// List of stages in a multi-level compaction.
// Only includes L1 through the last non-snapshot level.
Levels []*CompactionLevelConfig `yaml:"levels"`
@@ -292,6 +301,7 @@ func DefaultConfig() Config {
Interval: &defaultSnapshotInterval,
Retention: &defaultSnapshotRetention,
},
Socket: litestream.DefaultSocketConfig(),
L0Retention: &defaultL0Retention,
L0RetentionCheckInterval: &defaultL0RetentionCheckInterval,
ShutdownSyncTimeout: &defaultShutdownSyncTimeout,

View File

@@ -41,6 +41,9 @@ type ReplicateCommand struct {
// MCP server
MCP *MCPServer
// Server for IPC control commands.
Server *litestream.Server
// Manages the set of databases & compaction levels.
Store *litestream.Store
@@ -272,6 +275,17 @@ func (c *ReplicateCommand) Run(ctx context.Context) (err error) {
return fmt.Errorf("cannot open store: %w", err)
}
// Start control server if socket is enabled
if c.Config.Socket.Enabled {
c.Server = litestream.NewServer(c.Store)
c.Server.SocketPath = c.Config.Socket.Path
c.Server.SocketPerms = c.Config.Socket.Permissions
c.Server.PathExpander = expand
if err := c.Server.Start(); err != nil {
slog.Warn("failed to start control server", "error", err)
}
}
for _, entry := range watchables {
monitor, err := NewDirectoryMonitor(ctx, c.Store, entry.config, entry.dbs)
if err != nil {
@@ -402,6 +416,11 @@ func (c *ReplicateCommand) Close(ctx context.Context) error {
}
c.directoryMonitors = nil
if c.Server != nil {
if err := c.Server.Close(); err != nil {
slog.Error("error closing control server", "error", err)
}
}
if c.Store != nil {
if err := c.Store.Close(ctx); err != nil {
slog.Error("failed to close database", "error", err)

106
cmd/litestream/start.go Normal file
View File

@@ -0,0 +1,106 @@
package main
import (
"bytes"
"context"
"encoding/json"
"flag"
"fmt"
"io"
"net"
"net/http"
"time"
"github.com/benbjohnson/litestream"
)
// StartCommand represents the command to start replication for a database.
type StartCommand struct{}
// Run executes the start command.
func (c *StartCommand) Run(ctx context.Context, args []string) error {
fs := flag.NewFlagSet("litestream-start", flag.ContinueOnError)
timeout := fs.Int("timeout", 30, "timeout in seconds")
socketPath := fs.String("socket", "/var/run/litestream.sock", "control socket path")
fs.Usage = c.Usage
if err := fs.Parse(args); err != nil {
return err
}
if fs.NArg() == 0 {
return fmt.Errorf("database path required")
}
if fs.NArg() > 1 {
return fmt.Errorf("too many arguments")
}
dbPath := fs.Arg(0)
// Create HTTP client that connects via Unix socket with timeout
clientTimeout := time.Duration(*timeout) * time.Second
client := &http.Client{
Timeout: clientTimeout,
Transport: &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.DialTimeout("unix", *socketPath, clientTimeout)
},
},
}
req := litestream.StartRequest{
Path: dbPath,
Timeout: *timeout,
}
reqBody, err := json.Marshal(req)
if err != nil {
return fmt.Errorf("failed to marshal request: %w", err)
}
resp, err := client.Post("http://localhost/start", "application/json", bytes.NewReader(reqBody))
if err != nil {
return fmt.Errorf("failed to connect to control socket: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
var errResp litestream.ErrorResponse
if err := json.Unmarshal(body, &errResp); err == nil && errResp.Error != "" {
return fmt.Errorf("start failed: %s", errResp.Error)
}
return fmt.Errorf("start failed: %s", string(body))
}
var result litestream.StartResponse
if err := json.Unmarshal(body, &result); err != nil {
return fmt.Errorf("failed to parse response: %w", err)
}
output, err := json.MarshalIndent(result, "", " ")
if err != nil {
return fmt.Errorf("failed to format response: %w", err)
}
fmt.Println(string(output))
return nil
}
// Usage prints the help text for the start command.
func (c *StartCommand) Usage() {
fmt.Println(`
usage: litestream start [OPTIONS] DB_PATH
Start replication for a database.
Options:
-timeout SECONDS
Maximum time to wait in seconds (default: 30).
-socket PATH
Path to control socket (default: /var/run/litestream.sock).
`[1:])
}

107
cmd/litestream/stop.go Normal file
View File

@@ -0,0 +1,107 @@
package main
import (
"bytes"
"context"
"encoding/json"
"flag"
"fmt"
"io"
"net"
"net/http"
"time"
"github.com/benbjohnson/litestream"
)
// StopCommand represents the command to stop replication for a database.
type StopCommand struct{}
// Run executes the stop command.
func (c *StopCommand) Run(ctx context.Context, args []string) error {
fs := flag.NewFlagSet("litestream-stop", flag.ContinueOnError)
timeout := fs.Int("timeout", 30, "timeout in seconds")
socketPath := fs.String("socket", "/var/run/litestream.sock", "control socket path")
fs.Usage = c.Usage
if err := fs.Parse(args); err != nil {
return err
}
if fs.NArg() == 0 {
return fmt.Errorf("database path required")
}
if fs.NArg() > 1 {
return fmt.Errorf("too many arguments")
}
dbPath := fs.Arg(0)
// Create HTTP client that connects via Unix socket with timeout
clientTimeout := time.Duration(*timeout) * time.Second
client := &http.Client{
Timeout: clientTimeout,
Transport: &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.DialTimeout("unix", *socketPath, clientTimeout)
},
},
}
req := litestream.StopRequest{
Path: dbPath,
Timeout: *timeout,
}
reqBody, err := json.Marshal(req)
if err != nil {
return fmt.Errorf("failed to marshal request: %w", err)
}
resp, err := client.Post("http://localhost/stop", "application/json", bytes.NewReader(reqBody))
if err != nil {
return fmt.Errorf("failed to connect to control socket: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
var errResp litestream.ErrorResponse
if err := json.Unmarshal(body, &errResp); err == nil && errResp.Error != "" {
return fmt.Errorf("stop failed: %s", errResp.Error)
}
return fmt.Errorf("stop failed: %s", string(body))
}
var result litestream.StopResponse
if err := json.Unmarshal(body, &result); err != nil {
return fmt.Errorf("failed to parse response: %w", err)
}
output, err := json.MarshalIndent(result, "", " ")
if err != nil {
return fmt.Errorf("failed to format response: %w", err)
}
fmt.Println(string(output))
return nil
}
// Usage prints the help text for the stop command.
func (c *StopCommand) Usage() {
fmt.Println(`
usage: litestream stop [OPTIONS] DB_PATH
Stop replication for a database.
Stop always waits for shutdown and final sync.
Options:
-timeout SECONDS
Maximum time to wait in seconds (default: 30).
-socket PATH
Path to control socket (default: /var/run/litestream.sock).
`[1:])
}

29
db.go
View File

@@ -71,6 +71,7 @@ type DB struct {
pageSize int // page size, in bytes
notify chan struct{} // closes on WAL change
chkMu sync.RWMutex // checkpoint lock
opened bool // true if Open() was called and Close() not yet called
// syncedSinceCheckpoint tracks whether any data has been synced since
// the last checkpoint. Used to prevent time-based checkpoints from
@@ -233,6 +234,13 @@ func (db *DB) Path() string {
return db.path
}
// IsOpen returns true if the database has been opened.
func (db *DB) IsOpen() bool {
db.mu.RLock()
defer db.mu.RUnlock()
return db.opened
}
// WALPath returns the path to the database's WAL file.
func (db *DB) WALPath() string {
return db.path + "-wal"
@@ -395,6 +403,15 @@ func (db *DB) LastSuccessfulSyncAt() time.Time {
// Open initializes the background monitoring goroutine.
func (db *DB) Open() (err error) {
db.mu.Lock()
if db.opened {
db.mu.Unlock()
return nil // already open
}
// Recreate context for fresh start (handles reopen after close)
db.ctx, db.cancel = context.WithCancel(context.Background())
db.mu.Unlock()
// Validate fields on database.
if db.MinCheckpointPageN <= 0 {
return fmt.Errorf("minimum checkpoint page count required")
@@ -411,6 +428,11 @@ func (db *DB) Open() (err error) {
go func() { defer db.wg.Done(); db.monitor() }()
}
// Mark as opened only after successful initialization
db.mu.Lock()
db.opened = true
db.mu.Unlock()
return nil
}
@@ -448,14 +470,21 @@ func (db *DB) Close(ctx context.Context) (err error) {
if e := db.db.Close(); e != nil && err == nil {
err = e
}
db.db = nil
}
if db.f != nil {
if e := db.f.Close(); e != nil && err == nil {
err = e
}
db.f = nil
}
db.mu.Lock()
db.opened = false
db.rtx = nil
db.mu.Unlock()
return err
}

View File

@@ -2,6 +2,12 @@
# access-key-id: AKIAxxxxxxxxxxxxxxxx
# secret-access-key: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx/xxxxxxxxx
# Control socket for runtime start/stop commands
# socket:
# enabled: true # Enable control socket (default: false)
# path: /var/run/litestream.sock # Socket path (default: /var/run/litestream.sock)
# permissions: 0600 # Socket file permissions (default: 0600)
# dbs:
# - path: /path/to/primary/db # Database to replicate from
# replica:

257
server.go Normal file
View File

@@ -0,0 +1,257 @@
package litestream
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net"
"net/http"
"os"
"sync"
"time"
)
// SocketConfig configures the Unix socket for control commands.
type SocketConfig struct {
Enabled bool `yaml:"enabled"`
Path string `yaml:"path"`
Permissions uint32 `yaml:"permissions"`
}
// DefaultSocketConfig returns the default socket configuration.
func DefaultSocketConfig() SocketConfig {
return SocketConfig{
Enabled: false,
Path: "/var/run/litestream.sock",
Permissions: 0600,
}
}
// Server manages runtime control via Unix socket using HTTP.
type Server struct {
store *Store
// SocketPath is the path to the Unix socket.
SocketPath string
// SocketPerms is the file permissions for the socket.
SocketPerms uint32
// PathExpander optionally expands paths (e.g., ~ expansion).
// If nil, paths are used as-is.
PathExpander func(string) (string, error)
socketListener net.Listener
httpServer *http.Server
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
logger *slog.Logger
}
// NewServer creates a new Server instance.
func NewServer(store *Store) *Server {
ctx, cancel := context.WithCancel(context.Background())
s := &Server{
store: store,
SocketPerms: 0600,
ctx: ctx,
cancel: cancel,
logger: slog.Default(),
}
mux := http.NewServeMux()
mux.HandleFunc("POST /start", s.handleStart)
mux.HandleFunc("POST /stop", s.handleStop)
s.httpServer = &http.Server{Handler: mux}
return s
}
// Start begins listening for control connections.
func (s *Server) Start() error {
if s.SocketPath == "" {
return fmt.Errorf("socket path required")
}
// Check if socket file exists and is actually a socket before removing
if info, err := os.Lstat(s.SocketPath); err == nil {
if info.Mode()&os.ModeSocket != 0 {
if err := os.Remove(s.SocketPath); err != nil {
return fmt.Errorf("remove existing socket: %w", err)
}
} else {
return fmt.Errorf("socket path exists but is not a socket: %s", s.SocketPath)
}
} else if !os.IsNotExist(err) {
return fmt.Errorf("check socket path: %w", err)
}
listener, err := net.Listen("unix", s.SocketPath)
if err != nil {
return fmt.Errorf("listen on unix socket: %w", err)
}
s.socketListener = listener
if err := os.Chmod(s.SocketPath, os.FileMode(s.SocketPerms)); err != nil {
listener.Close()
return fmt.Errorf("chmod socket: %w", err)
}
s.logger.Info("control socket listening", "path", s.SocketPath)
s.wg.Add(1)
go func() {
defer s.wg.Done()
if err := s.httpServer.Serve(listener); err != nil && err != http.ErrServerClosed {
s.logger.Error("http server error", "error", err)
}
}()
return nil
}
// Close gracefully shuts down the control server.
func (s *Server) Close() error {
s.cancel()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if s.httpServer != nil {
if err := s.httpServer.Shutdown(ctx); err != nil {
s.logger.Error("http server shutdown error", "error", err)
}
}
s.wg.Wait()
return nil
}
// expandPath expands the path using PathExpander if set.
func (s *Server) expandPath(path string) (string, error) {
if s.PathExpander != nil {
return s.PathExpander(path)
}
return path, nil
}
func (s *Server) handleStart(w http.ResponseWriter, r *http.Request) {
var req StartRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSONError(w, http.StatusBadRequest, "invalid request body", err.Error())
return
}
if req.Path == "" {
writeJSONError(w, http.StatusBadRequest, "path required", nil)
return
}
expandedPath, err := s.expandPath(req.Path)
if err != nil {
writeJSONError(w, http.StatusBadRequest, fmt.Sprintf("invalid path: %v", err), nil)
return
}
ctx := s.ctx
if req.Timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(s.ctx, time.Duration(req.Timeout)*time.Second)
defer cancel()
}
if err := s.store.EnableDB(ctx, expandedPath); err != nil {
writeJSONError(w, http.StatusInternalServerError, err.Error(), nil)
return
}
writeJSON(w, http.StatusOK, StartResponse{
Status: "started",
Path: expandedPath,
})
}
func (s *Server) handleStop(w http.ResponseWriter, r *http.Request) {
var req StopRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSONError(w, http.StatusBadRequest, "invalid request body", err.Error())
return
}
if req.Path == "" {
writeJSONError(w, http.StatusBadRequest, "path required", nil)
return
}
expandedPath, err := s.expandPath(req.Path)
if err != nil {
writeJSONError(w, http.StatusBadRequest, fmt.Sprintf("invalid path: %v", err), nil)
return
}
timeout := req.Timeout
if timeout == 0 {
timeout = 30
}
ctx, cancel := context.WithTimeout(s.ctx, time.Duration(timeout)*time.Second)
defer cancel()
if err := s.store.DisableDB(ctx, expandedPath); err != nil {
writeJSONError(w, http.StatusInternalServerError, err.Error(), nil)
return
}
writeJSON(w, http.StatusOK, StopResponse{
Status: "stopped",
Path: expandedPath,
})
}
func writeJSON(w http.ResponseWriter, status int, v interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
json.NewEncoder(w).Encode(v)
}
func writeJSONError(w http.ResponseWriter, status int, message string, details interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
json.NewEncoder(w).Encode(ErrorResponse{
Error: message,
Details: details,
})
}
// StartRequest is the request body for the /start endpoint.
type StartRequest struct {
Path string `json:"path"`
Timeout int `json:"timeout,omitempty"`
}
// StartResponse is the response body for the /start endpoint.
type StartResponse struct {
Status string `json:"status"`
Path string `json:"path"`
}
// StopRequest is the request body for the /stop endpoint.
type StopRequest struct {
Path string `json:"path"`
Timeout int `json:"timeout,omitempty"`
}
// StopResponse is the response body for the /stop endpoint.
type StopResponse struct {
Status string `json:"status"`
Path string `json:"path"`
}
// ErrorResponse is returned when an error occurs.
type ErrorResponse struct {
Error string `json:"error"`
Details interface{} `json:"details,omitempty"`
}

View File

@@ -277,6 +277,62 @@ func (s *Store) RemoveDB(ctx context.Context, path string) error {
return nil
}
// EnableDB starts replication for a registered database.
// The context is checked for cancellation before opening.
// Note: db.Open() itself does not support cancellation.
func (s *Store) EnableDB(ctx context.Context, path string) error {
db := s.FindDB(path)
if db == nil {
return fmt.Errorf("database not found: %s", path)
}
if db.IsOpen() {
return fmt.Errorf("database already enabled: %s", path)
}
// Check for cancellation before starting open
if err := ctx.Err(); err != nil {
return fmt.Errorf("enable database: %w", err)
}
if err := db.Open(); err != nil {
return fmt.Errorf("open database: %w", err)
}
return nil
}
// DisableDB stops replication for a database.
func (s *Store) DisableDB(ctx context.Context, path string) error {
db := s.FindDB(path)
if db == nil {
return fmt.Errorf("database not found: %s", path)
}
if !db.IsOpen() {
return fmt.Errorf("database already disabled: %s", path)
}
if err := db.Close(ctx); err != nil {
return fmt.Errorf("close database: %w", err)
}
return nil
}
// FindDB returns the database with the given path.
func (s *Store) FindDB(path string) *DB {
s.mu.Lock()
defer s.mu.Unlock()
for _, db := range s.dbs {
if db.Path() == path {
return db
}
}
return nil
}
// SetL0Retention updates the retention window for L0 files and propagates it to
// all managed databases.
func (s *Store) SetL0Retention(d time.Duration) {
@@ -339,6 +395,9 @@ func (s *Store) monitorCompactionLevel(ctx context.Context, lvl *CompactionLevel
var notReadyDBs []string
for _, db := range s.DBs() {
if !db.IsOpen() {
continue // skip disabled DBs
}
_, err := s.CompactDB(ctx, db, lvl)
switch {
case errors.Is(err, ErrNoCompaction), errors.Is(err, ErrCompactionTooEarly):
@@ -398,6 +457,9 @@ LOOP:
}
for _, db := range s.DBs() {
if !db.IsOpen() {
continue // skip disabled DBs
}
if err := db.EnforceL0RetentionByTime(ctx); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
continue
@@ -496,20 +558,25 @@ func (s *Store) sendHeartbeatIfNeeded(ctx context.Context) {
}
// allDatabasesHealthy returns true if all databases have synced successfully
// since the given time. Returns false if there are no databases.
// since the given time. Returns false if there are no databases or no enabled databases.
func (s *Store) allDatabasesHealthy(since time.Time) bool {
dbs := s.DBs()
if len(dbs) == 0 {
return false
}
enabledCount := 0
for _, db := range dbs {
if !db.IsOpen() {
continue // skip disabled DBs
}
enabledCount++
lastSync := db.LastSuccessfulSyncAt()
if lastSync.IsZero() || lastSync.Before(since) {
return false
}
}
return true
return enabledCount > 0
}
// CompactDB performs a compaction or snapshot for a given database on a single destination level.