diff --git a/cmd/litestream/main.go b/cmd/litestream/main.go
index 9698bf50..1da07449 100644
--- a/cmd/litestream/main.go
+++ b/cmd/litestream/main.go
@@ -32,13 +32,20 @@ var (
     Version = "(development build)"
 )
 
-// errStop is a terminal error for indicating program should quit.
-var errStop = errors.New("stop")
+var (
+    // errStop is a terminal error for indicating program should quit.
+    errStop = errors.New("stop")
+
+    // errLeaseExpired is a terminal error indicating the program has exited due to lease expiration.
+    errLeaseExpired = errors.New("lease expired")
+)
 
 func main() {
     m := NewMain()
     if err := m.Run(context.Background(), os.Args[1:]); err == flag.ErrHelp || err == errStop {
         os.Exit(1)
+    } else if err == errLeaseExpired {
+        os.Exit(2)
     } else if err != nil {
         slog.Error("failed to run", "error", err)
         os.Exit(1)
@@ -89,8 +96,11 @@ func (m *Main) Run(ctx context.Context, args []string) (err error) {
         return err
     }
 
-    // Wait for signal to stop program.
+    // Wait for lease expiration or for a signal to stop program.
     select {
+    case <-c.leaseExpireCh:
+        return errLeaseExpired
+
     case err = <-c.execCh:
         slog.Info("subprocess exited, litestream shutting down")
     case sig := <-signalCh:
@@ -162,6 +172,9 @@ type Config struct {
     // List of databases to manage.
     DBs []*DBConfig `yaml:"dbs"`
 
+    // Optional. Distributed lease configuration.
+    Lease *LeaseConfig `yaml:"lease"`
+
     // Subcommand to execute during replication.
     // Litestream will shutdown when subcommand exits.
     Exec string `yaml:"exec"`
@@ -281,6 +294,119 @@ func ReadConfigFile(filename string, expandEnv bool) (_ Config, err error) {
     return config, nil
 }
 
+// LeaseConfig represents the configuration for a distributed lease.
+type LeaseConfig struct {
+    Type    string         `yaml:"type"` // "s3"
+    Path    string         `yaml:"path"`
+    URL     string         `yaml:"url"`
+    Timeout *time.Duration `yaml:"timeout"`
+    Owner   string         `yaml:"owner"`
+
+    // S3 settings
+    AccessKeyID     string `yaml:"access-key-id"`
+    SecretAccessKey string `yaml:"secret-access-key"`
+    Region          string `yaml:"region"`
+    Bucket          string `yaml:"bucket"`
+    Endpoint        string `yaml:"endpoint"`
+    ForcePathStyle  *bool  `yaml:"force-path-style"`
+    SkipVerify      bool   `yaml:"skip-verify"`
+}
+
+// NewLeaserFromConfig instantiates a lease client.
+func NewLeaserFromConfig(c *LeaseConfig) (_ litestream.Leaser, err error) {
+    // Ensure user did not specify URL in path.
+    if isURL(c.Path) {
+        return nil, fmt.Errorf("leaser path cannot be a url, please use the 'url' field instead: %s", c.Path)
+    }
+
+    switch c.Type {
+    case "s3":
+        return newS3LeaserFromConfig(c)
+    default:
+        return nil, fmt.Errorf("unknown leaser type in config: %q", c.Type)
+    }
+}
+
+// newS3LeaserFromConfig returns a new instance of s3.Leaser built from config.
+func newS3LeaserFromConfig(c *LeaseConfig) (_ *s3.Leaser, err error) {
+    // Ensure URL & constituent parts are not both specified.
+    if c.URL != "" && c.Path != "" {
+        return nil, fmt.Errorf("cannot specify url & path for s3 leaser")
+    } else if c.URL != "" && c.Bucket != "" {
+        return nil, fmt.Errorf("cannot specify url & bucket for s3 leaser")
+    }
+
+    bucket, path := c.Bucket, c.Path
+    region, endpoint, skipVerify := c.Region, c.Endpoint, c.SkipVerify
+
+    // Use path style if an endpoint is explicitly set. This works because the
+    // only service to not use path style is AWS which does not use an endpoint.
+    forcePathStyle := (endpoint != "")
+    if v := c.ForcePathStyle; v != nil {
+        forcePathStyle = *v
+    }
+
+    // Apply settings from URL, if specified.
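+    // For example, a hypothetical url of "s3://example-bucket/litestream/lease"
+    // would fill in the bucket, path and (for non-AWS hosts) endpoint fields below.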
+    if c.URL != "" {
+        _, host, upath, err := ParseReplicaURL(c.URL)
+        if err != nil {
+            return nil, err
+        }
+        ubucket, uregion, uendpoint, uforcePathStyle := s3.ParseHost(host)
+
+        // Only apply URL parts to fields that have not been overridden.
+        if path == "" {
+            path = upath
+        }
+        if bucket == "" {
+            bucket = ubucket
+        }
+        if region == "" {
+            region = uregion
+        }
+        if endpoint == "" {
+            endpoint = uendpoint
+        }
+        if !forcePathStyle {
+            forcePathStyle = uforcePathStyle
+        }
+    }
+
+    // Ensure required settings are set.
+    if bucket == "" {
+        return nil, fmt.Errorf("bucket required for s3 leaser")
+    }
+
+    // Build leaser.
+    leaser := s3.NewLeaser()
+    leaser.AccessKeyID = c.AccessKeyID
+    leaser.SecretAccessKey = c.SecretAccessKey
+    leaser.Bucket = bucket
+    leaser.Path = path
+    leaser.Region = region
+    leaser.Endpoint = endpoint
+    leaser.ForcePathStyle = forcePathStyle
+    leaser.SkipVerify = skipVerify
+
+    owner := c.Owner
+    if owner == "" {
+        owner, _ = os.Hostname()
+    }
+    leaser.Owner = owner
+
+    if v := c.Timeout; v != nil {
+        leaser.LeaseTimeout = *v
+    }
+
+    // Initialize leaser to build client.
+    if err := leaser.Open(); err != nil {
+        return nil, err
+    }
+
+    return leaser, nil
+}
+
 // DBConfig represents the configuration for a single database.
 type DBConfig struct {
     Path string `yaml:"path"`
diff --git a/cmd/litestream/replicate.go b/cmd/litestream/replicate.go
index 7c9b4f30..577c9f61 100644
--- a/cmd/litestream/replicate.go
+++ b/cmd/litestream/replicate.go
@@ -2,6 +2,7 @@ package main
 
 import (
     "context"
+    "errors"
     "flag"
     "fmt"
     "log/slog"
@@ -10,6 +11,9 @@ import (
     _ "net/http/pprof"
     "os"
     "os/exec"
+    "sync"
+    "sync/atomic"
+    "time"
 
     "github.com/benbjohnson/litestream"
     "github.com/benbjohnson/litestream/abs"
@@ -23,19 +27,39 @@ import (
 
 // ReplicateCommand represents a command that continuously replicates SQLite databases.
 type ReplicateCommand struct {
-    cmd    *exec.Cmd  // subcommand
-    execCh chan error // subcommand error channel
+    cmd           *exec.Cmd // subcommand
+    wg            sync.WaitGroup
+    execCh        chan error    // subcommand error channel
+    leaseExpireCh chan struct{} // lease expiration error channel
+    leaserCtx     context.Context
+    leaserCancel  context.CancelCauseFunc
+
+    // Holds the current lease, if any.
+    lease atomic.Value // *litestream.Lease
 
     Config Config
 
+    // Lease client for managing distributed lease.
+    // May be nil if no lease config specified.
+    Leaser litestream.Leaser
+
     // List of managed databases specified in the config.
     DBs []*litestream.DB
 }
 
 func NewReplicateCommand() *ReplicateCommand {
-    return &ReplicateCommand{
-        execCh: make(chan error),
+    c := &ReplicateCommand{
+        execCh:        make(chan error),
+        leaseExpireCh: make(chan struct{}),
     }
+    c.leaserCtx, c.leaserCancel = context.WithCancelCause(context.Background())
+
+    c.lease.Store((*litestream.Lease)(nil))
+    return c
+}
+
+// Lease returns the most recently acquired lease, or nil if none has been acquired.
+func (c *ReplicateCommand) Lease() *litestream.Lease {
+    return c.lease.Load().(*litestream.Lease)
 }
 
 // ParseFlags parses the CLI flags and loads the configuration file.
@@ -87,6 +111,18 @@ func (c *ReplicateCommand) Run() (err error) {
     // Display version information.
     slog.Info("litestream", "version", Version)
 
+    // Acquire lease if config specified.
+    if c.Config.Lease != nil {
+        c.Leaser, err = NewLeaserFromConfig(c.Config.Lease)
+        if err != nil {
+            return fmt.Errorf("initialize leaser: %w", err)
+        }
+
+        if err := c.acquireLease(context.Background()); err != nil {
+            return fmt.Errorf("acquire initial lease: %w", err)
+        }
+    }
+
     // Setup databases.
     if len(c.Config.DBs) == 0 {
         slog.Error("no databases specified in configuration")
@@ -175,9 +211,116 @@ func (c *ReplicateCommand) Close() (err error) {
             }
         }
     }
+
+    // Stop lease monitoring.
+    c.leaserCancel(errors.New("litestream shutting down"))
+    c.wg.Wait()
+
+    // Release the most recent lease.
+    if lease := c.Lease(); lease != nil {
+        slog.Info("releasing lease", slog.Int64("epoch", lease.Epoch))
+
+        if e := c.Leaser.ReleaseLease(context.Background(), lease.Epoch); e != nil {
+            slog.Error("failed to release lease",
+                slog.Int64("epoch", lease.Epoch),
+                slog.Any("error", e))
+        }
+    }
+
     return err
 }
 
+// acquireLease acquires the initial lease, blocking and retrying while another
+// unexpired lease exists, and then continuously monitors & renews the lease in
+// the background.
+func (c *ReplicateCommand) acquireLease(ctx context.Context) (err error) {
+    timer := time.NewTimer(1)
+    defer timer.Stop()
+
+    // Continually try to acquire lease if there is an existing lease.
+OUTER:
+    for {
+        select {
+        case <-ctx.Done():
+            return context.Cause(ctx)
+        case <-timer.C:
+            var leaseExistsError *litestream.LeaseExistsError
+            lease, err := c.Leaser.AcquireLease(ctx)
+            if errors.As(err, &leaseExistsError) {
+                timer.Reset(litestream.LeaseRetryInterval)
+                slog.Info("lease already exists, waiting to retry",
+                    slog.Int64("epoch", leaseExistsError.Lease.Epoch),
+                    slog.String("owner", leaseExistsError.Lease.Owner),
+                    slog.Time("expires", leaseExistsError.Lease.Deadline()))
+                continue
+            } else if err != nil {
+                return fmt.Errorf("acquire lease: %w", err)
+            }
+            c.lease.Store(lease)
+            break OUTER
+        }
+    }
+
+    lease := c.Lease()
+    slog.Info("lease acquired",
+        slog.Int64("epoch", lease.Epoch),
+        slog.Duration("timeout", lease.Timeout),
+        slog.String("owner", lease.Owner))
+
+    // Continuously monitor and renew lease in a separate goroutine.
+    c.wg.Add(1)
+    go func() { defer c.wg.Done(); c.monitorLease(c.leaserCtx) }()
+
+    return nil
+}
+
+func (c *ReplicateCommand) monitorLease(ctx context.Context) {
+    timer := time.NewTimer(c.Lease().Timeout / 2)
+    defer timer.Stop()
+
+    for {
+        select {
+        case <-ctx.Done():
+            slog.Info("stopping lease monitor")
+            return
+
+        case <-timer.C:
+            var leaseExistsError *litestream.LeaseExistsError
+
+            lease := c.Lease()
+            slog.Debug("attempting to renew lease", slog.Int64("epoch", lease.Epoch))
+
+            // Attempt to renew our currently held lease.
+            newLease, err := c.Leaser.RenewLease(ctx, lease)
+            if errors.As(err, &leaseExistsError) {
+                slog.Error("cannot renew lease, another lease exists, exiting",
+                    slog.Int64("epoch", leaseExistsError.Lease.Epoch),
+                    slog.String("owner", leaseExistsError.Lease.Owner))
+                c.leaseExpireCh <- struct{}{}
+                return
+            }
+
+            // If our lease has expired then give up and exit.
+            if lease.Expired() {
+                slog.Error("lease expired, exiting")
+                c.leaseExpireCh <- struct{}{}
+                return
+            }
+
+            // If we hit a temporary error then aggressively retry.
+            if err != nil {
+                slog.Warn("temporarily unable to renew lease, retrying", slog.Any("error", err))
+                timer.Reset(1 * time.Second)
+                continue
+            }
+
+            // Replace lease and try to renew after halfway through the timeout.
+            slog.Debug("lease renewed", slog.Int64("epoch", newLease.Epoch))
+            c.lease.Store(newLease)
+            timer.Reset(newLease.Timeout / 2)
+        }
+    }
+}
+
 // Usage prints the help screen to STDOUT.
 func (c *ReplicateCommand) Usage() {
     fmt.Printf(`
diff --git a/integration_test.go b/integration_test.go
new file mode 100644
index 00000000..971de4af
--- /dev/null
+++ b/integration_test.go
@@ -0,0 +1,45 @@
+package litestream_test
+
+import (
+    "flag"
+    "os"
+)
+
+// Enables integration tests.
+var integration = flag.String("integration", "file", "")
+
+// S3 settings
+var (
+    // Replica client settings
+    s3AccessKeyID     = flag.String("s3-access-key-id", os.Getenv("LITESTREAM_S3_ACCESS_KEY_ID"), "")
+    s3SecretAccessKey = flag.String("s3-secret-access-key", os.Getenv("LITESTREAM_S3_SECRET_ACCESS_KEY"), "")
+    s3Region          = flag.String("s3-region", os.Getenv("LITESTREAM_S3_REGION"), "")
+    s3Bucket          = flag.String("s3-bucket", os.Getenv("LITESTREAM_S3_BUCKET"), "")
+    s3Path            = flag.String("s3-path", os.Getenv("LITESTREAM_S3_PATH"), "")
+    s3Endpoint        = flag.String("s3-endpoint", os.Getenv("LITESTREAM_S3_ENDPOINT"), "")
+    s3ForcePathStyle  = flag.Bool("s3-force-path-style", os.Getenv("LITESTREAM_S3_FORCE_PATH_STYLE") == "true", "")
+    s3SkipVerify      = flag.Bool("s3-skip-verify", os.Getenv("LITESTREAM_S3_SKIP_VERIFY") == "true", "")
+)
+
+// Google cloud storage settings
+var (
+    gcsBucket = flag.String("gcs-bucket", os.Getenv("LITESTREAM_GCS_BUCKET"), "")
+    gcsPath   = flag.String("gcs-path", os.Getenv("LITESTREAM_GCS_PATH"), "")
+)
+
+// Azure blob storage settings
+var (
+    absAccountName = flag.String("abs-account-name", os.Getenv("LITESTREAM_ABS_ACCOUNT_NAME"), "")
+    absAccountKey  = flag.String("abs-account-key", os.Getenv("LITESTREAM_ABS_ACCOUNT_KEY"), "")
+    absBucket      = flag.String("abs-bucket", os.Getenv("LITESTREAM_ABS_BUCKET"), "")
+    absPath        = flag.String("abs-path", os.Getenv("LITESTREAM_ABS_PATH"), "")
+)
+
+// SFTP settings
+var (
+    sftpHost     = flag.String("sftp-host", os.Getenv("LITESTREAM_SFTP_HOST"), "")
+    sftpUser     = flag.String("sftp-user", os.Getenv("LITESTREAM_SFTP_USER"), "")
+    sftpPassword = flag.String("sftp-password", os.Getenv("LITESTREAM_SFTP_PASSWORD"), "")
+    sftpKeyPath  = flag.String("sftp-key-path", os.Getenv("LITESTREAM_SFTP_KEY_PATH"), "")
+    sftpPath     = flag.String("sftp-path", os.Getenv("LITESTREAM_SFTP_PATH"), "")
+)
diff --git a/leaser.go b/leaser.go
new file mode 100644
index 00000000..5bac2649
--- /dev/null
+++ b/leaser.go
@@ -0,0 +1,82 @@
+package litestream
+
+import (
+    "context"
+    "fmt"
+    "time"
+)
+
+// DefaultLeaseTimeout is the default amount of time to hold a lease object.
+const DefaultLeaseTimeout = 30 * time.Second
+
+// LeaseRetryInterval is the interval to retry lease acquisition when there
+// is an already existing lease. This ensures that the process will pick up
+// the lease quickly after it has expired from another process.
+const LeaseRetryInterval = 1 * time.Second
+
+// Leaser represents a client for a distributed leasing service.
+type Leaser interface {
+    // The name of the implementation (e.g. "s3").
+    Type() string
+
+    // Returns a sorted list of existing lease epochs.
+    Epochs(ctx context.Context) ([]int64, error)
+
+    // Attempts to acquire a new lease.
+    AcquireLease(ctx context.Context) (*Lease, error)
+
+    // Renews an existing lease.
+    RenewLease(ctx context.Context, lease *Lease) (*Lease, error)
+
+    // Releases a previously acquired lease via expiration.
+    ReleaseLease(ctx context.Context, epoch int64) error
+
+    // Removes the lease by deleting its underlying file.
+    // This is typically used for reaping expired leases or for test cleanup.
+    DeleteLease(ctx context.Context, epoch int64) error
+}
+
+// Lease represents a distributed lease to ensure that only a single Litestream
+// instance replicates to a replica at a time. This prevents duplicate streams
+// from overwriting each other and causing data loss.
+type Lease struct {
+    // Required. Incremented on each leader change.
+    Epoch int64 `json:"epoch"`
+
+    // Timestamp of when the lease was last modified.
+    ModTime time.Time `json:"-"`
+
+    // Required. Duration after last modified time that lease is valid.
+    // If set to zero, lease is immediately expired.
+    Timeout time.Duration `json:"timeout"`
+
+    // Optional. Specifies a description of the process that acquired the lease.
+    // For example, a hostname or machine ID.
+    Owner string `json:"owner,omitempty"`
+}
+
+// Expired returns true if the lease has expired.
+func (l *Lease) Expired() bool {
+    return l.Timeout <= 0 || l.Deadline().Before(time.Now())
+}
+
+// Deadline returns the time when the lease will expire if not renewed.
+func (l *Lease) Deadline() time.Time {
+    return l.ModTime.Add(l.Timeout)
+}
+
+// LeaseExistsError represents an error returned when trying to acquire a lease
+// when another lease already exists and has not expired yet.
+type LeaseExistsError struct {
+    Lease *Lease
+}
+
+// NewLeaseExistsError returns a new instance of LeaseExistsError.
+func NewLeaseExistsError(lease *Lease) *LeaseExistsError {
+    return &LeaseExistsError{Lease: lease}
+}
+
+// Error implements the error interface.
+func (e *LeaseExistsError) Error() string {
+    return fmt.Sprintf("lease exists (epoch %d)", e.Lease.Epoch)
+}
diff --git a/leaser_test.go b/leaser_test.go
new file mode 100644
index 00000000..12edfaf9
--- /dev/null
+++ b/leaser_test.go
@@ -0,0 +1,253 @@
+package litestream_test
+
+import (
+    "context"
+    "errors"
+    "fmt"
+    "math/rand"
+    "path"
+    "reflect"
+    "strings"
+    "testing"
+    "time"
+
+    "github.com/benbjohnson/litestream"
+    "github.com/benbjohnson/litestream/s3"
+)
+
+const (
+    testOwner = "TESTOWNER"
+)
+
+func TestLeaser_AcquireLease(t *testing.T) {
+    runWithLeaser(t, "OK", func(t *testing.T, leaser litestream.Leaser) {
+        // Create the initial lease.
+        lease, err := leaser.AcquireLease(context.Background())
+        if err != nil {
+            t.Fatal(err)
+        } else if got, want := lease.Epoch, int64(1); got != want {
+            t.Fatalf("Epoch=%v, want %v", got, want)
+        } else if got, want := lease.Owner, testOwner; got != want {
+            t.Fatalf("Owner=%v, want %v", got, want)
+        } else if got, want := lease.Timeout, litestream.DefaultLeaseTimeout; got != want {
+            t.Fatalf("Timeout=%v, want %v", got, want)
+        } else if lease.ModTime.IsZero() {
+            t.Fatalf("expected ModTime")
+        }
+
+        // Fetch associated epoch.
+        epochs, err := leaser.Epochs(context.Background())
+        if err != nil {
+            t.Fatal(err)
+        } else if got, want := epochs, []int64{1}; !reflect.DeepEqual(got, want) {
+            t.Fatalf("Epochs()=%v, want %v", got, want)
+        }
+    })
+
+    runWithLeaser(t, "Reacquire", func(t *testing.T, leaser litestream.Leaser) {
+        // Create the initial lease.
+        lease1, err := leaser.AcquireLease(context.Background())
+        if err != nil {
+            t.Fatal(err)
+        }
+
+        // Release the lease.
+        if err := leaser.ReleaseLease(context.Background(), lease1.Epoch); err != nil {
+            t.Fatal(err)
+        }
+
+        // Acquire a new lease.
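+        // The epoch is expected to advance even though the prior lease was
+        // released, since each acquisition writes a new lock file.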
+        lease2, err := leaser.AcquireLease(context.Background())
+        if err != nil {
+            t.Fatal(err)
+        } else if got, want := lease2.Epoch, int64(2); got != want {
+            t.Fatalf("Epoch=%v, want %v", got, want)
+        }
+    })
+
+    // Ensure that acquiring a lease before the previous one is released returns a LeaseExistsError.
+    runWithLeaser(t, "ErrLeaseExists", func(t *testing.T, leaser litestream.Leaser) {
+        // Acquire an initial lease.
+        if _, err := leaser.AcquireLease(context.Background()); err != nil {
+            t.Fatal(err)
+        }
+
+        // Attempt to acquire a new lease before it has been released.
+        var leaseExistsErr *litestream.LeaseExistsError
+        if _, err := leaser.AcquireLease(context.Background()); !errors.As(err, &leaseExistsErr) {
+            t.Fatalf("unexpected error: %#v", err)
+        } else if leaseExistsErr.Lease == nil {
+            t.Fatalf("expected lease")
+        } else if got, want := leaseExistsErr.Lease.Epoch, int64(1); got != want {
+            t.Fatalf("err.Lease.Epoch=%v, want %v", got, want)
+        }
+    })
+}
+
+func TestLeaser_RenewLease(t *testing.T) {
+    runWithLeaser(t, "OK", func(t *testing.T, leaser litestream.Leaser) {
+        // Create the initial lease.
+        lease1, err := leaser.AcquireLease(context.Background())
+        if err != nil {
+            t.Fatal(err)
+        }
+
+        // Renew lease.
+        lease2, err := leaser.RenewLease(context.Background(), lease1)
+        if err != nil {
+            t.Fatal(err)
+        } else if got, want := lease2.Epoch, int64(2); got != want {
+            t.Fatalf("Epoch=%v, want %v", got, want)
+        }
+
+        // Pause momentarily so original lease is reaped.
+        time.Sleep(1 * time.Second)
+
+        // Fetch associated epochs.
+        epochs, err := leaser.Epochs(context.Background())
+        if err != nil {
+            t.Fatal(err)
+        } else if got, want := epochs, []int64{2}; !reflect.DeepEqual(got, want) {
+            t.Fatalf("Epochs()=%v, want %v", got, want)
+        }
+    })
+
+    runWithLeaser(t, "Reaped", func(t *testing.T, leaser litestream.Leaser) {
+        // Try to renew a lease when all the lock files have been reaped.
+        lease, err := leaser.RenewLease(context.Background(), &litestream.Lease{Epoch: 100})
+        if err != nil {
+            t.Fatal(err)
+        } else if got, want := lease.Epoch, int64(1); got != want {
+            t.Fatalf("Epoch=%v, want %v", got, want)
+        }
+
+        // Fetch associated epochs.
+        epochs, err := leaser.Epochs(context.Background())
+        if err != nil {
+            t.Fatal(err)
+        } else if got, want := epochs, []int64{1}; !reflect.DeepEqual(got, want) {
+            t.Fatalf("Epochs()=%v, want %v", got, want)
+        }
+    })
+
+    runWithLeaser(t, "Released", func(t *testing.T, leaser litestream.Leaser) {
+        // Create & release the initial lease.
+        lease1, err := leaser.AcquireLease(context.Background())
+        if err != nil {
+            t.Fatal(err)
+        } else if err := leaser.ReleaseLease(context.Background(), lease1.Epoch); err != nil {
+            t.Fatal(err)
+        }
+
+        // Create & release a lease from a different client.
+        lease2, err := leaser.AcquireLease(context.Background())
+        if err != nil {
+            t.Fatal(err)
+        } else if err := leaser.ReleaseLease(context.Background(), lease2.Epoch); err != nil {
+            t.Fatal(err)
+        }
+
+        // Try to renew the original lease.
+        lease, err := leaser.RenewLease(context.Background(), lease1)
+        if err != nil {
+            t.Fatal(err)
+        } else if got, want := lease.Epoch, int64(3); got != want {
+            t.Fatalf("Epoch=%v, want %v", got, want)
+        }
+
+        time.Sleep(1 * time.Second)
+
+        // Fetch associated epochs.
+        epochs, err := leaser.Epochs(context.Background())
+        if err != nil {
+            t.Fatal(err)
+        } else if got, want := epochs, []int64{3}; !reflect.DeepEqual(got, want) {
+            t.Fatalf("Epochs()=%v, want %v", got, want)
+        }
+    })
+
+    runWithLeaser(t, "ErrLeaseExists", func(t *testing.T, leaser litestream.Leaser) {
+        // Create & release the initial lease.
+        lease1, err := leaser.AcquireLease(context.Background())
+        if err != nil {
+            t.Fatal(err)
+        } else if err := leaser.ReleaseLease(context.Background(), lease1.Epoch); err != nil {
+            t.Fatal(err)
+        }
+
+        // Create lease from different client.
+        if _, err := leaser.AcquireLease(context.Background()); err != nil {
+            t.Fatal(err)
+        }
+
+        // Try to renew the original lease.
+        var leaseExistsError *litestream.LeaseExistsError
+        if _, err := leaser.RenewLease(context.Background(), lease1); !errors.As(err, &leaseExistsError) {
+            t.Fatalf("unexpected error: %#v", err)
+        } else if got, want := leaseExistsError.Lease.Epoch, int64(2); got != want {
+            t.Fatalf("err.Lease.Epoch=%v, want %v", got, want)
+        }
+    })
+}
+
+// runWithLeaser executes fn against each supported leaser type (currently only "s3").
+func runWithLeaser(t *testing.T, name string, fn func(*testing.T, litestream.Leaser)) {
+    t.Run(name, func(t *testing.T) {
+        for _, typ := range strings.Split("s3", ",") {
+            t.Run(typ, func(t *testing.T) {
+                c := newOpenLeaser(t, typ)
+                defer mustDeleteAllLeases(t, c)
+                fn(t, c)
+            })
+        }
+    })
+}
+
+// newOpenLeaser returns a new, open leaser for integration testing by type name.
+func newOpenLeaser(tb testing.TB, typ string) litestream.Leaser {
+    tb.Helper()
+
+    switch typ {
+    case "s3":
+        leaser := newS3Leaser(tb)
+        leaser.Owner = testOwner
+        if err := leaser.Open(); err != nil {
+            tb.Fatal(err)
+        }
+        return leaser
+    default:
+        tb.Fatalf("invalid leaser type: %q", typ)
+        return nil
+    }
+}
+
+func newS3Leaser(tb testing.TB) *s3.Leaser {
+    tb.Helper()
+
+    l := s3.NewLeaser()
+    l.AccessKeyID = *s3AccessKeyID
+    l.SecretAccessKey = *s3SecretAccessKey
+    l.Region = *s3Region
+    l.Bucket = *s3Bucket
+    l.Path = path.Join(*s3Path, fmt.Sprintf("%016x", rand.Uint64()))
+    l.Endpoint = *s3Endpoint
+    l.ForcePathStyle = *s3ForcePathStyle
+    l.SkipVerify = *s3SkipVerify
+    return l
+}
+
+// mustDeleteAllLeases deletes all lease objects.
+func mustDeleteAllLeases(tb testing.TB, l litestream.Leaser) {
+    tb.Helper()
+
+    epochs, err := l.Epochs(context.Background())
+    if err != nil {
+        tb.Fatalf("cannot list lease epochs for deletion: %s", err)
+    }
+
+    for _, epoch := range epochs {
+        if err := l.DeleteLease(context.Background(), epoch); err != nil {
+            tb.Fatalf("cannot delete lease (epoch=%d): %s", epoch, err)
+        }
+    }
+}
diff --git a/replica_client_test.go b/replica_client_test.go
index cf6079b2..7494b86d 100644
--- a/replica_client_test.go
+++ b/replica_client_test.go
@@ -2,7 +2,6 @@ package litestream_test
 
 import (
     "context"
-    "flag"
     "fmt"
     "io"
     "math/rand"
@@ -21,47 +20,6 @@ import (
     "github.com/benbjohnson/litestream/sftp"
 )
 
-var (
-    // Enables integration tests.
-    integration = flag.String("integration", "file", "")
-)
-
-// S3 settings
-var (
-    // Replica client settings
-    s3AccessKeyID     = flag.String("s3-access-key-id", os.Getenv("LITESTREAM_S3_ACCESS_KEY_ID"), "")
-    s3SecretAccessKey = flag.String("s3-secret-access-key", os.Getenv("LITESTREAM_S3_SECRET_ACCESS_KEY"), "")
-    s3Region          = flag.String("s3-region", os.Getenv("LITESTREAM_S3_REGION"), "")
-    s3Bucket          = flag.String("s3-bucket", os.Getenv("LITESTREAM_S3_BUCKET"), "")
-    s3Path            = flag.String("s3-path", os.Getenv("LITESTREAM_S3_PATH"), "")
-    s3Endpoint        = flag.String("s3-endpoint", os.Getenv("LITESTREAM_S3_ENDPOINT"), "")
-    s3ForcePathStyle  = flag.Bool("s3-force-path-style", os.Getenv("LITESTREAM_S3_FORCE_PATH_STYLE") == "true", "")
-    s3SkipVerify      = flag.Bool("s3-skip-verify", os.Getenv("LITESTREAM_S3_SKIP_VERIFY") == "true", "")
-)
-
-// Google cloud storage settings
-var (
-    gcsBucket = flag.String("gcs-bucket", os.Getenv("LITESTREAM_GCS_BUCKET"), "")
-    gcsPath   = flag.String("gcs-path", os.Getenv("LITESTREAM_GCS_PATH"), "")
-)
-
-// Azure blob storage settings
-var (
-    absAccountName = flag.String("abs-account-name", os.Getenv("LITESTREAM_ABS_ACCOUNT_NAME"), "")
-    absAccountKey  = flag.String("abs-account-key", os.Getenv("LITESTREAM_ABS_ACCOUNT_KEY"), "")
-    absBucket      = flag.String("abs-bucket", os.Getenv("LITESTREAM_ABS_BUCKET"), "")
-    absPath        = flag.String("abs-path", os.Getenv("LITESTREAM_ABS_PATH"), "")
-)
-
-// SFTP settings
-var (
-    sftpHost     = flag.String("sftp-host", os.Getenv("LITESTREAM_SFTP_HOST"), "")
-    sftpUser     = flag.String("sftp-user", os.Getenv("LITESTREAM_SFTP_USER"), "")
-    sftpPassword = flag.String("sftp-password", os.Getenv("LITESTREAM_SFTP_PASSWORD"), "")
-    sftpKeyPath  = flag.String("sftp-key-path", os.Getenv("LITESTREAM_SFTP_KEY_PATH"), "")
-    sftpPath     = flag.String("sftp-path", os.Getenv("LITESTREAM_SFTP_PATH"), "")
-)
-
 func TestReplicaClient_Generations(t *testing.T) {
     RunWithReplicaClient(t, "OK", func(t *testing.T, c litestream.ReplicaClient) {
         t.Parallel()
@@ -395,7 +353,6 @@ func TestReplicaClient_WriteWALSegment(t *testing.T) {
 }
 
 func TestReplicaClient_WALReader(t *testing.T) {
-
     RunWithReplicaClient(t, "OK", func(t *testing.T, c litestream.ReplicaClient) {
         t.Parallel()
         if _, err := c.WriteWALSegment(context.Background(), litestream.Pos{Generation: "5efbd8d042012dca", Index: 10, Offset: 5}, strings.NewReader(`foobar`)); err != nil {
diff --git a/s3/leaser.go b/s3/leaser.go
new file mode 100644
index 00000000..171c4092
--- /dev/null
+++ b/s3/leaser.go
@@ -0,0 +1,271 @@
+package s3
+
+import (
+    "bytes"
+    "context"
+    "encoding/json"
+    "errors"
+    "fmt"
+    "log/slog"
+    "os"
+    "path"
+    "regexp"
+    "slices"
+    "strconv"
+    "strings"
+    "time"
+
+    "github.com/aws/aws-sdk-go/aws"
+    "github.com/aws/aws-sdk-go/aws/awserr"
+    "github.com/aws/aws-sdk-go/service/s3"
+    "github.com/benbjohnson/litestream"
+)
+
+// LockFileExt is the file extension used for epoch lock files.
+const LockFileExt = ".lock"
+
+var _ litestream.Leaser = (*Leaser)(nil)
+
+// Leaser is an implementation of a distributed lease using S3 conditional writes.
+type Leaser struct {
+    s3 *s3.S3 // s3 service
+
+    // Required. The amount of time that a lease is held for before expiring.
+    // This can be extended by renewing the lease before expiration.
+    LeaseTimeout time.Duration
+
+    // Optional. The name of the owner when a lease is acquired (e.g. hostname).
+    // This is used to provide human-readable information to see current leadership.
+    Owner string
+
+    // AWS authentication keys.
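+    // If left blank, the default AWS credential chain is used.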
+    AccessKeyID     string
+    SecretAccessKey string
+
+    // S3 bucket information
+    Region         string
+    Bucket         string
+    Path           string
+    Endpoint       string
+    ForcePathStyle bool
+    SkipVerify     bool
+}
+
+// NewLeaser returns a new instance of Leaser.
+func NewLeaser() *Leaser {
+    return &Leaser{
+        LeaseTimeout: litestream.DefaultLeaseTimeout,
+    }
+}
+
+// Type returns "s3" as the leaser type.
+func (l *Leaser) Type() string { return "s3" }
+
+func (l *Leaser) Open() error {
+    sess, err := newSession(context.Background(), l.AccessKeyID, l.SecretAccessKey, l.Region, l.Bucket, l.Path, l.Endpoint, l.ForcePathStyle, l.SkipVerify)
+    if err != nil {
+        return err
+    }
+    l.s3 = s3.New(sess)
+    return nil
+}
+
+// Epochs returns a list of epoch numbers from lease files in S3.
+//
+// Typically there should only be one or two as they will be cleaned up after
+// expiration. The returned epochs are not guaranteed to be live.
+func (l *Leaser) Epochs(ctx context.Context) ([]int64, error) {
+    var a []int64
+    err := l.s3.ListObjectsPagesWithContext(ctx, &s3.ListObjectsInput{
+        Bucket: aws.String(l.Bucket),
+        Prefix: aws.String(l.Path + "/"),
+    }, func(page *s3.ListObjectsOutput, lastPage bool) bool {
+        for _, obj := range page.Contents {
+            // Skip any files that don't match the lock file format.
+            name := path.Base(aws.StringValue(obj.Key))
+            if !lockFileRegex.MatchString(name) {
+                continue
+            }
+
+            // Parse epoch from base filename.
+            epochStr := strings.TrimSuffix(name, LockFileExt)
+            epoch, _ := strconv.ParseInt(epochStr, 16, 64)
+
+            a = append(a, epoch)
+        }
+        return true
+    })
+
+    // Epochs should be sorted anyway but perform an explicit sort to be safe.
+    slices.Sort(a)
+
+    return a, err
+}
+
+// AcquireLease attempts to acquire a new lease.
+func (l *Leaser) AcquireLease(ctx context.Context) (*litestream.Lease, error) {
+    return l.acquireLease(ctx, 0)
+}
+
+// RenewLease renews an existing lease with the same timeout.
+func (l *Leaser) RenewLease(ctx context.Context, lease *litestream.Lease) (*litestream.Lease, error) {
+    return l.acquireLease(ctx, lease.Epoch)
+}
+
+func (l *Leaser) acquireLease(ctx context.Context, prevEpoch int64) (*litestream.Lease, error) {
+    // List all epochs for all available lock files.
+    epochs, err := l.Epochs(ctx)
+    if err != nil {
+        return nil, fmt.Errorf("epochs: %w", err)
+    }
+
+    // Check if current epoch is valid and has not expired yet.
+    var epoch int64
+    if len(epochs) > 0 {
+        epoch = epochs[len(epochs)-1]
+
+        // Only check current lease if we don't have a previous lease that we
+        // are renewing or if the lease epoch doesn't match that previous lease.
+        if prevEpoch == 0 || epoch != prevEpoch {
+            if lease, err := l.lease(ctx, epoch); os.IsNotExist(err) {
+                // No lease, skip error checking
+            } else if err != nil {
+                return nil, fmt.Errorf("fetch lease (%d): %w", epoch, err)
+            } else if !lease.Expired() {
+                return nil, litestream.NewLeaseExistsError(lease)
+            }
+        }
+    }
+
+    // At this point we assume there is no lease owner so try to acquire the
+    // lock using the next epoch.
+    epoch++
+    lease, err := l.createLease(ctx, epoch)
+    if err != nil {
+        return nil, fmt.Errorf("create lease: %w", err)
+    }
+
+    // Reap old lease files in the background. This does not affect correctness
+    // since we have a new lease file in existence so we will just log errors.
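+    // Any lock file that fails to delete here is listed and reaped again on a
+    // later acquire/renew, so deletion errors are non-fatal.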
+    for _, epoch := range epochs {
+        epoch := epoch
+        go func() {
+            if err := l.DeleteLease(ctx, epoch); err != nil {
+                slog.Warn("cannot reap lease",
+                    slog.Int64("epoch", epoch),
+                    slog.Any("error", err))
+            }
+        }()
+    }
+
+    return lease, nil
+}
+
+// ReleaseLease releases a previously acquired lease via expiration.
+func (l *Leaser) ReleaseLease(ctx context.Context, epoch int64) error {
+    // Check if lease exists. Ignore if it has already been reaped.
+    lease, err := l.lease(ctx, epoch)
+    if os.IsNotExist(err) {
+        return nil // no lease file, exit
+    } else if err != nil {
+        return fmt.Errorf("fetch lease: %w", err)
+    } else if lease.Timeout <= 0 {
+        return nil // lease file already released
+    }
+
+    // Invalidate timeout so it expires immediately.
+    lease.Timeout = 0
+
+    // Overwrite previous object with expired lease.
+    body, err := json.MarshalIndent(lease, "", " ")
+    if err != nil {
+        return err
+    }
+
+    // Send an unconditional PUT to S3 to overwrite the lease object.
+    _, err = l.s3.PutObject(&s3.PutObjectInput{
+        Bucket: aws.String(l.Bucket),
+        Key:    aws.String(l.leaseKey(lease.Epoch)),
+        Body:   bytes.NewReader(body),
+    })
+    return err
+}
+
+// DeleteLease removes the lease by deleting its underlying file.
+// This is typically used for reaping expired leases or for test cleanup.
+func (l *Leaser) DeleteLease(ctx context.Context, epoch int64) error {
+    _, err := l.s3.DeleteObjectWithContext(ctx, &s3.DeleteObjectInput{
+        Bucket: aws.String(l.Bucket),
+        Key:    aws.String(l.leaseKey(epoch)),
+    })
+    if err != nil && !isNotExists(err) {
+        return err
+    }
+    return nil
+}
+
+func (l *Leaser) createLease(ctx context.Context, epoch int64) (*litestream.Lease, error) {
+    // Marshal lease object into formatted JSON. We set the last modified date
+    // to the current time to be conservative. We could also refetch the lease
+    // object so we'd get an accurate LastModified date but this should be ok.
+    lease := &litestream.Lease{
+        Epoch:   epoch,
+        ModTime: time.Now(),
+        Timeout: l.LeaseTimeout,
+        Owner:   l.Owner,
+    }
+    body, err := json.MarshalIndent(lease, "", " ")
+    if err != nil {
+        return nil, err
+    }
+
+    // Construct PUT with a conditional write and send to S3.
+    req, _ := l.s3.PutObjectRequest(&s3.PutObjectInput{
+        Bucket: aws.String(l.Bucket),
+        Key:    aws.String(l.leaseKey(epoch)),
+        Body:   bytes.NewReader(body),
+    })
+    req.HTTPRequest.Header.Add("If-None-Match", "*")
+
+    var awsErr awserr.Error
+    if err := req.Send(); errors.As(err, &awsErr) && awsErr.Code() == "PreconditionFailed" {
+        // If precondition failed then another instance raced and got the epoch
+        // first so we should return an error with that lease.
+        currentLease, err := l.lease(ctx, epoch)
+        if err != nil {
+            return nil, fmt.Errorf("fetch conflicting lease (%d): %w", epoch, err)
+        }
+        return nil, litestream.NewLeaseExistsError(currentLease)
+    } else if err != nil {
+        return nil, err
+    } else if req.Error != nil {
+        return nil, req.Error
+    }
+    return lease, nil
+}
+
+// lease fetches & decodes a lease file by epoch.
+func (l *Leaser) lease(ctx context.Context, epoch int64) (*litestream.Lease, error) {
+    output, err := l.s3.GetObjectWithContext(ctx, &s3.GetObjectInput{
+        Bucket: aws.String(l.Bucket),
+        Key:    aws.String(l.leaseKey(epoch)),
+    })
+    if isNotExists(err) {
+        return nil, os.ErrNotExist
+    } else if err != nil {
+        return nil, err
+    }
+
+    var lease litestream.Lease
+    if err := json.NewDecoder(output.Body).Decode(&lease); err != nil {
+        return nil, err
+    }
+    lease.ModTime = aws.TimeValue(output.LastModified)
+    return &lease, nil
+}
+
+func (l *Leaser) leaseKey(epoch int64) string {
+    return fmt.Sprintf("%s/%016x%s", l.Path, epoch, LockFileExt)
+}
+
+var lockFileRegex = regexp.MustCompile(`^[0-9a-f]{16}\.lock$`)
diff --git a/s3/replica_client.go b/s3/replica_client.go
index 7b513750..b0eac31a 100644
--- a/s3/replica_client.go
+++ b/s3/replica_client.go
@@ -2,11 +2,9 @@ package s3
 
 import (
     "context"
-    "crypto/tls"
     "fmt"
     "io"
     "net"
-    "net/http"
     "os"
     "path"
     "regexp"
@@ -15,8 +13,6 @@ import (
 
     "github.com/aws/aws-sdk-go/aws"
     "github.com/aws/aws-sdk-go/aws/awserr"
-    "github.com/aws/aws-sdk-go/aws/credentials"
-    "github.com/aws/aws-sdk-go/aws/session"
     "github.com/aws/aws-sdk-go/service/s3"
     "github.com/aws/aws-sdk-go/service/s3/s3manager"
     "github.com/benbjohnson/litestream"
@@ -30,9 +26,6 @@ const ReplicaClientType = "s3"
 // MaxKeys is the number of keys S3 can operate on per batch.
 const MaxKeys = 1000
 
-// DefaultRegion is the region used if one is not specified.
-const DefaultRegion = "us-east-1"
-
 var _ litestream.ReplicaClient = (*ReplicaClient)(nil)
 
 // ReplicaClient is a client for writing snapshots & WAL segments to disk.
@@ -73,79 +66,15 @@ func (c *ReplicaClient) Init(ctx context.Context) (err error) {
         return nil
     }
 
-    // Look up region if not specified and no endpoint is used.
-    // Endpoints are typically used for non-S3 object stores and do not
-    // necessarily require a region.
-    region := c.Region
-    if region == "" {
-        if c.Endpoint == "" {
-            if region, err = c.findBucketRegion(ctx, c.Bucket); err != nil {
-                return fmt.Errorf("cannot lookup bucket region: %w", err)
-            }
-        } else {
-            region = DefaultRegion // default for non-S3 object stores
-        }
-    }
-
-    // Create new AWS session.
-    config := c.config()
-    if region != "" {
-        config.Region = aws.String(region)
-    }
-
-    sess, err := session.NewSession(config)
+    sess, err := newSession(ctx, c.AccessKeyID, c.SecretAccessKey, c.Region, c.Bucket, c.Path, c.Endpoint, c.ForcePathStyle, c.SkipVerify)
     if err != nil {
-        return fmt.Errorf("cannot create aws session: %w", err)
+        return err
     }
     c.s3 = s3.New(sess)
     c.uploader = s3manager.NewUploader(sess)
     return nil
 }
 
-// config returns the AWS configuration. Uses the default credential chain
-// unless a key/secret are explicitly set.
-func (c *ReplicaClient) config() *aws.Config {
-    config := &aws.Config{}
-
-    if c.AccessKeyID != "" || c.SecretAccessKey != "" {
-        config.Credentials = credentials.NewStaticCredentials(c.AccessKeyID, c.SecretAccessKey, "")
-    }
-    if c.Endpoint != "" {
-        config.Endpoint = aws.String(c.Endpoint)
-    }
-    if c.ForcePathStyle {
-        config.S3ForcePathStyle = aws.Bool(c.ForcePathStyle)
-    }
-    if c.SkipVerify {
-        config.HTTPClient = &http.Client{Transport: &http.Transport{
-            TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
-        }}
-    }
-
-    return config
-}
-
-func (c *ReplicaClient) findBucketRegion(ctx context.Context, bucket string) (string, error) {
-    // Connect to US standard region to fetch info.
-    config := c.config()
-    config.Region = aws.String(DefaultRegion)
-    sess, err := session.NewSession(config)
-    if err != nil {
-        return "", err
-    }
-
-    // Fetch bucket location, if possible. Must be bucket owner.
-    // This call can return a nil location which means it's in us-east-1.
-    if out, err := s3.New(sess).HeadBucketWithContext(ctx, &s3.HeadBucketInput{
-        Bucket: aws.String(bucket),
-    }); err != nil {
-        return "", err
-    } else if out.BucketRegion != nil {
-        return *out.BucketRegion, nil
-    }
-    return DefaultRegion, nil
-}
-
 // Generations returns a list of available generation names.
 func (c *ReplicaClient) Generations(ctx context.Context) ([]string, error) {
     if err := c.Init(ctx); err != nil {
diff --git a/s3/s3.go b/s3/s3.go
new file mode 100644
index 00000000..879f4205
--- /dev/null
+++ b/s3/s3.go
@@ -0,0 +1,75 @@
+package s3
+
+import (
+    "context"
+    "crypto/tls"
+    "fmt"
+    "net/http"
+
+    "github.com/aws/aws-sdk-go/aws"
+    "github.com/aws/aws-sdk-go/aws/credentials"
+    "github.com/aws/aws-sdk-go/aws/session"
+    "github.com/aws/aws-sdk-go/service/s3"
+)
+
+// DefaultRegion is the region used if one is not specified.
+const DefaultRegion = "us-east-1"
+
+// newSession returns an AWS session built from the given credentials & S3
+// settings, looking up the bucket region when one is not provided.
+func newSession(ctx context.Context, accessKeyID, secretAccessKey, region, bucket, path, endpoint string, forcePathStyle, skipVerify bool) (sess *session.Session, err error) {
+    // Build session config.
+    var config aws.Config
+    if accessKeyID != "" || secretAccessKey != "" {
+        config.Credentials = credentials.NewStaticCredentials(accessKeyID, secretAccessKey, "")
+    }
+    if endpoint != "" {
+        config.Endpoint = aws.String(endpoint)
+    }
+    if forcePathStyle {
+        config.S3ForcePathStyle = aws.Bool(forcePathStyle)
+    }
+    if skipVerify {
+        config.HTTPClient = &http.Client{Transport: &http.Transport{
+            TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
+        }}
+    }
+
+    // Look up region if not specified and no endpoint is used.
+    // Endpoints are typically used for non-S3 object stores and do not
+    // necessarily require a region.
+    if region == "" {
+        if endpoint == "" {
+            if region, err = findBucketRegion(ctx, config, bucket); err != nil {
+                return nil, fmt.Errorf("cannot lookup bucket region: %w", err)
+            }
+        } else {
+            region = DefaultRegion // default for non-S3 object stores
+        }
+    }
+
+    // Set region if provided by user or by bucket lookup.
+    if region != "" {
+        config.Region = aws.String(region)
+    }
+
+    return session.NewSession(&config)
+}
+
+func findBucketRegion(ctx context.Context, config aws.Config, bucket string) (string, error) {
+    // Connect to US standard region to fetch info.
+    config.Region = aws.String(DefaultRegion)
+    sess, err := session.NewSession(&config)
+    if err != nil {
+        return "", err
+    }
+
+    // Fetch bucket location, if possible. Must be bucket owner.
+    // This call can return a nil location which means it's in us-east-1.
+    if out, err := s3.New(sess).HeadBucketWithContext(ctx, &s3.HeadBucketInput{
+        Bucket: aws.String(bucket),
+    }); err != nil {
+        return "", err
+    } else if out.BucketRegion != nil {
+        return *out.BucketRegion, nil
+    }
+    return DefaultRegion, nil
+}
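For reference, a minimal "lease" section consumed by LeaseConfig above might look like the following sketch (bucket, path & owner values are hypothetical; timeout defaults to 30s):

    lease:
      type: s3
      url: s3://example-bucket/litestream/lease
      timeout: 30s
      owner: my-hostname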