diff --git a/config/config.go b/config/config.go index e165c30db..a3467a7a1 100644 --- a/config/config.go +++ b/config/config.go @@ -107,10 +107,10 @@ var baseConfig = Config{ MaxMemory: 0, EvictionPolicy: EvictAllKeysLFU, EvictionRatio: 0.40, - KeysLimit: 10000, + KeysLimit: 20000000, AOFFile: "./dice-master.aof", PersistenceEnabled: true, - WriteAOFOnCleanup: true, + WriteAOFOnCleanup: false, LFULogFactor: 10, LogLevel: "info", PrettyPrintLogs: false, diff --git a/go.mod b/go.mod index d7a0756a2..05694d4d0 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,6 @@ require gotest.tools/v3 v3.5.1 require ( github.com/bytedance/sonic/loader v0.2.0 // indirect - github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect @@ -41,6 +40,7 @@ require ( require ( github.com/axiomhq/hyperloglog v0.2.0 github.com/bytedance/sonic v1.12.1 + github.com/cespare/xxhash/v2 v2.2.0 github.com/cockroachdb/swiss v0.0.0-20240612210725-f4de07ae6964 github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 github.com/dicedb/go-dice v0.0.0-20240820180649-d97f15fca831 diff --git a/integration_tests/commands/async/setup.go b/integration_tests/commands/async/setup.go index 02b23708d..da94f3df2 100644 --- a/integration_tests/commands/async/setup.go +++ b/integration_tests/commands/async/setup.go @@ -11,13 +11,11 @@ import ( "sync" "time" - "github.com/dicedb/dice/internal/shard" - + "github.com/dicedb/dice/config" "github.com/dicedb/dice/internal/clientio" - + derrors "github.com/dicedb/dice/internal/errors" "github.com/dicedb/dice/internal/server" - - "github.com/dicedb/dice/config" + "github.com/dicedb/dice/internal/shard" dstore "github.com/dicedb/dice/internal/store" "github.com/dicedb/dice/testutils" redis "github.com/dicedb/go-dice" @@ -121,7 +119,8 @@ func RunTestServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerOption const totalRetries = 100 var err error watchChan := make(chan dstore.WatchEvent, config.DiceConfig.Server.KeysLimit) - shardManager := shard.NewShardManager(1, watchChan, opt.Logger) + gec := make(chan error) + shardManager := shard.NewShardManager(1, watchChan, gec, opt.Logger) // Initialize the AsyncServer testServer := server.NewAsyncServer(shardManager, watchChan, opt.Logger) @@ -167,7 +166,7 @@ func RunTestServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerOption go func() { defer wg.Done() if err := testServer.Run(ctx); err != nil { - if errors.Is(err, server.ErrAborted) { + if errors.Is(err, derrors.ErrAborted) { cancelShardManager() return } diff --git a/integration_tests/commands/http/setup.go b/integration_tests/commands/http/setup.go index e582004c3..b72bcbb90 100644 --- a/integration_tests/commands/http/setup.go +++ b/integration_tests/commands/http/setup.go @@ -13,9 +13,9 @@ import ( "sync" "time" - "github.com/dicedb/dice/internal/querywatcher" - "github.com/dicedb/dice/config" + derrors "github.com/dicedb/dice/internal/errors" + "github.com/dicedb/dice/internal/querywatcher" "github.com/dicedb/dice/internal/server" "github.com/dicedb/dice/internal/shard" dstore "github.com/dicedb/dice/internal/store" @@ -90,8 +90,9 @@ func RunHTTPServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerOption config.DiceConfig.Network.IOBufferLength = 16 config.DiceConfig.Server.WriteAOFOnCleanup = false + globalErrChannel := make(chan error) watchChan := make(chan dstore.WatchEvent, config.DiceConfig.Server.KeysLimit) - 
shardManager := shard.NewShardManager(1, watchChan, opt.Logger) + shardManager := shard.NewShardManager(1, watchChan, globalErrChannel, opt.Logger) queryWatcherLocal := querywatcher.NewQueryManager(opt.Logger) config.HTTPPort = opt.Port // Initialize the HTTPServer @@ -118,7 +119,7 @@ func RunHTTPServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerOption err := testServer.Run(ctx) if err != nil { cancelShardManager() - if errors.Is(err, server.ErrAborted) { + if errors.Is(err, derrors.ErrAborted) { return } if err.Error() != "http: Server closed" { diff --git a/internal/auth/session.go b/internal/auth/session.go index dc1f3dc90..013ff419f 100644 --- a/internal/auth/session.go +++ b/internal/auth/session.go @@ -143,7 +143,6 @@ func (session *Session) Validate(username, password string) error { return fmt.Errorf("WRONGPASS invalid username-password pair or user is disabled") } -func (session *Session) Expire() (err error) { +func (session *Session) Expire() { session.Status = SessionStatusExpired - return } diff --git a/internal/auth/session_test.go b/internal/auth/session_test.go index 5db456edb..05f93ae56 100644 --- a/internal/auth/session_test.go +++ b/internal/auth/session_test.go @@ -144,10 +144,7 @@ func TestSessionValidate(t *testing.T) { func TestSessionExpire(t *testing.T) { session := NewSession() - err := session.Expire() - if err != nil { - t.Errorf("Session.Expire() returned an error: %v", err) - } + session.Expire() if session.Status != SessionStatusExpired { t.Errorf("Session.Expire() did not set status to Expired. Got %v, want %v", session.Status, SessionStatusExpired) } diff --git a/internal/clientio/iohandler/iohandler.go b/internal/clientio/iohandler/iohandler.go new file mode 100644 index 000000000..73b6e3adc --- /dev/null +++ b/internal/clientio/iohandler/iohandler.go @@ -0,0 +1,11 @@ +package iohandler + +import ( + "context" +) + +type IOHandler interface { + Read(ctx context.Context) ([]byte, error) + Write(ctx context.Context, response []byte) error + Close() error +} diff --git a/internal/clientio/iohandler/netconn/netconn.go b/internal/clientio/iohandler/netconn/netconn.go new file mode 100644 index 000000000..555af4982 --- /dev/null +++ b/internal/clientio/iohandler/netconn/netconn.go @@ -0,0 +1,185 @@ +package netconn + +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "log/slog" + "net" + "os" + "syscall" + "time" + + "github.com/dicedb/dice/internal/clientio/iohandler" +) + +const ( + maxRequestSize = 512 * 1024 // 512 KB + readBufferSize = 4 * 1024 // 4 KB + idleTimeout = 10 * time.Minute +) + +var ( + ErrRequestTooLarge = errors.New("request too large") + ErrIdleTimeout = errors.New("connection idle timeout") + ErrorClosed = errors.New("connection closed") +) + +// IOHandler handles I/O operations for a network connection +type IOHandler struct { + fd int + file *os.File + conn net.Conn + reader *bufio.Reader + writer *bufio.Writer + logger *slog.Logger +} + +var _ iohandler.IOHandler = (*IOHandler)(nil) + +// NewIOHandler creates a new IOHandler from a file descriptor +func NewIOHandler(clientFD int, logger *slog.Logger) (*IOHandler, error) { + file := os.NewFile(uintptr(clientFD), "client-connection") + if file == nil { + return nil, fmt.Errorf("failed to create file from file descriptor") + } + + // Ensure the file is closed if we exit this function with an error + var conn net.Conn + defer func() { + if conn == nil { + // Only close the file if we haven't successfully created a net.Conn + err := file.Close() + if err != nil { + 
logger.Warn("Error closing file in NewIOHandler:", slog.Any("error", err)) + } + } + }() + + var err error + conn, err = net.FileConn(file) + if err != nil { + return nil, fmt.Errorf("failed to create net.Conn from file descriptor: %w", err) + } + + return &IOHandler{ + fd: clientFD, + file: file, + conn: conn, + reader: bufio.NewReader(conn), + writer: bufio.NewWriter(conn), + logger: logger, + }, nil +} + +func NewIOHandlerWithConn(conn net.Conn) *IOHandler { + return &IOHandler{ + conn: conn, + reader: bufio.NewReader(conn), + writer: bufio.NewWriter(conn), + } +} + +func (h *IOHandler) FileDescriptor() int { + return h.fd +} + +// ReadRequest reads data from the network connection +func (h *IOHandler) Read(ctx context.Context) ([]byte, error) { + var data []byte + buf := make([]byte, readBufferSize) + + for { + select { + case <-ctx.Done(): + return data, ctx.Err() + default: + err := h.conn.SetReadDeadline(time.Now().Add(idleTimeout)) + if err != nil { + return nil, fmt.Errorf("error setting read deadline: %w", err) + } + + n, err := h.reader.Read(buf) + if n > 0 { + data = append(data, buf[:n]...) + } + if err != nil { + switch { + case errors.Is(err, syscall.EAGAIN), errors.Is(err, syscall.EWOULDBLOCK), errors.Is(err, io.EOF): + // No more data to read at this time + return data, nil + case errors.Is(err, net.ErrClosed), errors.Is(err, syscall.EPIPE), errors.Is(err, syscall.ECONNRESET): + h.logger.Error("Connection closed", slog.Any("error", err)) + cerr := h.Close() + if cerr != nil { + h.logger.Warn("Error closing connection", slog.Any("error", errors.Join(err, cerr))) + } + return nil, ErrorClosed + case errors.Is(err, syscall.ETIMEDOUT): + h.logger.Info("Connection idle timeout", slog.Any("error", err)) + cerr := h.Close() + if cerr != nil { + h.logger.Warn("Error closing connection", slog.Any("error", errors.Join(err, cerr))) + } + return nil, ErrIdleTimeout + default: + h.logger.Error("Error reading from connection", slog.Any("error", err)) + return nil, fmt.Errorf("error reading request: %w", err) + } + } + + if len(data) > maxRequestSize { + h.logger.Warn("Request too large", slog.Any("size", len(data))) + return nil, ErrRequestTooLarge + } + + // If we've read less than the buffer size, we've likely got all the data + if n < len(buf) { + return data, nil + } + } + } +} + +// WriteResponse writes the response back to the network connection +func (h *IOHandler) Write(ctx context.Context, response []byte) error { + errChan := make(chan error, 1) + + go func(errChan chan error) { + _, err := h.writer.Write(response) + if err == nil { + err = h.writer.Flush() + } + + errChan <- err + }(errChan) + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errChan: + if err != nil { + if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) { + cerr := h.Close() + if cerr != nil { + err = errors.Join(err, cerr) + } + + h.logger.Error("Connection closed", slog.Any("error", err)) + return err + } + + return fmt.Errorf("error writing response: %w", err) + } + } + + return nil +} + +// Close underlying network connection +func (h *IOHandler) Close() error { + h.logger.Info("Closing connection") + return errors.Join(h.conn.Close(), h.file.Close()) +} diff --git a/internal/clientio/iohandler/netconn/netconn_resp_test.go b/internal/clientio/iohandler/netconn/netconn_resp_test.go new file mode 100644 index 000000000..ba605810a --- /dev/null +++ b/internal/clientio/iohandler/netconn/netconn_resp_test.go @@ -0,0 +1,202 @@ +package 
netconn + +import ( + "bufio" + "context" + "errors" + "github.com/dicedb/dice/mocks" + "log/slog" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNetConnIOHandler_RESP(t *testing.T) { + tests := []struct { + name string + input string + expectedRead string + writeResponse string + expectedWrite string + readErr error + writeErr error + ctxTimeout time.Duration + expectedReadErr error + expectedWriteErr error + }{ + { + name: "Simple String", + input: "+OK\r\n", + expectedRead: "+OK\r\n", + writeResponse: "+OK\r\n", + expectedWrite: "+OK\r\n", + }, + { + name: "Error", + input: "-Error message\r\n", + expectedRead: "-Error message\r\n", + writeResponse: "-ERR unknown command 'FOOBAR'\r\n", + expectedWrite: "-ERR unknown command 'FOOBAR'\r\n", + }, + { + name: "Integer", + input: ":1000\r\n", + expectedRead: ":1000\r\n", + writeResponse: ":1000\r\n", + expectedWrite: ":1000\r\n", + }, + { + name: "Bulk String", + input: "$5\r\nhello\r\n", + expectedRead: "$5\r\nhello\r\n", + writeResponse: "$5\r\nworld\r\n", + expectedWrite: "$5\r\nworld\r\n", + }, + { + name: "Null Bulk String", + input: "$-1\r\n", + expectedRead: "$-1\r\n", + writeResponse: "$-1\r\n", + expectedWrite: "$-1\r\n", + }, + { + name: "Empty Bulk String", + input: "$0\r\n\r\n", + expectedRead: "$0\r\n\r\n", + writeResponse: "$0\r\n\r\n", + expectedWrite: "$0\r\n\r\n", + }, + { + name: "Array", + input: "*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n", + expectedRead: "*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n", + writeResponse: "*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n", + expectedWrite: "*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n", + }, + { + name: "Empty Array", + input: "*0\r\n", + expectedRead: "*0\r\n", + writeResponse: "*0\r\n", + expectedWrite: "*0\r\n", + }, + { + name: "Null Array", + input: "*-1\r\n", + expectedRead: "*-1\r\n", + writeResponse: "*-1\r\n", + expectedWrite: "*-1\r\n", + }, + { + name: "Nested Array", + input: "*2\r\n*2\r\n+foo\r\n+bar\r\n*2\r\n+hello\r\n+world\r\n", + expectedRead: "*2\r\n*2\r\n+foo\r\n+bar\r\n*2\r\n+hello\r\n+world\r\n", + writeResponse: "*2\r\n*2\r\n+foo\r\n+bar\r\n*2\r\n+hello\r\n+world\r\n", + expectedWrite: "*2\r\n*2\r\n+foo\r\n+bar\r\n*2\r\n+hello\r\n+world\r\n", + }, + { + name: "SET command", + input: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n", + expectedRead: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n", + writeResponse: "+OK\r\n", + expectedWrite: "+OK\r\n", + }, + { + name: "GET command", + input: "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n", + expectedRead: "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n", + writeResponse: "$5\r\nvalue\r\n", + expectedWrite: "$5\r\nvalue\r\n", + }, + { + name: "LPUSH command", + input: "*4\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$5\r\nvalue\r\n$6\r\nvalue2\r\n", + expectedRead: "*4\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$5\r\nvalue\r\n$6\r\nvalue2\r\n", + writeResponse: ":2\r\n", + expectedWrite: ":2\r\n", + }, + { + name: "HMSET command", + input: "*6\r\n$5\r\nHMSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n$6\r\nfield2\r\n$6\r\nvalue2\r\n", + expectedRead: "*6\r\n$5\r\nHMSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n$6\r\nfield2\r\n$6\r\nvalue2\r\n", + writeResponse: "+OK\r\n", + expectedWrite: "+OK\r\n", + }, + { + name: "Partial read", + input: "*2\r\n$5\r\nhello\r\n$5\r\nwor", + expectedRead: "*2\r\n$5\r\nhello\r\n$5\r\nwor", + writeResponse: "+OK\r\n", + expectedWrite: "+OK\r\n", + }, + { + name: "Read error", + input: "*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n", + readErr: errors.New("read error"), + expectedReadErr: errors.New("error 
reading request: read error"), + }, + { + name: "Write error", + input: "*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n", + expectedRead: "*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n", + writeResponse: strings.Repeat("Hello, World!\r\n", 100), + writeErr: errors.New("write error"), + expectedWriteErr: errors.New("error writing response: write error"), + }, + { + name: "Write error", + input: "*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n", + expectedRead: "*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n", + writeResponse: "Hello, World!\r\n", + writeErr: errors.New("write error"), + expectedWriteErr: errors.New("error writing response: write error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mock := &mockConn{ + readData: []byte(tt.input), + readErr: tt.readErr, + writeErr: tt.writeErr, + } + + handler := &IOHandler{ + conn: mock, + reader: bufio.NewReaderSize(mock, 512), + writer: bufio.NewWriterSize(mock, 1024), + logger: slog.New(mocks.SlogNoopHandler{}), + } + + ctx := context.Background() + if tt.ctxTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, tt.ctxTimeout) + defer cancel() + } + + // Test ReadRequest + data, err := handler.Read(ctx) + if tt.expectedReadErr != nil { + assert.Error(t, err) + assert.Equal(t, tt.expectedReadErr.Error(), err.Error()) + return + } else { + assert.NoError(t, err) + assert.Equal(t, []byte(tt.expectedRead), data) + } + + // Test WriteResponse + err = handler.Write(ctx, []byte(tt.writeResponse)) + if tt.expectedWriteErr != nil { + assert.Error(t, err) + assert.Equal(t, tt.expectedWriteErr.Error(), err.Error()) + } else { + assert.NoError(t, err) + assert.Equal(t, []byte(tt.expectedWrite), mock.writeData.Bytes()) + } + }) + } +} diff --git a/internal/clientio/iohandler/netconn/netconn_test.go b/internal/clientio/iohandler/netconn/netconn_test.go new file mode 100644 index 000000000..7f1064e54 --- /dev/null +++ b/internal/clientio/iohandler/netconn/netconn_test.go @@ -0,0 +1,279 @@ +package netconn + +import ( + "bufio" + "bytes" + "context" + "errors" + "github.com/dicedb/dice/mocks" + "io" + "log/slog" + "net" + "os" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockConn implements net.Conn interface for testing +type mockConn struct { + readData []byte + writeData bytes.Buffer + readErr error + writeErr error + closed bool + mu sync.Mutex +} + +func (m *mockConn) Read(b []byte) (n int, err error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.readErr != nil { + return 0, m.readErr + } + if len(m.readData) == 0 { + return 0, io.EOF + } + n = copy(b, m.readData) + m.readData = m.readData[n:] + return n, nil +} + +func (m *mockConn) Write(b []byte) (n int, err error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.writeErr != nil { + return 0, m.writeErr + } + return m.writeData.Write(b) +} + +func (m *mockConn) Close() error { + m.closed = true + return nil +} +func (m *mockConn) LocalAddr() net.Addr { return nil } +func (m *mockConn) RemoteAddr() net.Addr { return nil } +func (m *mockConn) SetDeadline(t time.Time) error { return nil } +func (m *mockConn) SetReadDeadline(t time.Time) error { return nil } +func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil } + +type mockFile struct { + f *os.File +} + +func (f *mockFile) Close() error { + return nil +} + +func (m *mockConn) File() (*os.File, error) { + return &os.File{}, nil +} + +func TestNetConnIOHandler(t *testing.T) { + tests := []struct { + name string + readData []byte + readErr 
error + writeErr error + ctxTimeout time.Duration + response []byte + expectedRead []byte + expectedWrite []byte + expectedReadErr error + expectedWriteErr error + }{ + { + name: "Simple read and write", + readData: []byte("Hello, World!\r\n"), + expectedRead: []byte("Hello, World!\r\n"), + expectedWrite: []byte("Response\r\n"), + }, + { + name: "Read error", + readErr: errors.New("read error"), + expectedReadErr: errors.New("error reading request: read error"), + expectedWrite: []byte("Response\r\n"), + }, + { + name: "Write error", + readData: []byte("Hello, World!\r\n"), + expectedRead: []byte("Hello, World!\r\n"), + writeErr: errors.New("write error"), + response: []byte("Hello, World!\r\n"), + expectedWriteErr: errors.New("error writing response: write error"), + }, + { + name: "Large data read", + readData: bytes.Repeat([]byte("a"), 1000), + expectedRead: bytes.Repeat([]byte("a"), 1000), + expectedWrite: []byte("Response\r\n"), + }, + { + name: "Empty read", + readData: []byte{}, + expectedRead: []byte(nil), + expectedWrite: []byte("Response\r\n"), + }, + { + name: "Read with multiple chunks", + readData: []byte("Hello\r\nWorld\r\n"), + expectedRead: []byte("Hello\r\nWorld\r\n"), + expectedWrite: []byte("Response\r\n"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mock := &mockConn{ + readData: tt.readData, + readErr: tt.readErr, + writeErr: tt.writeErr, + } + + handler := &IOHandler{ + conn: mock, + reader: bufio.NewReaderSize(mock, 512), + writer: bufio.NewWriterSize(mock, 1024), + logger: slog.New(mocks.SlogNoopHandler{}), + } + + ctx := context.Background() + if tt.ctxTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, tt.ctxTimeout) + defer cancel() + } + + // Test ReadRequest + data, err := handler.Read(ctx) + if tt.expectedReadErr != nil { + assert.Error(t, err) + assert.Equal(t, tt.expectedReadErr.Error(), err.Error()) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedRead, data) + } + + // Test WriteResponse + if tt.response == nil { + tt.response = []byte("Response\r\n") + } + + err = handler.Write(ctx, tt.response) + + if tt.expectedWriteErr != nil { + assert.Error(t, err) + assert.Equal(t, tt.expectedWriteErr.Error(), err.Error()) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedWrite, mock.writeData.Bytes()) + } + }) + } +} + +func TestNewNetConnIOHandler(t *testing.T) { + tests := []struct { + name string + setup func() (int, func(), error) + expectedErr error + }{ + { + name: "Closed file descriptor", + setup: func() (int, func(), error) { + f, err := os.CreateTemp("", "test") + if err != nil { + return 0, nil, err + } + fd := int(f.Fd()) + f.Close() // Close immediately to create a closed fd + return fd, func() {}, nil + }, + expectedErr: errors.New("failed to create net.Conn"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fd, cleanup, err := tt.setup() + require.NoError(t, err, "Setup failed") + defer cleanup() + + logger := slog.New(mocks.SlogNoopHandler{}) + handler, err := NewIOHandler(fd, logger) + + if tt.expectedErr != nil { + assert.Error(t, err) + assert.Nil(t, handler) + assert.Contains(t, err.Error(), tt.expectedErr.Error()) + } else { + assert.NoError(t, err) + assert.NotNil(t, handler) + assert.NotNil(t, handler.conn) + assert.NotNil(t, handler.reader) + assert.NotNil(t, handler.writer) + + // Test if the created handler can perform basic I/O + testData := []byte("Hello, World!") + go func() { + _, err := 
handler.conn.(io.Writer).Write(testData) + assert.NoError(t, err) + }() + + readData, err := handler.Read(context.Background()) + assert.NoError(t, err) + assert.Equal(t, testData, readData) + } + }) + } +} + +func TestNewNetConnIOHandler_RealNetwork(t *testing.T) { // More of an integration test + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err, "Failed to create listener") + defer listener.Close() + + go func() { + conn, err := listener.Accept() + if err != nil { + t.Errorf("Failed to accept connection: %v", err) + return + } + defer conn.Close() + + _, err = conn.Write([]byte("Hello, World!")) + if err != nil { + t.Errorf("Failed to write to connection: %v", err) + } + }() + + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err, "Failed to dial") + + tcpConn, ok := conn.(*net.TCPConn) + require.True(t, ok, "Not a TCP connection") + + file, err := tcpConn.File() + require.NoError(t, err, "Failed to get file from connection") + + fd := int(file.Fd()) + + logger := slog.New(mocks.SlogNoopHandler{}) + handler, err := NewIOHandler(fd, logger) + require.NoError(t, err, "Failed to create IOHandler") + + testData := []byte("Hello, World!") + readData, err := handler.Read(context.Background()) + assert.NoError(t, err) + assert.Equal(t, testData, readData) + + err = handler.Close() + assert.NoError(t, err) + + file.Close() + conn.Close() +} diff --git a/internal/clientio/requestparser/parser.go b/internal/clientio/requestparser/parser.go new file mode 100644 index 000000000..691a277a3 --- /dev/null +++ b/internal/clientio/requestparser/parser.go @@ -0,0 +1,9 @@ +package requestparser + +import ( + "github.com/dicedb/dice/internal/cmd" +) + +type Parser interface { + Parse(data []byte) ([]*cmd.RedisCmd, error) +} diff --git a/internal/clientio/requestparser/resp/respparser.go b/internal/clientio/requestparser/resp/respparser.go new file mode 100644 index 000000000..624542b39 --- /dev/null +++ b/internal/clientio/requestparser/resp/respparser.go @@ -0,0 +1,249 @@ +package respparser + +import ( + "bytes" + "errors" + "fmt" + "log/slog" + "strconv" + "strings" + + "github.com/dicedb/dice/internal/cmd" +) + +type RESPType byte + +const ( + SimpleString RESPType = '+' + Error RESPType = '-' + Integer RESPType = ':' + BulkString RESPType = '$' + Array RESPType = '*' +) + +// Common errors +var ( + ErrInvalidInput = errors.New("invalid input") + ErrUnexpectedEOF = errors.New("unexpected EOF") + ErrProtocolError = errors.New("protocol error") +) + +// CRLF is the line delimiter in RESP +var CRLF = []byte{'\r', '\n'} + +// Parser is responsible for parsing RESP protocol data +type Parser struct { + data []byte + pos int + logger *slog.Logger +} + +// NewParser creates a new RESP parser +func NewParser(l *slog.Logger) *Parser { + return &Parser{ + pos: 0, + logger: l, + } +} + +// SetData WARNING: This function is added for testing purposes only +func (p *Parser) SetData(data []byte) { + p.data = data + p.pos = 0 +} + +// Parse parses the entire input and returns a slice of RedisCmd +func (p *Parser) Parse(data []byte) ([]*cmd.RedisCmd, error) { + p.SetData(data) + var commands []*cmd.RedisCmd + for p.pos < len(p.data) { + c, err := p.parseCommand() + if err != nil { + return commands, err + } + + commands = append(commands, c) + } + + return commands, nil +} + +func (p *Parser) parseCommand() (*cmd.RedisCmd, error) { + if p.pos >= len(p.data) { + return nil, ErrUnexpectedEOF + } + + // A Dice command should always be an array as it follows RESP2 
specifications + elements, err := p.parse() + if err != nil { + p.logger.Error("error while parsing command", slog.Any("cmd", string(p.data)), slog.Any("error", err)) + return nil, fmt.Errorf("error parsing command: %w", err) + } + + if len(elements) == 0 { + return nil, fmt.Errorf("error while parsing command, empty command") + } + + return &cmd.RedisCmd{ + Cmd: strings.ToUpper(elements[0]), + Args: elements[1:], + }, nil +} + +func (p *Parser) parse() ([]string, error) { + count := 1 + if p.data[p.pos] == byte(Array) { + var err error + count, err = p.parseArrayLength() + if err != nil { + return nil, err + } + } + + result := make([]string, 0, count) + for i := 0; i < count; i++ { + val, err := p.ParseOne() + if err != nil { + return nil, fmt.Errorf("parse array element %d: %w", i, err) + } + + str, err := p.convertToString(val) + if err != nil { + return nil, err + } + + result = append(result, str) + } + + return result, nil +} + +func (p *Parser) parseArrayLength() (int, error) { + line, err := p.readLine() + if err != nil { + return 0, fmt.Errorf("parse array length: %w", err) + } + + count, err := strconv.Atoi(string(line[1:])) // Remove '*' + if err != nil { + return 0, fmt.Errorf("invalid array length type") + } + + if count <= 0 { + return 0, fmt.Errorf("invalid array length %d", count) + } + + return count, nil +} + +func (p *Parser) convertToString(val any) (string, error) { + switch v := val.(type) { + case string: + return v, nil + case int64: + return strconv.FormatInt(v, 10), nil + default: + return "", fmt.Errorf("unexpected type %T", val) + } +} + +func (p *Parser) ParseOne() (any, error) { + for { + if p.pos >= len(p.data) { + return "", ErrUnexpectedEOF + } + + switch RESPType(p.data[p.pos]) { + case SimpleString: + return p.parseSimpleString() + case Error: + return p.parseError() + case Integer: + return p.parseInteger() + case BulkString: + return p.parseBulkString() + case Array: + return p.parse() + default: + return "", fmt.Errorf("%w: unknown type %c", ErrProtocolError, p.data[p.pos]) + } + } +} + +func (p *Parser) parseSimpleString() (string, error) { + p.pos++ // Skip the '+' + return p.readLineAsString() +} + +func (p *Parser) parseError() (string, error) { + p.pos++ // Skip the '-' + return p.readLineAsString() +} + +func (p *Parser) parseInteger() (val int64, err error) { + p.pos++ // Skip the ':' + line, err := p.readLineAsString() + if err != nil { + return 0, fmt.Errorf("parse integer: %w", err) + } + + return strconv.ParseInt(line, 10, 64) +} + +func (p *Parser) parseBulkString() (string, error) { + line, err := p.readLine() + if err != nil { + return "", fmt.Errorf("parse bulk string: %w", err) + } + + length, err := strconv.Atoi(string(line[1:])) // Remove '$' + if err != nil { + return "", fmt.Errorf("invalid bulk string length %q: %w", line, err) + } + + if length == -1 { + return "(nil)", nil // Null bulk string + } + + if length < -1 { + return "", fmt.Errorf("invalid bulk string length: %d", length) + } + + if p.pos+length+2 > len(p.data) { // +2 for CRLF + return "", ErrUnexpectedEOF + } + + content := string(p.data[p.pos : p.pos+length]) + p.pos += length + 2 // Skip the string content and CRLF + + // Verify CRLF after content + if !bytes.Equal(p.data[p.pos-2:p.pos], CRLF) { + return "", errors.New("bulk string not terminated by CRLF") + } + + return content, nil +} + +func (p *Parser) readLineAsString() (string, error) { + line, err := p.readLine() + if err != nil { + return "", err + } + + return string(line), nil +} + +func (p *Parser) 
readLine() ([]byte, error) { + if p.pos >= len(p.data) { + return nil, ErrUnexpectedEOF + } + + end := bytes.Index(p.data[p.pos:], CRLF) + if end == -1 { + return nil, ErrUnexpectedEOF + } + + line := p.data[p.pos : p.pos+end] + p.pos += end + 2 // +2 to move past CRLF + return line, nil +} diff --git a/internal/clientio/requestparser/resp/respparser_test.go b/internal/clientio/requestparser/resp/respparser_test.go new file mode 100644 index 000000000..fd98a731a --- /dev/null +++ b/internal/clientio/requestparser/resp/respparser_test.go @@ -0,0 +1,337 @@ +package respparser + +import ( + "github.com/dicedb/dice/mocks" + "log/slog" + "reflect" + "testing" + + "github.com/dicedb/dice/internal/cmd" +) + +func TestParser_Parse(t *testing.T) { + tests := []struct { + name string + input string + want []*cmd.RedisCmd + wantErr bool + }{ + { + name: "Simple SET command", + input: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n", + want: []*cmd.RedisCmd{ + {Cmd: "SET", Args: []string{"key", "value"}}, + }, + }, + { + name: "GET command", + input: "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n", + want: []*cmd.RedisCmd{ + {Cmd: "GET", Args: []string{"key"}}, + }, + }, + { + name: "Multiple commands", + input: "*2\r\n$4\r\nPING\r\n$4\r\nPONG\r\n*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n", + want: []*cmd.RedisCmd{ + {Cmd: "PING", Args: []string{"PONG"}}, + {Cmd: "SET", Args: []string{"key", "value"}}, + }, + }, + { + name: "Command with integer argument", + input: "*3\r\n$6\r\nEXPIRE\r\n$3\r\nkey\r\n:60\r\n", + want: []*cmd.RedisCmd{ + {Cmd: "EXPIRE", Args: []string{"key", "60"}}, + }, + }, + { + name: "Invalid command (not an array)", + input: "NOT AN ARRAY\r\n", + wantErr: true, + }, + { + name: "Empty command", + input: "*0\r\n", + wantErr: true, + }, + { + name: "Command with null bulk string argument", + input: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$-1\r\n", + want: []*cmd.RedisCmd{ + {Cmd: "SET", Args: []string{"key", "(nil)"}}, + }, + }, + { + name: "Command with Simple String argument", + input: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n+OK\r\n", + want: []*cmd.RedisCmd{ + {Cmd: "SET", Args: []string{"key", "OK"}}, + }, + }, + { + name: "Command with Error argument", + input: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n-ERR Invalid argument\r\n", + want: []*cmd.RedisCmd{ + {Cmd: "SET", Args: []string{"key", "ERR Invalid argument"}}, + }, + }, + { + name: "Command with mixed argument types", + input: "*5\r\n$4\r\nMSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n:1000\r\n+OK\r\n", + want: []*cmd.RedisCmd{ + {Cmd: "MSET", Args: []string{"key", "value", "1000", "OK"}}, + }, + }, + { + name: "Invalid array length", + input: "*-2\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n", + wantErr: true, + }, + { + name: "Incomplete command", + input: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n", + wantErr: true, + }, + { + name: "Command with empty bulk string", + input: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$0\r\n\r\n", + want: []*cmd.RedisCmd{ + {Cmd: "SET", Args: []string{"key", ""}}, + }, + }, + { + name: "Invalid bulk string length", + input: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$-2\r\nvalue\r\n", + wantErr: true, + }, + { + name: "Non-integer bulk string length", + input: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$abc\r\nvalue\r\n", + wantErr: true, + }, + { + name: "Large bulk string", + input: "*2\r\n$4\r\nECHO\r\n$1000\r\n" + string(make([]byte, 1000)) + "\r\n", + want: []*cmd.RedisCmd{ + {Cmd: "ECHO", Args: []string{string(make([]byte, 1000))}}, + }, + }, + { + name: "Incomplete CRLF", + input: "*2\r\n$4\r\nECHO\r\n$5\r\nhello\r", + wantErr: true, + }, + } + + for _, tt 
:= range tests { + t.Run(tt.name, func(t *testing.T) { + l := slog.New(mocks.SlogNoopHandler{}) + p := NewParser(l) + got, err := p.Parse([]byte(tt.input)) + if (err != nil) != tt.wantErr { + t.Errorf("Parser.Parse() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && !reflect.DeepEqual(got, tt.want) { + t.Errorf("Parser.Parse() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestParser_parseSimpleString(t *testing.T) { + tests := []struct { + name string + input string + want string + wantErr bool + }{ + {"Valid simple string", "+OK\r\n", "OK", false}, + {"Empty simple string", "+\r\n", "", false}, + {"Simple string with spaces", "+Hello World\r\n", "Hello World", false}, + {"Incomplete simple string", "+OK", "", true}, + {"Missing CR", "+OK\n", "", true}, + {"Missing LF", "+OK\r", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Parser{data: []byte(tt.input), pos: 0} + got, err := p.parseSimpleString() + if (err != nil) != tt.wantErr { + t.Errorf("parseSimpleString() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("parseSimpleString() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestParser_parseError(t *testing.T) { + tests := []struct { + name string + input string + want string + wantErr bool + }{ + {"Valid error", "-Error message\r\n", "Error message", false}, + {"Empty error", "-\r\n", "", false}, + {"Error with spaces", "-ERR unknown command\r\n", "ERR unknown command", false}, + {"Incomplete error", "-Error", "", true}, + {"Missing CR", "-Error\n", "", true}, + {"Missing LF", "-Error\r", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Parser{data: []byte(tt.input), pos: 0} + got, err := p.parseError() + if (err != nil) != tt.wantErr { + t.Errorf("parseError() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("parseError() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestParser_parseInteger(t *testing.T) { + tests := []struct { + name string + input string + want int64 + wantErr bool + }{ + {"Valid positive integer", ":1000\r\n", 1000, false}, + {"Valid negative integer", ":-1000\r\n", -1000, false}, + {"Zero", ":0\r\n", 0, false}, + {"Large integer", ":9223372036854775807\r\n", 9223372036854775807, false}, + {"Invalid integer (float)", ":3.14\r\n", 0, true}, + {"Invalid integer (text)", ":abc\r\n", 0, true}, + {"Incomplete integer", ":123", 0, true}, + {"Missing CR", ":123\n", 0, true}, + {"Missing LF", ":123\r", 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Parser{data: []byte(tt.input), pos: 0} + got, err := p.parseInteger() + if (err != nil) != tt.wantErr { + t.Errorf("parseInteger() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("parseInteger() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestParser_parseBulkString(t *testing.T) { + tests := []struct { + name string + input string + want string + wantErr bool + }{ + {"Valid bulk string", "$5\r\nhello\r\n", "hello", false}, + {"Empty bulk string", "$0\r\n\r\n", "", false}, + {"Null bulk string", "$-1\r\n", "(nil)", false}, + {"Bulk string with spaces", "$11\r\nhello world\r\n", "hello world", false}, + {"Invalid length (negative)", "$-2\r\nhello\r\n", "", true}, + {"Invalid length (non-numeric)", "$abc\r\nhello\r\n", "", true}, + {"Incomplete bulk string", "$5\r\nhell", "", true}, + {"Missing CR", "$5\r\nhello\n", "", true}, + 
{"Missing LF", "$5\r\nhello\r", "", true}, + {"Length mismatch", "$4\r\nhello\r\n", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Parser{data: []byte(tt.input), pos: 0} + got, err := p.parseBulkString() + if (err != nil) != tt.wantErr { + t.Errorf("parseBulkString() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("parseBulkString() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestParser_parseArray(t *testing.T) { + tests := []struct { + name string + input string + want []string + wantErr bool + }{ + { + name: "Valid array", + input: "*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n", + want: []string{"hello", "world"}, + }, + { + name: "Empty array", + input: "*0\r\n", + wantErr: true, + }, + { + name: "Null array", + input: "*-1\r\n", + wantErr: true, + }, + { + name: "Array with mixed types", + input: "*3\r\n:1\r\n$5\r\nhello\r\n+world\r\n", + want: []string{"1", "hello", "world"}, + }, + { + name: "Invalid array length", + input: "*-2\r\n", + wantErr: true, + }, + { + name: "Non-numeric array length", + input: "*abc\r\n", + wantErr: true, + }, + { + name: "Array length mismatch (too few elements)", + input: "*3\r\n$5\r\nhello\r\n$5\r\nworld\r\n", + wantErr: true, + }, + { + name: "Array length mismatch (too many elements)", + input: "*1\r\n$5\r\nhello\r\n$5\r\nworld\r\n", + want: []string{"hello"}, // Truncated parsing + }, + { + name: "Incomplete array", + input: "*2\r\n$5\r\nhello\r\n", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Parser{data: []byte(tt.input), pos: 0} + got, err := p.parse() + if (err != nil) != tt.wantErr { + t.Errorf("parse() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && !reflect.DeepEqual(got, tt.want) { + t.Errorf("parse() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/cmd/cmds.go b/internal/cmd/cmds.go index deba83639..d0dfae6fa 100644 --- a/internal/cmd/cmds.go +++ b/internal/cmd/cmds.go @@ -1,9 +1,12 @@ package cmd type RedisCmd struct { - ID uint32 - Cmd string - Args []string + RequestID uint32 + Cmd string + Args []string } -type RedisCmds []*RedisCmd +type RedisCmds struct { + Cmds []*RedisCmd + RequestID uint32 +} diff --git a/internal/comm/client.go b/internal/comm/client.go index cfbdc88b7..33601ca67 100644 --- a/internal/comm/client.go +++ b/internal/comm/client.go @@ -37,25 +37,29 @@ func (c *Client) TxnBegin() { } func (c *Client) TxnDiscard() { - c.Cqueue = make(cmd.RedisCmds, 0) + c.Cqueue.Cmds = make([]*cmd.RedisCmd, 0) c.IsTxn = false } func (c *Client) TxnQueue(redisCmd *cmd.RedisCmd) { - c.Cqueue = append(c.Cqueue, redisCmd) + c.Cqueue.Cmds = append(c.Cqueue.Cmds, redisCmd) } func NewClient(fd int) *Client { + cmds := make([]*cmd.RedisCmd, 0) return &Client{ - Fd: fd, - Cqueue: make(cmd.RedisCmds, 0), + Fd: fd, + Cqueue: cmd.RedisCmds{ + Cmds: cmds, + }, Session: auth.NewSession(), } } func NewHTTPQwatchClient(qwatchResponseChan chan QwatchResponse, clientIdentifierID uint32) *Client { + cmds := make([]*cmd.RedisCmd, 0) return &Client{ - Cqueue: make(cmd.RedisCmds, 0), + Cqueue: cmd.RedisCmds{Cmds: cmds}, Session: auth.NewSession(), ClientIdentifierID: clientIdentifierID, HTTPQwatchResponseChan: qwatchResponseChan, diff --git a/internal/errors/errors.go b/internal/errors/errors.go index 0138c91da..aec6e5cea 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -25,6 +25,11 @@ const ( JSONPathNotExistErr = "-ERR Path '%s' does not exist" 
JSONPathValueTypeErr = "-WRONGTYPE wrong type of path value - expected string but found integer" InvalidExpireTime = "-ERR invalid expire time" + InternalServerError = "-ERR: Internal server error, unable to process command" +) + +var ( + ErrAborted = errors.New("server received ABORT command") ) type DiceError struct { diff --git a/internal/eval/eval_amd64.go b/internal/eval/eval_amd64.go index beac53e77..d370a64d8 100644 --- a/internal/eval/eval_amd64.go +++ b/internal/eval/eval_amd64.go @@ -5,6 +5,7 @@ package eval import ( "syscall" + "github.com/dicedb/dice/config" "github.com/dicedb/dice/internal/clientio" diceerrors "github.com/dicedb/dice/internal/errors" "github.com/dicedb/dice/internal/server/utils" @@ -21,6 +22,13 @@ func EvalBGREWRITEAOF(args []string, store *dstore.Store) []byte { // This technique utilizes the CoW or copy-on-write, so while the main process is free to modify them // the child would save all the pages to disk. // Check details here -https://www.sobyte.net/post/2022-10/fork-cow/ + // TODO: Fix this to work with the threading + // TODO: Problem at hand: In multi-threaded environment, each shard instance would fork a child process. + // TODO: Each child process would now have a copy of the network file descriptor thus resulting in resource leaks. + // TODO: We need to find an alternative approach for the multi-threaded environment. + if config.EnableMultiThreading { + return nil + } newChild, _, _ := syscall.Syscall(syscall.SYS_FORK, 0, 0, 0) if newChild == 0 { // We are inside child process now, so we'll start flushing to disk. diff --git a/internal/eval/eval_darwin_arm64.go b/internal/eval/eval_darwin_arm64.go index 48a142d15..0d13a5d5c 100644 --- a/internal/eval/eval_darwin_arm64.go +++ b/internal/eval/eval_darwin_arm64.go @@ -5,6 +5,7 @@ package eval import ( "syscall" + "github.com/dicedb/dice/config" "github.com/dicedb/dice/internal/clientio" diceerrors "github.com/dicedb/dice/internal/errors" dstore "github.com/dicedb/dice/internal/store" @@ -20,6 +21,13 @@ func EvalBGREWRITEAOF(args []string, store *dstore.Store) []byte { // This technique utilizes the CoW or copy-on-write, so while the main process is free to modify them // the child would save all the pages to disk. // Check details here -https://www.sobyte.net/post/2022-10/fork-cow/ + // TODO: Fix this to work with the threading + // TODO: Problem at hand: In multi-threaded environment, each shard instance would fork a child process. + // TODO: Each child process would now have a copy of the network file descriptor thus resulting in resource leaks. + // TODO: We need to find an alternative approach for the multi-threaded environment. + if config.EnableMultiThreading { + return nil + } pid, _, err := syscall.RawSyscall(syscall.SYS_FORK, 0, 0, 0) if err != 0 { diff --git a/internal/eval/eval_linux_arm64.go b/internal/eval/eval_linux_arm64.go index 8af44c565..6b18d48db 100644 --- a/internal/eval/eval_linux_arm64.go +++ b/internal/eval/eval_linux_arm64.go @@ -21,6 +21,13 @@ func EvalBGREWRITEAOF(args []string, store *dstore.Store) []byte { // This technique utilizes the CoW or copy-on-write, so while the main process is free to modify them // the child would save all the pages to disk. // Check details here -https://www.sobyte.net/post/2022-10/fork-cow/ + // TODO: Fix this to work with the threading + // TODO: Problem at hand: In multi-threaded environment, each shard instance would fork a child process. 
+ // TODO: Each child process would now have a copy of the network file descriptor thus resulting in resource leaks. + // TODO: We need to find an alternative approach for the multi-threaded environment. + if config.EnableMultiThreading { + return nil + } childThreadID, _, _ := syscall.Syscall(syscall.SYS_GETTID, 0, 0, 0) newChild, _, _ := syscall.Syscall(syscall.SYS_CLONE, syscall.CLONE_PARENT_SETTID|syscall.CLONE_CHILD_CLEARTID|uintptr(syscall.SIGCHLD), 0, childThreadID) if newChild == 0 { diff --git a/internal/eval/eval_test.go b/internal/eval/eval_test.go index 5b0f33363..5a19d652a 100644 --- a/internal/eval/eval_test.go +++ b/internal/eval/eval_test.go @@ -9,16 +9,15 @@ import ( "strings" "testing" "time" - - "github.com/bytedance/sonic" - "github.com/dicedb/dice/internal/server/utils" - "github.com/ohler55/ojg/jp" "github.com/axiomhq/hyperloglog" + "github.com/bytedance/sonic" "github.com/dicedb/dice/internal/clientio" diceerrors "github.com/dicedb/dice/internal/errors" "github.com/dicedb/dice/internal/object" + "github.com/dicedb/dice/internal/server/utils" dstore "github.com/dicedb/dice/internal/store" + "github.com/ohler55/ojg/jp" testifyAssert "github.com/stretchr/testify/assert" "gotest.tools/v3/assert" ) diff --git a/internal/ops/store_op.go b/internal/ops/store_op.go index 7691de5fa..56278c7ce 100644 --- a/internal/ops/store_op.go +++ b/internal/ops/store_op.go @@ -7,10 +7,10 @@ import ( ) type StoreOp struct { - SeqID int16 // SeqID is the sequence id of the operation within a single request (optional, may be used for ordering) + SeqID uint8 // SeqID is the sequence id of the operation within a single request (optional, may be used for ordering) RequestID uint32 // RequestID identifies the request that this StoreOp belongs to Cmd *cmd.RedisCmd // Cmd is the atomic Store command (e.g., GET, SET) - ShardID int // ShardID of the shard on which the Store command will be executed + ShardID uint8 // ShardID of the shard on which the Store command will be executed WorkerID string // WorkerID is the ID of the worker that sent this Store operation Client *comm.Client // Client that sent this Store operation. TODO: This can potentially replace the WorkerID in the future HTTPOp bool // HTTPOp is true if this Store operation is a HTTP operation diff --git a/internal/server/cmdMeta.go b/internal/server/cmd_meta.go similarity index 96% rename from internal/server/cmdMeta.go rename to internal/server/cmd_meta.go index 10b1a0ed9..4bec9e376 100644 --- a/internal/server/cmdMeta.go +++ b/internal/server/cmd_meta.go @@ -13,11 +13,11 @@ type CmdType int // Enum values for CmdType using iota for auto-increment. // Global commands don't interact with shards, SingleShard commands interact with one shard, -// Multishard commands interact with multiple shards, and Custom commands require a direct client connection. +// MultiShard commands interact with multiple shards, and Custom commands require a direct client connection. const ( Global CmdType = iota // Global commands don't need to interact with shards. SingleShard // Single-shard commands interact with only one shard. - Multishard // Multishard commands interact with multiple shards using scatter-gather logic. + MultiShard // MultiShard commands interact with multiple shards using scatter-gather logic. Custom // Custom commands involve direct client communication. 
) diff --git a/internal/server/httpServer.go b/internal/server/httpServer.go index 8738bb533..1b95a707c 100644 --- a/internal/server/httpServer.go +++ b/internal/server/httpServer.go @@ -12,14 +12,13 @@ import ( "sync" "time" - "github.com/dicedb/dice/internal/comm" - + "github.com/dicedb/dice/config" "github.com/dicedb/dice/internal/clientio" "github.com/dicedb/dice/internal/cmd" - "github.com/dicedb/dice/internal/server/utils" - - "github.com/dicedb/dice/config" + "github.com/dicedb/dice/internal/comm" + derrors "github.com/dicedb/dice/internal/errors" "github.com/dicedb/dice/internal/ops" + "github.com/dicedb/dice/internal/server/utils" "github.com/dicedb/dice/internal/shard" ) @@ -100,7 +99,7 @@ func (s *HTTPServer) Run(ctx context.Context) error { select { case <-ctx.Done(): case <-s.shutdownChan: - err = ErrAborted + err = derrors.ErrAborted s.logger.Debug("Shutting down HTTP Server") } diff --git a/internal/server/multitherading.go b/internal/server/multitherading.go deleted file mode 100644 index 69ca1ad25..000000000 --- a/internal/server/multitherading.go +++ /dev/null @@ -1,88 +0,0 @@ -package server - -import ( - "bytes" - - "github.com/dicedb/dice/internal/cmd" - "github.com/dicedb/dice/internal/comm" - "github.com/dicedb/dice/internal/eval" - "github.com/dicedb/dice/internal/ops" - "github.com/dicedb/dice/internal/shard" - "github.com/twmb/murmur3" -) - -// getShard calculates the shard ID for a given key using Murmur3 hashing. -// It returns the shard ID by computing the hash modulo the number of shards (n). -func getShard(key string, n uint32) uint32 { - hash := murmur3.Sum32([]byte(key)) - return hash % n -} - -// cmdsBreakup breaks down a Redis command into smaller commands if multisharding is supported. -// It uses the metadata to check if the command supports multisharding and calls the respective breakup function. -// If multisharding is not supported, it returns the original command in a slice. -func (s *AsyncServer) cmdsBreakup(redisCmd *cmd.RedisCmd, c *comm.Client) []cmd.RedisCmd { - val, ok := WorkerCmdsMeta[redisCmd.Cmd] - if !ok { - return []cmd.RedisCmd{*redisCmd} - } - - // if command supports multisharding then send the command - // to the respective breakup function - // Which can return array of broken down commands - return val.Breakup(s.shardManager, redisCmd, c) -} - -// scatter distributes the Redis commands to the respective shards based on the key. -// For each command, it calculates the shard ID and sends the command to the shard's request channel for processing. -func (s *AsyncServer) scatter(cmds []cmd.RedisCmd, c *comm.Client) { - // Otherwise check for the shard based on the key using hash - // and send it to the particular shard - for i := 0; i < len(cmds); i++ { - var id uint32 - if len(cmds[i].Args) > 0 { - key := cmds[i].Args[i] - id = getShard(key, uint32(s.shardManager.GetShardCount())) - } - s.shardManager.GetShard(shard.ShardID(id)).ReqChan <- &ops.StoreOp{ - Cmd: &cmds[i], - WorkerID: "server", - ShardID: int(id), - Client: c, - } - } -} - -// gather collects the responses from multiple shards and writes the results into the provided buffer. -// It first waits for responses from all the shards and then processes the result based on the command type (SingleShard, Custom, or Multishard). 
-func (s *AsyncServer) gather(redisCmd *cmd.RedisCmd, buf *bytes.Buffer, numShards int, c CmdType) { - // Loop to wait for messages from numberof shards - var evalResp []eval.EvalResponse - for i := 0; i < numShards; i++ { - resp, ok := <-s.ioChan - if ok { - evalResp = append(evalResp, resp.EvalResponse) - } - } - - // Check if command supports multisharding - val, ok := WorkerCmdsMeta[redisCmd.Cmd] - if !ok { - buf.Write(evalResp[0].Result.([]byte)) - return - } - - switch c { - case SingleShard, Custom: - if evalResp[0].Error != nil { - buf.WriteString(evalResp[0].Error.Error()) - return - } - buf.Write(evalResp[0].Result.([]byte)) - - case Multishard: - buf.Write(val.Gather(evalResp...)) - default: - buf.WriteString("ERR Invalid command type") - } -} diff --git a/internal/server/resp/server.go b/internal/server/resp/server.go new file mode 100644 index 000000000..e4e617d08 --- /dev/null +++ b/internal/server/resp/server.go @@ -0,0 +1,221 @@ +package resp + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net" + "sync" + "sync/atomic" + "syscall" + "time" + + "github.com/dicedb/dice/config" + "github.com/dicedb/dice/internal/clientio/iohandler/netconn" + respparser "github.com/dicedb/dice/internal/clientio/requestparser/resp" + "github.com/dicedb/dice/internal/ops" + "github.com/dicedb/dice/internal/shard" + "github.com/dicedb/dice/internal/worker" +) + +var ( + workerCounter uint64 + startTime = time.Now().UnixNano() / int64(time.Millisecond) +) + +var ( + ErrInvalidIPAddress = errors.New("invalid IP address") +) + +const ( + DefaultConnBacklogSize = 128 +) + +type Server struct { + Host string + Port int + serverFD int + connBacklogSize int + wm *worker.WorkerManager + sm *shard.ShardManager + globalErrorChan chan error + logger *slog.Logger +} + +func NewServer(sm *shard.ShardManager, wm *worker.WorkerManager, gec chan error, l *slog.Logger) *Server { + return &Server{ + Host: config.DiceConfig.Server.Addr, + Port: config.DiceConfig.Server.Port, + connBacklogSize: DefaultConnBacklogSize, + wm: wm, + sm: sm, + globalErrorChan: gec, + logger: l, + } +} + +func (s *Server) Run(ctx context.Context) (err error) { + // BindAndListen the desired port to the server + if err = s.BindAndListen(); err != nil { + s.logger.Error("failed to bind server", slog.Any("error", err)) + return err + } + + defer s.ReleasePort() + + // Start a go routine to accept connections + errChan := make(chan error, 1) + wg := &sync.WaitGroup{} + wg.Add(1) + go func(wg *sync.WaitGroup) { + defer wg.Done() + if err := s.AcceptConnectionRequests(ctx, wg); err != nil { + errChan <- fmt.Errorf("failed to accept connections %w", err) + } + }(wg) + + s.logger.Info("DiceDB ready to accept connections on port", slog.Int("resp-port", config.Port)) + + select { + case <-ctx.Done(): + s.logger.Info("Context canceled, initiating shutdown") + case err = <-errChan: + s.logger.Error("Error while accepting connections, initiating shutdown", slog.Any("error", err)) + } + + s.Shutdown() + + wg.Wait() // Wait for the go routines to finish + s.logger.Info("All connections are closed, RESP server exiting gracefully.") + + return err +} + +func (s *Server) BindAndListen() error { + serverFD, socketErr := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0) + if socketErr != nil { + return fmt.Errorf("failed to create socket: %w", socketErr) + } + + // Close the socket on exit if an error occurs + var err error + defer func() { + if err != nil { + if closeErr := syscall.Close(serverFD); closeErr != nil { + // Wrap the close 
error with the original bind/listen error + s.logger.Error("Error occurred", slog.Any("error", err), "additionally, failed to close socket", slog.Any("close-err", closeErr)) + } else { + s.logger.Error("Error occurred", slog.Any("error", err)) + } + } + }() + + if err = syscall.SetsockoptInt(serverFD, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil { + return fmt.Errorf("failed to set SO_REUSEADDR: %w", err) + } + + if err = syscall.SetNonblock(serverFD, true); err != nil { + return fmt.Errorf("failed to set socket to non-blocking: %w", err) + } + + ip4 := net.ParseIP(s.Host) + if ip4 == nil { + return ErrInvalidIPAddress + } + + sockAddr := &syscall.SockaddrInet4{ + Port: s.Port, + Addr: [4]byte{ip4[0], ip4[1], ip4[2], ip4[3]}, + } + if err = syscall.Bind(serverFD, sockAddr); err != nil { + return fmt.Errorf("failed to bind socket: %w", err) + } + + if err = syscall.Listen(serverFD, s.connBacklogSize); err != nil { + return fmt.Errorf("failed to listen on socket: %w", err) + } + + s.serverFD = serverFD + s.logger.Info("RESP Server successfully bound", slog.String("Host", s.Host), slog.Int("Port", s.Port)) + return nil +} + +// ReleasePort closes the server socket. +func (s *Server) ReleasePort() { + if err := syscall.Close(s.serverFD); err != nil { + s.logger.Error("Failed to close server socket", slog.Any("error", err)) + } else { + s.logger.Debug("Server socket closed successfully") + } +} + +// AcceptConnectionRequests accepts new client connections +func (s *Server) AcceptConnectionRequests(ctx context.Context, wg *sync.WaitGroup) error { + for { + select { + case <-ctx.Done(): + s.logger.Info("Context canceled, initiating RESP server shutdown") + + return ctx.Err() + default: + clientFD, _, err := syscall.Accept(s.serverFD) + if err != nil { + if errors.Is(err, syscall.EAGAIN) || errors.Is(err, syscall.EWOULDBLOCK) { + continue // No more connections to accept at this time + } + + return fmt.Errorf("error accepting connection: %w", err) + } + + // Register a new worker for the client + ioHandler, err := netconn.NewIOHandler(clientFD, s.logger) + if err != nil { + s.logger.Error("Failed to create new IOHandler for clientFD", slog.Int("client-fd", clientFD), slog.Any("error", err)) + return err + } + + parser := respparser.NewParser(s.logger) + respChan := make(chan *ops.StoreResponse) + wID := GenerateUniqueWorkerID() + w := worker.NewWorker(wID, respChan, ioHandler, parser, s.sm, s.globalErrorChan, s.logger) + if err != nil { + s.logger.Error("Failed to create new worker for clientFD", slog.Int("client-fd", clientFD), slog.Any("error", err)) + return err + } + + // Register the worker with the worker manager + err = s.wm.RegisterWorker(w) + if err != nil { + return err + } + + wg.Add(1) + go func(wID string) { + defer wg.Done() + defer func(wm *worker.WorkerManager, workerID string) { + err := wm.UnregisterWorker(workerID) + if err != nil { + s.logger.Warn("Failed to unregister worker", slog.String("worker-id", wID), slog.Any("error", err)) + } + }(s.wm, wID) + wctx, cwctx := context.WithCancel(ctx) + defer cwctx() + err := w.Start(wctx) + if err != nil { + s.logger.Debug("Worker stopped", slog.String("worker-id", wID), slog.Any("error", err)) + } + }(wID) + } + } +} + +func GenerateUniqueWorkerID() string { + count := atomic.AddUint64(&workerCounter, 1) + timestamp := time.Now().UnixNano()/int64(time.Millisecond) - startTime + return fmt.Sprintf("W-%d-%d", timestamp, count) +} + +func (s *Server) Shutdown() { + // Not implemented +} diff --git a/internal/server/server.go 
b/internal/server/server.go index 0afdce469..e9f412708 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -6,14 +6,13 @@ import ( "errors" "fmt" "io" + "log/slog" "net" "strings" "sync" "syscall" "time" - "log/slog" - "github.com/dicedb/dice/config" "github.com/dicedb/dice/internal/auth" "github.com/dicedb/dice/internal/clientio" @@ -28,7 +27,6 @@ import ( dstore "github.com/dicedb/dice/internal/store" ) -var ErrAborted = errors.New("server received ABORT command") var ErrInvalidIPAddress = errors.New("invalid IP address") type AsyncServer struct { @@ -223,7 +221,7 @@ func (s *AsyncServer) eventLoop(ctx context.Context) error { } } else { if err := s.handleClientEvent(event); err != nil { - if errors.Is(err, ErrAborted) { + if errors.Is(err, diceerrors.ErrAborted) { s.logger.Debug("Received abort command, initiating graceful shutdown") return err } else if !errors.Is(err, syscall.ECONNRESET) && !errors.Is(err, net.ErrClosed) { @@ -272,52 +270,30 @@ func (s *AsyncServer) handleClientEvent(event iomultiplexer.Event) error { s.EvalAndRespond(commands, client) if hasAbort { - return ErrAborted + return diceerrors.ErrAborted } return nil } -// executeCommandToBuffer handles the execution of a Redis command and writes the result into a buffer. -// It first checks if the command supports multisharding or is a single-shard command. -// If necessary, it breaks down the command into multiple parts and scatters them to the appropriate shards. -// Finally, it gathers responses from the shards and writes the result to the buffer. func (s *AsyncServer) executeCommandToBuffer(redisCmd *cmd.RedisCmd, buf *bytes.Buffer, c *comm.Client) { - // Break down the single command into multiple commands if multisharding is supported. - // The length of commandBreakup helps determine how many shards to wait for responses. - commandBreakup := []cmd.RedisCmd{} - - // Retrieve metadata for the command to determine if multisharding is supported. - val, ok := WorkerCmdsMeta[redisCmd.Cmd] - if !ok { - // If no metadata exists, treat it as a single command. - commandBreakup = append(commandBreakup, *redisCmd) - } else { - // Depending on the command type, decide how to handle it. - switch val.CmdType { - case Global: - // If it's a global command, process it immediately without involving any shards. - buf.Write(val.RespNoShards(redisCmd.Args)) - return - - case SingleShard, Custom: - // For single-shard or custom commands, process them without breaking up. - commandBreakup = append(commandBreakup, *redisCmd) - - case Multishard: - // If the command supports multisharding, break it down into multiple commands. - commandBreakup = s.cmdsBreakup(redisCmd, c) - } + s.shardManager.GetShard(0).ReqChan <- &ops.StoreOp{ + Cmd: redisCmd, + WorkerID: "server", + ShardID: 0, + Client: c, } - // Scatter the broken-down commands to the appropriate shards. - s.scatter(commandBreakup, c) + resp := <-s.ioChan + if resp.EvalResponse.Error != nil { + buf.WriteString(resp.EvalResponse.Error.Error()) + return + } - // Gather the responses from the shards and write them to the buffer. 
- s.gather(redisCmd, buf, len(commandBreakup), val.CmdType) + buf.Write(resp.EvalResponse.Result.([]byte)) } -func readCommands(c io.ReadWriter) (cmd.RedisCmds, bool, error) { +func readCommands(c io.ReadWriter) (*cmd.RedisCmds, bool, error) { var hasABORT = false rp := clientio.NewRESPParser(c) values, err := rp.DecodeMultiple() @@ -351,7 +327,11 @@ func readCommands(c io.ReadWriter) (cmd.RedisCmds, bool, error) { hasABORT = true } } - return cmds, hasABORT, nil + + rCmds := &cmd.RedisCmds{ + Cmds: cmds, + } + return rCmds, hasABORT, nil } func toArrayString(ai []interface{}) ([]string, error) { @@ -366,11 +346,11 @@ func toArrayString(ai []interface{}) ([]string, error) { return as, nil } -func (s *AsyncServer) EvalAndRespond(cmds cmd.RedisCmds, c *comm.Client) { +func (s *AsyncServer) EvalAndRespond(cmds *cmd.RedisCmds, c *comm.Client) { var resp []byte buf := bytes.NewBuffer(resp) - for _, redisCmd := range cmds { + for _, redisCmd := range cmds.Cmds { if !s.isAuthenticated(redisCmd, c, buf) { continue } @@ -427,17 +407,18 @@ func (s *AsyncServer) handleNonTransactionCommand(redisCmd *cmd.RedisCmd, c *com } func (s *AsyncServer) executeTransaction(c *comm.Client, buf *bytes.Buffer) { - _, err := fmt.Fprintf(buf, "*%d\r\n", len(c.Cqueue)) + cmds := c.Cqueue.Cmds + _, err := fmt.Fprintf(buf, "*%d\r\n", len(cmds)) if err != nil { s.logger.Error("Error writing to buffer", slog.Any("error", err)) return } - for _, cmd := range c.Cqueue { + for _, cmd := range cmds { s.executeCommandToBuffer(cmd, buf, c) } - c.Cqueue = make(cmd.RedisCmds, 0) + c.Cqueue.Cmds = make([]*cmd.RedisCmd, 0) c.IsTxn = false } diff --git a/internal/shard/shard_manager.go b/internal/shard/shard_manager.go index 78053f492..a3be00523 100644 --- a/internal/shard/shard_manager.go +++ b/internal/shard/shard_manager.go @@ -2,13 +2,13 @@ package shard import ( "context" - "log" "log/slog" "os" "os/signal" "sync" "syscall" + "github.com/cespare/xxhash/v2" "github.com/dicedb/dice/internal/ops" dstore "github.com/dicedb/dice/internal/store" ) @@ -19,28 +19,32 @@ type ShardManager struct { // concurrently without synchronization. shards []*ShardThread shardReqMap map[ShardID]chan *ops.StoreOp // shardReqMap is a map of shard id to its respective request channel - globalErrorChan chan *ShardError // globalErrorChan is the common global error channel for all Shards + globalErrorChan chan error // globalErrorChan is the common global error channel for all Shards + ShardErrorChan chan *ShardError // ShardErrorChan is the channel for sending shard-level errors sigChan chan os.Signal // sigChan is the signal channel for the shard manager + shardCount uint8 // shardCount is the number of shards managed by this manager } // NewShardManager creates a new ShardManager instance with the given number of Shards and a parent context. 
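The constructor below now receives the global error channel from main and surfaces shard-level failures on a separate ShardErrorChan instead of logging them internally. A minimal wiring sketch of that split, assuming only the exported names visible in this diff (shard.NewShardManager, ShardManager.Run, ShardErrorChan, dstore.WatchEvent); the channel sizes and the text handler are arbitrary choices for illustration:

package main

import (
	"context"
	"log/slog"
	"os"

	"github.com/dicedb/dice/internal/shard"
	dstore "github.com/dicedb/dice/internal/store"
)

func main() {
	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))

	// The watch channel and the shared global error channel mirror what
	// main.go wires up; the buffer sizes here are illustrative.
	watchChan := make(chan dstore.WatchEvent, 1024)
	globalErrChan := make(chan error, 1)

	sm := shard.NewShardManager(1, watchChan, globalErrChan, logger)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go sm.Run(ctx)

	// Shard-level errors are no longer consumed by the manager itself;
	// the caller drains ShardErrorChan and decides how to react.
	go func() {
		for serr := range sm.ShardErrorChan {
			logger.Error("shard error",
				slog.Any("shard", serr.ShardID),
				slog.Any("error", serr.Error))
		}
	}()

	// A real process would block here on OS signals; omitted for brevity.
}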
-func NewShardManager(shardCount int8, watchChan chan dstore.WatchEvent, logger *slog.Logger) *ShardManager { +func NewShardManager(shardCount uint8, watchChan chan dstore.WatchEvent, globalErrorChan chan error, logger *slog.Logger) *ShardManager { shards := make([]*ShardThread, shardCount) shardReqMap := make(map[ShardID]chan *ops.StoreOp) - globalErrorChan := make(chan *ShardError) + shardErrorChan := make(chan *ShardError) - for i := int8(0); i < shardCount; i++ { + for i := uint8(0); i < shardCount; i++ { // Shards are numbered from 0 to shardCount-1 - shard := NewShardThread(ShardID(i), globalErrorChan, watchChan, logger) + shard := NewShardThread(i, globalErrorChan, shardErrorChan, watchChan, logger) shards[i] = shard - shardReqMap[ShardID(i)] = shard.ReqChan + shardReqMap[i] = shard.ReqChan } return &ShardManager{ shards: shards, shardReqMap: shardReqMap, globalErrorChan: globalErrorChan, + ShardErrorChan: shardErrorChan, sigChan: make(chan os.Signal, 1), + shardCount: shardCount, } } @@ -57,7 +61,6 @@ func (manager *ShardManager) Run(ctx context.Context) { wg.Add(1) go func() { defer wg.Done() - manager.listenForErrors() }() select { @@ -67,8 +70,8 @@ func (manager *ShardManager) Run(ctx context.Context) { // OS signal received, trigger shutdown } - close(manager.globalErrorChan) // Close the error channel after all Shards stop - wg.Wait() // Wait for all shard goroutines to exit. + close(manager.ShardErrorChan) // Close the error channel after all Shards stop + wg.Wait() // Wait for all shard goroutines to exit. } // start initializes and starts the shard threads. @@ -84,17 +87,15 @@ func (manager *ShardManager) start(ctx context.Context, wg *sync.WaitGroup) { } } -// listenForErrors listens to the global error channel and logs the errors. It exits when the error channel is closed. -func (manager *ShardManager) listenForErrors() { - for err := range manager.globalErrorChan { - // Handle or log shard errors here - log.Printf("Shard %d error: %v", err.shardID, err.err) - } +func (manager *ShardManager) GetShardInfo(key string) (id ShardID, c chan *ops.StoreOp) { + hash := xxhash.Sum64String(key) + id = ShardID(hash % uint64(manager.GetShardCount())) + return id, manager.GetShard(id).ReqChan } // GetShardCount returns the number of shards managed by this ShardManager. -func (manager *ShardManager) GetShardCount() int { - return len(manager.shards) +func (manager *ShardManager) GetShardCount() int8 { + return int8(len(manager.shards)) } // GetShard returns the ShardThread for the given ShardID. diff --git a/internal/shard/shard_thread.go b/internal/shard/shard_thread.go index b06e6bf12..088688ea4 100644 --- a/internal/shard/shard_thread.go +++ b/internal/shard/shard_thread.go @@ -16,11 +16,11 @@ import ( dstore "github.com/dicedb/dice/internal/store" ) -type ShardID int8 +type ShardID = uint8 type ShardError struct { - shardID ShardID // shardID is the ID of the shard that encountered the error - err error // err is the error that occurred + ShardID ShardID // ShardID is the ID of the shard that encountered the error + Error error // Error is the error that occurred } type ShardThread struct { @@ -29,20 +29,22 @@ type ShardThread struct { ReqChan chan *ops.StoreOp // ReqChan is this shard's channel for receiving requests. workerMap map[string]chan *ops.StoreResponse // workerMap maps workerID to its unique response channel workerMutex sync.RWMutex // workerMutex is the workerMap's mutex for thread safety. 
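GetShardInfo above routes a key to a shard by hashing it with xxhash and reducing modulo the shard count. A self-contained sketch of that mapping, using only the xxhash dependency promoted in go.mod; shardFor is a hypothetical local helper that mirrors the method, not part of the dice API:

package main

import (
	"fmt"

	"github.com/cespare/xxhash/v2"
)

// shardFor mirrors ShardManager.GetShardInfo: hash the key, then reduce it
// modulo the shard count to pick the owning shard and its request channel.
func shardFor(key string, shardCount uint64) uint64 {
	return xxhash.Sum64String(key) % shardCount
}

func main() {
	// Related keys can land on different shards, which is why multi-key
	// commands need the decompose/compose handling introduced further down.
	for _, key := range []string{"user:1", "user:2", "session:abc"} {
		fmt.Printf("%-12s -> shard %d\n", key, shardFor(key, 4))
	}
}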
- errorChan chan *ShardError // errorChan is the channel for sending system-level errors. + globalErrorChan chan error // globalErrorChan is the channel for sending system-level errors. + shardErrorChan chan *ShardError // ShardErrorChan is the channel for sending shard-level errors. lastCronExecTime time.Time // lastCronExecTime is the last time the shard executed cron tasks. cronFrequency time.Duration // cronFrequency is the frequency at which the shard executes cron tasks. logger *slog.Logger // logger is the logger for the shard. } // NewShardThread creates a new ShardThread instance with the given shard id and error channel. -func NewShardThread(id ShardID, errorChan chan *ShardError, watchChan chan dstore.WatchEvent, logger *slog.Logger) *ShardThread { +func NewShardThread(id ShardID, gec chan error, sec chan *ShardError, watchChan chan dstore.WatchEvent, logger *slog.Logger) *ShardThread { return &ShardThread{ id: id, store: dstore.NewStore(watchChan), ReqChan: make(chan *ops.StoreOp, 1000), workerMap: make(map[string]chan *ops.StoreResponse), - errorChan: errorChan, + globalErrorChan: gec, + shardErrorChan: sec, lastCronExecTime: utils.GetCurrentTime(), cronFrequency: config.DiceConfig.Server.ShardCronFrequency, logger: logger, @@ -93,15 +95,20 @@ func (shard *ShardThread) processRequest(op *ops.StoreOp) { workerChan, ok := shard.workerMap[op.WorkerID] shard.workerMutex.RUnlock() - if !ok { - shard.errorChan <- &ShardError{shardID: shard.id, err: fmt.Errorf(diceerrors.WorkerNotFoundErr, op.WorkerID)} - return + sp := &ops.StoreResponse{ + RequestID: op.RequestID, } - workerChan <- &ops.StoreResponse{ - RequestID: op.RequestID, - EvalResponse: resp, + if ok { + sp.EvalResponse = resp + } else { + shard.shardErrorChan <- &ShardError{ + ShardID: shard.id, + Error: fmt.Errorf(diceerrors.WorkerNotFoundErr, op.WorkerID), + } } + + workerChan <- sp } // cleanup handles cleanup logic when the shard stops. diff --git a/internal/server/cmdBreakup.go b/internal/worker/cmd_breakup.go similarity index 90% rename from internal/server/cmdBreakup.go rename to internal/worker/cmd_breakup.go index 71170e5db..da57feed3 100644 --- a/internal/server/cmdBreakup.go +++ b/internal/worker/cmd_breakup.go @@ -1,8 +1,8 @@ -package server +package worker // Breakup file is used by Worker to split commands that need to be executed // across multiple shards. For commands that operate on multiple keys or -// require distribution across shards (e.g., Multishard commands), a Breakup +// require distribution across shards (e.g., MultiShard commands), a Breakup // function is invoked to break the original command into multiple smaller // commands, each targeted at a specific shard. // diff --git a/internal/worker/cmd_meta.go b/internal/worker/cmd_meta.go new file mode 100644 index 000000000..f1eee11f6 --- /dev/null +++ b/internal/worker/cmd_meta.go @@ -0,0 +1,93 @@ +package worker + +import ( + "fmt" + + "github.com/dicedb/dice/internal/cmd" + "github.com/dicedb/dice/internal/eval" + "github.com/dicedb/dice/internal/logger" +) + +type CmdType int + +const ( + Global CmdType = iota + SingleShard + MultiShard + Custom +) + +const ( + // Global commands + CmdPing = "PING" + CmdAbort = "ABORT" + CmdAuth = "AUTH" + + // Single-shard commands. 
+ CmdSet = "SET" + CmdGet = "GET" + CmdGetSet = "GETSET" +) + +type CommandsMeta struct { + CmdType + Cmd string + WorkerCommandHandler func([]string) []byte + decomposeCommand func(redisCmd *cmd.RedisCmd) []*cmd.RedisCmd + composeResponse func(responses ...eval.EvalResponse) []byte +} + +var WorkerCommandsMeta = map[string]CommandsMeta{ + // Global commands. + CmdPing: { + CmdType: Global, + WorkerCommandHandler: eval.RespPING, + }, + CmdAbort: { + CmdType: Custom, + }, + CmdAuth: { + CmdType: Custom, + }, + + // Single-shard commands. + CmdSet: { + CmdType: SingleShard, + }, + CmdGet: { + CmdType: SingleShard, + }, + CmdGetSet: { + CmdType: SingleShard, + }, +} + +func init() { + l := logger.New(logger.Opts{WithTimestamp: true}) + // Validate the metadata for each command + for c, meta := range WorkerCommandsMeta { + if err := validateCmdMeta(c, meta); err != nil { + l.Error("error validating worker command metadata %s: %v", c, err) + } + } +} + +// validateCmdMeta ensures that the metadata for each command is properly configured +func validateCmdMeta(c string, meta CommandsMeta) error { + switch meta.CmdType { + case Global: + if meta.WorkerCommandHandler == nil { + return fmt.Errorf("global command %s must have WorkerCommandHandler function", c) + } + case MultiShard: + if meta.decomposeCommand == nil || meta.composeResponse == nil { + return fmt.Errorf("multi-shard command %s must have both decomposeCommand and composeResponse implemented", c) + } + case SingleShard, Custom: + // No specific validations for these types currently + default: + return fmt.Errorf("unknown command type for %s", c) + } + + return nil +} diff --git a/internal/server/gather.go b/internal/worker/gather.go similarity index 88% rename from internal/server/gather.go rename to internal/worker/gather.go index 8737243b9..8afa4ba98 100644 --- a/internal/server/gather.go +++ b/internal/worker/gather.go @@ -1,8 +1,8 @@ -package server +package worker // Gather file is used by Worker to collect and process responses // from multiple shards. For commands that are executed across -// several shards (e.g., Multishard commands), a Gather function +// several shards (e.g., MultiShard commands), a Gather function // is responsible for aggregating the results. 
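The init hook above checks every WorkerCommandsMeta entry once at startup, so a MultiShard command can never be registered without its decompose and compose functions. A reduced sketch of that validation idea with local stand-in types (cmdMeta, metaTable, and the handlers here are illustrative, not the real cmd or eval packages):

package main

import (
	"fmt"
	"log"
)

type cmdType int

const (
	global cmdType = iota
	singleShard
	multiShard
)

type cmdMeta struct {
	typ       cmdType
	handler   func(args []string) []byte     // required for global commands
	decompose func(args []string) [][]string // required for multi-shard commands
	compose   func(parts ...[]byte) []byte   // required for multi-shard commands
}

var metaTable = map[string]cmdMeta{
	"PING": {typ: global, handler: func([]string) []byte { return []byte("+PONG\r\n") }},
	"GET":  {typ: singleShard},
	"MGET": {typ: multiShard}, // deliberately incomplete so the validator flags it
}

func validate(name string, m cmdMeta) error {
	switch m.typ {
	case global:
		if m.handler == nil {
			return fmt.Errorf("global command %s must have a handler", name)
		}
	case multiShard:
		if m.decompose == nil || m.compose == nil {
			return fmt.Errorf("multi-shard command %s needs decompose and compose", name)
		}
	}
	return nil
}

func main() {
	// Run the same kind of check cmd_meta.go performs in init().
	for name, m := range metaTable {
		if err := validate(name, m); err != nil {
			log.Printf("invalid command metadata: %v", err)
		}
	}
}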
// // Each Gather function takes input in the form of shard responses, diff --git a/internal/worker/worker.go b/internal/worker/worker.go new file mode 100644 index 000000000..45fa8c188 --- /dev/null +++ b/internal/worker/worker.go @@ -0,0 +1,352 @@ +package worker + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net" + "syscall" + "time" + + "github.com/dicedb/dice/config" + "github.com/dicedb/dice/internal/auth" + "github.com/dicedb/dice/internal/clientio" + "github.com/dicedb/dice/internal/clientio/iohandler" + "github.com/dicedb/dice/internal/clientio/requestparser" + "github.com/dicedb/dice/internal/cmd" + diceerrors "github.com/dicedb/dice/internal/errors" + "github.com/dicedb/dice/internal/eval" + "github.com/dicedb/dice/internal/ops" + "github.com/dicedb/dice/internal/shard" +) + +// Worker interface +type Worker interface { + ID() string + Start(context.Context) error + Stop() error +} + +type BaseWorker struct { + id string + ioHandler iohandler.IOHandler + parser requestparser.Parser + shardManager *shard.ShardManager + respChan chan *ops.StoreResponse + Session *auth.Session + globalErrorChan chan error + logger *slog.Logger +} + +func NewWorker(wid string, respChan chan *ops.StoreResponse, + ioHandler iohandler.IOHandler, parser requestparser.Parser, + shardManager *shard.ShardManager, gec chan error, + logger *slog.Logger) *BaseWorker { + return &BaseWorker{ + id: wid, + ioHandler: ioHandler, + parser: parser, + shardManager: shardManager, + globalErrorChan: gec, + respChan: respChan, + logger: logger, + Session: auth.NewSession(), + } +} + +func (w *BaseWorker) ID() string { + return w.id +} + +func (w *BaseWorker) Start(ctx context.Context) error { + errChan := make(chan error, 1) + for { + select { + case <-ctx.Done(): + err := w.Stop() + if err != nil { + w.logger.Warn("Error stopping worker:", slog.String("workerID", w.id), slog.Any("error", err)) + } + return ctx.Err() + case err := <-errChan: + if err != nil { + if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) { + w.logger.Error("Connection closed for worker", slog.String("workerID", w.id), slog.Any("error", err)) + return err + } + } + return fmt.Errorf("error writing response: %w", err) + default: + data, err := w.ioHandler.Read(ctx) + if err != nil { + w.logger.Debug("Read error, connection closed possibly", slog.String("workerID", w.id), slog.Any("error", err)) + return err + } + cmds, err := w.parser.Parse(data) + if err != nil { + err = w.ioHandler.Write(ctx, clientio.Encode(err, true)) + if err != nil { + w.logger.Debug("Write error, connection closed possibly", slog.String("workerID", w.id), slog.Any("error", err)) + return err + } + } + if len(cmds) == 0 { + err = w.ioHandler.Write(ctx, clientio.Encode("ERR: Invalid request", true)) + if err != nil { + w.logger.Debug("Write error, connection closed possibly", slog.String("workerID", w.id), slog.Any("error", err)) + return err + } + continue + } + + // DiceDB supports clients to send only one request at a time + // We also need to ensure that the client is blocked until the response is received + if len(cmds) > 1 { + err = w.ioHandler.Write(ctx, clientio.Encode("ERR: Multiple commands not supported", true)) + if err != nil { + w.logger.Debug("Write error, connection closed possibly", slog.String("workerID", w.id), slog.Any("error", err)) + return err + } + } + + err = w.isAuthenticated(cmds[0]) + if err != nil { + werr := w.ioHandler.Write(ctx, clientio.Encode(err, false)) + if werr != nil { + 
w.logger.Debug("Write error, connection closed possibly", slog.Any("error", errors.Join(err, werr))) + return errors.Join(err, werr) + } + } + // executeCommand executes the command and return the response back to the client + func(errChan chan error) { + execctx, cancel := context.WithTimeout(ctx, 1*time.Second) // Timeout if + defer cancel() + err = w.executeCommand(execctx, cmds[0]) + if err != nil { + w.logger.Error("Error executing command", slog.String("workerID", w.id), slog.Any("error", err)) + if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.ETIMEDOUT) { + w.logger.Debug("Connection closed for worker", slog.String("workerID", w.id), slog.Any("error", err)) + errChan <- err + } + } + }(errChan) + } + } +} + +func (w *BaseWorker) executeCommand(ctx context.Context, redisCmd *cmd.RedisCmd) error { + // Break down the single command into multiple commands if multisharding is supported. + // The length of cmdList helps determine how many shards to wait for responses. + cmdList := make([]*cmd.RedisCmd, 0) + + // Retrieve metadata for the command to determine if multisharding is supported. + meta, ok := WorkerCommandsMeta[redisCmd.Cmd] + if !ok { + // If no metadata exists, treat it as a single command. + cmdList = append(cmdList, redisCmd) + } else { + // Depending on the command type, decide how to handle it. + switch meta.CmdType { + case Global: + // If it's a global command, process it immediately without involving any shards. + err := w.ioHandler.Write(ctx, meta.WorkerCommandHandler(redisCmd.Args)) + w.logger.Debug("Error executing for worker", slog.String("workerID", w.id), slog.Any("error", err)) + return err + + case SingleShard: + // For single-shard or custom commands, process them without breaking up. + cmdList = append(cmdList, redisCmd) + + case MultiShard: + // If the command supports multisharding, break it down into multiple commands. + cmdList = meta.decomposeCommand(redisCmd) + case Custom: + switch redisCmd.Cmd { + case CmdAuth: + err := w.ioHandler.Write(ctx, w.RespAuth(redisCmd.Args)) + w.logger.Error("Error sending auth response to worker", slog.String("workerID", w.id), slog.Any("error", err)) + return err + case CmdAbort: + w.logger.Info("Received ABORT command, initiating server shutdown", slog.String("workerID", w.id)) + w.globalErrorChan <- diceerrors.ErrAborted + return nil + default: + cmdList = append(cmdList, redisCmd) + } + } + } + + // Scatter the broken-down commands to the appropriate shards. + err := w.scatter(ctx, cmdList) + if err != nil { + return err + } + + // Gather the responses from the shards and write them to the buffer. + err = w.gather(ctx, redisCmd.Cmd, len(cmdList), meta.CmdType) + if err != nil { + return err + } + + return nil +} + +// scatter distributes the Redis commands to the respective shards based on the key. +// For each command, it calculates the shard ID and sends the command to the shard's request channel for processing. 
+func (w *BaseWorker) scatter(ctx context.Context, cmds []*cmd.RedisCmd) error { + // Otherwise check for the shard based on the key using hash + // and send it to the particular shard + select { + case <-ctx.Done(): + return ctx.Err() + default: + for i := uint8(0); i < uint8(len(cmds)); i++ { + var rc chan *ops.StoreOp + var sid shard.ShardID + var key string + if len(cmds[i].Args) > 0 { + key = cmds[i].Args[0] + } else { + key = cmds[i].Cmd + } + + sid, rc = w.shardManager.GetShardInfo(key) + + rc <- &ops.StoreOp{ + SeqID: i, + RequestID: cmds[i].RequestID, + Cmd: cmds[i], + WorkerID: w.id, + ShardID: sid, + Client: nil, + } + } + } + + return nil +} + +// gather collects the responses from multiple shards and writes the results into the provided buffer. +// It first waits for responses from all the shards and then processes the result based on the command type (SingleShard, Custom, or Multishard). +func (w *BaseWorker) gather(ctx context.Context, c string, numCmds int, ct CmdType) error { + // Loop to wait for messages from numberof shards + var evalResp []eval.EvalResponse + for numCmds != 0 { + select { + case <-ctx.Done(): + w.logger.Error("Timed out waiting for response from shards", slog.String("workerID", w.id), slog.Any("error", ctx.Err())) + case resp, ok := <-w.respChan: + if ok { + evalResp = append(evalResp, resp.EvalResponse) + } + numCmds-- + continue + case sError, ok := <-w.shardManager.ShardErrorChan: + if ok { + w.logger.Error("Error from shard", slog.String("workerID", w.id), slog.Any("error", sError)) + } + } + } + + // TODO: This is a temporary solution. In the future, all commands should be refactored to be multi-shard compatible. + // TODO: There are a few commands such as QWATCH, RENAME, MGET, MSET that wouldn't work in multi-shard mode without refactoring. + // TODO: These commands should be refactored to be multi-shard compatible before DICE-DB is completely multi-shard. + // Check if command is part of the new WorkerCommandsMeta map i.e. if the command has been refactored to be multi-shard compatible. + // If not found, treat it as a command that's not yet refactored, and write the response back to the client. 
+ val, ok := WorkerCommandsMeta[c] + if !ok { + if evalResp[0].Error != nil { + err := w.ioHandler.Write(ctx, []byte(evalResp[0].Error.Error())) + if err != nil { + w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) + return err + } + } + + err := w.ioHandler.Write(ctx, evalResp[0].Result.([]byte)) + if err != nil { + w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) + return err + } + + return nil + } + + switch ct { + case SingleShard, Custom: + if evalResp[0].Error != nil { + err := w.ioHandler.Write(ctx, []byte(evalResp[0].Error.Error())) + if err != nil { + w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) + } + + return err + } + + err := w.ioHandler.Write(ctx, evalResp[0].Result.([]byte)) + if err != nil { + w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) + return err + } + + case MultiShard: + err := w.ioHandler.Write(ctx, val.composeResponse(evalResp...)) + if err != nil { + w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) + return err + } + + default: + w.logger.Error("Unknown command type", slog.String("workerID", w.id)) + err := w.ioHandler.Write(ctx, []byte(diceerrors.InternalServerError)) + if err != nil { + w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) + return err + } + } + + return nil +} + +func (w *BaseWorker) isAuthenticated(redisCmd *cmd.RedisCmd) error { + if redisCmd.Cmd != auth.AuthCmd && !w.Session.IsActive() { + return errors.New("NOAUTH Authentication required") + } + + return nil +} + +// RespAuth returns with an encoded "OK" if the user is authenticated +// If the user is not authenticated, it returns with an encoded error message +func (w *BaseWorker) RespAuth(args []string) []byte { + // Check for incorrect number of arguments (arity error). + if len(args) < 1 || len(args) > 2 { + return diceerrors.NewErrArity("AUTH") // Return an error if the number of arguments is not equal to 1. + } + + if config.DiceConfig.Auth.Password == "" { + return diceerrors.NewErrWithMessage("AUTH called without any password configured for the default user. 
Are you sure your configuration is correct?") + } + + username := config.DiceConfig.Auth.UserName + var password string + + if len(args) == 1 { + password = args[0] + } else { + username, password = args[0], args[1] + } + + if err := w.Session.Validate(username, password); err != nil { + return clientio.Encode(err, false) + } + + return clientio.RespOK +} + +func (w *BaseWorker) Stop() error { + w.logger.Info("Stopping worker", slog.String("workerID", w.id)) + w.Session.Expire() + return nil +} diff --git a/internal/worker/workermanager.go b/internal/worker/workermanager.go new file mode 100644 index 000000000..96befcbcb --- /dev/null +++ b/internal/worker/workermanager.go @@ -0,0 +1,74 @@ +package worker + +import ( + "errors" + "sync" + + "github.com/dicedb/dice/internal/shard" +) + +type WorkerManager struct { + connectedClients sync.Map + numWorkers int + maxClients int + shardManager *shard.ShardManager + mu sync.Mutex +} + +var ( + ErrMaxClientsReached = errors.New("maximum number of clients reached") + ErrWorkerNotFound = errors.New("worker not found") +) + +func NewWorkerManager(maxClients int, sm *shard.ShardManager) *WorkerManager { + return &WorkerManager{ + maxClients: maxClients, + shardManager: sm, + } +} + +func (wm *WorkerManager) RegisterWorker(worker Worker) error { + wm.mu.Lock() + defer wm.mu.Unlock() + + if wm.GetWorkerCount() >= wm.maxClients { + return ErrMaxClientsReached + } + + wm.connectedClients.Store(worker.ID(), worker) + respChan := worker.(*BaseWorker).respChan + if respChan != nil { + wm.shardManager.RegisterWorker(worker.ID(), respChan) // TODO: Change respChan type to ShardResponse + } + + wm.numWorkers++ + return nil +} + +func (wm *WorkerManager) GetWorkerCount() int { + return wm.numWorkers +} + +func (wm *WorkerManager) GetWorker(workerID string) (Worker, bool) { + worker, ok := wm.connectedClients.Load(workerID) + if !ok { + return nil, false + } + return worker.(Worker), true +} + +func (wm *WorkerManager) UnregisterWorker(workerID string) error { + if worker, loaded := wm.connectedClients.LoadAndDelete(workerID); loaded { + w := worker.(Worker) + if err := w.Stop(); err != nil { + return err + } + } else { + return ErrWorkerNotFound + } + + wm.shardManager.UnregisterWorker(workerID) + wm.numWorkers++ + + return nil +} diff --git a/main.go b/main.go index 4519c80f6..2f339849a 100644 --- a/main.go +++ b/main.go @@ -4,21 +4,21 @@ import ( "context" "errors" "flag" + "log/slog" "os" "os/signal" "runtime" "sync" "syscall" + "github.com/dicedb/dice/config" + diceerrors "github.com/dicedb/dice/internal/errors" "github.com/dicedb/dice/internal/logger" + "github.com/dicedb/dice/internal/server" + "github.com/dicedb/dice/internal/server/resp" "github.com/dicedb/dice/internal/shard" dstore "github.com/dicedb/dice/internal/store" - - "github.com/dicedb/dice/internal/server" - - "log/slog" - - "github.com/dicedb/dice/config" + "github.com/dicedb/dice/internal/worker" ) func init() { @@ -47,6 +47,7 @@ func main() { signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT) watchChan := make(chan dstore.WatchEvent, config.DiceConfig.Server.KeysLimit) + var serverErrCh chan error // Get the number of available CPU cores on the machine using runtime.NumCPU(). // This determines the total number of logical processors that can be utilized @@ -55,11 +56,13 @@ func main() { // If not enabled multithreading, server will run on a single core. 
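When multithreading is enabled, the branch below replaces the AsyncServer with the worker-per-connection RESP server. A condensed sketch of that wiring, using the constructors exactly as they appear in this diff (shard.NewShardManager, worker.NewWorkerManager, resp.NewServer) and assuming config.DiceConfig is already populated with its defaults:

package main

import (
	"context"
	"log/slog"
	"os"
	"runtime"

	"github.com/dicedb/dice/config"
	"github.com/dicedb/dice/internal/server/resp"
	"github.com/dicedb/dice/internal/shard"
	dstore "github.com/dicedb/dice/internal/store"
	"github.com/dicedb/dice/internal/worker"
)

func main() {
	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	watchChan := make(chan dstore.WatchEvent, config.DiceConfig.Server.KeysLimit)
	serverErrCh := make(chan error, 1)

	// One shard per logical core, all reporting into the shared error channel.
	sm := shard.NewShardManager(uint8(runtime.NumCPU()), watchChan, serverErrCh, logger)
	go sm.Run(ctx)

	// The worker manager caps concurrent clients; the RESP server registers
	// one worker per accepted connection against it.
	wm := worker.NewWorkerManager(config.DiceConfig.Server.MaxClients, sm)
	respServer := resp.NewServer(sm, wm, serverErrCh, logger)

	if err := respServer.Run(ctx); err != nil {
		logger.Error("resp server stopped", slog.Any("error", err))
	}
}

Signal handling, the HTTP server, and graceful shutdown are left out of the sketch; the full flow is in the main.go hunk that follows.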
var numCores int if config.EnableMultiThreading { + serverErrCh = make(chan error, 1) + logr.Debug("The DiceDB server has started in multi-threaded mode.", slog.Int("number of cores", numCores)) numCores = runtime.NumCPU() - logr.Info("The DiceDB server has started in multi-threaded mode.", slog.Int("number of cores", numCores)) } else { + serverErrCh = make(chan error, 2) + logr.Debug("The DiceDB server has started in single-threaded mode.") numCores = 1 - logr.Info("The DiceDB server has started in single-threaded mode.") } // The runtime.GOMAXPROCS(numCores) call limits the number of operating system @@ -68,33 +71,10 @@ func main() { // improving concurrency performance across multiple goroutines. runtime.GOMAXPROCS(numCores) - shardManager := shard.NewShardManager(int8(numCores), watchChan, logr) - - // Initialize the AsyncServer - asyncServer := server.NewAsyncServer(shardManager, watchChan, logr) - httpServer := server.NewHTTPServer(shardManager, logr) - - // Initialize the HTTP server - - // Find a port and bind it - if err := asyncServer.FindPortAndBind(); err != nil { - cancel() - logr.Error("Error finding and binding port", - slog.Any("error", err), - ) - os.Exit(1) - } + // Initialize the ShardManager + shardManager := shard.NewShardManager(uint8(numCores), watchChan, serverErrCh, logr) wg := sync.WaitGroup{} - // Goroutine to handle shutdown signals - - wg.Add(1) - go func() { - defer wg.Done() - <-sigs - asyncServer.InitiateShutdown() - cancel() - }() wg.Add(1) go func() { @@ -102,50 +82,105 @@ func main() { shardManager.Run(ctx) }() - serverErrCh := make(chan error, 2) var serverWg sync.WaitGroup - serverWg.Add(1) - go func() { - defer serverWg.Done() - // Run the server - err := asyncServer.Run(ctx) - - // Handling different server errors - if err != nil { - if errors.Is(err, context.Canceled) { - logr.Debug("Server was canceled") - } else if errors.Is(err, server.ErrAborted) { - logr.Debug("Server received abort command") - } else { - logr.Error( - "Server error", - slog.Any("error", err), - ) - } - serverErrCh <- err - } else { - logr.Debug("Server stopped without error") + + // Initialize the AsyncServer server + // Find a port and bind it + if !config.EnableMultiThreading { + asyncServer := server.NewAsyncServer(shardManager, watchChan, logr) + if err := asyncServer.FindPortAndBind(); err != nil { + cancel() + logr.Error("Error finding and binding port", slog.Any("error", err)) + os.Exit(1) } - }() - serverWg.Add(1) - go func() { - defer serverWg.Done() - // Run the HTTP server - err := httpServer.Run(ctx) - if err != nil { - if errors.Is(err, context.Canceled) { - logr.Debug("HTTP Server was canceled") - } else if errors.Is(err, server.ErrAborted) { - logr.Debug("HTTP received abort command") + serverWg.Add(1) + go func() { + defer serverWg.Done() + // Run the server + err := asyncServer.Run(ctx) + + // Handling different server errors + if err != nil { + if errors.Is(err, context.Canceled) { + logr.Debug("Server was canceled") + } else if errors.Is(err, diceerrors.ErrAborted) { + logr.Debug("Server received abort command") + } else { + logr.Error( + "Server error", + slog.Any("error", err), + ) + } + serverErrCh <- err } else { - logr.Error("HTTP Server error", slog.Any("error", err)) + logr.Debug("Server stopped without error") } - serverErrCh <- err - } else { - logr.Debug("HTTP Server stopped without error") - } - }() + }() + + // Goroutine to handle shutdown signals + wg.Add(1) + go func() { + defer wg.Done() + <-sigs + asyncServer.InitiateShutdown() + cancel() + 
}() + + // Initialize the HTTP server + httpServer := server.NewHTTPServer(shardManager, logr) + serverWg.Add(1) + go func() { + defer serverWg.Done() + // Run the HTTP server + err := httpServer.Run(ctx) + if err != nil { + if errors.Is(err, context.Canceled) { + logr.Debug("HTTP Server was canceled") + } else if errors.Is(err, diceerrors.ErrAborted) { + logr.Debug("HTTP received abort command") + } else { + logr.Error("HTTP Server error", slog.Any("error", err)) + } + serverErrCh <- err + } else { + logr.Debug("HTTP Server stopped without error") + } + }() + } else { + workerManager := worker.NewWorkerManager(config.DiceConfig.Server.MaxClients, shardManager) + // Initialize the RESP Server + respServer := resp.NewServer(shardManager, workerManager, serverErrCh, logr) + serverWg.Add(1) + go func() { + defer serverWg.Done() + // Run the server + err := respServer.Run(ctx) + + // Handling different server errors + if err != nil { + if errors.Is(err, context.Canceled) { + logr.Debug("Server was canceled") + } else if errors.Is(err, diceerrors.ErrAborted) { + logr.Debug("Server received abort command") + } else { + logr.Error("Server error", "error", err) + } + serverErrCh <- err + } else { + logr.Debug("Server stopped without error") + } + }() + + // Goroutine to handle shutdown signals + wg.Add(1) + go func() { + defer wg.Done() + <-sigs + respServer.Shutdown() + cancel() + }() + } go func() { serverWg.Wait() @@ -153,8 +188,8 @@ func main() { }() for err := range serverErrCh { - if err != nil && errors.Is(err, server.ErrAborted) { - // if either the AsyncServer or the HTTPServer received an abort command, + if err != nil && errors.Is(err, diceerrors.ErrAborted) { + // if either the AsyncServer/RESPServer or the HTTPServer received an abort command, // cancel the context, helping gracefully exiting all servers cancel() } diff --git a/mocks/slog_noop.go b/mocks/slog_noop.go new file mode 100644 index 000000000..f8dae98b3 --- /dev/null +++ b/mocks/slog_noop.go @@ -0,0 +1,14 @@ +package mocks + +import ( + "context" + "log/slog" +) + +// SlogNoopHandler is a no-op implementation of slog.Handler +type SlogNoopHandler struct{} + +func (h SlogNoopHandler) Enabled(context.Context, slog.Level) bool { return false } +func (h SlogNoopHandler) Handle(context.Context, slog.Record) error { return nil } +func (h SlogNoopHandler) WithAttrs([]slog.Attr) slog.Handler { return h } +func (h SlogNoopHandler) WithGroup(string) slog.Handler { return h }
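The new mocks package closes out the change with a no-op slog.Handler, which lets tests hand a logger to servers, shards, and workers without producing output. A small illustrative test (the test name and assertions are hypothetical, not part of this diff):

package mocks_test

import (
	"context"
	"log/slog"
	"testing"

	"github.com/dicedb/dice/mocks"
)

// TestNoopLoggerStaysSilent builds a *slog.Logger backed by SlogNoopHandler:
// every call site is satisfied, nothing is written, and all levels report
// as disabled.
func TestNoopLoggerStaysSilent(t *testing.T) {
	logger := slog.New(mocks.SlogNoopHandler{})

	logger.Info("this line is swallowed", slog.String("key", "value"))
	logger.Error("so is this", slog.Any("error", context.Canceled))

	if logger.Enabled(context.Background(), slog.LevelError) {
		t.Fatal("noop handler should report every level as disabled")
	}
}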