diff --git a/login1/dbus.go b/login1/dbus.go index 4cc6bb95..0ff7cccc 100644 --- a/login1/dbus.go +++ b/login1/dbus.go @@ -16,6 +16,7 @@ package login1 import ( + "context" "fmt" "os" "strconv" @@ -59,6 +60,7 @@ type connectionManager interface { type Caller interface { // TODO: This method should eventually be removed, as it provides no context support. Call(method string, flags dbus.Flags, args ...interface{}) *dbus.Call + CallWithContext(ctx context.Context, method string, flags dbus.Flags, args ...interface{}) *dbus.Call } // New establishes a connection to the system bus and authenticates. @@ -347,6 +349,15 @@ func (c *Conn) Reboot(askForAuth bool) { c.object.Call(dbusInterface+".Reboot", 0, askForAuth) } +// Reboot asks logind for a reboot using given context, optionally asking for auth. +func (c *Conn) RebootWithContext(ctx context.Context, askForAuth bool) error { + if call := c.object.CallWithContext(ctx, dbusInterface+".Reboot", 0, askForAuth); call.Err != nil { + return fmt.Errorf("calling reboot: %w", call.Err) + } + + return nil +} + // Inhibit takes inhibition lock in logind. func (c *Conn) Inhibit(what, who, why, mode string) (*os.File, error) { var fd dbus.UnixFD diff --git a/login1/dbus_test.go b/login1/dbus_test.go index 8816fd8b..1fd64e09 100644 --- a/login1/dbus_test.go +++ b/login1/dbus_test.go @@ -15,6 +15,8 @@ package login1_test import ( + "context" + "errors" "fmt" "os/user" "regexp" @@ -142,6 +144,168 @@ func Test_Creating_new_connection_with_custom_connection(t *testing.T) { }) } +//nolint:funlen // Many subtests. +func Test_Rebooting_with_context(t *testing.T) { + t.Parallel() + + t.Run("calls_login1_reboot_method_on_manager_interface", func(t *testing.T) { + t.Parallel() + + rebootCalled := false + + askForReboot := false + + connectionWithContextCheck := &mockConnection{ + ObjectF: func(string, dbus.ObjectPath) dbus.BusObject { + return &mockObject{ + CallWithContextF: func(ctx context.Context, method string, flags dbus.Flags, args ...interface{}) *dbus.Call { + rebootCalled = true + + expectedMethodName := "org.freedesktop.login1.Manager.Reboot" + + if method != expectedMethodName { + t.Fatalf("Expected method %q being called, got %q", expectedMethodName, method) + } + + if len(args) != 1 { + t.Fatalf("Expected one argument to call, got %q", args) + } + + askedForReboot, ok := args[0].(bool) + if !ok { + t.Fatalf("Expected first argument to be of type %T, got %T", askForReboot, args[0]) + } + + if askForReboot != askedForReboot { + t.Fatalf("Expected argument to be %t, got %t", askForReboot, askedForReboot) + } + + return &dbus.Call{} + }, + } + }, + } + + testConn, err := login1.NewWithConnection(connectionWithContextCheck) + if err != nil { + t.Fatalf("Unexpected error creating connection: %v", err) + } + + if err := testConn.RebootWithContext(context.Background(), askForReboot); err != nil { + t.Fatalf("Unexpected error rebooting: %v", err) + } + + if !rebootCalled { + t.Fatalf("Expected reboot method call on given D-Bus connection") + } + }) + + t.Run("asks_for_auth_when_requested", func(t *testing.T) { + t.Parallel() + + rebootCalled := false + + askForReboot := true + + connectionWithContextCheck := &mockConnection{ + ObjectF: func(string, dbus.ObjectPath) dbus.BusObject { + return &mockObject{ + CallWithContextF: func(ctx context.Context, method string, flags dbus.Flags, args ...interface{}) *dbus.Call { + rebootCalled = true + + if len(args) != 1 { + t.Fatalf("Expected one argument to call, got %q", args) + } + + askedForReboot, ok := args[0].(bool) + if !ok { + t.Fatalf("Expected first argument to be of type %T, got %T", askForReboot, args[0]) + } + + if askForReboot != askedForReboot { + t.Fatalf("Expected argument to be %t, got %t", askForReboot, askedForReboot) + } + + return &dbus.Call{} + }, + } + }, + } + + testConn, err := login1.NewWithConnection(connectionWithContextCheck) + if err != nil { + t.Fatalf("Unexpected error creating connection: %v", err) + } + + if err := testConn.RebootWithContext(context.Background(), askForReboot); err != nil { + t.Fatalf("Unexpected error rebooting: %v", err) + } + + if !rebootCalled { + t.Fatalf("Expected reboot method call on given D-Bus connection") + } + }) + + t.Run("use_given_context_for_D-Bus_call", func(t *testing.T) { + t.Parallel() + + testKey := struct{}{} + expectedValue := "bar" + + ctx := context.WithValue(context.Background(), testKey, expectedValue) + + connectionWithContextCheck := &mockConnection{ + ObjectF: func(string, dbus.ObjectPath) dbus.BusObject { + return &mockObject{ + CallWithContextF: func(ctx context.Context, method string, flags dbus.Flags, args ...interface{}) *dbus.Call { + if val := ctx.Value(testKey); val != expectedValue { + t.Fatalf("Got unexpected context on call") + } + + return &dbus.Call{} + }, + } + }, + } + + testConn, err := login1.NewWithConnection(connectionWithContextCheck) + if err != nil { + t.Fatalf("Unexpected error creating connection: %v", err) + } + + if err := testConn.RebootWithContext(ctx, false); err != nil { + t.Fatalf("Unexpected error rebooting: %v", err) + } + }) + + t.Run("returns_error_when_D-Bus_call_fails", func(t *testing.T) { + t.Parallel() + + expectedError := fmt.Errorf("reboot error") + + connectionWithFailingObjectCall := &mockConnection{ + ObjectF: func(string, dbus.ObjectPath) dbus.BusObject { + return &mockObject{ + CallWithContextF: func(ctx context.Context, method string, flags dbus.Flags, args ...interface{}) *dbus.Call { + return &dbus.Call{ + Err: expectedError, + } + }, + } + }, + } + + testConn, err := login1.NewWithConnection(connectionWithFailingObjectCall) + if err != nil { + t.Fatalf("Unexpected error creating connection: %v", err) + } + + if err := testConn.RebootWithContext(context.Background(), false); !errors.Is(err, expectedError) { + t.Fatalf("Unexpected error rebooting: %v", err) + } + }) +} + // mockConnection is a test helper for mocking dbus.Conn. type mockConnection struct { ObjectF func(string, dbus.ObjectPath) dbus.BusObject @@ -178,3 +342,70 @@ func (m *mockConnection) Close() error { func (m *mockConnection) BusObject() dbus.BusObject { return nil } + +// mockObject is a mock of dbus.BusObject. +type mockObject struct { + CallWithContextF func(context.Context, string, dbus.Flags, ...interface{}) *dbus.Call + CallF func(string, dbus.Flags, ...interface{}) *dbus.Call +} + +// mockObject must implement dbus.BusObject to be usable for other packages in tests, though not +// all methods must actually be mockable. See https://github.com/dbus/dbus/issues/252 for details. +var _ dbus.BusObject = &mockObject{} + +// CallWithContext ... +// +//nolint:lll // Upstream signature, can't do much with that. +func (m *mockObject) CallWithContext(ctx context.Context, method string, flags dbus.Flags, args ...interface{}) *dbus.Call { + if m.CallWithContextF == nil { + return &dbus.Call{} + } + + return m.CallWithContextF(ctx, method, flags, args...) +} + +// Call ... +func (m *mockObject) Call(method string, flags dbus.Flags, args ...interface{}) *dbus.Call { + if m.CallF == nil { + return &dbus.Call{} + } + + return m.CallF(method, flags, args...) +} + +// Go ... +func (m *mockObject) Go(method string, flags dbus.Flags, ch chan *dbus.Call, args ...interface{}) *dbus.Call { + return &dbus.Call{} +} + +// GoWithContext ... +// +//nolint:lll // Upstream signature, can't do much with that. +func (m *mockObject) GoWithContext(ctx context.Context, method string, flags dbus.Flags, ch chan *dbus.Call, args ...interface{}) *dbus.Call { + return &dbus.Call{} +} + +// AddMatchSignal ... +func (m *mockObject) AddMatchSignal(iface, member string, options ...dbus.MatchOption) *dbus.Call { + return &dbus.Call{} +} + +// RemoveMatchSignal ... +func (m *mockObject) RemoveMatchSignal(iface, member string, options ...dbus.MatchOption) *dbus.Call { + return &dbus.Call{} +} + +// GetProperty ... +func (m *mockObject) GetProperty(p string) (dbus.Variant, error) { return dbus.Variant{}, nil } + +// StoreProperty ... +func (m *mockObject) StoreProperty(p string, value interface{}) error { return nil } + +// SetProperty ... +func (m *mockObject) SetProperty(p string, v interface{}) error { return nil } + +// Destination ... +func (m *mockObject) Destination() string { return "" } + +// Path ... +func (m *mockObject) Path() dbus.ObjectPath { return "" }