diff --git a/pkg/cassdcutil/manage.go b/pkg/cassdcutil/manage.go index 642415d..ea0f627 100644 --- a/pkg/cassdcutil/manage.go +++ b/pkg/cassdcutil/manage.go @@ -11,6 +11,11 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" ) +const ( + defaultPollInterval = 1 * time.Second + defaultTimeout = 10 * time.Minute +) + type CassManager struct { client client.Client } @@ -39,28 +44,22 @@ func (c *CassManager) ModifyStoppedState(ctx context.Context, name, namespace st if wait { if stop { - if err := waitutil.PollUntilContextTimeout(ctx, 10*time.Second, 10*time.Minute, true, func(context.Context) (bool, error) { - return c.RefreshStatus(ctx, cassdc, cassdcapi.DatacenterStopped, corev1.ConditionTrue) - }); err != nil { + if err := c.WaitForStatus(ctx, cassdc, cassdcapi.DatacenterStopped, corev1.ConditionTrue, defaultPollInterval, defaultTimeout); err != nil { return err } - // And wait for it to finish.. - return waitutil.PollUntilContextTimeout(ctx, 10*time.Second, 10*time.Minute, true, func(context.Context) (bool, error) { - return c.RefreshStatus(ctx, cassdc, cassdcapi.DatacenterReady, corev1.ConditionFalse) - }) - } + if err := c.WaitForStatus(ctx, cassdc, cassdcapi.DatacenterReady, corev1.ConditionFalse, defaultPollInterval, defaultTimeout); err != nil { + return err + } + } else { + if err := c.WaitForStatus(ctx, cassdc, cassdcapi.DatacenterStopped, corev1.ConditionFalse, defaultPollInterval, defaultTimeout); err != nil { + return err + } - if err := waitutil.PollUntilContextTimeout(ctx, 10*time.Second, 10*time.Minute, true, func(context.Context) (bool, error) { - return c.RefreshStatus(ctx, cassdc, cassdcapi.DatacenterStopped, corev1.ConditionFalse) - }); err != nil { - return err + if err := c.WaitForStatus(ctx, cassdc, cassdcapi.DatacenterReady, corev1.ConditionTrue, defaultPollInterval, defaultTimeout); err != nil { + return err + } } - - // And wait for it to finish.. - return waitutil.PollUntilContextTimeout(ctx, 10*time.Second, 10*time.Minute, true, func(context.Context) (bool, error) { - return c.RefreshStatus(ctx, cassdc, cassdcapi.DatacenterReady, corev1.ConditionTrue) - }) } return nil @@ -95,3 +94,20 @@ func (c *CassManager) RestartDc(ctx context.Context, name, namespace, rack strin } return nil } + +func (c *CassManager) WaitForStatus(ctx context.Context, cassdc *cassdcapi.CassandraDatacenter, status cassdcapi.DatacenterConditionType, wanted corev1.ConditionStatus, interval, timeout time.Duration) error { + if interval == 0 { + interval = defaultPollInterval + } + + if timeout == 0 { + timeout = defaultTimeout + } + + if err := waitutil.PollUntilContextTimeout(ctx, interval, timeout, true, func(context.Context) (bool, error) { + return c.RefreshStatus(ctx, cassdc, status, wanted) + }); err != nil { + return err + } + return nil +}