Skip to content

Commit

Permalink
feat(go): add basic driver logging
Browse files Browse the repository at this point in the history
Fixes #492.
  • Loading branch information
lidavidm committed Sep 12, 2023
1 parent d9439e2 commit b9e144d
Show file tree
Hide file tree
Showing 9 changed files with 350 additions and 24 deletions.
15 changes: 15 additions & 0 deletions go/adbc/driver/flightsql/flightsql_adbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ import (
"github.com/apache/arrow/go/v13/arrow/memory"
"github.com/bluele/gcache"
"golang.org/x/exp/maps"
"golang.org/x/exp/slog"
"google.golang.org/grpc"
grpccodes "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
Expand Down Expand Up @@ -156,6 +157,7 @@ func (d Driver) NewDatabase(opts map[string]string) (adbc.Database, error) {
db.dialOpts.block = false
db.dialOpts.maxMsgSize = 16 * 1024 * 1024

db.logger = nilLogger()
db.options = make(map[string]string)

return db, db.SetOptions(opts)
Expand Down Expand Up @@ -186,11 +188,20 @@ type database struct {
timeout timeoutOption
dialOpts dbDialOpts
enableCookies bool
logger *slog.Logger
options map[string]string

alloc memory.Allocator
}

func (d *database) SetLogger(logger *slog.Logger) {
if logger != nil {
d.logger = logger
} else {
d.logger = nilLogger()
}
}

func (d *database) SetOptions(cnOptions map[string]string) error {
var tlsConfig tls.Config

Expand Down Expand Up @@ -691,6 +702,10 @@ func (b *bearerAuthMiddleware) HeadersReceived(ctx context.Context, md metadata.
func getFlightClient(ctx context.Context, loc string, d *database) (*flightsql.Client, error) {
authMiddle := &bearerAuthMiddleware{hdrs: d.hdrs.Copy()}
middleware := []flight.ClientMiddleware{
{
Unary: makeUnaryLoggingInterceptor(d.logger),
Stream: makeStreamLoggingInterceptor(d.logger),
},
flight.CreateClientMiddleware(authMiddle),
{
Unary: unaryTimeoutInterceptor,
Expand Down
103 changes: 103 additions & 0 deletions go/adbc/driver/flightsql/logging.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package flightsql

import (
"context"
"io"
"os"
"time"

"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"golang.org/x/exp/slog"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)

func nilLogger() *slog.Logger {
h := slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
AddSource: false,
Level: slog.LevelError,
})
return slog.New(h)
}

func makeUnaryLoggingInterceptor(logger *slog.Logger) grpc.UnaryClientInterceptor {
interceptor := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
start := time.Now()
// Ignore errors
outgoing, _ := metadata.FromOutgoingContext(ctx)
err := invoker(ctx, method, req, reply, cc, opts...)
if logger.Enabled(ctx, slog.LevelDebug) {
logger.DebugContext(ctx, method, "target", cc.Target(), "duration", time.Since(start), "err", err, "metadata", outgoing)
} else {
keys := maps.Keys(outgoing)
slices.Sort(keys)
logger.InfoContext(ctx, method, "target", cc.Target(), "duration", time.Since(start), "err", err, "metadata", keys)
}
return err
}
return interceptor
}

func makeStreamLoggingInterceptor(logger *slog.Logger) grpc.StreamClientInterceptor {
interceptor := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
start := time.Now()
// Ignore errors
outgoing, _ := metadata.FromOutgoingContext(ctx)
stream, err := streamer(ctx, desc, cc, method, opts...)
if err != nil {
logger.InfoContext(ctx, method, "target", cc.Target(), "duration", time.Since(start), "err", err)
return stream, err
}

return &loggedStream{ClientStream: stream, logger: logger, ctx: ctx, method: method, start: start, target: cc.Target(), outgoing: outgoing}, err
}
return interceptor
}

type loggedStream struct {
grpc.ClientStream

logger *slog.Logger
ctx context.Context
method string
start time.Time
target string
outgoing metadata.MD
}

func (stream *loggedStream) RecvMsg(m any) error {
err := stream.ClientStream.RecvMsg(m)
if err != nil {
loggedErr := err
if loggedErr == io.EOF {
loggedErr = nil
}

if stream.logger.Enabled(stream.ctx, slog.LevelDebug) {
stream.logger.DebugContext(stream.ctx, stream.method, "target", stream.target, "duration", time.Since(stream.start), "err", loggedErr, "metadata", stream.outgoing)
} else {
keys := maps.Keys(stream.outgoing)
slices.Sort(keys)
stream.logger.InfoContext(stream.ctx, stream.method, "target", stream.target, "duration", time.Since(stream.start), "err", loggedErr, "metadata", keys)
}
}
return err
}
30 changes: 30 additions & 0 deletions go/adbc/ext.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package adbc

import (
"golang.org/x/exp/slog"
)

// DatabaseLogging is a Database that also supports logging information to an
// application-supplied log sink.
//
// EXPERIMENTAL. Not formally part of the ADBC APIs.
type DatabaseLogging interface {
SetLogger(*slog.Logger)
}
2 changes: 1 addition & 1 deletion go/adbc/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

module github.com/apache/arrow-adbc/go/adbc

go 1.18
go 1.19

require (
github.com/apache/arrow/go/v13 v13.0.0
Expand Down
54 changes: 49 additions & 5 deletions go/adbc/pkg/_tmpl/driver.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ import (
"os"
"runtime"
"runtime/cgo"
"strings"
"sync/atomic"
"unsafe"

Expand All @@ -61,6 +62,7 @@ import (
"github.com/apache/arrow/go/v13/arrow/cdata"
"github.com/apache/arrow/go/v13/arrow/memory"
"github.com/apache/arrow/go/v13/arrow/memory/mallocator"
"golang.org/x/exp/slog"
)

// Must use malloc() to respect CGO rules
Expand All @@ -71,6 +73,7 @@ var drv = {{.Driver}}{Alloc: mallocator.NewMallocator()}
var globalPoison int32 = 0

const errPrefix = "[{{.Prefix}}] "
const logLevelEnvVar = "ADBC_DRIVER_{{.PrefixUpper}}_LOG_LEVEL"

func setErr(err *C.struct_AdbcError, format string, vals ...interface{}) {
if err == nil {
Expand Down Expand Up @@ -162,6 +165,45 @@ func poison(err *C.struct_AdbcError, fname string, e interface{}) C.AdbcStatusCo
return C.ADBC_STATUS_INTERNAL
}

// Check environment variables and enable logging if possible.
func initLoggingFromEnv(db adbc.Database) {
logLevel := slog.LevelError
switch strings.ToLower(os.Getenv(logLevelEnvVar)) {
case "debug":
logLevel = slog.LevelDebug
case "info":
logLevel = slog.LevelInfo
case "warn":
case "warning":
logLevel = slog.LevelWarn
case "error":
logLevel = slog.LevelError
case "":
return
default:
printLoggingHelp()
return
}

h := slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
AddSource: false,
Level: logLevel,
})
logger := slog.New(h)

ext, ok := db.(adbc.DatabaseLogging)
if !ok {
logger.Error("{{.Prefix}} does not support logging")
return
}
ext.SetLogger(logger)
}

func printLoggingHelp() {
fmt.Fprintf(os.Stderr, "{{.Prefix}}: to enable logging, set %s to 'debug', 'info', 'warn', or 'error'", logLevelEnvVar)
}


// Allocate a new cgo.Handle and store its address in a heap-allocated
// uintptr_t. Experimentally, this was found to be necessary, else
// something (the Go runtime?) would corrupt (garbage-collect?) the
Expand Down Expand Up @@ -305,7 +347,7 @@ func (cStream *cArrayStream) maybeError() C.int {

//export {{.Prefix}}ArrayStreamGetLastError
func {{.Prefix}}ArrayStreamGetLastError(stream *C.struct_ArrowArrayStream) *C.cchar_t {
if stream == nil || stream.release != (*[0]byte)(C.{{.Prefix}}ArrayStreamRelease) {
if stream == nil || stream.release != (*[0]byte)(C.{{.Prefix}}ArrayStreamRelease) || stream.private_data == nil {
return nil
}
cStream := getFromHandle[cArrayStream](stream.private_data)
Expand All @@ -317,7 +359,7 @@ func {{.Prefix}}ArrayStreamGetLastError(stream *C.struct_ArrowArrayStream) *C.cc

//export {{.Prefix}}ArrayStreamGetNext
func {{.Prefix}}ArrayStreamGetNext(stream *C.struct_ArrowArrayStream, array *C.struct_ArrowArray) C.int {
if stream == nil || stream.release != (*[0]byte)(C.{{.Prefix}}ArrayStreamRelease) {
if stream == nil || stream.release != (*[0]byte)(C.{{.Prefix}}ArrayStreamRelease) || stream.private_data == nil {
return C.EINVAL
}
cStream := getFromHandle[cArrayStream](stream.private_data)
Expand All @@ -332,7 +374,7 @@ func {{.Prefix}}ArrayStreamGetNext(stream *C.struct_ArrowArrayStream, array *C.s

//export {{.Prefix}}ArrayStreamGetSchema
func {{.Prefix}}ArrayStreamGetSchema(stream *C.struct_ArrowArrayStream, schema *C.struct_ArrowSchema) C.int {
if stream == nil || stream.release != (*[0]byte)(C.{{.Prefix}}ArrayStreamRelease) {
if stream == nil || stream.release != (*[0]byte)(C.{{.Prefix}}ArrayStreamRelease) || stream.private_data == nil {
return C.EINVAL
}
cStream := getFromHandle[cArrayStream](stream.private_data)
Expand All @@ -346,7 +388,7 @@ func {{.Prefix}}ArrayStreamGetSchema(stream *C.struct_ArrowArrayStream, schema *

//export {{.Prefix}}ArrayStreamRelease
func {{.Prefix}}ArrayStreamRelease(stream *C.struct_ArrowArrayStream) {
if stream == nil || stream.release != (*[0]byte)(C.{{.Prefix}}ArrayStreamRelease) {
if stream == nil || stream.release != (*[0]byte)(C.{{.Prefix}}ArrayStreamRelease) || stream.private_data == nil {
return
}
h := (*(*cgo.Handle)(stream.private_data))
Expand All @@ -365,7 +407,7 @@ func {{.Prefix}}ArrayStreamRelease(stream *C.struct_ArrowArrayStream) {

//export {{.Prefix}}ErrorFromArrayStream
func {{.Prefix}}ErrorFromArrayStream(stream *C.struct_ArrowArrayStream, status *C.AdbcStatusCode) (*C.struct_AdbcError) {
if stream == nil || stream.release != (*[0]byte)(C.{{.Prefix}}ArrayStreamRelease) {
if stream == nil || stream.release != (*[0]byte)(C.{{.Prefix}}ArrayStreamRelease) || stream.private_data == nil {
return nil
}
cStream := getFromHandle[cArrayStream](stream.private_data)
Expand Down Expand Up @@ -509,6 +551,8 @@ func {{.Prefix}}DatabaseInit(db *C.struct_AdbcDatabase, err *C.struct_AdbcError)
return C.AdbcStatusCode(errToAdbcErr(err, aerr))
}

initLoggingFromEnv(adb)

cdb.db = adb
return C.ADBC_STATUS_OK
}
Expand Down
Loading

0 comments on commit b9e144d

Please sign in to comment.