mirror of
https://github.com/benbjohnson/litestream.git
synced 2026-01-25 05:06:30 +00:00
Refactor replica URL parsing (#884)
Co-authored-by: Cory LaNou <cory@lanou.com> Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
@@ -28,6 +29,10 @@ import (
|
||||
"github.com/benbjohnson/litestream/internal"
|
||||
)
|
||||
|
||||
func init() {
|
||||
litestream.RegisterReplicaClientFactory("abs", NewReplicaClientFromURL)
|
||||
}
|
||||
|
||||
// ReplicaClientType is the client type for this package.
|
||||
const ReplicaClientType = "abs"
|
||||
|
||||
@@ -60,6 +65,27 @@ func NewReplicaClient() *ReplicaClient {
|
||||
}
|
||||
}
|
||||
|
||||
// NewReplicaClientFromURL creates a new ReplicaClient from URL components.
|
||||
// This is used by the replica client factory registration.
|
||||
// URL format: abs://[account-name@]container/path
|
||||
func NewReplicaClientFromURL(scheme, host, urlPath string, query url.Values, userinfo *url.Userinfo) (litestream.ReplicaClient, error) {
|
||||
client := NewReplicaClient()
|
||||
|
||||
// Extract account name from userinfo if present (abs://account@container/path)
|
||||
if userinfo != nil {
|
||||
client.AccountName = userinfo.Username()
|
||||
}
|
||||
|
||||
client.Bucket = host
|
||||
client.Path = urlPath
|
||||
|
||||
if client.Bucket == "" {
|
||||
return nil, fmt.Errorf("bucket required for abs replica URL")
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// Type returns "abs" as the client type.
|
||||
func (c *ReplicaClient) Type() string {
|
||||
return ReplicaClientType
|
||||
|
||||
@@ -16,43 +16,40 @@ import (
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/psanford/sqlite3vfs"
|
||||
|
||||
"github.com/benbjohnson/litestream"
|
||||
"github.com/benbjohnson/litestream/s3"
|
||||
|
||||
// Import all replica backends to register their URL factories.
|
||||
_ "github.com/benbjohnson/litestream/abs"
|
||||
_ "github.com/benbjohnson/litestream/file"
|
||||
_ "github.com/benbjohnson/litestream/gs"
|
||||
_ "github.com/benbjohnson/litestream/nats"
|
||||
_ "github.com/benbjohnson/litestream/oss"
|
||||
_ "github.com/benbjohnson/litestream/s3"
|
||||
_ "github.com/benbjohnson/litestream/sftp"
|
||||
_ "github.com/benbjohnson/litestream/webdav"
|
||||
)
|
||||
|
||||
func main() {}
|
||||
|
||||
//export LitestreamVFSRegister
|
||||
func LitestreamVFSRegister() {
|
||||
var client litestream.ReplicaClient
|
||||
var err error
|
||||
client := s3.NewReplicaClient()
|
||||
client.AccessKeyID = os.Getenv("AWS_ACCESS_KEY_ID")
|
||||
client.SecretAccessKey = os.Getenv("AWS_SECRET_ACCESS_KEY")
|
||||
client.Region = os.Getenv("LITESTREAM_S3_REGION")
|
||||
client.Bucket = os.Getenv("LITESTREAM_S3_BUCKET")
|
||||
client.Path = os.Getenv("LITESTREAM_S3_PATH")
|
||||
client.Endpoint = os.Getenv("LITESTREAM_S3_ENDPOINT")
|
||||
|
||||
if v := os.Getenv("LITESTREAM_S3_FORCE_PATH_STYLE"); v != "" {
|
||||
if client.ForcePathStyle, err = strconv.ParseBool(v); err != nil {
|
||||
log.Fatalf("failed to parse LITESTREAM_S3_FORCE_PATH_STYLE: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
if v := os.Getenv("LITESTREAM_S3_SKIP_VERIFY"); v != "" {
|
||||
if client.SkipVerify, err = strconv.ParseBool(v); err != nil {
|
||||
log.Fatalf("failed to parse LITESTREAM_S3_SKIP_VERIFY: %s", err)
|
||||
}
|
||||
replicaURL := os.Getenv("LITESTREAM_REPLICA_URL")
|
||||
client, err = litestream.NewReplicaClientFromURL(replicaURL)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create replica client from URL: %s", err)
|
||||
}
|
||||
|
||||
// Initialize the client.
|
||||
if err := client.Init(context.Background()); err != nil {
|
||||
log.Fatalf("failed to initialize litestream s3 client: %s", err)
|
||||
log.Fatalf("failed to initialize litestream replica client: %s", err)
|
||||
}
|
||||
|
||||
var level slog.Level
|
||||
|
||||
@@ -28,7 +28,7 @@ func (c *LTXCommand) Run(ctx context.Context, args []string) (err error) {
|
||||
}
|
||||
|
||||
var r *litestream.Replica
|
||||
if isURL(fs.Arg(0)) {
|
||||
if litestream.IsURL(fs.Arg(0)) {
|
||||
if *configPath != "" {
|
||||
return fmt.Errorf("cannot specify a replica URL and the -config flag")
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"os/user"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -994,7 +993,7 @@ type ReplicaConfig struct {
|
||||
// NewReplicaFromConfig instantiates a replica for a DB based on a config.
|
||||
func NewReplicaFromConfig(c *ReplicaConfig, db *litestream.DB) (_ *litestream.Replica, err error) {
|
||||
// Ensure user did not specify URL in path.
|
||||
if isURL(c.Path) {
|
||||
if litestream.IsURL(c.Path) {
|
||||
return nil, fmt.Errorf("replica path cannot be a url, please use the 'url' field instead: %s", c.Path)
|
||||
}
|
||||
|
||||
@@ -1064,7 +1063,7 @@ func newFileReplicaClientFromConfig(c *ReplicaConfig, r *litestream.Replica) (_
|
||||
// Parse configPath from URL, if specified.
|
||||
configPath := c.Path
|
||||
if c.URL != "" {
|
||||
if _, _, configPath, err = ParseReplicaURL(c.URL); err != nil {
|
||||
if _, _, configPath, err = litestream.ParseReplicaURL(c.URL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -1126,7 +1125,7 @@ func NewS3ReplicaClientFromConfig(c *ReplicaConfig, _ *litestream.Replica) (_ *s
|
||||
}
|
||||
|
||||
if c.URL != "" {
|
||||
_, host, upath, query, err := ParseReplicaURLWithQuery(c.URL)
|
||||
_, host, upath, query, _, err := litestream.ParseReplicaURLWithQuery(c.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1140,7 +1139,7 @@ func NewS3ReplicaClientFromConfig(c *ReplicaConfig, _ *litestream.Replica) (_ *s
|
||||
|
||||
if strings.HasPrefix(host, "arn:") {
|
||||
ubucket = host
|
||||
uregion = regionFromS3ARN(host)
|
||||
uregion = litestream.RegionFromS3ARN(host)
|
||||
} else {
|
||||
ubucket, uregion, uendpoint, uforcePathStyle = s3.ParseHost(host)
|
||||
}
|
||||
@@ -1168,11 +1167,11 @@ func NewS3ReplicaClientFromConfig(c *ReplicaConfig, _ *litestream.Replica) (_ *s
|
||||
if qSkipVerify := query.Get("skipVerify"); qSkipVerify != "" {
|
||||
skipVerify = qSkipVerify == "true"
|
||||
}
|
||||
if v, ok := boolQueryValue(query, "signPayload", "sign-payload"); ok {
|
||||
if v, ok := litestream.BoolQueryValue(query, "signPayload", "sign-payload"); ok {
|
||||
usignPayload = v
|
||||
usignPayloadSet = true
|
||||
}
|
||||
if v, ok := boolQueryValue(query, "requireContentMD5", "require-content-md5"); ok {
|
||||
if v, ok := litestream.BoolQueryValue(query, "requireContentMD5", "require-content-md5"); ok {
|
||||
urequireContentMD5 = v
|
||||
urequireContentMD5Set = true
|
||||
}
|
||||
@@ -1206,8 +1205,8 @@ func NewS3ReplicaClientFromConfig(c *ReplicaConfig, _ *litestream.Replica) (_ *s
|
||||
return nil, fmt.Errorf("bucket required for s3 replica")
|
||||
}
|
||||
|
||||
isTigris := isTigrisEndpoint(endpoint)
|
||||
if !isTigris && !endpointWasSet && isTigrisEndpoint(c.Endpoint) {
|
||||
isTigris := litestream.IsTigrisEndpoint(endpoint)
|
||||
if !isTigris && !endpointWasSet && litestream.IsTigrisEndpoint(c.Endpoint) {
|
||||
isTigris = true
|
||||
}
|
||||
|
||||
@@ -1253,7 +1252,7 @@ func newGSReplicaClientFromConfig(c *ReplicaConfig, _ *litestream.Replica) (_ *g
|
||||
|
||||
// Apply settings from URL, if specified.
|
||||
if c.URL != "" {
|
||||
_, uhost, upath, err := ParseReplicaURL(c.URL)
|
||||
_, uhost, upath, err := litestream.ParseReplicaURL(c.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1438,7 +1437,7 @@ func newNATSReplicaClientFromConfig(c *ReplicaConfig, _ *litestream.Replica) (_
|
||||
// Parse URL if provided to extract bucket name and server URL
|
||||
var url, bucket string
|
||||
if c.URL != "" {
|
||||
scheme, host, bucketPath, err := ParseReplicaURL(c.URL)
|
||||
scheme, host, bucketPath, err := litestream.ParseReplicaURL(c.URL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid NATS URL: %w", err)
|
||||
}
|
||||
@@ -1520,7 +1519,7 @@ func newOSSReplicaClientFromConfig(c *ReplicaConfig, _ *litestream.Replica) (_ *
|
||||
|
||||
// Apply settings from URL, if specified.
|
||||
if c.URL != "" {
|
||||
_, host, upath, err := ParseReplicaURL(c.URL)
|
||||
_, host, upath, err := litestream.ParseReplicaURL(c.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1585,128 +1584,6 @@ func applyLitestreamEnv() {
|
||||
}
|
||||
}
|
||||
|
||||
// ParseReplicaURL parses a replica URL.
|
||||
func ParseReplicaURL(s string) (scheme, host, urlpath string, err error) {
|
||||
if strings.HasPrefix(strings.ToLower(s), "s3://arn:") {
|
||||
return parseS3AccessPointURL(s)
|
||||
}
|
||||
|
||||
scheme, host, urlpath, _, err = ParseReplicaURLWithQuery(s)
|
||||
return scheme, host, urlpath, err
|
||||
}
|
||||
|
||||
// ParseReplicaURLWithQuery parses a replica URL and returns query parameters.
|
||||
func ParseReplicaURLWithQuery(s string) (scheme, host, urlpath string, query url.Values, err error) {
|
||||
// Handle S3 Access Point ARNs which can't be parsed by standard url.Parse
|
||||
if strings.HasPrefix(strings.ToLower(s), "s3://arn:") {
|
||||
scheme, host, urlpath, err := parseS3AccessPointURL(s)
|
||||
return scheme, host, urlpath, nil, err
|
||||
}
|
||||
|
||||
u, err := url.Parse(s)
|
||||
if err != nil {
|
||||
return "", "", "", nil, err
|
||||
}
|
||||
|
||||
switch u.Scheme {
|
||||
case "file":
|
||||
scheme, u.Scheme = u.Scheme, ""
|
||||
// Remove query params from path for file URLs
|
||||
u.RawQuery = ""
|
||||
return scheme, "", path.Clean(u.String()), nil, nil
|
||||
|
||||
case "":
|
||||
return u.Scheme, u.Host, u.Path, nil, fmt.Errorf("replica url scheme required: %s", s)
|
||||
|
||||
default:
|
||||
return u.Scheme, u.Host, strings.TrimPrefix(path.Clean(u.Path), "/"), u.Query(), nil
|
||||
}
|
||||
}
|
||||
|
||||
func parseS3AccessPointURL(s string) (scheme, host, urlpath string, err error) {
|
||||
const prefix = "s3://"
|
||||
if !strings.HasPrefix(strings.ToLower(s), prefix) {
|
||||
return "", "", "", fmt.Errorf("invalid s3 access point url: %s", s)
|
||||
}
|
||||
|
||||
arnWithPath := s[len(prefix):]
|
||||
bucket, key, err := splitS3AccessPointARN(arnWithPath)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
|
||||
return "s3", bucket, cleanReplicaURLPath(key), nil
|
||||
}
|
||||
|
||||
func splitS3AccessPointARN(s string) (bucket, key string, err error) {
|
||||
lower := strings.ToLower(s)
|
||||
const marker = ":accesspoint/"
|
||||
idx := strings.Index(lower, marker)
|
||||
if idx == -1 {
|
||||
return "", "", fmt.Errorf("invalid s3 access point arn: %s", s)
|
||||
}
|
||||
|
||||
nameStart := idx + len(marker)
|
||||
if nameStart >= len(s) {
|
||||
return "", "", fmt.Errorf("invalid s3 access point arn: %s", s)
|
||||
}
|
||||
|
||||
remainder := s[nameStart:]
|
||||
slashIdx := strings.IndexByte(remainder, '/')
|
||||
if slashIdx == -1 {
|
||||
return s, "", nil
|
||||
}
|
||||
|
||||
bucketEnd := nameStart + slashIdx
|
||||
bucket = s[:bucketEnd]
|
||||
key = remainder[slashIdx+1:]
|
||||
return bucket, key, nil
|
||||
}
|
||||
|
||||
func cleanReplicaURLPath(p string) string {
|
||||
if p == "" {
|
||||
return ""
|
||||
}
|
||||
cleaned := path.Clean("/" + p)
|
||||
cleaned = strings.TrimPrefix(cleaned, "/")
|
||||
if cleaned == "." {
|
||||
return ""
|
||||
}
|
||||
return cleaned
|
||||
}
|
||||
|
||||
func boolQueryValue(query url.Values, keys ...string) (bool, bool) {
|
||||
if query == nil {
|
||||
return false, false
|
||||
}
|
||||
for _, key := range keys {
|
||||
if raw := query.Get(key); raw != "" {
|
||||
switch strings.ToLower(raw) {
|
||||
case "true", "1", "t", "yes":
|
||||
return true, true
|
||||
case "false", "0", "f", "no":
|
||||
return false, true
|
||||
default:
|
||||
return false, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
func isTigrisEndpoint(endpoint string) bool {
|
||||
endpoint = strings.TrimSpace(strings.ToLower(endpoint))
|
||||
if endpoint == "" {
|
||||
return false
|
||||
}
|
||||
if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") {
|
||||
if u, err := url.Parse(endpoint); err == nil && u.Host != "" {
|
||||
endpoint = u.Host
|
||||
}
|
||||
}
|
||||
return endpoint == "fly.storage.tigris.dev"
|
||||
}
|
||||
|
||||
type boolSetting struct {
|
||||
value bool
|
||||
set bool
|
||||
@@ -1727,27 +1604,10 @@ func (s *boolSetting) ApplyDefault(value bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func regionFromS3ARN(arn string) string {
|
||||
parts := strings.SplitN(arn, ":", 6)
|
||||
if len(parts) >= 4 {
|
||||
return parts[3]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// isURL returns true if s can be parsed and has a scheme.
|
||||
func isURL(s string) bool {
|
||||
return regexp.MustCompile(`^\w+:\/\/`).MatchString(s)
|
||||
}
|
||||
|
||||
// ReplicaType returns the type based on the type field or extracted from the URL.
|
||||
func (c *ReplicaConfig) ReplicaType() string {
|
||||
scheme, _, _, _ := ParseReplicaURL(c.URL)
|
||||
if scheme != "" {
|
||||
if scheme == "webdavs" {
|
||||
return "webdav"
|
||||
}
|
||||
return scheme
|
||||
if replicaType := litestream.ReplicaTypeFromURL(c.URL); replicaType != "" {
|
||||
return replicaType
|
||||
} else if c.Type != "" {
|
||||
return c.Type
|
||||
}
|
||||
|
||||
@@ -496,7 +496,7 @@ snapshot:
|
||||
|
||||
func TestParseReplicaURL_AccessPoint(t *testing.T) {
|
||||
t.Run("WithPrefix", func(t *testing.T) {
|
||||
scheme, host, urlPath, err := main.ParseReplicaURL("s3://arn:aws:s3:us-east-1:123456789012:accesspoint/db-access/backups/prod")
|
||||
scheme, host, urlPath, err := litestream.ParseReplicaURL("s3://arn:aws:s3:us-east-1:123456789012:accesspoint/db-access/backups/prod")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -512,7 +512,7 @@ func TestParseReplicaURL_AccessPoint(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Invalid", func(t *testing.T) {
|
||||
if _, _, _, err := main.ParseReplicaURL("s3://arn:aws:s3:us-east-1:123456789012:accesspoint/"); err == nil {
|
||||
if _, _, _, err := litestream.ParseReplicaURL("s3://arn:aws:s3:us-east-1:123456789012:accesspoint/"); err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
})
|
||||
@@ -1498,7 +1498,7 @@ func TestFindSQLiteDatabases(t *testing.T) {
|
||||
func TestParseReplicaURLWithQuery(t *testing.T) {
|
||||
t.Run("S3WithEndpoint", func(t *testing.T) {
|
||||
url := "s3://mybucket/path/to/db?endpoint=localhost:9000®ion=us-east-1&forcePathStyle=true"
|
||||
scheme, host, path, query, err := main.ParseReplicaURLWithQuery(url)
|
||||
scheme, host, path, query, _, err := litestream.ParseReplicaURLWithQuery(url)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1524,7 +1524,7 @@ func TestParseReplicaURLWithQuery(t *testing.T) {
|
||||
|
||||
t.Run("S3WithoutQuery", func(t *testing.T) {
|
||||
url := "s3://mybucket/path/to/db"
|
||||
scheme, host, path, query, err := main.ParseReplicaURLWithQuery(url)
|
||||
scheme, host, path, query, _, err := litestream.ParseReplicaURLWithQuery(url)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1544,7 +1544,7 @@ func TestParseReplicaURLWithQuery(t *testing.T) {
|
||||
|
||||
t.Run("FileURL", func(t *testing.T) {
|
||||
url := "file:///path/to/db"
|
||||
scheme, host, path, query, err := main.ParseReplicaURLWithQuery(url)
|
||||
scheme, host, path, query, _, err := litestream.ParseReplicaURLWithQuery(url)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1565,7 +1565,7 @@ func TestParseReplicaURLWithQuery(t *testing.T) {
|
||||
t.Run("BackwardCompatibility", func(t *testing.T) {
|
||||
// Test that ParseReplicaURL still works as before
|
||||
url := "s3://mybucket/path/to/db?endpoint=localhost:9000"
|
||||
scheme, host, path, err := main.ParseReplicaURL(url)
|
||||
scheme, host, path, err := litestream.ParseReplicaURL(url)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1582,7 +1582,7 @@ func TestParseReplicaURLWithQuery(t *testing.T) {
|
||||
|
||||
t.Run("S3TigrisExample", func(t *testing.T) {
|
||||
url := "s3://mybucket/db?endpoint=fly.storage.tigris.dev®ion=auto"
|
||||
scheme, host, path, query, err := main.ParseReplicaURLWithQuery(url)
|
||||
scheme, host, path, query, _, err := litestream.ParseReplicaURLWithQuery(url)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1605,7 +1605,7 @@ func TestParseReplicaURLWithQuery(t *testing.T) {
|
||||
|
||||
t.Run("S3WithSkipVerify", func(t *testing.T) {
|
||||
url := "s3://mybucket/db?endpoint=self-signed.local&skipVerify=true"
|
||||
_, _, _, query, err := main.ParseReplicaURLWithQuery(url)
|
||||
_, _, _, query, _, err := litestream.ParseReplicaURLWithQuery(url)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ func (c *RestoreCommand) Run(ctx context.Context, args []string) (err error) {
|
||||
|
||||
// Determine replica to restore from.
|
||||
var r *litestream.Replica
|
||||
if isURL(fs.Arg(0)) {
|
||||
if litestream.IsURL(fs.Arg(0)) {
|
||||
if *configPath != "" {
|
||||
return fmt.Errorf("cannot specify a replica URL and the -config flag")
|
||||
}
|
||||
|
||||
@@ -20,6 +20,8 @@ type testReplicaClient struct {
|
||||
dir string
|
||||
}
|
||||
|
||||
func (c *testReplicaClient) Init(_ context.Context) error { return nil }
|
||||
|
||||
func (c *testReplicaClient) Type() string { return "test" }
|
||||
|
||||
func (c *testReplicaClient) LTXFiles(_ context.Context, _ int, _ ltx.TXID, _ bool) (ltx.FileIterator, error) {
|
||||
@@ -59,6 +61,8 @@ type errorReplicaClient struct {
|
||||
writeErr error
|
||||
}
|
||||
|
||||
func (c *errorReplicaClient) Init(_ context.Context) error { return nil }
|
||||
|
||||
func (c *errorReplicaClient) Type() string { return "error" }
|
||||
|
||||
func (c *errorReplicaClient) LTXFiles(_ context.Context, _ int, _ ltx.TXID, _ bool) (ltx.FileIterator, error) {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
@@ -16,6 +17,10 @@ import (
|
||||
"github.com/benbjohnson/litestream/internal"
|
||||
)
|
||||
|
||||
func init() {
|
||||
litestream.RegisterReplicaClientFactory("file", NewReplicaClientFromURL)
|
||||
}
|
||||
|
||||
// ReplicaClientType is the client type for this package.
|
||||
const ReplicaClientType = "file"
|
||||
|
||||
@@ -37,6 +42,16 @@ func NewReplicaClient(path string) *ReplicaClient {
|
||||
}
|
||||
}
|
||||
|
||||
// NewReplicaClientFromURL creates a new ReplicaClient from URL components.
|
||||
// This is used by the replica client factory registration.
|
||||
func NewReplicaClientFromURL(scheme, host, urlPath string, query url.Values, userinfo *url.Userinfo) (litestream.ReplicaClient, error) {
|
||||
// For file URLs, the path is the full path
|
||||
if urlPath == "" {
|
||||
return nil, fmt.Errorf("file replica path required")
|
||||
}
|
||||
return NewReplicaClient(urlPath), nil
|
||||
}
|
||||
|
||||
// db returns the database, if available.
|
||||
func (c *ReplicaClient) db() *litestream.DB {
|
||||
if c.Replica == nil {
|
||||
@@ -50,6 +65,11 @@ func (c *ReplicaClient) Type() string {
|
||||
return ReplicaClientType
|
||||
}
|
||||
|
||||
// Init is a no-op for file replica client as no initialization is required.
|
||||
func (c *ReplicaClient) Init(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Path returns the destination path to replicate the database to.
|
||||
func (c *ReplicaClient) Path() string {
|
||||
return c.path
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"sync"
|
||||
@@ -20,6 +21,10 @@ import (
|
||||
"github.com/benbjohnson/litestream/internal"
|
||||
)
|
||||
|
||||
func init() {
|
||||
litestream.RegisterReplicaClientFactory("gs", NewReplicaClientFromURL)
|
||||
}
|
||||
|
||||
// ReplicaClientType is the client type for this package.
|
||||
const ReplicaClientType = "gs"
|
||||
|
||||
@@ -47,6 +52,19 @@ func NewReplicaClient() *ReplicaClient {
|
||||
}
|
||||
}
|
||||
|
||||
// NewReplicaClientFromURL creates a new ReplicaClient from URL components.
|
||||
// This is used by the replica client factory registration.
|
||||
func NewReplicaClientFromURL(scheme, host, urlPath string, query url.Values, userinfo *url.Userinfo) (litestream.ReplicaClient, error) {
|
||||
if host == "" {
|
||||
return nil, fmt.Errorf("bucket required for gs replica URL")
|
||||
}
|
||||
|
||||
client := NewReplicaClient()
|
||||
client.Bucket = host
|
||||
client.Path = urlPath
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// Type returns "gs" as the client type.
|
||||
func (c *ReplicaClient) Type() string {
|
||||
return ReplicaClientType
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
var _ litestream.ReplicaClient = (*ReplicaClient)(nil)
|
||||
|
||||
type ReplicaClient struct {
|
||||
InitFunc func(ctx context.Context) error
|
||||
DeleteAllFunc func(ctx context.Context) error
|
||||
LTXFilesFunc func(ctx context.Context, level int, seek ltx.TXID, useMetadata bool) (ltx.FileIterator, error)
|
||||
OpenLTXFileFunc func(ctx context.Context, level int, minTXID, maxTXID ltx.TXID, offset, size int64) (io.ReadCloser, error)
|
||||
@@ -21,6 +22,13 @@ type ReplicaClient struct {
|
||||
|
||||
func (c *ReplicaClient) Type() string { return "mock" }
|
||||
|
||||
func (c *ReplicaClient) Init(ctx context.Context) error {
|
||||
if c.InitFunc != nil {
|
||||
return c.InitFunc(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ReplicaClient) DeleteAll(ctx context.Context) error {
|
||||
return c.DeleteAllFunc(ctx)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
@@ -22,6 +23,10 @@ import (
|
||||
"github.com/benbjohnson/litestream/internal"
|
||||
)
|
||||
|
||||
func init() {
|
||||
litestream.RegisterReplicaClientFactory("nats", NewReplicaClientFromURL)
|
||||
}
|
||||
|
||||
// ReplicaClientType is the client type for this package.
|
||||
const ReplicaClientType = "nats"
|
||||
|
||||
@@ -84,6 +89,33 @@ func NewReplicaClient() *ReplicaClient {
|
||||
}
|
||||
}
|
||||
|
||||
// NewReplicaClientFromURL creates a new ReplicaClient from URL components.
|
||||
// This is used by the replica client factory registration.
|
||||
// URL format: nats://[user:pass@]host[:port]/bucket
|
||||
func NewReplicaClientFromURL(scheme, host, urlPath string, query url.Values, userinfo *url.Userinfo) (litestream.ReplicaClient, error) {
|
||||
client := NewReplicaClient()
|
||||
|
||||
// Reconstruct URL without bucket path
|
||||
if host != "" {
|
||||
client.URL = fmt.Sprintf("nats://%s", host)
|
||||
}
|
||||
|
||||
// Extract credentials from userinfo if present
|
||||
if userinfo != nil {
|
||||
client.Username = userinfo.Username()
|
||||
client.Password, _ = userinfo.Password()
|
||||
}
|
||||
|
||||
// Extract bucket name from path
|
||||
bucket := strings.Trim(urlPath, "/")
|
||||
if bucket == "" {
|
||||
return nil, fmt.Errorf("bucket required for nats replica URL")
|
||||
}
|
||||
client.BucketName = bucket
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// Type returns "nats" as the client type.
|
||||
func (c *ReplicaClient) Type() string {
|
||||
return ReplicaClientType
|
||||
|
||||
@@ -23,6 +23,10 @@ import (
|
||||
"github.com/benbjohnson/litestream/internal"
|
||||
)
|
||||
|
||||
func init() {
|
||||
litestream.RegisterReplicaClientFactory("oss", NewReplicaClientFromURL)
|
||||
}
|
||||
|
||||
// ReplicaClientType is the client type for this package.
|
||||
const ReplicaClientType = "oss"
|
||||
|
||||
@@ -67,6 +71,24 @@ func NewReplicaClient() *ReplicaClient {
|
||||
}
|
||||
}
|
||||
|
||||
// NewReplicaClientFromURL creates a new ReplicaClient from URL components.
|
||||
// This is used by the replica client factory registration.
|
||||
// URL format: oss://bucket[.oss-region.aliyuncs.com]/path
|
||||
func NewReplicaClientFromURL(scheme, host, urlPath string, query url.Values, userinfo *url.Userinfo) (litestream.ReplicaClient, error) {
|
||||
client := NewReplicaClient()
|
||||
|
||||
bucket, region, _ := ParseHost(host)
|
||||
if bucket == "" {
|
||||
return nil, fmt.Errorf("bucket required for oss replica URL")
|
||||
}
|
||||
|
||||
client.Bucket = bucket
|
||||
client.Region = region
|
||||
client.Path = urlPath
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// Type returns "oss" as the client type.
|
||||
func (c *ReplicaClient) Type() string {
|
||||
return ReplicaClientType
|
||||
|
||||
@@ -19,6 +19,11 @@ type ReplicaClient interface {
|
||||
// Type returns the type of client.
|
||||
Type() string
|
||||
|
||||
// Init initializes the replica client connection.
|
||||
// This may establish connections, validate configuration, etc.
|
||||
// Implementations should be idempotent (no-op if already initialized).
|
||||
Init(ctx context.Context) error
|
||||
|
||||
// LTXFiles returns an iterator of all LTX files on the replica for a given level.
|
||||
// If seek is specified, the iterator start from the given TXID or the next available if not found.
|
||||
// If useMetadata is true, the iterator fetches accurate timestamps from metadata for timestamp-based restore.
|
||||
|
||||
@@ -435,7 +435,7 @@ AAAEDzV1D6COyvFGhSiZa6ll9aXZ2IMWED3KGrvCNjEEtYHwnK0+GdwOelXlAXdqLx/qvS
|
||||
c.Host = addr
|
||||
c.HostKey = expectedHostKey
|
||||
|
||||
_, err = c.Init(context.Background())
|
||||
err = c.Init(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("SFTP connection failed: %v", err)
|
||||
}
|
||||
@@ -449,7 +449,7 @@ AAAEDzV1D6COyvFGhSiZa6ll9aXZ2IMWED3KGrvCNjEEtYHwnK0+GdwOelXlAXdqLx/qvS
|
||||
c.Host = addr
|
||||
c.HostKey = invalidHostKey
|
||||
|
||||
_, err = c.Init(context.Background())
|
||||
err = c.Init(context.Background())
|
||||
if err == nil {
|
||||
t.Fatalf("SFTP connection established despite invalid host key")
|
||||
}
|
||||
@@ -475,7 +475,7 @@ AAAEDzV1D6COyvFGhSiZa6ll9aXZ2IMWED3KGrvCNjEEtYHwnK0+GdwOelXlAXdqLx/qvS
|
||||
c.User = "foo"
|
||||
c.Host = addr
|
||||
|
||||
_, err = c.Init(context.Background())
|
||||
err = c.Init(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("SFTP connection failed: %v", err)
|
||||
}
|
||||
|
||||
209
replica_url.go
Normal file
209
replica_url.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package litestream
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ReplicaClientFactory is a function that creates a ReplicaClient from URL components.
|
||||
// The userinfo parameter contains credentials from the URL (e.g., user:pass@host).
|
||||
type ReplicaClientFactory func(scheme, host, urlPath string, query url.Values, userinfo *url.Userinfo) (ReplicaClient, error)
|
||||
|
||||
var (
|
||||
replicaClientFactories = make(map[string]ReplicaClientFactory)
|
||||
replicaClientFactoriesMu sync.RWMutex
|
||||
)
|
||||
|
||||
// RegisterReplicaClientFactory registers a factory function for creating replica clients
|
||||
// for a given URL scheme. This is typically called from init() functions in backend packages.
|
||||
func RegisterReplicaClientFactory(scheme string, factory ReplicaClientFactory) {
|
||||
replicaClientFactoriesMu.Lock()
|
||||
defer replicaClientFactoriesMu.Unlock()
|
||||
replicaClientFactories[scheme] = factory
|
||||
}
|
||||
|
||||
// NewReplicaClientFromURL creates a new ReplicaClient from a URL string.
|
||||
// The URL scheme determines which backend is used (s3, gs, abs, file, etc.).
|
||||
func NewReplicaClientFromURL(rawURL string) (ReplicaClient, error) {
|
||||
scheme, host, urlPath, query, userinfo, err := ParseReplicaURLWithQuery(rawURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Normalize webdavs to webdav
|
||||
factoryScheme := scheme
|
||||
if factoryScheme == "webdavs" {
|
||||
factoryScheme = "webdav"
|
||||
}
|
||||
|
||||
replicaClientFactoriesMu.RLock()
|
||||
factory, ok := replicaClientFactories[factoryScheme]
|
||||
replicaClientFactoriesMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported replica URL scheme: %q", scheme)
|
||||
}
|
||||
|
||||
return factory(scheme, host, urlPath, query, userinfo)
|
||||
}
|
||||
|
||||
// ReplicaTypeFromURL returns the replica type from a URL string.
|
||||
// Returns empty string if the URL is invalid or has no scheme.
|
||||
func ReplicaTypeFromURL(rawURL string) string {
|
||||
scheme, _, _, _ := ParseReplicaURL(rawURL)
|
||||
if scheme == "" {
|
||||
return ""
|
||||
}
|
||||
if scheme == "webdavs" {
|
||||
return "webdav"
|
||||
}
|
||||
return scheme
|
||||
}
|
||||
|
||||
// ParseReplicaURL parses a replica URL and returns the scheme, host, and path.
|
||||
func ParseReplicaURL(s string) (scheme, host, urlPath string, err error) {
|
||||
if strings.HasPrefix(strings.ToLower(s), "s3://arn:") {
|
||||
return parseS3AccessPointURL(s)
|
||||
}
|
||||
|
||||
scheme, host, urlPath, _, _, err = ParseReplicaURLWithQuery(s)
|
||||
return scheme, host, urlPath, err
|
||||
}
|
||||
|
||||
// ParseReplicaURLWithQuery parses a replica URL and returns query parameters and userinfo.
|
||||
func ParseReplicaURLWithQuery(s string) (scheme, host, urlPath string, query url.Values, userinfo *url.Userinfo, err error) {
|
||||
// Handle S3 Access Point ARNs which can't be parsed by standard url.Parse
|
||||
if strings.HasPrefix(strings.ToLower(s), "s3://arn:") {
|
||||
scheme, host, urlPath, err := parseS3AccessPointURL(s)
|
||||
return scheme, host, urlPath, nil, nil, err
|
||||
}
|
||||
|
||||
u, err := url.Parse(s)
|
||||
if err != nil {
|
||||
return "", "", "", nil, nil, err
|
||||
}
|
||||
|
||||
switch u.Scheme {
|
||||
case "file":
|
||||
scheme, u.Scheme = u.Scheme, ""
|
||||
// Remove query params from path for file URLs
|
||||
u.RawQuery = ""
|
||||
return scheme, "", path.Clean(u.String()), nil, nil, nil
|
||||
|
||||
case "":
|
||||
return u.Scheme, u.Host, u.Path, nil, nil, fmt.Errorf("replica url scheme required: %s", s)
|
||||
|
||||
default:
|
||||
return u.Scheme, u.Host, strings.TrimPrefix(path.Clean(u.Path), "/"), u.Query(), u.User, nil
|
||||
}
|
||||
}
|
||||
|
||||
// parseS3AccessPointURL parses an S3 Access Point URL (s3://arn:...).
|
||||
func parseS3AccessPointURL(s string) (scheme, host, urlPath string, err error) {
|
||||
const prefix = "s3://"
|
||||
if !strings.HasPrefix(strings.ToLower(s), prefix) {
|
||||
return "", "", "", fmt.Errorf("invalid s3 access point url: %s", s)
|
||||
}
|
||||
|
||||
arnWithPath := s[len(prefix):]
|
||||
bucket, key, err := splitS3AccessPointARN(arnWithPath)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
|
||||
return "s3", bucket, CleanReplicaURLPath(key), nil
|
||||
}
|
||||
|
||||
// splitS3AccessPointARN splits an S3 Access Point ARN into bucket and key components.
|
||||
func splitS3AccessPointARN(s string) (bucket, key string, err error) {
|
||||
lower := strings.ToLower(s)
|
||||
const marker = ":accesspoint/"
|
||||
idx := strings.Index(lower, marker)
|
||||
if idx == -1 {
|
||||
return "", "", fmt.Errorf("invalid s3 access point arn: %s", s)
|
||||
}
|
||||
|
||||
nameStart := idx + len(marker)
|
||||
if nameStart >= len(s) {
|
||||
return "", "", fmt.Errorf("invalid s3 access point arn: %s", s)
|
||||
}
|
||||
|
||||
remainder := s[nameStart:]
|
||||
slashIdx := strings.IndexByte(remainder, '/')
|
||||
if slashIdx == -1 {
|
||||
return s, "", nil
|
||||
}
|
||||
|
||||
bucketEnd := nameStart + slashIdx
|
||||
bucket = s[:bucketEnd]
|
||||
key = remainder[slashIdx+1:]
|
||||
return bucket, key, nil
|
||||
}
|
||||
|
||||
// CleanReplicaURLPath cleans a URL path for use in replica storage.
|
||||
func CleanReplicaURLPath(p string) string {
|
||||
if p == "" {
|
||||
return ""
|
||||
}
|
||||
cleaned := path.Clean("/" + p)
|
||||
cleaned = strings.TrimPrefix(cleaned, "/")
|
||||
if cleaned == "." {
|
||||
return ""
|
||||
}
|
||||
return cleaned
|
||||
}
|
||||
|
||||
// RegionFromS3ARN extracts the region from an S3 ARN.
|
||||
func RegionFromS3ARN(arn string) string {
|
||||
parts := strings.SplitN(arn, ":", 6)
|
||||
if len(parts) >= 4 {
|
||||
return parts[3]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// BoolQueryValue returns a boolean value from URL query parameters.
|
||||
// It checks multiple keys in order and returns the value and whether it was set.
|
||||
func BoolQueryValue(query url.Values, keys ...string) (value bool, ok bool) {
|
||||
if query == nil {
|
||||
return false, false
|
||||
}
|
||||
for _, key := range keys {
|
||||
if raw := query.Get(key); raw != "" {
|
||||
switch strings.ToLower(raw) {
|
||||
case "true", "1", "t", "yes":
|
||||
return true, true
|
||||
case "false", "0", "f", "no":
|
||||
return false, true
|
||||
default:
|
||||
return false, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
// IsTigrisEndpoint returns true if the endpoint is the Tigris object storage service.
|
||||
func IsTigrisEndpoint(endpoint string) bool {
|
||||
endpoint = strings.TrimSpace(strings.ToLower(endpoint))
|
||||
if endpoint == "" {
|
||||
return false
|
||||
}
|
||||
if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") {
|
||||
if u, err := url.Parse(endpoint); err == nil && u.Host != "" {
|
||||
endpoint = u.Host
|
||||
}
|
||||
}
|
||||
return endpoint == "fly.storage.tigris.dev"
|
||||
}
|
||||
|
||||
// IsURL returns true if s appears to be a URL (has a scheme).
|
||||
var isURLRegex = regexp.MustCompile(`^\w+:\/\/`)
|
||||
|
||||
func IsURL(s string) bool {
|
||||
return isURLRegex.MatchString(s)
|
||||
}
|
||||
632
replica_url_test.go
Normal file
632
replica_url_test.go
Normal file
@@ -0,0 +1,632 @@
|
||||
package litestream_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/benbjohnson/litestream"
|
||||
"github.com/benbjohnson/litestream/abs"
|
||||
"github.com/benbjohnson/litestream/file"
|
||||
"github.com/benbjohnson/litestream/gs"
|
||||
"github.com/benbjohnson/litestream/nats"
|
||||
"github.com/benbjohnson/litestream/oss"
|
||||
"github.com/benbjohnson/litestream/s3"
|
||||
"github.com/benbjohnson/litestream/sftp"
|
||||
"github.com/benbjohnson/litestream/webdav"
|
||||
)
|
||||
|
||||
func TestNewReplicaClientFromURL(t *testing.T) {
|
||||
t.Run("S3", func(t *testing.T) {
|
||||
client, err := litestream.NewReplicaClientFromURL("s3://mybucket/path/to/db")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if client.Type() != "s3" {
|
||||
t.Errorf("expected type 's3', got %q", client.Type())
|
||||
}
|
||||
s3Client, ok := client.(*s3.ReplicaClient)
|
||||
if !ok {
|
||||
t.Fatalf("expected *s3.ReplicaClient, got %T", client)
|
||||
}
|
||||
if s3Client.Bucket != "mybucket" {
|
||||
t.Errorf("expected bucket 'mybucket', got %q", s3Client.Bucket)
|
||||
}
|
||||
if s3Client.Path != "path/to/db" {
|
||||
t.Errorf("expected path 'path/to/db', got %q", s3Client.Path)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("S3WithQueryParams", func(t *testing.T) {
|
||||
client, err := litestream.NewReplicaClientFromURL("s3://mybucket/db?endpoint=localhost:9000®ion=us-west-2")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s3Client, ok := client.(*s3.ReplicaClient)
|
||||
if !ok {
|
||||
t.Fatalf("expected *s3.ReplicaClient, got %T", client)
|
||||
}
|
||||
if s3Client.Endpoint != "http://localhost:9000" {
|
||||
t.Errorf("expected endpoint 'http://localhost:9000', got %q", s3Client.Endpoint)
|
||||
}
|
||||
if s3Client.Region != "us-west-2" {
|
||||
t.Errorf("expected region 'us-west-2', got %q", s3Client.Region)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("File", func(t *testing.T) {
|
||||
client, err := litestream.NewReplicaClientFromURL("file:///tmp/replica")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if client.Type() != "file" {
|
||||
t.Errorf("expected type 'file', got %q", client.Type())
|
||||
}
|
||||
fileClient, ok := client.(*file.ReplicaClient)
|
||||
if !ok {
|
||||
t.Fatalf("expected *file.ReplicaClient, got %T", client)
|
||||
}
|
||||
if fileClient.Path() != "/tmp/replica" {
|
||||
t.Errorf("expected path '/tmp/replica', got %q", fileClient.Path())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GS", func(t *testing.T) {
|
||||
client, err := litestream.NewReplicaClientFromURL("gs://mybucket/path")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if client.Type() != "gs" {
|
||||
t.Errorf("expected type 'gs', got %q", client.Type())
|
||||
}
|
||||
gsClient, ok := client.(*gs.ReplicaClient)
|
||||
if !ok {
|
||||
t.Fatalf("expected *gs.ReplicaClient, got %T", client)
|
||||
}
|
||||
if gsClient.Bucket != "mybucket" {
|
||||
t.Errorf("expected bucket 'mybucket', got %q", gsClient.Bucket)
|
||||
}
|
||||
if gsClient.Path != "path" {
|
||||
t.Errorf("expected path 'path', got %q", gsClient.Path)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GS_MissingBucket", func(t *testing.T) {
|
||||
_, err := litestream.NewReplicaClientFromURL("gs:///path")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing bucket")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ABS", func(t *testing.T) {
|
||||
client, err := litestream.NewReplicaClientFromURL("abs://mycontainer/path")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if client.Type() != "abs" {
|
||||
t.Errorf("expected type 'abs', got %q", client.Type())
|
||||
}
|
||||
absClient, ok := client.(*abs.ReplicaClient)
|
||||
if !ok {
|
||||
t.Fatalf("expected *abs.ReplicaClient, got %T", client)
|
||||
}
|
||||
if absClient.Bucket != "mycontainer" {
|
||||
t.Errorf("expected bucket 'mycontainer', got %q", absClient.Bucket)
|
||||
}
|
||||
if absClient.Path != "path" {
|
||||
t.Errorf("expected path 'path', got %q", absClient.Path)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ABS_WithAccount", func(t *testing.T) {
|
||||
client, err := litestream.NewReplicaClientFromURL("abs://myaccount@mycontainer/path")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
absClient, ok := client.(*abs.ReplicaClient)
|
||||
if !ok {
|
||||
t.Fatalf("expected *abs.ReplicaClient, got %T", client)
|
||||
}
|
||||
if absClient.AccountName != "myaccount" {
|
||||
t.Errorf("expected account 'myaccount', got %q", absClient.AccountName)
|
||||
}
|
||||
if absClient.Bucket != "mycontainer" {
|
||||
t.Errorf("expected bucket 'mycontainer', got %q", absClient.Bucket)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ABS_MissingBucket", func(t *testing.T) {
|
||||
_, err := litestream.NewReplicaClientFromURL("abs:///path")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing bucket")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SFTP", func(t *testing.T) {
|
||||
client, err := litestream.NewReplicaClientFromURL("sftp://myuser@host.example.com/path")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if client.Type() != "sftp" {
|
||||
t.Errorf("expected type 'sftp', got %q", client.Type())
|
||||
}
|
||||
sftpClient, ok := client.(*sftp.ReplicaClient)
|
||||
if !ok {
|
||||
t.Fatalf("expected *sftp.ReplicaClient, got %T", client)
|
||||
}
|
||||
if sftpClient.Host != "host.example.com" {
|
||||
t.Errorf("expected host 'host.example.com', got %q", sftpClient.Host)
|
||||
}
|
||||
if sftpClient.User != "myuser" {
|
||||
t.Errorf("expected user 'myuser', got %q", sftpClient.User)
|
||||
}
|
||||
if sftpClient.Path != "path" {
|
||||
t.Errorf("expected path 'path', got %q", sftpClient.Path)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SFTP_WithPassword", func(t *testing.T) {
|
||||
client, err := litestream.NewReplicaClientFromURL("sftp://myuser:secret@host.example.com/path")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sftpClient, ok := client.(*sftp.ReplicaClient)
|
||||
if !ok {
|
||||
t.Fatalf("expected *sftp.ReplicaClient, got %T", client)
|
||||
}
|
||||
if sftpClient.User != "myuser" {
|
||||
t.Errorf("expected user 'myuser', got %q", sftpClient.User)
|
||||
}
|
||||
if sftpClient.Password != "secret" {
|
||||
t.Errorf("expected password 'secret', got %q", sftpClient.Password)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SFTP_RequiresUserError", func(t *testing.T) {
|
||||
_, err := litestream.NewReplicaClientFromURL("sftp://host.example.com/path")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing user")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SFTP_MissingHost", func(t *testing.T) {
|
||||
_, err := litestream.NewReplicaClientFromURL("sftp:///path")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing host")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("WebDAV", func(t *testing.T) {
|
||||
client, err := litestream.NewReplicaClientFromURL("webdav://host.example.com/path")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if client.Type() != "webdav" {
|
||||
t.Errorf("expected type 'webdav', got %q", client.Type())
|
||||
}
|
||||
webdavClient, ok := client.(*webdav.ReplicaClient)
|
||||
if !ok {
|
||||
t.Fatalf("expected *webdav.ReplicaClient, got %T", client)
|
||||
}
|
||||
if webdavClient.URL != "http://host.example.com" {
|
||||
t.Errorf("expected URL 'http://host.example.com', got %q", webdavClient.URL)
|
||||
}
|
||||
if webdavClient.Path != "path" {
|
||||
t.Errorf("expected path 'path', got %q", webdavClient.Path)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("WebDAVS", func(t *testing.T) {
|
||||
client, err := litestream.NewReplicaClientFromURL("webdavs://host.example.com/path")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if client.Type() != "webdav" {
|
||||
t.Errorf("expected type 'webdav', got %q", client.Type())
|
||||
}
|
||||
webdavClient, ok := client.(*webdav.ReplicaClient)
|
||||
if !ok {
|
||||
t.Fatalf("expected *webdav.ReplicaClient, got %T", client)
|
||||
}
|
||||
if webdavClient.URL != "https://host.example.com" {
|
||||
t.Errorf("expected URL 'https://host.example.com', got %q", webdavClient.URL)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("WebDAV_WithCredentials", func(t *testing.T) {
|
||||
client, err := litestream.NewReplicaClientFromURL("webdav://myuser:secret@host.example.com/path")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
webdavClient, ok := client.(*webdav.ReplicaClient)
|
||||
if !ok {
|
||||
t.Fatalf("expected *webdav.ReplicaClient, got %T", client)
|
||||
}
|
||||
if webdavClient.Username != "myuser" {
|
||||
t.Errorf("expected username 'myuser', got %q", webdavClient.Username)
|
||||
}
|
||||
if webdavClient.Password != "secret" {
|
||||
t.Errorf("expected password 'secret', got %q", webdavClient.Password)
|
||||
}
|
||||
if webdavClient.URL != "http://host.example.com" {
|
||||
t.Errorf("expected URL 'http://host.example.com', got %q", webdavClient.URL)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("WebDAVS_WithCredentials", func(t *testing.T) {
|
||||
client, err := litestream.NewReplicaClientFromURL("webdavs://myuser:secret@host.example.com/path")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
webdavClient, ok := client.(*webdav.ReplicaClient)
|
||||
if !ok {
|
||||
t.Fatalf("expected *webdav.ReplicaClient, got %T", client)
|
||||
}
|
||||
if webdavClient.Username != "myuser" {
|
||||
t.Errorf("expected username 'myuser', got %q", webdavClient.Username)
|
||||
}
|
||||
if webdavClient.Password != "secret" {
|
||||
t.Errorf("expected password 'secret', got %q", webdavClient.Password)
|
||||
}
|
||||
if webdavClient.URL != "https://host.example.com" {
|
||||
t.Errorf("expected URL 'https://host.example.com', got %q", webdavClient.URL)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("WebDAV_MissingHost", func(t *testing.T) {
|
||||
_, err := litestream.NewReplicaClientFromURL("webdav:///path")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing host")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NATS", func(t *testing.T) {
|
||||
client, err := litestream.NewReplicaClientFromURL("nats://localhost:4222/mybucket")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if client.Type() != "nats" {
|
||||
t.Errorf("expected type 'nats', got %q", client.Type())
|
||||
}
|
||||
natsClient, ok := client.(*nats.ReplicaClient)
|
||||
if !ok {
|
||||
t.Fatalf("expected *nats.ReplicaClient, got %T", client)
|
||||
}
|
||||
if natsClient.URL != "nats://localhost:4222" {
|
||||
t.Errorf("expected URL 'nats://localhost:4222', got %q", natsClient.URL)
|
||||
}
|
||||
if natsClient.BucketName != "mybucket" {
|
||||
t.Errorf("expected bucket 'mybucket', got %q", natsClient.BucketName)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NATS_WithCredentials", func(t *testing.T) {
|
||||
client, err := litestream.NewReplicaClientFromURL("nats://myuser:secret@localhost:4222/mybucket")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
natsClient, ok := client.(*nats.ReplicaClient)
|
||||
if !ok {
|
||||
t.Fatalf("expected *nats.ReplicaClient, got %T", client)
|
||||
}
|
||||
if natsClient.Username != "myuser" {
|
||||
t.Errorf("expected username 'myuser', got %q", natsClient.Username)
|
||||
}
|
||||
if natsClient.Password != "secret" {
|
||||
t.Errorf("expected password 'secret', got %q", natsClient.Password)
|
||||
}
|
||||
if natsClient.URL != "nats://localhost:4222" {
|
||||
t.Errorf("expected URL 'nats://localhost:4222', got %q", natsClient.URL)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NATS_MissingBucket", func(t *testing.T) {
|
||||
_, err := litestream.NewReplicaClientFromURL("nats://localhost:4222/")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing bucket")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OSS", func(t *testing.T) {
|
||||
client, err := litestream.NewReplicaClientFromURL("oss://mybucket/path")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if client.Type() != "oss" {
|
||||
t.Errorf("expected type 'oss', got %q", client.Type())
|
||||
}
|
||||
ossClient, ok := client.(*oss.ReplicaClient)
|
||||
if !ok {
|
||||
t.Fatalf("expected *oss.ReplicaClient, got %T", client)
|
||||
}
|
||||
if ossClient.Bucket != "mybucket" {
|
||||
t.Errorf("expected bucket 'mybucket', got %q", ossClient.Bucket)
|
||||
}
|
||||
if ossClient.Path != "path" {
|
||||
t.Errorf("expected path 'path', got %q", ossClient.Path)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OSS_WithRegion", func(t *testing.T) {
|
||||
client, err := litestream.NewReplicaClientFromURL("oss://mybucket.oss-cn-shanghai.aliyuncs.com/path")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ossClient, ok := client.(*oss.ReplicaClient)
|
||||
if !ok {
|
||||
t.Fatalf("expected *oss.ReplicaClient, got %T", client)
|
||||
}
|
||||
if ossClient.Bucket != "mybucket" {
|
||||
t.Errorf("expected bucket 'mybucket', got %q", ossClient.Bucket)
|
||||
}
|
||||
// Note: Region is extracted without the 'oss-' prefix
|
||||
if ossClient.Region != "cn-shanghai" {
|
||||
t.Errorf("expected region 'cn-shanghai', got %q", ossClient.Region)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OSS_MissingBucket", func(t *testing.T) {
|
||||
_, err := litestream.NewReplicaClientFromURL("oss:///path")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing bucket")
|
||||
}
|
||||
})
|
||||
|
||||
// Note: file:// with empty path returns "." due to path.Clean behavior.
|
||||
// This is technically valid but may not be the intended behavior.
|
||||
t.Run("File_EmptyPathReturnsDot", func(t *testing.T) {
|
||||
client, err := litestream.NewReplicaClientFromURL("file://")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fileClient, ok := client.(*file.ReplicaClient)
|
||||
if !ok {
|
||||
t.Fatalf("expected *file.ReplicaClient, got %T", client)
|
||||
}
|
||||
// path.Clean("") returns "." which passes the empty check
|
||||
if fileClient.Path() != "." {
|
||||
t.Errorf("expected path '.', got %q", fileClient.Path())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("S3_ARN", func(t *testing.T) {
|
||||
client, err := litestream.NewReplicaClientFromURL("s3://arn:aws:s3:us-east-1:123456789012:accesspoint/db-access/backups")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s3Client, ok := client.(*s3.ReplicaClient)
|
||||
if !ok {
|
||||
t.Fatalf("expected *s3.ReplicaClient, got %T", client)
|
||||
}
|
||||
if s3Client.Bucket != "arn:aws:s3:us-east-1:123456789012:accesspoint/db-access" {
|
||||
t.Errorf("expected bucket ARN, got %q", s3Client.Bucket)
|
||||
}
|
||||
if s3Client.Path != "backups" {
|
||||
t.Errorf("expected path 'backups', got %q", s3Client.Path)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("S3_MissingBucket", func(t *testing.T) {
|
||||
_, err := litestream.NewReplicaClientFromURL("s3:///path")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing bucket")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmptyURL", func(t *testing.T) {
|
||||
_, err := litestream.NewReplicaClientFromURL("")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty URL")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UnsupportedScheme", func(t *testing.T) {
|
||||
_, err := litestream.NewReplicaClientFromURL("unknown://bucket/path")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unsupported scheme")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InvalidURL", func(t *testing.T) {
|
||||
_, err := litestream.NewReplicaClientFromURL("not-a-valid-url")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid URL")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestReplicaTypeFromURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
url string
|
||||
expected string
|
||||
}{
|
||||
{"s3://bucket/path", "s3"},
|
||||
{"gs://bucket/path", "gs"},
|
||||
{"abs://container/path", "abs"},
|
||||
{"file:///path/to/replica", "file"},
|
||||
{"sftp://host/path", "sftp"},
|
||||
{"webdav://host/path", "webdav"},
|
||||
{"webdavs://host/path", "webdav"},
|
||||
{"nats://host/bucket", "nats"},
|
||||
{"oss://bucket/path", "oss"},
|
||||
{"", ""},
|
||||
{"invalid", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.url, func(t *testing.T) {
|
||||
got := litestream.ReplicaTypeFromURL(tt.url)
|
||||
if got != tt.expected {
|
||||
t.Errorf("ReplicaTypeFromURL(%q) = %q, want %q", tt.url, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
expected bool
|
||||
}{
|
||||
{"s3://bucket/path", true},
|
||||
{"file:///path", true},
|
||||
{"https://example.com", true},
|
||||
{"/path/to/file", false},
|
||||
{"relative/path", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.s, func(t *testing.T) {
|
||||
got := litestream.IsURL(tt.s)
|
||||
if got != tt.expected {
|
||||
t.Errorf("IsURL(%q) = %v, want %v", tt.s, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBoolQueryValue(t *testing.T) {
|
||||
t.Run("True values", func(t *testing.T) {
|
||||
for _, v := range []string{"true", "True", "TRUE", "1", "t", "yes"} {
|
||||
query := make(map[string][]string)
|
||||
query["key"] = []string{v}
|
||||
value, ok := litestream.BoolQueryValue(query, "key")
|
||||
if !ok {
|
||||
t.Errorf("BoolQueryValue with %q should be ok", v)
|
||||
}
|
||||
if !value {
|
||||
t.Errorf("BoolQueryValue with %q should be true", v)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("False values", func(t *testing.T) {
|
||||
for _, v := range []string{"false", "False", "FALSE", "0", "f", "no"} {
|
||||
query := make(map[string][]string)
|
||||
query["key"] = []string{v}
|
||||
value, ok := litestream.BoolQueryValue(query, "key")
|
||||
if !ok {
|
||||
t.Errorf("BoolQueryValue with %q should be ok", v)
|
||||
}
|
||||
if value {
|
||||
t.Errorf("BoolQueryValue with %q should be false", v)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Missing key", func(t *testing.T) {
|
||||
query := make(map[string][]string)
|
||||
_, ok := litestream.BoolQueryValue(query, "key")
|
||||
if ok {
|
||||
t.Error("BoolQueryValue with missing key should not be ok")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Multiple keys", func(t *testing.T) {
|
||||
query := make(map[string][]string)
|
||||
query["key2"] = []string{"true"}
|
||||
value, ok := litestream.BoolQueryValue(query, "key1", "key2")
|
||||
if !ok {
|
||||
t.Error("BoolQueryValue should find second key")
|
||||
}
|
||||
if !value {
|
||||
t.Error("BoolQueryValue should return true for second key")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Nil query", func(t *testing.T) {
|
||||
_, ok := litestream.BoolQueryValue(nil, "key")
|
||||
if ok {
|
||||
t.Error("BoolQueryValue with nil query should not be ok")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid value returns false with ok", func(t *testing.T) {
|
||||
query := make(map[string][]string)
|
||||
query["key"] = []string{"invalid"}
|
||||
value, ok := litestream.BoolQueryValue(query, "key")
|
||||
if !ok {
|
||||
t.Error("BoolQueryValue with invalid value should be ok")
|
||||
}
|
||||
if value {
|
||||
t.Error("BoolQueryValue with invalid value should be false")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsTigrisEndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
endpoint string
|
||||
expected bool
|
||||
}{
|
||||
{"fly.storage.tigris.dev", true},
|
||||
{"FLY.STORAGE.TIGRIS.DEV", true},
|
||||
{"https://fly.storage.tigris.dev", true},
|
||||
{"http://fly.storage.tigris.dev", true},
|
||||
{"s3.amazonaws.com", false},
|
||||
{"localhost:9000", false},
|
||||
{"", false},
|
||||
{" ", false},
|
||||
{"https://s3.us-east-1.amazonaws.com", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.endpoint, func(t *testing.T) {
|
||||
got := litestream.IsTigrisEndpoint(tt.endpoint)
|
||||
if got != tt.expected {
|
||||
t.Errorf("IsTigrisEndpoint(%q) = %v, want %v", tt.endpoint, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegionFromS3ARN(t *testing.T) {
|
||||
tests := []struct {
|
||||
arn string
|
||||
expected string
|
||||
}{
|
||||
{"arn:aws:s3:us-east-1:123456789012:accesspoint/db-access", "us-east-1"},
|
||||
{"arn:aws:s3:eu-west-1:123456789012:accesspoint/db-access", "eu-west-1"},
|
||||
{"arn:aws:s3:ap-southeast-2:123456789012:accesspoint/db-access", "ap-southeast-2"},
|
||||
{"arn:aws:s3::123456789012:accesspoint/db-access", ""},
|
||||
{"invalid-arn", ""},
|
||||
{"", ""},
|
||||
{"arn:aws:s3", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.arn, func(t *testing.T) {
|
||||
got := litestream.RegionFromS3ARN(tt.arn)
|
||||
if got != tt.expected {
|
||||
t.Errorf("RegionFromS3ARN(%q) = %q, want %q", tt.arn, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanReplicaURLPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
path string
|
||||
expected string
|
||||
}{
|
||||
{"", ""},
|
||||
{"path", "path"},
|
||||
{"/path", "path"},
|
||||
{"path/", "path"},
|
||||
{"/path/", "path"},
|
||||
{"path/to/db", "path/to/db"},
|
||||
{"/path/to/db", "path/to/db"},
|
||||
{"//path//to//db", "path/to/db"},
|
||||
{".", ""},
|
||||
{"/.", ""},
|
||||
{"./path", "path"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.path, func(t *testing.T) {
|
||||
got := litestream.CleanReplicaURLPath(tt.path)
|
||||
if got != tt.expected {
|
||||
t.Errorf("CleanReplicaURLPath(%q) = %q, want %q", tt.path, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -38,6 +38,10 @@ import (
|
||||
"github.com/benbjohnson/litestream/internal"
|
||||
)
|
||||
|
||||
func init() {
|
||||
litestream.RegisterReplicaClientFactory("s3", NewReplicaClientFromURL)
|
||||
}
|
||||
|
||||
// ReplicaClientType is the client type for this package.
|
||||
const ReplicaClientType = "s3"
|
||||
|
||||
@@ -90,6 +94,109 @@ func NewReplicaClient() *ReplicaClient {
|
||||
}
|
||||
}
|
||||
|
||||
// NewReplicaClientFromURL creates a new ReplicaClient from URL components.
|
||||
// This is used by the replica client factory registration.
|
||||
func NewReplicaClientFromURL(scheme, host, urlPath string, query url.Values, userinfo *url.Userinfo) (litestream.ReplicaClient, error) {
|
||||
client := NewReplicaClient()
|
||||
|
||||
var (
|
||||
bucket string
|
||||
region string
|
||||
endpoint string
|
||||
forcePathStyle bool
|
||||
skipVerify bool
|
||||
signPayload bool
|
||||
signPayloadSet bool
|
||||
requireMD5 bool
|
||||
requireMD5Set bool
|
||||
)
|
||||
|
||||
// Parse host for bucket and region
|
||||
if strings.HasPrefix(host, "arn:") {
|
||||
bucket = host
|
||||
region = litestream.RegionFromS3ARN(host)
|
||||
} else {
|
||||
bucket, region, endpoint, forcePathStyle = ParseHost(host)
|
||||
}
|
||||
|
||||
// Override with query parameters if provided
|
||||
if qEndpoint := query.Get("endpoint"); qEndpoint != "" {
|
||||
// Ensure endpoint has a scheme
|
||||
if !strings.HasPrefix(qEndpoint, "http://") && !strings.HasPrefix(qEndpoint, "https://") {
|
||||
qEndpoint = "http://" + qEndpoint
|
||||
}
|
||||
endpoint = qEndpoint
|
||||
// Default to path style for custom endpoints unless explicitly set to false
|
||||
if query.Get("forcePathStyle") != "false" {
|
||||
forcePathStyle = true
|
||||
}
|
||||
}
|
||||
if qRegion := query.Get("region"); qRegion != "" {
|
||||
region = qRegion
|
||||
}
|
||||
if qForcePathStyle := query.Get("forcePathStyle"); qForcePathStyle != "" {
|
||||
forcePathStyle = qForcePathStyle == "true"
|
||||
}
|
||||
if qSkipVerify := query.Get("skipVerify"); qSkipVerify != "" {
|
||||
skipVerify = qSkipVerify == "true"
|
||||
}
|
||||
if v, ok := litestream.BoolQueryValue(query, "signPayload", "sign-payload"); ok {
|
||||
signPayload = v
|
||||
signPayloadSet = true
|
||||
}
|
||||
if v, ok := litestream.BoolQueryValue(query, "requireContentMD5", "require-content-md5"); ok {
|
||||
requireMD5 = v
|
||||
requireMD5Set = true
|
||||
}
|
||||
|
||||
// Ensure bucket is set
|
||||
if bucket == "" {
|
||||
return nil, fmt.Errorf("bucket required for s3 replica URL")
|
||||
}
|
||||
|
||||
// Check for Tigris endpoint
|
||||
isTigris := litestream.IsTigrisEndpoint(endpoint)
|
||||
|
||||
// Read authentication from environment variables
|
||||
if v := os.Getenv("AWS_ACCESS_KEY_ID"); v != "" {
|
||||
client.AccessKeyID = v
|
||||
} else if v := os.Getenv("LITESTREAM_ACCESS_KEY_ID"); v != "" {
|
||||
client.AccessKeyID = v
|
||||
}
|
||||
if v := os.Getenv("AWS_SECRET_ACCESS_KEY"); v != "" {
|
||||
client.SecretAccessKey = v
|
||||
} else if v := os.Getenv("LITESTREAM_SECRET_ACCESS_KEY"); v != "" {
|
||||
client.SecretAccessKey = v
|
||||
}
|
||||
|
||||
// Configure client
|
||||
client.Bucket = bucket
|
||||
client.Path = urlPath
|
||||
client.Region = region
|
||||
client.Endpoint = endpoint
|
||||
client.ForcePathStyle = forcePathStyle
|
||||
client.SkipVerify = skipVerify
|
||||
|
||||
// Apply Tigris defaults
|
||||
if isTigris {
|
||||
if !signPayloadSet {
|
||||
signPayload, signPayloadSet = true, true
|
||||
}
|
||||
if !requireMD5Set {
|
||||
requireMD5, requireMD5Set = false, true
|
||||
}
|
||||
}
|
||||
|
||||
if signPayloadSet {
|
||||
client.SignPayload = signPayload
|
||||
}
|
||||
if requireMD5Set {
|
||||
client.RequireContentMD5 = requireMD5
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// Type returns "s3" as the client type.
|
||||
func (c *ReplicaClient) Type() string {
|
||||
return ReplicaClientType
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"sync"
|
||||
@@ -21,6 +22,10 @@ import (
|
||||
"github.com/benbjohnson/litestream/internal"
|
||||
)
|
||||
|
||||
func init() {
|
||||
litestream.RegisterReplicaClientFactory("sftp", NewReplicaClientFromURL)
|
||||
}
|
||||
|
||||
// ReplicaClientType is the client type for this package.
|
||||
const ReplicaClientType = "sftp"
|
||||
|
||||
@@ -61,13 +66,44 @@ func NewReplicaClient() *ReplicaClient {
|
||||
}
|
||||
}
|
||||
|
||||
// NewReplicaClientFromURL creates a new ReplicaClient from URL components.
|
||||
// This is used by the replica client factory registration.
|
||||
// URL format: sftp://[user[:password]@]host[:port]/path
|
||||
func NewReplicaClientFromURL(scheme, host, urlPath string, query url.Values, userinfo *url.Userinfo) (litestream.ReplicaClient, error) {
|
||||
client := NewReplicaClient()
|
||||
|
||||
// Extract credentials from userinfo
|
||||
if userinfo != nil {
|
||||
client.User = userinfo.Username()
|
||||
client.Password, _ = userinfo.Password()
|
||||
}
|
||||
|
||||
client.Host = host
|
||||
client.Path = urlPath
|
||||
|
||||
if client.Host == "" {
|
||||
return nil, fmt.Errorf("host required for sftp replica URL")
|
||||
}
|
||||
if client.User == "" {
|
||||
return nil, fmt.Errorf("user required for sftp replica URL")
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// Type returns "sftp" as the client type.
|
||||
func (c *ReplicaClient) Type() string {
|
||||
return ReplicaClientType
|
||||
}
|
||||
|
||||
// Init initializes the connection to SFTP. No-op if already initialized.
|
||||
func (c *ReplicaClient) Init(ctx context.Context) (_ *sftp.Client, err error) {
|
||||
func (c *ReplicaClient) Init(ctx context.Context) error {
|
||||
_, err := c.init(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// init initializes the connection and returns the SFTP client.
|
||||
func (c *ReplicaClient) init(ctx context.Context) (_ *sftp.Client, err error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
@@ -144,7 +180,7 @@ func (c *ReplicaClient) Init(ctx context.Context) (_ *sftp.Client, err error) {
|
||||
func (c *ReplicaClient) DeleteAll(ctx context.Context) (err error) {
|
||||
defer func() { c.resetOnConnError(err) }()
|
||||
|
||||
sftpClient, err := c.Init(ctx)
|
||||
sftpClient, err := c.init(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -188,7 +224,7 @@ func (c *ReplicaClient) DeleteAll(ctx context.Context) (err error) {
|
||||
func (c *ReplicaClient) LTXFiles(ctx context.Context, level int, seek ltx.TXID, _ bool) (_ ltx.FileIterator, err error) {
|
||||
defer func() { c.resetOnConnError(err) }()
|
||||
|
||||
sftpClient, err := c.Init(ctx)
|
||||
sftpClient, err := c.init(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -227,7 +263,7 @@ func (c *ReplicaClient) LTXFiles(ctx context.Context, level int, seek ltx.TXID,
|
||||
func (c *ReplicaClient) WriteLTXFile(ctx context.Context, level int, minTXID, maxTXID ltx.TXID, rd io.Reader) (info *ltx.FileInfo, err error) {
|
||||
defer func() { c.resetOnConnError(err) }()
|
||||
|
||||
sftpClient, err := c.Init(ctx)
|
||||
sftpClient, err := c.init(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -287,7 +323,7 @@ func (c *ReplicaClient) WriteLTXFile(ctx context.Context, level int, minTXID, ma
|
||||
func (c *ReplicaClient) OpenLTXFile(ctx context.Context, level int, minTXID, maxTXID ltx.TXID, offset, size int64) (_ io.ReadCloser, err error) {
|
||||
defer func() { c.resetOnConnError(err) }()
|
||||
|
||||
sftpClient, err := c.Init(ctx)
|
||||
sftpClient, err := c.init(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -316,7 +352,7 @@ func (c *ReplicaClient) OpenLTXFile(ctx context.Context, level int, minTXID, max
|
||||
func (c *ReplicaClient) DeleteLTXFiles(ctx context.Context, a []*ltx.FileInfo) (err error) {
|
||||
defer func() { c.resetOnConnError(err) }()
|
||||
|
||||
sftpClient, err := c.Init(ctx)
|
||||
sftpClient, err := c.init(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -339,7 +375,7 @@ func (c *ReplicaClient) DeleteLTXFiles(ctx context.Context, a []*ltx.FileInfo) (
|
||||
func (c *ReplicaClient) Cleanup(ctx context.Context) (err error) {
|
||||
defer func() { c.resetOnConnError(err) }()
|
||||
|
||||
sftpClient, err := c.Init(ctx)
|
||||
sftpClient, err := c.init(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -126,6 +126,8 @@ func newDelayedReplicaClient(delay time.Duration) *delayedReplicaClient {
|
||||
|
||||
func (c *delayedReplicaClient) Type() string { return "delayed" }
|
||||
|
||||
func (c *delayedReplicaClient) Init(context.Context) error { return nil }
|
||||
|
||||
func (c *delayedReplicaClient) key(level int, min, max ltx.TXID) string {
|
||||
return fmt.Sprintf("%d:%s:%s", level, min.String(), max.String())
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"sort"
|
||||
@@ -19,6 +20,11 @@ import (
|
||||
"github.com/benbjohnson/litestream/internal"
|
||||
)
|
||||
|
||||
func init() {
|
||||
litestream.RegisterReplicaClientFactory("webdav", NewReplicaClientFromURL)
|
||||
litestream.RegisterReplicaClientFactory("webdavs", NewReplicaClientFromURL)
|
||||
}
|
||||
|
||||
const ReplicaClientType = "webdav"
|
||||
|
||||
const (
|
||||
@@ -46,11 +52,45 @@ func NewReplicaClient() *ReplicaClient {
|
||||
}
|
||||
}
|
||||
|
||||
// NewReplicaClientFromURL creates a new ReplicaClient from URL components.
|
||||
// This is used by the replica client factory registration.
|
||||
// URL format: webdav://[user[:password]@]host[:port]/path or webdavs://... (for HTTPS)
|
||||
func NewReplicaClientFromURL(scheme, host, urlPath string, query url.Values, userinfo *url.Userinfo) (litestream.ReplicaClient, error) {
|
||||
client := NewReplicaClient()
|
||||
|
||||
// Determine HTTP or HTTPS based on scheme
|
||||
httpScheme := "http"
|
||||
if scheme == "webdavs" {
|
||||
httpScheme = "https"
|
||||
}
|
||||
|
||||
// Extract credentials from userinfo
|
||||
if userinfo != nil {
|
||||
client.Username = userinfo.Username()
|
||||
client.Password, _ = userinfo.Password()
|
||||
}
|
||||
|
||||
if host == "" {
|
||||
return nil, fmt.Errorf("host required for webdav replica URL")
|
||||
}
|
||||
|
||||
client.URL = fmt.Sprintf("%s://%s", httpScheme, host)
|
||||
client.Path = urlPath
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *ReplicaClient) Type() string {
|
||||
return ReplicaClientType
|
||||
}
|
||||
|
||||
func (c *ReplicaClient) Init(ctx context.Context) (_ *gowebdav.Client, err error) {
|
||||
func (c *ReplicaClient) Init(ctx context.Context) error {
|
||||
_, err := c.init(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// init initializes the connection and returns the WebDAV client.
|
||||
func (c *ReplicaClient) init(ctx context.Context) (_ *gowebdav.Client, err error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
@@ -75,7 +115,7 @@ func (c *ReplicaClient) Init(ctx context.Context) (_ *gowebdav.Client, err error
|
||||
}
|
||||
|
||||
func (c *ReplicaClient) DeleteAll(ctx context.Context) error {
|
||||
client, err := c.Init(ctx)
|
||||
client, err := c.init(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -90,7 +130,7 @@ func (c *ReplicaClient) DeleteAll(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (c *ReplicaClient) LTXFiles(ctx context.Context, level int, seek ltx.TXID, _ bool) (_ ltx.FileIterator, err error) {
|
||||
client, err := c.Init(ctx)
|
||||
client, err := c.init(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -190,7 +230,7 @@ func (c *ReplicaClient) LTXFiles(ctx context.Context, level int, seek ltx.TXID,
|
||||
// - https://github.com/nextcloud/server/issues/7995 (0-byte file bug)
|
||||
// - https://evertpot.com/260/ (WebDAV chunked encoding compatibility)
|
||||
func (c *ReplicaClient) WriteLTXFile(ctx context.Context, level int, minTXID, maxTXID ltx.TXID, rd io.Reader) (info *ltx.FileInfo, err error) {
|
||||
client, err := c.Init(ctx)
|
||||
client, err := c.init(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -255,7 +295,7 @@ func (c *ReplicaClient) WriteLTXFile(ctx context.Context, level int, minTXID, ma
|
||||
}
|
||||
|
||||
func (c *ReplicaClient) OpenLTXFile(ctx context.Context, level int, minTXID, maxTXID ltx.TXID, offset, size int64) (_ io.ReadCloser, err error) {
|
||||
client, err := c.Init(ctx)
|
||||
client, err := c.init(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -307,7 +347,7 @@ func (c *ReplicaClient) OpenLTXFile(ctx context.Context, level int, minTXID, max
|
||||
}
|
||||
|
||||
func (c *ReplicaClient) DeleteLTXFiles(ctx context.Context, a []*ltx.FileInfo) error {
|
||||
client, err := c.Init(ctx)
|
||||
client, err := c.init(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ func TestReplicaClient_Init_RequiresURL(t *testing.T) {
|
||||
c := webdav.NewReplicaClient()
|
||||
c.URL = ""
|
||||
|
||||
if _, err := c.Init(context.TODO()); err == nil {
|
||||
if err := c.Init(context.TODO()); err == nil {
|
||||
t.Fatal("expected error when URL is empty")
|
||||
} else if got, want := err.Error(), "webdav url required"; got != want {
|
||||
t.Fatalf("error=%v, want %v", got, want)
|
||||
|
||||
Reference in New Issue
Block a user