From 8efcdd7e59bb3f7ce30cba7fc408a9e1155e10be Mon Sep 17 00:00:00 2001 From: Ben Johnson Date: Wed, 10 Dec 2025 16:10:55 -0700 Subject: [PATCH] Refactor replica URL parsing (#884) Co-authored-by: Cory LaNou Co-authored-by: Claude Opus 4.5 --- abs/replica_client.go | 26 ++ cmd/litestream-vfs/main.go | 37 +- cmd/litestream/ltx.go | 2 +- cmd/litestream/main.go | 166 +-------- cmd/litestream/main_test.go | 16 +- cmd/litestream/restore.go | 2 +- db_internal_test.go | 4 + file/replica_client.go | 20 + gs/replica_client.go | 18 + mock/replica_client.go | 8 + nats/replica_client.go | 32 ++ oss/replica_client.go | 22 ++ replica_client.go | 5 + replica_client_test.go | 6 +- replica_url.go | 209 +++++++++++ replica_url_test.go | 632 ++++++++++++++++++++++++++++++++ s3/replica_client.go | 107 ++++++ sftp/replica_client.go | 50 ++- store_compaction_remote_test.go | 2 + webdav/replica_client.go | 52 ++- webdav/replica_client_test.go | 2 +- 21 files changed, 1218 insertions(+), 200 deletions(-) create mode 100644 replica_url.go create mode 100644 replica_url_test.go diff --git a/abs/replica_client.go b/abs/replica_client.go index 966584e..69cac15 100644 --- a/abs/replica_client.go +++ b/abs/replica_client.go @@ -8,6 +8,7 @@ import ( "io" "log/slog" "net/http" + "net/url" "os" "path" "strings" @@ -28,6 +29,10 @@ import ( "github.com/benbjohnson/litestream/internal" ) +func init() { + litestream.RegisterReplicaClientFactory("abs", NewReplicaClientFromURL) +} + // ReplicaClientType is the client type for this package. const ReplicaClientType = "abs" @@ -60,6 +65,27 @@ func NewReplicaClient() *ReplicaClient { } } +// NewReplicaClientFromURL creates a new ReplicaClient from URL components. +// This is used by the replica client factory registration. 
+// URL format: abs://[account-name@]container/path +func NewReplicaClientFromURL(scheme, host, urlPath string, query url.Values, userinfo *url.Userinfo) (litestream.ReplicaClient, error) { + client := NewReplicaClient() + + // Extract account name from userinfo if present (abs://account@container/path) + if userinfo != nil { + client.AccountName = userinfo.Username() + } + + client.Bucket = host + client.Path = urlPath + + if client.Bucket == "" { + return nil, fmt.Errorf("bucket required for abs replica URL") + } + + return client, nil +} + // Type returns "abs" as the client type. func (c *ReplicaClient) Type() string { return ReplicaClientType diff --git a/cmd/litestream-vfs/main.go b/cmd/litestream-vfs/main.go index 3911c79..d7853e8 100644 --- a/cmd/litestream-vfs/main.go +++ b/cmd/litestream-vfs/main.go @@ -16,43 +16,40 @@ import ( "log" "log/slog" "os" - "strconv" "strings" "unsafe" "github.com/psanford/sqlite3vfs" "github.com/benbjohnson/litestream" - "github.com/benbjohnson/litestream/s3" + + // Import all replica backends to register their URL factories. 
+ _ "github.com/benbjohnson/litestream/abs" + _ "github.com/benbjohnson/litestream/file" + _ "github.com/benbjohnson/litestream/gs" + _ "github.com/benbjohnson/litestream/nats" + _ "github.com/benbjohnson/litestream/oss" + _ "github.com/benbjohnson/litestream/s3" + _ "github.com/benbjohnson/litestream/sftp" + _ "github.com/benbjohnson/litestream/webdav" ) func main() {} //export LitestreamVFSRegister func LitestreamVFSRegister() { + var client litestream.ReplicaClient var err error - client := s3.NewReplicaClient() - client.AccessKeyID = os.Getenv("AWS_ACCESS_KEY_ID") - client.SecretAccessKey = os.Getenv("AWS_SECRET_ACCESS_KEY") - client.Region = os.Getenv("LITESTREAM_S3_REGION") - client.Bucket = os.Getenv("LITESTREAM_S3_BUCKET") - client.Path = os.Getenv("LITESTREAM_S3_PATH") - client.Endpoint = os.Getenv("LITESTREAM_S3_ENDPOINT") - if v := os.Getenv("LITESTREAM_S3_FORCE_PATH_STYLE"); v != "" { - if client.ForcePathStyle, err = strconv.ParseBool(v); err != nil { - log.Fatalf("failed to parse LITESTREAM_S3_FORCE_PATH_STYLE: %s", err) - } - } - - if v := os.Getenv("LITESTREAM_S3_SKIP_VERIFY"); v != "" { - if client.SkipVerify, err = strconv.ParseBool(v); err != nil { - log.Fatalf("failed to parse LITESTREAM_S3_SKIP_VERIFY: %s", err) - } + replicaURL := os.Getenv("LITESTREAM_REPLICA_URL") + client, err = litestream.NewReplicaClientFromURL(replicaURL) + if err != nil { + log.Fatalf("failed to create replica client from URL: %s", err) } + // Initialize the client. 
if err := client.Init(context.Background()); err != nil { - log.Fatalf("failed to initialize litestream s3 client: %s", err) + log.Fatalf("failed to initialize litestream replica client: %s", err) } var level slog.Level diff --git a/cmd/litestream/ltx.go b/cmd/litestream/ltx.go index 5be084a..dfe2090 100644 --- a/cmd/litestream/ltx.go +++ b/cmd/litestream/ltx.go @@ -28,7 +28,7 @@ func (c *LTXCommand) Run(ctx context.Context, args []string) (err error) { } var r *litestream.Replica - if isURL(fs.Arg(0)) { + if litestream.IsURL(fs.Arg(0)) { if *configPath != "" { return fmt.Errorf("cannot specify a replica URL and the -config flag") } diff --git a/cmd/litestream/main.go b/cmd/litestream/main.go index f4a2331..30517f1 100644 --- a/cmd/litestream/main.go +++ b/cmd/litestream/main.go @@ -13,7 +13,6 @@ import ( "os/user" "path" "path/filepath" - "regexp" "strings" "time" @@ -994,7 +993,7 @@ type ReplicaConfig struct { // NewReplicaFromConfig instantiates a replica for a DB based on a config. func NewReplicaFromConfig(c *ReplicaConfig, db *litestream.DB) (_ *litestream.Replica, err error) { // Ensure user did not specify URL in path. - if isURL(c.Path) { + if litestream.IsURL(c.Path) { return nil, fmt.Errorf("replica path cannot be a url, please use the 'url' field instead: %s", c.Path) } @@ -1064,7 +1063,7 @@ func newFileReplicaClientFromConfig(c *ReplicaConfig, r *litestream.Replica) (_ // Parse configPath from URL, if specified. 
configPath := c.Path if c.URL != "" { - if _, _, configPath, err = ParseReplicaURL(c.URL); err != nil { + if _, _, configPath, err = litestream.ParseReplicaURL(c.URL); err != nil { return nil, err } } @@ -1126,7 +1125,7 @@ func NewS3ReplicaClientFromConfig(c *ReplicaConfig, _ *litestream.Replica) (_ *s } if c.URL != "" { - _, host, upath, query, err := ParseReplicaURLWithQuery(c.URL) + _, host, upath, query, _, err := litestream.ParseReplicaURLWithQuery(c.URL) if err != nil { return nil, err } @@ -1140,7 +1139,7 @@ func NewS3ReplicaClientFromConfig(c *ReplicaConfig, _ *litestream.Replica) (_ *s if strings.HasPrefix(host, "arn:") { ubucket = host - uregion = regionFromS3ARN(host) + uregion = litestream.RegionFromS3ARN(host) } else { ubucket, uregion, uendpoint, uforcePathStyle = s3.ParseHost(host) } @@ -1168,11 +1167,11 @@ func NewS3ReplicaClientFromConfig(c *ReplicaConfig, _ *litestream.Replica) (_ *s if qSkipVerify := query.Get("skipVerify"); qSkipVerify != "" { skipVerify = qSkipVerify == "true" } - if v, ok := boolQueryValue(query, "signPayload", "sign-payload"); ok { + if v, ok := litestream.BoolQueryValue(query, "signPayload", "sign-payload"); ok { usignPayload = v usignPayloadSet = true } - if v, ok := boolQueryValue(query, "requireContentMD5", "require-content-md5"); ok { + if v, ok := litestream.BoolQueryValue(query, "requireContentMD5", "require-content-md5"); ok { urequireContentMD5 = v urequireContentMD5Set = true } @@ -1206,8 +1205,8 @@ func NewS3ReplicaClientFromConfig(c *ReplicaConfig, _ *litestream.Replica) (_ *s return nil, fmt.Errorf("bucket required for s3 replica") } - isTigris := isTigrisEndpoint(endpoint) - if !isTigris && !endpointWasSet && isTigrisEndpoint(c.Endpoint) { + isTigris := litestream.IsTigrisEndpoint(endpoint) + if !isTigris && !endpointWasSet && litestream.IsTigrisEndpoint(c.Endpoint) { isTigris = true } @@ -1253,7 +1252,7 @@ func newGSReplicaClientFromConfig(c *ReplicaConfig, _ *litestream.Replica) (_ *g // Apply settings from 
URL, if specified. if c.URL != "" { - _, uhost, upath, err := ParseReplicaURL(c.URL) + _, uhost, upath, err := litestream.ParseReplicaURL(c.URL) if err != nil { return nil, err } @@ -1438,7 +1437,7 @@ func newNATSReplicaClientFromConfig(c *ReplicaConfig, _ *litestream.Replica) (_ // Parse URL if provided to extract bucket name and server URL var url, bucket string if c.URL != "" { - scheme, host, bucketPath, err := ParseReplicaURL(c.URL) + scheme, host, bucketPath, err := litestream.ParseReplicaURL(c.URL) if err != nil { return nil, fmt.Errorf("invalid NATS URL: %w", err) } @@ -1520,7 +1519,7 @@ func newOSSReplicaClientFromConfig(c *ReplicaConfig, _ *litestream.Replica) (_ * // Apply settings from URL, if specified. if c.URL != "" { - _, host, upath, err := ParseReplicaURL(c.URL) + _, host, upath, err := litestream.ParseReplicaURL(c.URL) if err != nil { return nil, err } @@ -1585,128 +1584,6 @@ func applyLitestreamEnv() { } } -// ParseReplicaURL parses a replica URL. -func ParseReplicaURL(s string) (scheme, host, urlpath string, err error) { - if strings.HasPrefix(strings.ToLower(s), "s3://arn:") { - return parseS3AccessPointURL(s) - } - - scheme, host, urlpath, _, err = ParseReplicaURLWithQuery(s) - return scheme, host, urlpath, err -} - -// ParseReplicaURLWithQuery parses a replica URL and returns query parameters. 
-func ParseReplicaURLWithQuery(s string) (scheme, host, urlpath string, query url.Values, err error) { - // Handle S3 Access Point ARNs which can't be parsed by standard url.Parse - if strings.HasPrefix(strings.ToLower(s), "s3://arn:") { - scheme, host, urlpath, err := parseS3AccessPointURL(s) - return scheme, host, urlpath, nil, err - } - - u, err := url.Parse(s) - if err != nil { - return "", "", "", nil, err - } - - switch u.Scheme { - case "file": - scheme, u.Scheme = u.Scheme, "" - // Remove query params from path for file URLs - u.RawQuery = "" - return scheme, "", path.Clean(u.String()), nil, nil - - case "": - return u.Scheme, u.Host, u.Path, nil, fmt.Errorf("replica url scheme required: %s", s) - - default: - return u.Scheme, u.Host, strings.TrimPrefix(path.Clean(u.Path), "/"), u.Query(), nil - } -} - -func parseS3AccessPointURL(s string) (scheme, host, urlpath string, err error) { - const prefix = "s3://" - if !strings.HasPrefix(strings.ToLower(s), prefix) { - return "", "", "", fmt.Errorf("invalid s3 access point url: %s", s) - } - - arnWithPath := s[len(prefix):] - bucket, key, err := splitS3AccessPointARN(arnWithPath) - if err != nil { - return "", "", "", err - } - - return "s3", bucket, cleanReplicaURLPath(key), nil -} - -func splitS3AccessPointARN(s string) (bucket, key string, err error) { - lower := strings.ToLower(s) - const marker = ":accesspoint/" - idx := strings.Index(lower, marker) - if idx == -1 { - return "", "", fmt.Errorf("invalid s3 access point arn: %s", s) - } - - nameStart := idx + len(marker) - if nameStart >= len(s) { - return "", "", fmt.Errorf("invalid s3 access point arn: %s", s) - } - - remainder := s[nameStart:] - slashIdx := strings.IndexByte(remainder, '/') - if slashIdx == -1 { - return s, "", nil - } - - bucketEnd := nameStart + slashIdx - bucket = s[:bucketEnd] - key = remainder[slashIdx+1:] - return bucket, key, nil -} - -func cleanReplicaURLPath(p string) string { - if p == "" { - return "" - } - cleaned := 
path.Clean("/" + p) - cleaned = strings.TrimPrefix(cleaned, "/") - if cleaned == "." { - return "" - } - return cleaned -} - -func boolQueryValue(query url.Values, keys ...string) (bool, bool) { - if query == nil { - return false, false - } - for _, key := range keys { - if raw := query.Get(key); raw != "" { - switch strings.ToLower(raw) { - case "true", "1", "t", "yes": - return true, true - case "false", "0", "f", "no": - return false, true - default: - return false, true - } - } - } - return false, false -} - -func isTigrisEndpoint(endpoint string) bool { - endpoint = strings.TrimSpace(strings.ToLower(endpoint)) - if endpoint == "" { - return false - } - if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") { - if u, err := url.Parse(endpoint); err == nil && u.Host != "" { - endpoint = u.Host - } - } - return endpoint == "fly.storage.tigris.dev" -} - type boolSetting struct { value bool set bool @@ -1727,27 +1604,10 @@ func (s *boolSetting) ApplyDefault(value bool) { } } -func regionFromS3ARN(arn string) string { - parts := strings.SplitN(arn, ":", 6) - if len(parts) >= 4 { - return parts[3] - } - return "" -} - -// isURL returns true if s can be parsed and has a scheme. -func isURL(s string) bool { - return regexp.MustCompile(`^\w+:\/\/`).MatchString(s) -} - // ReplicaType returns the type based on the type field or extracted from the URL. 
func (c *ReplicaConfig) ReplicaType() string { - scheme, _, _, _ := ParseReplicaURL(c.URL) - if scheme != "" { - if scheme == "webdavs" { - return "webdav" - } - return scheme + if replicaType := litestream.ReplicaTypeFromURL(c.URL); replicaType != "" { + return replicaType } else if c.Type != "" { return c.Type } diff --git a/cmd/litestream/main_test.go b/cmd/litestream/main_test.go index 409d31a..6cb1029 100644 --- a/cmd/litestream/main_test.go +++ b/cmd/litestream/main_test.go @@ -496,7 +496,7 @@ snapshot: func TestParseReplicaURL_AccessPoint(t *testing.T) { t.Run("WithPrefix", func(t *testing.T) { - scheme, host, urlPath, err := main.ParseReplicaURL("s3://arn:aws:s3:us-east-1:123456789012:accesspoint/db-access/backups/prod") + scheme, host, urlPath, err := litestream.ParseReplicaURL("s3://arn:aws:s3:us-east-1:123456789012:accesspoint/db-access/backups/prod") if err != nil { t.Fatal(err) } @@ -512,7 +512,7 @@ func TestParseReplicaURL_AccessPoint(t *testing.T) { }) t.Run("Invalid", func(t *testing.T) { - if _, _, _, err := main.ParseReplicaURL("s3://arn:aws:s3:us-east-1:123456789012:accesspoint/"); err == nil { + if _, _, _, err := litestream.ParseReplicaURL("s3://arn:aws:s3:us-east-1:123456789012:accesspoint/"); err == nil { t.Fatal("expected error") } }) @@ -1498,7 +1498,7 @@ func TestFindSQLiteDatabases(t *testing.T) { func TestParseReplicaURLWithQuery(t *testing.T) { t.Run("S3WithEndpoint", func(t *testing.T) { url := "s3://mybucket/path/to/db?endpoint=localhost:9000&region=us-east-1&forcePathStyle=true" - scheme, host, path, query, err := main.ParseReplicaURLWithQuery(url) + scheme, host, path, query, _, err := litestream.ParseReplicaURLWithQuery(url) if err != nil { t.Fatal(err) } @@ -1524,7 +1524,7 @@ func TestParseReplicaURLWithQuery(t *testing.T) { t.Run("S3WithoutQuery", func(t *testing.T) { url := "s3://mybucket/path/to/db" - scheme, host, path, query, err := main.ParseReplicaURLWithQuery(url) + scheme, host, path, query, _, err := 
litestream.ParseReplicaURLWithQuery(url) if err != nil { t.Fatal(err) } @@ -1544,7 +1544,7 @@ func TestParseReplicaURLWithQuery(t *testing.T) { t.Run("FileURL", func(t *testing.T) { url := "file:///path/to/db" - scheme, host, path, query, err := main.ParseReplicaURLWithQuery(url) + scheme, host, path, query, _, err := litestream.ParseReplicaURLWithQuery(url) if err != nil { t.Fatal(err) } @@ -1565,7 +1565,7 @@ func TestParseReplicaURLWithQuery(t *testing.T) { t.Run("BackwardCompatibility", func(t *testing.T) { // Test that ParseReplicaURL still works as before url := "s3://mybucket/path/to/db?endpoint=localhost:9000" - scheme, host, path, err := main.ParseReplicaURL(url) + scheme, host, path, err := litestream.ParseReplicaURL(url) if err != nil { t.Fatal(err) } @@ -1582,7 +1582,7 @@ func TestParseReplicaURLWithQuery(t *testing.T) { t.Run("S3TigrisExample", func(t *testing.T) { url := "s3://mybucket/db?endpoint=fly.storage.tigris.dev&region=auto" - scheme, host, path, query, err := main.ParseReplicaURLWithQuery(url) + scheme, host, path, query, _, err := litestream.ParseReplicaURLWithQuery(url) if err != nil { t.Fatal(err) } @@ -1605,7 +1605,7 @@ func TestParseReplicaURLWithQuery(t *testing.T) { t.Run("S3WithSkipVerify", func(t *testing.T) { url := "s3://mybucket/db?endpoint=self-signed.local&skipVerify=true" - _, _, _, query, err := main.ParseReplicaURLWithQuery(url) + _, _, _, query, _, err := litestream.ParseReplicaURLWithQuery(url) if err != nil { t.Fatal(err) } diff --git a/cmd/litestream/restore.go b/cmd/litestream/restore.go index e0bfa15..0dbc4ff 100644 --- a/cmd/litestream/restore.go +++ b/cmd/litestream/restore.go @@ -46,7 +46,7 @@ func (c *RestoreCommand) Run(ctx context.Context, args []string) (err error) { // Determine replica to restore from. 
var r *litestream.Replica - if isURL(fs.Arg(0)) { + if litestream.IsURL(fs.Arg(0)) { if *configPath != "" { return fmt.Errorf("cannot specify a replica URL and the -config flag") } diff --git a/db_internal_test.go b/db_internal_test.go index 3afa5d2..ddb9ddf 100644 --- a/db_internal_test.go +++ b/db_internal_test.go @@ -20,6 +20,8 @@ type testReplicaClient struct { dir string } +func (c *testReplicaClient) Init(_ context.Context) error { return nil } + func (c *testReplicaClient) Type() string { return "test" } func (c *testReplicaClient) LTXFiles(_ context.Context, _ int, _ ltx.TXID, _ bool) (ltx.FileIterator, error) { @@ -59,6 +61,8 @@ type errorReplicaClient struct { writeErr error } +func (c *errorReplicaClient) Init(_ context.Context) error { return nil } + func (c *errorReplicaClient) Type() string { return "error" } func (c *errorReplicaClient) LTXFiles(_ context.Context, _ int, _ ltx.TXID, _ bool) (ltx.FileIterator, error) { diff --git a/file/replica_client.go b/file/replica_client.go index 4e5f95f..3f04d4d 100644 --- a/file/replica_client.go +++ b/file/replica_client.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "log/slog" + "net/url" "os" "path/filepath" "time" @@ -16,6 +17,10 @@ import ( "github.com/benbjohnson/litestream/internal" ) +func init() { + litestream.RegisterReplicaClientFactory("file", NewReplicaClientFromURL) +} + // ReplicaClientType is the client type for this package. const ReplicaClientType = "file" @@ -37,6 +42,16 @@ func NewReplicaClient(path string) *ReplicaClient { } } +// NewReplicaClientFromURL creates a new ReplicaClient from URL components. +// This is used by the replica client factory registration. 
+func NewReplicaClientFromURL(scheme, host, urlPath string, query url.Values, userinfo *url.Userinfo) (litestream.ReplicaClient, error) { + // For file URLs, the path is the full path + if urlPath == "" { + return nil, fmt.Errorf("file replica path required") + } + return NewReplicaClient(urlPath), nil +} + // db returns the database, if available. func (c *ReplicaClient) db() *litestream.DB { if c.Replica == nil { @@ -50,6 +65,11 @@ func (c *ReplicaClient) Type() string { return ReplicaClientType } +// Init is a no-op for file replica client as no initialization is required. +func (c *ReplicaClient) Init(ctx context.Context) error { + return nil +} + // Path returns the destination path to replicate the database to. func (c *ReplicaClient) Path() string { return c.path diff --git a/gs/replica_client.go b/gs/replica_client.go index 9ae6902..927e5d4 100644 --- a/gs/replica_client.go +++ b/gs/replica_client.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "log/slog" + "net/url" "os" "path" "sync" @@ -20,6 +21,10 @@ import ( "github.com/benbjohnson/litestream/internal" ) +func init() { + litestream.RegisterReplicaClientFactory("gs", NewReplicaClientFromURL) +} + // ReplicaClientType is the client type for this package. const ReplicaClientType = "gs" @@ -47,6 +52,19 @@ func NewReplicaClient() *ReplicaClient { } } +// NewReplicaClientFromURL creates a new ReplicaClient from URL components. +// This is used by the replica client factory registration. +func NewReplicaClientFromURL(scheme, host, urlPath string, query url.Values, userinfo *url.Userinfo) (litestream.ReplicaClient, error) { + if host == "" { + return nil, fmt.Errorf("bucket required for gs replica URL") + } + + client := NewReplicaClient() + client.Bucket = host + client.Path = urlPath + return client, nil +} + // Type returns "gs" as the client type. 
func (c *ReplicaClient) Type() string { return ReplicaClientType diff --git a/mock/replica_client.go b/mock/replica_client.go index 43d6300..6e610fb 100644 --- a/mock/replica_client.go +++ b/mock/replica_client.go @@ -12,6 +12,7 @@ import ( var _ litestream.ReplicaClient = (*ReplicaClient)(nil) type ReplicaClient struct { + InitFunc func(ctx context.Context) error DeleteAllFunc func(ctx context.Context) error LTXFilesFunc func(ctx context.Context, level int, seek ltx.TXID, useMetadata bool) (ltx.FileIterator, error) OpenLTXFileFunc func(ctx context.Context, level int, minTXID, maxTXID ltx.TXID, offset, size int64) (io.ReadCloser, error) @@ -21,6 +22,13 @@ type ReplicaClient struct { func (c *ReplicaClient) Type() string { return "mock" } +func (c *ReplicaClient) Init(ctx context.Context) error { + if c.InitFunc != nil { + return c.InitFunc(ctx) + } + return nil +} + func (c *ReplicaClient) DeleteAll(ctx context.Context) error { return c.DeleteAllFunc(ctx) } diff --git a/nats/replica_client.go b/nats/replica_client.go index a727fb7..3989c94 100644 --- a/nats/replica_client.go +++ b/nats/replica_client.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "log/slog" + "net/url" "os" "sort" "strconv" @@ -22,6 +23,10 @@ import ( "github.com/benbjohnson/litestream/internal" ) +func init() { + litestream.RegisterReplicaClientFactory("nats", NewReplicaClientFromURL) +} + // ReplicaClientType is the client type for this package. const ReplicaClientType = "nats" @@ -84,6 +89,33 @@ func NewReplicaClient() *ReplicaClient { } } +// NewReplicaClientFromURL creates a new ReplicaClient from URL components. +// This is used by the replica client factory registration. 
+// URL format: nats://[user:pass@]host[:port]/bucket +func NewReplicaClientFromURL(scheme, host, urlPath string, query url.Values, userinfo *url.Userinfo) (litestream.ReplicaClient, error) { + client := NewReplicaClient() + + // Reconstruct URL without bucket path + if host != "" { + client.URL = fmt.Sprintf("nats://%s", host) + } + + // Extract credentials from userinfo if present + if userinfo != nil { + client.Username = userinfo.Username() + client.Password, _ = userinfo.Password() + } + + // Extract bucket name from path + bucket := strings.Trim(urlPath, "/") + if bucket == "" { + return nil, fmt.Errorf("bucket required for nats replica URL") + } + client.BucketName = bucket + + return client, nil +} + // Type returns "nats" as the client type. func (c *ReplicaClient) Type() string { return ReplicaClientType diff --git a/oss/replica_client.go b/oss/replica_client.go index 4881a0f..e87d07d 100644 --- a/oss/replica_client.go +++ b/oss/replica_client.go @@ -23,6 +23,10 @@ import ( "github.com/benbjohnson/litestream/internal" ) +func init() { + litestream.RegisterReplicaClientFactory("oss", NewReplicaClientFromURL) +} + // ReplicaClientType is the client type for this package. const ReplicaClientType = "oss" @@ -67,6 +71,24 @@ func NewReplicaClient() *ReplicaClient { } } +// NewReplicaClientFromURL creates a new ReplicaClient from URL components. +// This is used by the replica client factory registration. +// URL format: oss://bucket[.oss-region.aliyuncs.com]/path +func NewReplicaClientFromURL(scheme, host, urlPath string, query url.Values, userinfo *url.Userinfo) (litestream.ReplicaClient, error) { + client := NewReplicaClient() + + bucket, region, _ := ParseHost(host) + if bucket == "" { + return nil, fmt.Errorf("bucket required for oss replica URL") + } + + client.Bucket = bucket + client.Region = region + client.Path = urlPath + + return client, nil +} + // Type returns "oss" as the client type. 
func (c *ReplicaClient) Type() string { return ReplicaClientType diff --git a/replica_client.go b/replica_client.go index ef46620..770e394 100644 --- a/replica_client.go +++ b/replica_client.go @@ -19,6 +19,11 @@ type ReplicaClient interface { // Type returns the type of client. Type() string + // Init initializes the replica client connection. + // This may establish connections, validate configuration, etc. + // Implementations should be idempotent (no-op if already initialized). + Init(ctx context.Context) error + // LTXFiles returns an iterator of all LTX files on the replica for a given level. // If seek is specified, the iterator start from the given TXID or the next available if not found. // If useMetadata is true, the iterator fetches accurate timestamps from metadata for timestamp-based restore. diff --git a/replica_client_test.go b/replica_client_test.go index d54b895..e71a73b 100644 --- a/replica_client_test.go +++ b/replica_client_test.go @@ -435,7 +435,7 @@ AAAEDzV1D6COyvFGhSiZa6ll9aXZ2IMWED3KGrvCNjEEtYHwnK0+GdwOelXlAXdqLx/qvS c.Host = addr c.HostKey = expectedHostKey - _, err = c.Init(context.Background()) + err = c.Init(context.Background()) if err != nil { t.Fatalf("SFTP connection failed: %v", err) } @@ -449,7 +449,7 @@ AAAEDzV1D6COyvFGhSiZa6ll9aXZ2IMWED3KGrvCNjEEtYHwnK0+GdwOelXlAXdqLx/qvS c.Host = addr c.HostKey = invalidHostKey - _, err = c.Init(context.Background()) + err = c.Init(context.Background()) if err == nil { t.Fatalf("SFTP connection established despite invalid host key") } @@ -475,7 +475,7 @@ AAAEDzV1D6COyvFGhSiZa6ll9aXZ2IMWED3KGrvCNjEEtYHwnK0+GdwOelXlAXdqLx/qvS c.User = "foo" c.Host = addr - _, err = c.Init(context.Background()) + err = c.Init(context.Background()) if err != nil { t.Fatalf("SFTP connection failed: %v", err) } diff --git a/replica_url.go b/replica_url.go new file mode 100644 index 0000000..87051f0 --- /dev/null +++ b/replica_url.go @@ -0,0 +1,209 @@ +package litestream + +import ( + "fmt" + "net/url" + "path" + 
"regexp" + "strings" + "sync" +) + +// ReplicaClientFactory is a function that creates a ReplicaClient from URL components. +// The userinfo parameter contains credentials from the URL (e.g., user:pass@host). +type ReplicaClientFactory func(scheme, host, urlPath string, query url.Values, userinfo *url.Userinfo) (ReplicaClient, error) + +var ( + replicaClientFactories = make(map[string]ReplicaClientFactory) + replicaClientFactoriesMu sync.RWMutex +) + +// RegisterReplicaClientFactory registers a factory function for creating replica clients +// for a given URL scheme. This is typically called from init() functions in backend packages. +func RegisterReplicaClientFactory(scheme string, factory ReplicaClientFactory) { + replicaClientFactoriesMu.Lock() + defer replicaClientFactoriesMu.Unlock() + replicaClientFactories[scheme] = factory +} + +// NewReplicaClientFromURL creates a new ReplicaClient from a URL string. +// The URL scheme determines which backend is used (s3, gs, abs, file, etc.). +func NewReplicaClientFromURL(rawURL string) (ReplicaClient, error) { + scheme, host, urlPath, query, userinfo, err := ParseReplicaURLWithQuery(rawURL) + if err != nil { + return nil, err + } + + // Normalize webdavs to webdav + factoryScheme := scheme + if factoryScheme == "webdavs" { + factoryScheme = "webdav" + } + + replicaClientFactoriesMu.RLock() + factory, ok := replicaClientFactories[factoryScheme] + replicaClientFactoriesMu.RUnlock() + + if !ok { + return nil, fmt.Errorf("unsupported replica URL scheme: %q", scheme) + } + + return factory(scheme, host, urlPath, query, userinfo) +} + +// ReplicaTypeFromURL returns the replica type from a URL string. +// Returns empty string if the URL is invalid or has no scheme. 
+func ReplicaTypeFromURL(rawURL string) string { + scheme, _, _, _ := ParseReplicaURL(rawURL) + if scheme == "" { + return "" + } + if scheme == "webdavs" { + return "webdav" + } + return scheme +} + +// ParseReplicaURL parses a replica URL and returns the scheme, host, and path. +func ParseReplicaURL(s string) (scheme, host, urlPath string, err error) { + if strings.HasPrefix(strings.ToLower(s), "s3://arn:") { + return parseS3AccessPointURL(s) + } + + scheme, host, urlPath, _, _, err = ParseReplicaURLWithQuery(s) + return scheme, host, urlPath, err +} + +// ParseReplicaURLWithQuery parses a replica URL and returns query parameters and userinfo. +func ParseReplicaURLWithQuery(s string) (scheme, host, urlPath string, query url.Values, userinfo *url.Userinfo, err error) { + // Handle S3 Access Point ARNs which can't be parsed by standard url.Parse + if strings.HasPrefix(strings.ToLower(s), "s3://arn:") { + scheme, host, urlPath, err := parseS3AccessPointURL(s) + return scheme, host, urlPath, nil, nil, err + } + + u, err := url.Parse(s) + if err != nil { + return "", "", "", nil, nil, err + } + + switch u.Scheme { + case "file": + scheme, u.Scheme = u.Scheme, "" + // Remove query params from path for file URLs + u.RawQuery = "" + return scheme, "", path.Clean(u.String()), nil, nil, nil + + case "": + return u.Scheme, u.Host, u.Path, nil, nil, fmt.Errorf("replica url scheme required: %s", s) + + default: + return u.Scheme, u.Host, strings.TrimPrefix(path.Clean(u.Path), "/"), u.Query(), u.User, nil + } +} + +// parseS3AccessPointURL parses an S3 Access Point URL (s3://arn:...). 
+func parseS3AccessPointURL(s string) (scheme, host, urlPath string, err error) { + const prefix = "s3://" + if !strings.HasPrefix(strings.ToLower(s), prefix) { + return "", "", "", fmt.Errorf("invalid s3 access point url: %s", s) + } + + arnWithPath := s[len(prefix):] + bucket, key, err := splitS3AccessPointARN(arnWithPath) + if err != nil { + return "", "", "", err + } + + return "s3", bucket, CleanReplicaURLPath(key), nil +} + +// splitS3AccessPointARN splits an S3 Access Point ARN into bucket and key components. +func splitS3AccessPointARN(s string) (bucket, key string, err error) { + lower := strings.ToLower(s) + const marker = ":accesspoint/" + idx := strings.Index(lower, marker) + if idx == -1 { + return "", "", fmt.Errorf("invalid s3 access point arn: %s", s) + } + + nameStart := idx + len(marker) + if nameStart >= len(s) { + return "", "", fmt.Errorf("invalid s3 access point arn: %s", s) + } + + remainder := s[nameStart:] + slashIdx := strings.IndexByte(remainder, '/') + if slashIdx == -1 { + return s, "", nil + } + + bucketEnd := nameStart + slashIdx + bucket = s[:bucketEnd] + key = remainder[slashIdx+1:] + return bucket, key, nil +} + +// CleanReplicaURLPath cleans a URL path for use in replica storage. +func CleanReplicaURLPath(p string) string { + if p == "" { + return "" + } + cleaned := path.Clean("/" + p) + cleaned = strings.TrimPrefix(cleaned, "/") + if cleaned == "." { + return "" + } + return cleaned +} + +// RegionFromS3ARN extracts the region from an S3 ARN. +func RegionFromS3ARN(arn string) string { + parts := strings.SplitN(arn, ":", 6) + if len(parts) >= 4 { + return parts[3] + } + return "" +} + +// BoolQueryValue returns a boolean value from URL query parameters. +// It checks multiple keys in order and returns the value and whether it was set. 
+func BoolQueryValue(query url.Values, keys ...string) (value bool, ok bool) { + if query == nil { + return false, false + } + for _, key := range keys { + if raw := query.Get(key); raw != "" { + switch strings.ToLower(raw) { + case "true", "1", "t", "yes": + return true, true + case "false", "0", "f", "no": + return false, true + default: + return false, true + } + } + } + return false, false +} + +// IsTigrisEndpoint returns true if the endpoint is the Tigris object storage service. +func IsTigrisEndpoint(endpoint string) bool { + endpoint = strings.TrimSpace(strings.ToLower(endpoint)) + if endpoint == "" { + return false + } + if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") { + if u, err := url.Parse(endpoint); err == nil && u.Host != "" { + endpoint = u.Host + } + } + return endpoint == "fly.storage.tigris.dev" +} + +// IsURL returns true if s appears to be a URL (has a scheme). +var isURLRegex = regexp.MustCompile(`^\w+:\/\/`) + +func IsURL(s string) bool { + return isURLRegex.MatchString(s) +} diff --git a/replica_url_test.go b/replica_url_test.go new file mode 100644 index 0000000..b77125a --- /dev/null +++ b/replica_url_test.go @@ -0,0 +1,632 @@ +package litestream_test + +import ( + "testing" + + "github.com/benbjohnson/litestream" + "github.com/benbjohnson/litestream/abs" + "github.com/benbjohnson/litestream/file" + "github.com/benbjohnson/litestream/gs" + "github.com/benbjohnson/litestream/nats" + "github.com/benbjohnson/litestream/oss" + "github.com/benbjohnson/litestream/s3" + "github.com/benbjohnson/litestream/sftp" + "github.com/benbjohnson/litestream/webdav" +) + +func TestNewReplicaClientFromURL(t *testing.T) { + t.Run("S3", func(t *testing.T) { + client, err := litestream.NewReplicaClientFromURL("s3://mybucket/path/to/db") + if err != nil { + t.Fatal(err) + } + if client.Type() != "s3" { + t.Errorf("expected type 's3', got %q", client.Type()) + } + s3Client, ok := client.(*s3.ReplicaClient) + if !ok { + 
t.Fatalf("expected *s3.ReplicaClient, got %T", client) + } + if s3Client.Bucket != "mybucket" { + t.Errorf("expected bucket 'mybucket', got %q", s3Client.Bucket) + } + if s3Client.Path != "path/to/db" { + t.Errorf("expected path 'path/to/db', got %q", s3Client.Path) + } + }) + + t.Run("S3WithQueryParams", func(t *testing.T) { + client, err := litestream.NewReplicaClientFromURL("s3://mybucket/db?endpoint=localhost:9000&region=us-west-2") + if err != nil { + t.Fatal(err) + } + s3Client, ok := client.(*s3.ReplicaClient) + if !ok { + t.Fatalf("expected *s3.ReplicaClient, got %T", client) + } + if s3Client.Endpoint != "http://localhost:9000" { + t.Errorf("expected endpoint 'http://localhost:9000', got %q", s3Client.Endpoint) + } + if s3Client.Region != "us-west-2" { + t.Errorf("expected region 'us-west-2', got %q", s3Client.Region) + } + }) + + t.Run("File", func(t *testing.T) { + client, err := litestream.NewReplicaClientFromURL("file:///tmp/replica") + if err != nil { + t.Fatal(err) + } + if client.Type() != "file" { + t.Errorf("expected type 'file', got %q", client.Type()) + } + fileClient, ok := client.(*file.ReplicaClient) + if !ok { + t.Fatalf("expected *file.ReplicaClient, got %T", client) + } + if fileClient.Path() != "/tmp/replica" { + t.Errorf("expected path '/tmp/replica', got %q", fileClient.Path()) + } + }) + + t.Run("GS", func(t *testing.T) { + client, err := litestream.NewReplicaClientFromURL("gs://mybucket/path") + if err != nil { + t.Fatal(err) + } + if client.Type() != "gs" { + t.Errorf("expected type 'gs', got %q", client.Type()) + } + gsClient, ok := client.(*gs.ReplicaClient) + if !ok { + t.Fatalf("expected *gs.ReplicaClient, got %T", client) + } + if gsClient.Bucket != "mybucket" { + t.Errorf("expected bucket 'mybucket', got %q", gsClient.Bucket) + } + if gsClient.Path != "path" { + t.Errorf("expected path 'path', got %q", gsClient.Path) + } + }) + + t.Run("GS_MissingBucket", func(t *testing.T) { + _, err := 
litestream.NewReplicaClientFromURL("gs:///path") + if err == nil { + t.Fatal("expected error for missing bucket") + } + }) + + t.Run("ABS", func(t *testing.T) { + client, err := litestream.NewReplicaClientFromURL("abs://mycontainer/path") + if err != nil { + t.Fatal(err) + } + if client.Type() != "abs" { + t.Errorf("expected type 'abs', got %q", client.Type()) + } + absClient, ok := client.(*abs.ReplicaClient) + if !ok { + t.Fatalf("expected *abs.ReplicaClient, got %T", client) + } + if absClient.Bucket != "mycontainer" { + t.Errorf("expected bucket 'mycontainer', got %q", absClient.Bucket) + } + if absClient.Path != "path" { + t.Errorf("expected path 'path', got %q", absClient.Path) + } + }) + + t.Run("ABS_WithAccount", func(t *testing.T) { + client, err := litestream.NewReplicaClientFromURL("abs://myaccount@mycontainer/path") + if err != nil { + t.Fatal(err) + } + absClient, ok := client.(*abs.ReplicaClient) + if !ok { + t.Fatalf("expected *abs.ReplicaClient, got %T", client) + } + if absClient.AccountName != "myaccount" { + t.Errorf("expected account 'myaccount', got %q", absClient.AccountName) + } + if absClient.Bucket != "mycontainer" { + t.Errorf("expected bucket 'mycontainer', got %q", absClient.Bucket) + } + }) + + t.Run("ABS_MissingBucket", func(t *testing.T) { + _, err := litestream.NewReplicaClientFromURL("abs:///path") + if err == nil { + t.Fatal("expected error for missing bucket") + } + }) + + t.Run("SFTP", func(t *testing.T) { + client, err := litestream.NewReplicaClientFromURL("sftp://myuser@host.example.com/path") + if err != nil { + t.Fatal(err) + } + if client.Type() != "sftp" { + t.Errorf("expected type 'sftp', got %q", client.Type()) + } + sftpClient, ok := client.(*sftp.ReplicaClient) + if !ok { + t.Fatalf("expected *sftp.ReplicaClient, got %T", client) + } + if sftpClient.Host != "host.example.com" { + t.Errorf("expected host 'host.example.com', got %q", sftpClient.Host) + } + if sftpClient.User != "myuser" { + t.Errorf("expected user 
'myuser', got %q", sftpClient.User) + } + if sftpClient.Path != "path" { + t.Errorf("expected path 'path', got %q", sftpClient.Path) + } + }) + + t.Run("SFTP_WithPassword", func(t *testing.T) { + client, err := litestream.NewReplicaClientFromURL("sftp://myuser:secret@host.example.com/path") + if err != nil { + t.Fatal(err) + } + sftpClient, ok := client.(*sftp.ReplicaClient) + if !ok { + t.Fatalf("expected *sftp.ReplicaClient, got %T", client) + } + if sftpClient.User != "myuser" { + t.Errorf("expected user 'myuser', got %q", sftpClient.User) + } + if sftpClient.Password != "secret" { + t.Errorf("expected password 'secret', got %q", sftpClient.Password) + } + }) + + t.Run("SFTP_RequiresUserError", func(t *testing.T) { + _, err := litestream.NewReplicaClientFromURL("sftp://host.example.com/path") + if err == nil { + t.Fatal("expected error for missing user") + } + }) + + t.Run("SFTP_MissingHost", func(t *testing.T) { + _, err := litestream.NewReplicaClientFromURL("sftp:///path") + if err == nil { + t.Fatal("expected error for missing host") + } + }) + + t.Run("WebDAV", func(t *testing.T) { + client, err := litestream.NewReplicaClientFromURL("webdav://host.example.com/path") + if err != nil { + t.Fatal(err) + } + if client.Type() != "webdav" { + t.Errorf("expected type 'webdav', got %q", client.Type()) + } + webdavClient, ok := client.(*webdav.ReplicaClient) + if !ok { + t.Fatalf("expected *webdav.ReplicaClient, got %T", client) + } + if webdavClient.URL != "http://host.example.com" { + t.Errorf("expected URL 'http://host.example.com', got %q", webdavClient.URL) + } + if webdavClient.Path != "path" { + t.Errorf("expected path 'path', got %q", webdavClient.Path) + } + }) + + t.Run("WebDAVS", func(t *testing.T) { + client, err := litestream.NewReplicaClientFromURL("webdavs://host.example.com/path") + if err != nil { + t.Fatal(err) + } + if client.Type() != "webdav" { + t.Errorf("expected type 'webdav', got %q", client.Type()) + } + webdavClient, ok := 
client.(*webdav.ReplicaClient) + if !ok { + t.Fatalf("expected *webdav.ReplicaClient, got %T", client) + } + if webdavClient.URL != "https://host.example.com" { + t.Errorf("expected URL 'https://host.example.com', got %q", webdavClient.URL) + } + }) + + t.Run("WebDAV_WithCredentials", func(t *testing.T) { + client, err := litestream.NewReplicaClientFromURL("webdav://myuser:secret@host.example.com/path") + if err != nil { + t.Fatal(err) + } + webdavClient, ok := client.(*webdav.ReplicaClient) + if !ok { + t.Fatalf("expected *webdav.ReplicaClient, got %T", client) + } + if webdavClient.Username != "myuser" { + t.Errorf("expected username 'myuser', got %q", webdavClient.Username) + } + if webdavClient.Password != "secret" { + t.Errorf("expected password 'secret', got %q", webdavClient.Password) + } + if webdavClient.URL != "http://host.example.com" { + t.Errorf("expected URL 'http://host.example.com', got %q", webdavClient.URL) + } + }) + + t.Run("WebDAVS_WithCredentials", func(t *testing.T) { + client, err := litestream.NewReplicaClientFromURL("webdavs://myuser:secret@host.example.com/path") + if err != nil { + t.Fatal(err) + } + webdavClient, ok := client.(*webdav.ReplicaClient) + if !ok { + t.Fatalf("expected *webdav.ReplicaClient, got %T", client) + } + if webdavClient.Username != "myuser" { + t.Errorf("expected username 'myuser', got %q", webdavClient.Username) + } + if webdavClient.Password != "secret" { + t.Errorf("expected password 'secret', got %q", webdavClient.Password) + } + if webdavClient.URL != "https://host.example.com" { + t.Errorf("expected URL 'https://host.example.com', got %q", webdavClient.URL) + } + }) + + t.Run("WebDAV_MissingHost", func(t *testing.T) { + _, err := litestream.NewReplicaClientFromURL("webdav:///path") + if err == nil { + t.Fatal("expected error for missing host") + } + }) + + t.Run("NATS", func(t *testing.T) { + client, err := litestream.NewReplicaClientFromURL("nats://localhost:4222/mybucket") + if err != nil { + t.Fatal(err) + 
} + if client.Type() != "nats" { + t.Errorf("expected type 'nats', got %q", client.Type()) + } + natsClient, ok := client.(*nats.ReplicaClient) + if !ok { + t.Fatalf("expected *nats.ReplicaClient, got %T", client) + } + if natsClient.URL != "nats://localhost:4222" { + t.Errorf("expected URL 'nats://localhost:4222', got %q", natsClient.URL) + } + if natsClient.BucketName != "mybucket" { + t.Errorf("expected bucket 'mybucket', got %q", natsClient.BucketName) + } + }) + + t.Run("NATS_WithCredentials", func(t *testing.T) { + client, err := litestream.NewReplicaClientFromURL("nats://myuser:secret@localhost:4222/mybucket") + if err != nil { + t.Fatal(err) + } + natsClient, ok := client.(*nats.ReplicaClient) + if !ok { + t.Fatalf("expected *nats.ReplicaClient, got %T", client) + } + if natsClient.Username != "myuser" { + t.Errorf("expected username 'myuser', got %q", natsClient.Username) + } + if natsClient.Password != "secret" { + t.Errorf("expected password 'secret', got %q", natsClient.Password) + } + if natsClient.URL != "nats://localhost:4222" { + t.Errorf("expected URL 'nats://localhost:4222', got %q", natsClient.URL) + } + }) + + t.Run("NATS_MissingBucket", func(t *testing.T) { + _, err := litestream.NewReplicaClientFromURL("nats://localhost:4222/") + if err == nil { + t.Fatal("expected error for missing bucket") + } + }) + + t.Run("OSS", func(t *testing.T) { + client, err := litestream.NewReplicaClientFromURL("oss://mybucket/path") + if err != nil { + t.Fatal(err) + } + if client.Type() != "oss" { + t.Errorf("expected type 'oss', got %q", client.Type()) + } + ossClient, ok := client.(*oss.ReplicaClient) + if !ok { + t.Fatalf("expected *oss.ReplicaClient, got %T", client) + } + if ossClient.Bucket != "mybucket" { + t.Errorf("expected bucket 'mybucket', got %q", ossClient.Bucket) + } + if ossClient.Path != "path" { + t.Errorf("expected path 'path', got %q", ossClient.Path) + } + }) + + t.Run("OSS_WithRegion", func(t *testing.T) { + client, err := 
litestream.NewReplicaClientFromURL("oss://mybucket.oss-cn-shanghai.aliyuncs.com/path") + if err != nil { + t.Fatal(err) + } + ossClient, ok := client.(*oss.ReplicaClient) + if !ok { + t.Fatalf("expected *oss.ReplicaClient, got %T", client) + } + if ossClient.Bucket != "mybucket" { + t.Errorf("expected bucket 'mybucket', got %q", ossClient.Bucket) + } + // Note: Region is extracted without the 'oss-' prefix + if ossClient.Region != "cn-shanghai" { + t.Errorf("expected region 'cn-shanghai', got %q", ossClient.Region) + } + }) + + t.Run("OSS_MissingBucket", func(t *testing.T) { + _, err := litestream.NewReplicaClientFromURL("oss:///path") + if err == nil { + t.Fatal("expected error for missing bucket") + } + }) + + // Note: file:// with empty path returns "." due to path.Clean behavior. + // This is technically valid but may not be the intended behavior. + t.Run("File_EmptyPathReturnsDot", func(t *testing.T) { + client, err := litestream.NewReplicaClientFromURL("file://") + if err != nil { + t.Fatal(err) + } + fileClient, ok := client.(*file.ReplicaClient) + if !ok { + t.Fatalf("expected *file.ReplicaClient, got %T", client) + } + // path.Clean("") returns "." which passes the empty check + if fileClient.Path() != "." 
{ + t.Errorf("expected path '.', got %q", fileClient.Path()) + } + }) + + t.Run("S3_ARN", func(t *testing.T) { + client, err := litestream.NewReplicaClientFromURL("s3://arn:aws:s3:us-east-1:123456789012:accesspoint/db-access/backups") + if err != nil { + t.Fatal(err) + } + s3Client, ok := client.(*s3.ReplicaClient) + if !ok { + t.Fatalf("expected *s3.ReplicaClient, got %T", client) + } + if s3Client.Bucket != "arn:aws:s3:us-east-1:123456789012:accesspoint/db-access" { + t.Errorf("expected bucket ARN, got %q", s3Client.Bucket) + } + if s3Client.Path != "backups" { + t.Errorf("expected path 'backups', got %q", s3Client.Path) + } + }) + + t.Run("S3_MissingBucket", func(t *testing.T) { + _, err := litestream.NewReplicaClientFromURL("s3:///path") + if err == nil { + t.Fatal("expected error for missing bucket") + } + }) + + t.Run("EmptyURL", func(t *testing.T) { + _, err := litestream.NewReplicaClientFromURL("") + if err == nil { + t.Fatal("expected error for empty URL") + } + }) + + t.Run("UnsupportedScheme", func(t *testing.T) { + _, err := litestream.NewReplicaClientFromURL("unknown://bucket/path") + if err == nil { + t.Fatal("expected error for unsupported scheme") + } + }) + + t.Run("InvalidURL", func(t *testing.T) { + _, err := litestream.NewReplicaClientFromURL("not-a-valid-url") + if err == nil { + t.Fatal("expected error for invalid URL") + } + }) +} + +func TestReplicaTypeFromURL(t *testing.T) { + tests := []struct { + url string + expected string + }{ + {"s3://bucket/path", "s3"}, + {"gs://bucket/path", "gs"}, + {"abs://container/path", "abs"}, + {"file:///path/to/replica", "file"}, + {"sftp://host/path", "sftp"}, + {"webdav://host/path", "webdav"}, + {"webdavs://host/path", "webdav"}, + {"nats://host/bucket", "nats"}, + {"oss://bucket/path", "oss"}, + {"", ""}, + {"invalid", ""}, + } + + for _, tt := range tests { + t.Run(tt.url, func(t *testing.T) { + got := litestream.ReplicaTypeFromURL(tt.url) + if got != tt.expected { + t.Errorf("ReplicaTypeFromURL(%q) = 
%q, want %q", tt.url, got, tt.expected) + } + }) + } +} + +func TestIsURL(t *testing.T) { + tests := []struct { + s string + expected bool + }{ + {"s3://bucket/path", true}, + {"file:///path", true}, + {"https://example.com", true}, + {"/path/to/file", false}, + {"relative/path", false}, + {"", false}, + } + + for _, tt := range tests { + t.Run(tt.s, func(t *testing.T) { + got := litestream.IsURL(tt.s) + if got != tt.expected { + t.Errorf("IsURL(%q) = %v, want %v", tt.s, got, tt.expected) + } + }) + } +} + +func TestBoolQueryValue(t *testing.T) { + t.Run("True values", func(t *testing.T) { + for _, v := range []string{"true", "True", "TRUE", "1", "t", "yes"} { + query := make(map[string][]string) + query["key"] = []string{v} + value, ok := litestream.BoolQueryValue(query, "key") + if !ok { + t.Errorf("BoolQueryValue with %q should be ok", v) + } + if !value { + t.Errorf("BoolQueryValue with %q should be true", v) + } + } + }) + + t.Run("False values", func(t *testing.T) { + for _, v := range []string{"false", "False", "FALSE", "0", "f", "no"} { + query := make(map[string][]string) + query["key"] = []string{v} + value, ok := litestream.BoolQueryValue(query, "key") + if !ok { + t.Errorf("BoolQueryValue with %q should be ok", v) + } + if value { + t.Errorf("BoolQueryValue with %q should be false", v) + } + } + }) + + t.Run("Missing key", func(t *testing.T) { + query := make(map[string][]string) + _, ok := litestream.BoolQueryValue(query, "key") + if ok { + t.Error("BoolQueryValue with missing key should not be ok") + } + }) + + t.Run("Multiple keys", func(t *testing.T) { + query := make(map[string][]string) + query["key2"] = []string{"true"} + value, ok := litestream.BoolQueryValue(query, "key1", "key2") + if !ok { + t.Error("BoolQueryValue should find second key") + } + if !value { + t.Error("BoolQueryValue should return true for second key") + } + }) + + t.Run("Nil query", func(t *testing.T) { + _, ok := litestream.BoolQueryValue(nil, "key") + if ok { + 
t.Error("BoolQueryValue with nil query should not be ok") + } + }) + + t.Run("Invalid value returns false with ok", func(t *testing.T) { + query := make(map[string][]string) + query["key"] = []string{"invalid"} + value, ok := litestream.BoolQueryValue(query, "key") + if !ok { + t.Error("BoolQueryValue with invalid value should be ok") + } + if value { + t.Error("BoolQueryValue with invalid value should be false") + } + }) +} + +func TestIsTigrisEndpoint(t *testing.T) { + tests := []struct { + endpoint string + expected bool + }{ + {"fly.storage.tigris.dev", true}, + {"FLY.STORAGE.TIGRIS.DEV", true}, + {"https://fly.storage.tigris.dev", true}, + {"http://fly.storage.tigris.dev", true}, + {"s3.amazonaws.com", false}, + {"localhost:9000", false}, + {"", false}, + {" ", false}, + {"https://s3.us-east-1.amazonaws.com", false}, + } + + for _, tt := range tests { + t.Run(tt.endpoint, func(t *testing.T) { + got := litestream.IsTigrisEndpoint(tt.endpoint) + if got != tt.expected { + t.Errorf("IsTigrisEndpoint(%q) = %v, want %v", tt.endpoint, got, tt.expected) + } + }) + } +} + +func TestRegionFromS3ARN(t *testing.T) { + tests := []struct { + arn string + expected string + }{ + {"arn:aws:s3:us-east-1:123456789012:accesspoint/db-access", "us-east-1"}, + {"arn:aws:s3:eu-west-1:123456789012:accesspoint/db-access", "eu-west-1"}, + {"arn:aws:s3:ap-southeast-2:123456789012:accesspoint/db-access", "ap-southeast-2"}, + {"arn:aws:s3::123456789012:accesspoint/db-access", ""}, + {"invalid-arn", ""}, + {"", ""}, + {"arn:aws:s3", ""}, + } + + for _, tt := range tests { + t.Run(tt.arn, func(t *testing.T) { + got := litestream.RegionFromS3ARN(tt.arn) + if got != tt.expected { + t.Errorf("RegionFromS3ARN(%q) = %q, want %q", tt.arn, got, tt.expected) + } + }) + } +} + +func TestCleanReplicaURLPath(t *testing.T) { + tests := []struct { + path string + expected string + }{ + {"", ""}, + {"path", "path"}, + {"/path", "path"}, + {"path/", "path"}, + {"/path/", "path"}, + {"path/to/db", 
"path/to/db"}, + {"/path/to/db", "path/to/db"}, + {"//path//to//db", "path/to/db"}, + {".", ""}, + {"/.", ""}, + {"./path", "path"}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + got := litestream.CleanReplicaURLPath(tt.path) + if got != tt.expected { + t.Errorf("CleanReplicaURLPath(%q) = %q, want %q", tt.path, got, tt.expected) + } + }) + } +} diff --git a/s3/replica_client.go b/s3/replica_client.go index 2eeb988..e3e4867 100644 --- a/s3/replica_client.go +++ b/s3/replica_client.go @@ -38,6 +38,10 @@ import ( "github.com/benbjohnson/litestream/internal" ) +func init() { + litestream.RegisterReplicaClientFactory("s3", NewReplicaClientFromURL) +} + // ReplicaClientType is the client type for this package. const ReplicaClientType = "s3" @@ -90,6 +94,109 @@ func NewReplicaClient() *ReplicaClient { } } +// NewReplicaClientFromURL creates a new ReplicaClient from URL components. +// This is used by the replica client factory registration. +func NewReplicaClientFromURL(scheme, host, urlPath string, query url.Values, userinfo *url.Userinfo) (litestream.ReplicaClient, error) { + client := NewReplicaClient() + + var ( + bucket string + region string + endpoint string + forcePathStyle bool + skipVerify bool + signPayload bool + signPayloadSet bool + requireMD5 bool + requireMD5Set bool + ) + + // Parse host for bucket and region + if strings.HasPrefix(host, "arn:") { + bucket = host + region = litestream.RegionFromS3ARN(host) + } else { + bucket, region, endpoint, forcePathStyle = ParseHost(host) + } + + // Override with query parameters if provided + if qEndpoint := query.Get("endpoint"); qEndpoint != "" { + // Ensure endpoint has a scheme + if !strings.HasPrefix(qEndpoint, "http://") && !strings.HasPrefix(qEndpoint, "https://") { + qEndpoint = "http://" + qEndpoint + } + endpoint = qEndpoint + // Default to path style for custom endpoints unless explicitly set to false + if query.Get("forcePathStyle") != "false" { + forcePathStyle = true + } + 
} + if qRegion := query.Get("region"); qRegion != "" { + region = qRegion + } + if qForcePathStyle := query.Get("forcePathStyle"); qForcePathStyle != "" { + forcePathStyle = qForcePathStyle == "true" + } + if qSkipVerify := query.Get("skipVerify"); qSkipVerify != "" { + skipVerify = qSkipVerify == "true" + } + if v, ok := litestream.BoolQueryValue(query, "signPayload", "sign-payload"); ok { + signPayload = v + signPayloadSet = true + } + if v, ok := litestream.BoolQueryValue(query, "requireContentMD5", "require-content-md5"); ok { + requireMD5 = v + requireMD5Set = true + } + + // Ensure bucket is set + if bucket == "" { + return nil, fmt.Errorf("bucket required for s3 replica URL") + } + + // Check for Tigris endpoint + isTigris := litestream.IsTigrisEndpoint(endpoint) + + // Read authentication from environment variables + if v := os.Getenv("AWS_ACCESS_KEY_ID"); v != "" { + client.AccessKeyID = v + } else if v := os.Getenv("LITESTREAM_ACCESS_KEY_ID"); v != "" { + client.AccessKeyID = v + } + if v := os.Getenv("AWS_SECRET_ACCESS_KEY"); v != "" { + client.SecretAccessKey = v + } else if v := os.Getenv("LITESTREAM_SECRET_ACCESS_KEY"); v != "" { + client.SecretAccessKey = v + } + + // Configure client + client.Bucket = bucket + client.Path = urlPath + client.Region = region + client.Endpoint = endpoint + client.ForcePathStyle = forcePathStyle + client.SkipVerify = skipVerify + + // Apply Tigris defaults + if isTigris { + if !signPayloadSet { + signPayload, signPayloadSet = true, true + } + if !requireMD5Set { + requireMD5, requireMD5Set = false, true + } + } + + if signPayloadSet { + client.SignPayload = signPayload + } + if requireMD5Set { + client.RequireContentMD5 = requireMD5 + } + + return client, nil +} + // Type returns "s3" as the client type. 
func (c *ReplicaClient) Type() string { return ReplicaClientType diff --git a/sftp/replica_client.go b/sftp/replica_client.go index e73b300..12ccd9a 100644 --- a/sftp/replica_client.go +++ b/sftp/replica_client.go @@ -8,6 +8,7 @@ import ( "io" "log/slog" "net" + "net/url" "os" "path" "sync" @@ -21,6 +22,10 @@ import ( "github.com/benbjohnson/litestream/internal" ) +func init() { + litestream.RegisterReplicaClientFactory("sftp", NewReplicaClientFromURL) +} + // ReplicaClientType is the client type for this package. const ReplicaClientType = "sftp" @@ -61,13 +66,44 @@ func NewReplicaClient() *ReplicaClient { } } +// NewReplicaClientFromURL creates a new ReplicaClient from URL components. +// This is used by the replica client factory registration. +// URL format: sftp://[user[:password]@]host[:port]/path +func NewReplicaClientFromURL(scheme, host, urlPath string, query url.Values, userinfo *url.Userinfo) (litestream.ReplicaClient, error) { + client := NewReplicaClient() + + // Extract credentials from userinfo + if userinfo != nil { + client.User = userinfo.Username() + client.Password, _ = userinfo.Password() + } + + client.Host = host + client.Path = urlPath + + if client.Host == "" { + return nil, fmt.Errorf("host required for sftp replica URL") + } + if client.User == "" { + return nil, fmt.Errorf("user required for sftp replica URL") + } + + return client, nil +} + // Type returns "sftp" as the client type. func (c *ReplicaClient) Type() string { return ReplicaClientType } // Init initializes the connection to SFTP. No-op if already initialized. -func (c *ReplicaClient) Init(ctx context.Context) (_ *sftp.Client, err error) { +func (c *ReplicaClient) Init(ctx context.Context) error { + _, err := c.init(ctx) + return err +} + +// init initializes the connection and returns the SFTP client. 
+func (c *ReplicaClient) init(ctx context.Context) (_ *sftp.Client, err error) { c.mu.Lock() defer c.mu.Unlock() @@ -144,7 +180,7 @@ func (c *ReplicaClient) Init(ctx context.Context) (_ *sftp.Client, err error) { func (c *ReplicaClient) DeleteAll(ctx context.Context) (err error) { defer func() { c.resetOnConnError(err) }() - sftpClient, err := c.Init(ctx) + sftpClient, err := c.init(ctx) if err != nil { return err } @@ -188,7 +224,7 @@ func (c *ReplicaClient) DeleteAll(ctx context.Context) (err error) { func (c *ReplicaClient) LTXFiles(ctx context.Context, level int, seek ltx.TXID, _ bool) (_ ltx.FileIterator, err error) { defer func() { c.resetOnConnError(err) }() - sftpClient, err := c.Init(ctx) + sftpClient, err := c.init(ctx) if err != nil { return nil, err } @@ -227,7 +263,7 @@ func (c *ReplicaClient) LTXFiles(ctx context.Context, level int, seek ltx.TXID, func (c *ReplicaClient) WriteLTXFile(ctx context.Context, level int, minTXID, maxTXID ltx.TXID, rd io.Reader) (info *ltx.FileInfo, err error) { defer func() { c.resetOnConnError(err) }() - sftpClient, err := c.Init(ctx) + sftpClient, err := c.init(ctx) if err != nil { return nil, err } @@ -287,7 +323,7 @@ func (c *ReplicaClient) WriteLTXFile(ctx context.Context, level int, minTXID, ma func (c *ReplicaClient) OpenLTXFile(ctx context.Context, level int, minTXID, maxTXID ltx.TXID, offset, size int64) (_ io.ReadCloser, err error) { defer func() { c.resetOnConnError(err) }() - sftpClient, err := c.Init(ctx) + sftpClient, err := c.init(ctx) if err != nil { return nil, err } @@ -316,7 +352,7 @@ func (c *ReplicaClient) OpenLTXFile(ctx context.Context, level int, minTXID, max func (c *ReplicaClient) DeleteLTXFiles(ctx context.Context, a []*ltx.FileInfo) (err error) { defer func() { c.resetOnConnError(err) }() - sftpClient, err := c.Init(ctx) + sftpClient, err := c.init(ctx) if err != nil { return err } @@ -339,7 +375,7 @@ func (c *ReplicaClient) DeleteLTXFiles(ctx context.Context, a []*ltx.FileInfo) ( func (c 
*ReplicaClient) Cleanup(ctx context.Context) (err error) { defer func() { c.resetOnConnError(err) }() - sftpClient, err := c.Init(ctx) + sftpClient, err := c.init(ctx) if err != nil { return err } diff --git a/store_compaction_remote_test.go b/store_compaction_remote_test.go index 6d26a85..dcdd13f 100644 --- a/store_compaction_remote_test.go +++ b/store_compaction_remote_test.go @@ -126,6 +126,8 @@ func newDelayedReplicaClient(delay time.Duration) *delayedReplicaClient { func (c *delayedReplicaClient) Type() string { return "delayed" } +func (c *delayedReplicaClient) Init(context.Context) error { return nil } + func (c *delayedReplicaClient) key(level int, min, max ltx.TXID) string { return fmt.Sprintf("%d:%s:%s", level, min.String(), max.String()) } diff --git a/webdav/replica_client.go b/webdav/replica_client.go index 7f89567..2754004 100644 --- a/webdav/replica_client.go +++ b/webdav/replica_client.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "log/slog" + "net/url" "os" "path" "sort" @@ -19,6 +20,11 @@ import ( "github.com/benbjohnson/litestream/internal" ) +func init() { + litestream.RegisterReplicaClientFactory("webdav", NewReplicaClientFromURL) + litestream.RegisterReplicaClientFactory("webdavs", NewReplicaClientFromURL) +} + const ReplicaClientType = "webdav" const ( @@ -46,11 +52,45 @@ func NewReplicaClient() *ReplicaClient { } } +// NewReplicaClientFromURL creates a new ReplicaClient from URL components. +// This is used by the replica client factory registration. +// URL format: webdav://[user[:password]@]host[:port]/path or webdavs://... 
(for HTTPS) +func NewReplicaClientFromURL(scheme, host, urlPath string, query url.Values, userinfo *url.Userinfo) (litestream.ReplicaClient, error) { + client := NewReplicaClient() + + // Determine HTTP or HTTPS based on scheme + httpScheme := "http" + if scheme == "webdavs" { + httpScheme = "https" + } + + // Extract credentials from userinfo + if userinfo != nil { + client.Username = userinfo.Username() + client.Password, _ = userinfo.Password() + } + + if host == "" { + return nil, fmt.Errorf("host required for webdav replica URL") + } + + client.URL = fmt.Sprintf("%s://%s", httpScheme, host) + client.Path = urlPath + + return client, nil +} + func (c *ReplicaClient) Type() string { return ReplicaClientType } -func (c *ReplicaClient) Init(ctx context.Context) (_ *gowebdav.Client, err error) { +func (c *ReplicaClient) Init(ctx context.Context) error { + _, err := c.init(ctx) + return err +} + +// init initializes the connection and returns the WebDAV client. +func (c *ReplicaClient) init(ctx context.Context) (_ *gowebdav.Client, err error) { c.mu.Lock() defer c.mu.Unlock() @@ -75,7 +115,7 @@ func (c *ReplicaClient) Init(ctx context.Context) (_ *gowebdav.Client, err error } func (c *ReplicaClient) DeleteAll(ctx context.Context) error { - client, err := c.Init(ctx) + client, err := c.init(ctx) if err != nil { return err } @@ -90,7 +130,7 @@ func (c *ReplicaClient) DeleteAll(ctx context.Context) error { } func (c *ReplicaClient) LTXFiles(ctx context.Context, level int, seek ltx.TXID, _ bool) (_ ltx.FileIterator, err error) { - client, err := c.Init(ctx) + client, err := c.init(ctx) if err != nil { return nil, err } @@ -190,7 +230,7 @@ func (c *ReplicaClient) LTXFiles(ctx context.Context, level int, seek ltx.TXID, // - https://github.com/nextcloud/server/issues/7995 (0-byte file bug) // - https://evertpot.com/260/ (WebDAV chunked encoding compatibility) func (c *ReplicaClient) WriteLTXFile(ctx context.Context, level int, minTXID, maxTXID ltx.TXID, rd io.Reader) (info 
*ltx.FileInfo, err error) { - client, err := c.Init(ctx) + client, err := c.init(ctx) if err != nil { return nil, err } @@ -255,7 +295,7 @@ func (c *ReplicaClient) WriteLTXFile(ctx context.Context, level int, minTXID, ma } func (c *ReplicaClient) OpenLTXFile(ctx context.Context, level int, minTXID, maxTXID ltx.TXID, offset, size int64) (_ io.ReadCloser, err error) { - client, err := c.Init(ctx) + client, err := c.init(ctx) if err != nil { return nil, err } @@ -307,7 +347,7 @@ func (c *ReplicaClient) OpenLTXFile(ctx context.Context, level int, minTXID, max } func (c *ReplicaClient) DeleteLTXFiles(ctx context.Context, a []*ltx.FileInfo) error { - client, err := c.Init(ctx) + client, err := c.init(ctx) if err != nil { return err } diff --git a/webdav/replica_client_test.go b/webdav/replica_client_test.go index 1f7a0b3..86202f7 100644 --- a/webdav/replica_client_test.go +++ b/webdav/replica_client_test.go @@ -29,7 +29,7 @@ func TestReplicaClient_Init_RequiresURL(t *testing.T) { c := webdav.NewReplicaClient() c.URL = "" - if _, err := c.Init(context.TODO()); err == nil { + if err := c.Init(context.TODO()); err == nil { t.Fatal("expected error when URL is empty") } else if got, want := err.Error(), "webdav url required"; got != want { t.Fatalf("error=%v, want %v", got, want)