From ed28f2ee62fe908429ff7fb736eabc8a5ccb5a16 Mon Sep 17 00:00:00 2001 From: Cory LaNou Date: Sat, 17 Jan 2026 10:32:41 -0600 Subject: [PATCH] feat(replicate): add IPC control commands for dynamic start/stop (#1010) Co-authored-by: Claude Opus 4.5 --- cmd/litestream/main.go | 10 ++ cmd/litestream/replicate.go | 19 +++ cmd/litestream/start.go | 106 +++++++++++++++ cmd/litestream/stop.go | 107 +++++++++++++++ db.go | 29 ++++ etc/litestream.yml | 6 + server.go | 257 ++++++++++++++++++++++++++++++++++++ store.go | 71 +++++++++- 8 files changed, 603 insertions(+), 2 deletions(-) create mode 100644 cmd/litestream/start.go create mode 100644 cmd/litestream/stop.go create mode 100644 server.go diff --git a/cmd/litestream/main.go b/cmd/litestream/main.go index 4dd3aed..ed3153b 100644 --- a/cmd/litestream/main.go +++ b/cmd/litestream/main.go @@ -159,6 +159,10 @@ func (m *Main) Run(ctx context.Context, args []string) (err error) { slog.Info("litestream shut down") return err + case "start": + return (&StartCommand{}).Run(ctx, args) + case "stop": + return (&StopCommand{}).Run(ctx, args) case "reset": return (&ResetCommand{}).Run(ctx, args) case "restore": @@ -198,7 +202,9 @@ The commands are: replicate runs a server to replicate databases reset reset local state for a database restore recovers database backup from a replica + start start replication for a database status display replication status for databases + stop stop replication for a database version prints the binary version `[1:]) } @@ -211,6 +217,9 @@ type Config struct { // Bind address for serving metrics. Addr string `yaml:"addr"` + // Socket configuration for control commands. + Socket litestream.SocketConfig `yaml:"socket"` + // List of stages in a multi-level compaction. // Only includes L1 through the last non-snapshot level. Levels []*CompactionLevelConfig `yaml:"levels"` @@ -292,6 +301,7 @@ func DefaultConfig() Config { Interval: &defaultSnapshotInterval, Retention: &defaultSnapshotRetention, }, + Socket: litestream.DefaultSocketConfig(), L0Retention: &defaultL0Retention, L0RetentionCheckInterval: &defaultL0RetentionCheckInterval, ShutdownSyncTimeout: &defaultShutdownSyncTimeout, diff --git a/cmd/litestream/replicate.go b/cmd/litestream/replicate.go index 8da2a17..5c65b2d 100644 --- a/cmd/litestream/replicate.go +++ b/cmd/litestream/replicate.go @@ -41,6 +41,9 @@ type ReplicateCommand struct { // MCP server MCP *MCPServer + // Server for IPC control commands. + Server *litestream.Server + // Manages the set of databases & compaction levels. Store *litestream.Store @@ -272,6 +275,17 @@ func (c *ReplicateCommand) Run(ctx context.Context) (err error) { return fmt.Errorf("cannot open store: %w", err) } + // Start control server if socket is enabled + if c.Config.Socket.Enabled { + c.Server = litestream.NewServer(c.Store) + c.Server.SocketPath = c.Config.Socket.Path + c.Server.SocketPerms = c.Config.Socket.Permissions + c.Server.PathExpander = expand + if err := c.Server.Start(); err != nil { + slog.Warn("failed to start control server", "error", err) + } + } + for _, entry := range watchables { monitor, err := NewDirectoryMonitor(ctx, c.Store, entry.config, entry.dbs) if err != nil { @@ -402,6 +416,11 @@ func (c *ReplicateCommand) Close(ctx context.Context) error { } c.directoryMonitors = nil + if c.Server != nil { + if err := c.Server.Close(); err != nil { + slog.Error("error closing control server", "error", err) + } + } if c.Store != nil { if err := c.Store.Close(ctx); err != nil { slog.Error("failed to close database", "error", err) diff --git a/cmd/litestream/start.go b/cmd/litestream/start.go new file mode 100644 index 0000000..4ef7cf5 --- /dev/null +++ b/cmd/litestream/start.go @@ -0,0 +1,106 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "flag" + "fmt" + "io" + "net" + "net/http" + "time" + + "github.com/benbjohnson/litestream" +) + +// StartCommand represents the command to start replication for a database. +type StartCommand struct{} + +// Run executes the start command. +func (c *StartCommand) Run(ctx context.Context, args []string) error { + fs := flag.NewFlagSet("litestream-start", flag.ContinueOnError) + timeout := fs.Int("timeout", 30, "timeout in seconds") + socketPath := fs.String("socket", "/var/run/litestream.sock", "control socket path") + fs.Usage = c.Usage + if err := fs.Parse(args); err != nil { + return err + } + + if fs.NArg() == 0 { + return fmt.Errorf("database path required") + } + if fs.NArg() > 1 { + return fmt.Errorf("too many arguments") + } + + dbPath := fs.Arg(0) + + // Create HTTP client that connects via Unix socket with timeout + clientTimeout := time.Duration(*timeout) * time.Second + client := &http.Client{ + Timeout: clientTimeout, + Transport: &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.DialTimeout("unix", *socketPath, clientTimeout) + }, + }, + } + + req := litestream.StartRequest{ + Path: dbPath, + Timeout: *timeout, + } + reqBody, err := json.Marshal(req) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + resp, err := client.Post("http://localhost/start", "application/json", bytes.NewReader(reqBody)) + if err != nil { + return fmt.Errorf("failed to connect to control socket: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + var errResp litestream.ErrorResponse + if err := json.Unmarshal(body, &errResp); err == nil && errResp.Error != "" { + return fmt.Errorf("start failed: %s", errResp.Error) + } + return fmt.Errorf("start failed: %s", string(body)) + } + + var result litestream.StartResponse + if err := json.Unmarshal(body, &result); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + + output, err := json.MarshalIndent(result, "", " ") + if err != nil { + return fmt.Errorf("failed to format response: %w", err) + } + fmt.Println(string(output)) + + return nil +} + +// Usage prints the help text for the start command. +func (c *StartCommand) Usage() { + fmt.Println(` +usage: litestream start [OPTIONS] DB_PATH + +Start replication for a database. + +Options: + -timeout SECONDS + Maximum time to wait in seconds (default: 30). + + -socket PATH + Path to control socket (default: /var/run/litestream.sock). +`[1:]) +} diff --git a/cmd/litestream/stop.go b/cmd/litestream/stop.go new file mode 100644 index 0000000..c38a895 --- /dev/null +++ b/cmd/litestream/stop.go @@ -0,0 +1,107 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "flag" + "fmt" + "io" + "net" + "net/http" + "time" + + "github.com/benbjohnson/litestream" +) + +// StopCommand represents the command to stop replication for a database. +type StopCommand struct{} + +// Run executes the stop command. +func (c *StopCommand) Run(ctx context.Context, args []string) error { + fs := flag.NewFlagSet("litestream-stop", flag.ContinueOnError) + timeout := fs.Int("timeout", 30, "timeout in seconds") + socketPath := fs.String("socket", "/var/run/litestream.sock", "control socket path") + fs.Usage = c.Usage + if err := fs.Parse(args); err != nil { + return err + } + + if fs.NArg() == 0 { + return fmt.Errorf("database path required") + } + if fs.NArg() > 1 { + return fmt.Errorf("too many arguments") + } + + dbPath := fs.Arg(0) + + // Create HTTP client that connects via Unix socket with timeout + clientTimeout := time.Duration(*timeout) * time.Second + client := &http.Client{ + Timeout: clientTimeout, + Transport: &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.DialTimeout("unix", *socketPath, clientTimeout) + }, + }, + } + + req := litestream.StopRequest{ + Path: dbPath, + Timeout: *timeout, + } + reqBody, err := json.Marshal(req) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + resp, err := client.Post("http://localhost/stop", "application/json", bytes.NewReader(reqBody)) + if err != nil { + return fmt.Errorf("failed to connect to control socket: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + var errResp litestream.ErrorResponse + if err := json.Unmarshal(body, &errResp); err == nil && errResp.Error != "" { + return fmt.Errorf("stop failed: %s", errResp.Error) + } + return fmt.Errorf("stop failed: %s", string(body)) + } + + var result litestream.StopResponse + if err := json.Unmarshal(body, &result); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + + output, err := json.MarshalIndent(result, "", " ") + if err != nil { + return fmt.Errorf("failed to format response: %w", err) + } + fmt.Println(string(output)) + + return nil +} + +// Usage prints the help text for the stop command. +func (c *StopCommand) Usage() { + fmt.Println(` +usage: litestream stop [OPTIONS] DB_PATH + +Stop replication for a database. +Stop always waits for shutdown and final sync. + +Options: + -timeout SECONDS + Maximum time to wait in seconds (default: 30). + + -socket PATH + Path to control socket (default: /var/run/litestream.sock). +`[1:]) +} diff --git a/db.go b/db.go index 5cae409..bc07bc5 100644 --- a/db.go +++ b/db.go @@ -71,6 +71,7 @@ type DB struct { pageSize int // page size, in bytes notify chan struct{} // closes on WAL change chkMu sync.RWMutex // checkpoint lock + opened bool // true if Open() was called and Close() not yet called // syncedSinceCheckpoint tracks whether any data has been synced since // the last checkpoint. Used to prevent time-based checkpoints from @@ -233,6 +234,13 @@ func (db *DB) Path() string { return db.path } +// IsOpen returns true if the database has been opened. +func (db *DB) IsOpen() bool { + db.mu.RLock() + defer db.mu.RUnlock() + return db.opened +} + // WALPath returns the path to the database's WAL file. func (db *DB) WALPath() string { return db.path + "-wal" @@ -395,6 +403,15 @@ func (db *DB) LastSuccessfulSyncAt() time.Time { // Open initializes the background monitoring goroutine. func (db *DB) Open() (err error) { + db.mu.Lock() + if db.opened { + db.mu.Unlock() + return nil // already open + } + // Recreate context for fresh start (handles reopen after close) + db.ctx, db.cancel = context.WithCancel(context.Background()) + db.mu.Unlock() + // Validate fields on database. if db.MinCheckpointPageN <= 0 { return fmt.Errorf("minimum checkpoint page count required") @@ -411,6 +428,11 @@ func (db *DB) Open() (err error) { go func() { defer db.wg.Done(); db.monitor() }() } + // Mark as opened only after successful initialization + db.mu.Lock() + db.opened = true + db.mu.Unlock() + return nil } @@ -448,14 +470,21 @@ func (db *DB) Close(ctx context.Context) (err error) { if e := db.db.Close(); e != nil && err == nil { err = e } + db.db = nil } if db.f != nil { if e := db.f.Close(); e != nil && err == nil { err = e } + db.f = nil } + db.mu.Lock() + db.opened = false + db.rtx = nil + db.mu.Unlock() + return err } diff --git a/etc/litestream.yml b/etc/litestream.yml index 28f3719..c2a31e2 100644 --- a/etc/litestream.yml +++ b/etc/litestream.yml @@ -2,6 +2,12 @@ # access-key-id: AKIAxxxxxxxxxxxxxxxx # secret-access-key: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx/xxxxxxxxx +# Control socket for runtime start/stop commands +# socket: +# enabled: true # Enable control socket (default: false) +# path: /var/run/litestream.sock # Socket path (default: /var/run/litestream.sock) +# permissions: 0600 # Socket file permissions (default: 0600) + # dbs: # - path: /path/to/primary/db # Database to replicate from # replica: diff --git a/server.go b/server.go new file mode 100644 index 0000000..c53318a --- /dev/null +++ b/server.go @@ -0,0 +1,257 @@ +package litestream + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net" + "net/http" + "os" + "sync" + "time" +) + +// SocketConfig configures the Unix socket for control commands. +type SocketConfig struct { + Enabled bool `yaml:"enabled"` + Path string `yaml:"path"` + Permissions uint32 `yaml:"permissions"` +} + +// DefaultSocketConfig returns the default socket configuration. +func DefaultSocketConfig() SocketConfig { + return SocketConfig{ + Enabled: false, + Path: "/var/run/litestream.sock", + Permissions: 0600, + } +} + +// Server manages runtime control via Unix socket using HTTP. +type Server struct { + store *Store + + // SocketPath is the path to the Unix socket. + SocketPath string + + // SocketPerms is the file permissions for the socket. + SocketPerms uint32 + + // PathExpander optionally expands paths (e.g., ~ expansion). + // If nil, paths are used as-is. + PathExpander func(string) (string, error) + + socketListener net.Listener + httpServer *http.Server + + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + + logger *slog.Logger +} + +// NewServer creates a new Server instance. +func NewServer(store *Store) *Server { + ctx, cancel := context.WithCancel(context.Background()) + s := &Server{ + store: store, + SocketPerms: 0600, + ctx: ctx, + cancel: cancel, + logger: slog.Default(), + } + + mux := http.NewServeMux() + mux.HandleFunc("POST /start", s.handleStart) + mux.HandleFunc("POST /stop", s.handleStop) + + s.httpServer = &http.Server{Handler: mux} + + return s +} + +// Start begins listening for control connections. +func (s *Server) Start() error { + if s.SocketPath == "" { + return fmt.Errorf("socket path required") + } + + // Check if socket file exists and is actually a socket before removing + if info, err := os.Lstat(s.SocketPath); err == nil { + if info.Mode()&os.ModeSocket != 0 { + if err := os.Remove(s.SocketPath); err != nil { + return fmt.Errorf("remove existing socket: %w", err) + } + } else { + return fmt.Errorf("socket path exists but is not a socket: %s", s.SocketPath) + } + } else if !os.IsNotExist(err) { + return fmt.Errorf("check socket path: %w", err) + } + + listener, err := net.Listen("unix", s.SocketPath) + if err != nil { + return fmt.Errorf("listen on unix socket: %w", err) + } + s.socketListener = listener + + if err := os.Chmod(s.SocketPath, os.FileMode(s.SocketPerms)); err != nil { + listener.Close() + return fmt.Errorf("chmod socket: %w", err) + } + + s.logger.Info("control socket listening", "path", s.SocketPath) + + s.wg.Add(1) + go func() { + defer s.wg.Done() + if err := s.httpServer.Serve(listener); err != nil && err != http.ErrServerClosed { + s.logger.Error("http server error", "error", err) + } + }() + + return nil +} + +// Close gracefully shuts down the control server. +func (s *Server) Close() error { + s.cancel() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if s.httpServer != nil { + if err := s.httpServer.Shutdown(ctx); err != nil { + s.logger.Error("http server shutdown error", "error", err) + } + } + s.wg.Wait() + return nil +} + +// expandPath expands the path using PathExpander if set. +func (s *Server) expandPath(path string) (string, error) { + if s.PathExpander != nil { + return s.PathExpander(path) + } + return path, nil +} + +func (s *Server) handleStart(w http.ResponseWriter, r *http.Request) { + var req StartRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSONError(w, http.StatusBadRequest, "invalid request body", err.Error()) + return + } + + if req.Path == "" { + writeJSONError(w, http.StatusBadRequest, "path required", nil) + return + } + + expandedPath, err := s.expandPath(req.Path) + if err != nil { + writeJSONError(w, http.StatusBadRequest, fmt.Sprintf("invalid path: %v", err), nil) + return + } + + ctx := s.ctx + if req.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(s.ctx, time.Duration(req.Timeout)*time.Second) + defer cancel() + } + + if err := s.store.EnableDB(ctx, expandedPath); err != nil { + writeJSONError(w, http.StatusInternalServerError, err.Error(), nil) + return + } + + writeJSON(w, http.StatusOK, StartResponse{ + Status: "started", + Path: expandedPath, + }) +} + +func (s *Server) handleStop(w http.ResponseWriter, r *http.Request) { + var req StopRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSONError(w, http.StatusBadRequest, "invalid request body", err.Error()) + return + } + + if req.Path == "" { + writeJSONError(w, http.StatusBadRequest, "path required", nil) + return + } + + expandedPath, err := s.expandPath(req.Path) + if err != nil { + writeJSONError(w, http.StatusBadRequest, fmt.Sprintf("invalid path: %v", err), nil) + return + } + + timeout := req.Timeout + if timeout == 0 { + timeout = 30 + } + ctx, cancel := context.WithTimeout(s.ctx, time.Duration(timeout)*time.Second) + defer cancel() + + if err := s.store.DisableDB(ctx, expandedPath); err != nil { + writeJSONError(w, http.StatusInternalServerError, err.Error(), nil) + return + } + + writeJSON(w, http.StatusOK, StopResponse{ + Status: "stopped", + Path: expandedPath, + }) +} + +func writeJSON(w http.ResponseWriter, status int, v interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + json.NewEncoder(w).Encode(v) +} + +func writeJSONError(w http.ResponseWriter, status int, message string, details interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + json.NewEncoder(w).Encode(ErrorResponse{ + Error: message, + Details: details, + }) +} + +// StartRequest is the request body for the /start endpoint. +type StartRequest struct { + Path string `json:"path"` + Timeout int `json:"timeout,omitempty"` +} + +// StartResponse is the response body for the /start endpoint. +type StartResponse struct { + Status string `json:"status"` + Path string `json:"path"` +} + +// StopRequest is the request body for the /stop endpoint. +type StopRequest struct { + Path string `json:"path"` + Timeout int `json:"timeout,omitempty"` +} + +// StopResponse is the response body for the /stop endpoint. +type StopResponse struct { + Status string `json:"status"` + Path string `json:"path"` +} + +// ErrorResponse is returned when an error occurs. +type ErrorResponse struct { + Error string `json:"error"` + Details interface{} `json:"details,omitempty"` +} diff --git a/store.go b/store.go index 3cd82c1..e7764f1 100644 --- a/store.go +++ b/store.go @@ -277,6 +277,62 @@ func (s *Store) RemoveDB(ctx context.Context, path string) error { return nil } +// EnableDB starts replication for a registered database. +// The context is checked for cancellation before opening. +// Note: db.Open() itself does not support cancellation. +func (s *Store) EnableDB(ctx context.Context, path string) error { + db := s.FindDB(path) + if db == nil { + return fmt.Errorf("database not found: %s", path) + } + + if db.IsOpen() { + return fmt.Errorf("database already enabled: %s", path) + } + + // Check for cancellation before starting open + if err := ctx.Err(); err != nil { + return fmt.Errorf("enable database: %w", err) + } + + if err := db.Open(); err != nil { + return fmt.Errorf("open database: %w", err) + } + + return nil +} + +// DisableDB stops replication for a database. +func (s *Store) DisableDB(ctx context.Context, path string) error { + db := s.FindDB(path) + if db == nil { + return fmt.Errorf("database not found: %s", path) + } + + if !db.IsOpen() { + return fmt.Errorf("database already disabled: %s", path) + } + + if err := db.Close(ctx); err != nil { + return fmt.Errorf("close database: %w", err) + } + + return nil +} + +// FindDB returns the database with the given path. +func (s *Store) FindDB(path string) *DB { + s.mu.Lock() + defer s.mu.Unlock() + + for _, db := range s.dbs { + if db.Path() == path { + return db + } + } + return nil +} + // SetL0Retention updates the retention window for L0 files and propagates it to // all managed databases. func (s *Store) SetL0Retention(d time.Duration) { @@ -339,6 +395,9 @@ func (s *Store) monitorCompactionLevel(ctx context.Context, lvl *CompactionLevel var notReadyDBs []string for _, db := range s.DBs() { + if !db.IsOpen() { + continue // skip disabled DBs + } _, err := s.CompactDB(ctx, db, lvl) switch { case errors.Is(err, ErrNoCompaction), errors.Is(err, ErrCompactionTooEarly): @@ -398,6 +457,9 @@ LOOP: } for _, db := range s.DBs() { + if !db.IsOpen() { + continue // skip disabled DBs + } if err := db.EnforceL0RetentionByTime(ctx); err != nil { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { continue @@ -496,20 +558,25 @@ func (s *Store) sendHeartbeatIfNeeded(ctx context.Context) { } // allDatabasesHealthy returns true if all databases have synced successfully -// since the given time. Returns false if there are no databases. +// since the given time. Returns false if there are no databases or no enabled databases. func (s *Store) allDatabasesHealthy(since time.Time) bool { dbs := s.DBs() if len(dbs) == 0 { return false } + enabledCount := 0 for _, db := range dbs { + if !db.IsOpen() { + continue // skip disabled DBs + } + enabledCount++ lastSync := db.LastSuccessfulSyncAt() if lastSync.IsZero() || lastSync.Before(since) { return false } } - return true + return enabledCount > 0 } // CompactDB performs a compaction or snapshot for a given database on a single destination level.