Skip to content

Commit

Permalink
Fix custom tracer header propagation (#263)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexshtin authored Oct 9, 2020
1 parent 0914229 commit d2bd562
Show file tree
Hide file tree
Showing 17 changed files with 352 additions and 171 deletions.
2 changes: 1 addition & 1 deletion converter/composite_data_converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type (
// NewCompositeDataConverter creates new instance of CompositeDataConverter from ordered list of PayloadConverters.
// Order is important here because during serialization DataConverter will try PayloadsConverters in
// that order until PayloadConverter returns non nil payload.
// Last PayloadConverter should always serialize the value (JSONPayloadConverter is good candidate for it),
// Last PayloadConverter should always serialize the value (JSONPayloadConverter is good candidate for it).
func NewCompositeDataConverter(payloadConverters ...PayloadConverter) *CompositeDataConverter {
dc := &CompositeDataConverter{
payloadConverters: make(map[string]PayloadConverter, len(payloadConverters)),
Expand Down
5 changes: 3 additions & 2 deletions converter/default_data_converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ var (
defaultDataConverter = NewCompositeDataConverter(
NewNilPayloadConverter(),
NewByteSlicePayloadConverter(),
// Only one proto converter should be used.
// Although they check for different interfaces (proto.Message and proto.Marshaler) all proto messages implements both interfaces.

// Only one proto converter (JSON or regular) should be used because they check for the same proto.Message interface.
NewProtoJSONPayloadConverter(),
// NewProtoPayloadConverter(),

NewJSONPayloadConverter(),
)
)
Expand Down
3 changes: 2 additions & 1 deletion internal/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,8 @@ func NewValues(data *commonpb.Payloads) converter.EncodedValues {
return newEncodedValues(data, nil)
}

// checkHealth checks service health using gRPC health check: // https://github.com/grpc/grpc/blob/master/doc/health-checking.md
// checkHealth checks service health using gRPC health check:
// https://github.com/grpc/grpc/blob/master/doc/health-checking.md
func checkHealth(connection grpc.ClientConnInterface, options ConnectionOptions) error {
if options.DisableHealthCheck {
return nil
Expand Down
2 changes: 1 addition & 1 deletion internal/error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ func Test_ContinueAsNewError(t *testing.T) {

s := &WorkflowTestSuite{
header: header,
contextPropagators: []ContextPropagator{NewStringMapPropagator([]string{"test"})},
contextPropagators: []ContextPropagator{NewKeysPropagator([]string{"test"})},
}
wfEnv := s.NewTestWorkflowEnvironment()
wfEnv.RegisterWorkflowWithOptions(continueAsNewWorkflowFn, RegisterWorkflowOptions{
Expand Down
13 changes: 11 additions & 2 deletions internal/headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type HeaderWriter interface {

// HeaderReader is an interface to read information from temporal headers
type HeaderReader interface {
Get(string) (*commonpb.Payload, bool)
ForEachKey(handler func(string, *commonpb.Payload) error) error
}

Expand Down Expand Up @@ -74,9 +75,17 @@ func (hr *headerReader) ForEachKey(handler func(string, *commonpb.Payload) error
return nil
}

func (hr *headerReader) Get(key string) (*commonpb.Payload, bool) {
if hr.header == nil {
panic("headerReader.header is nil")
}
payload, ok := hr.header.Fields[key]
return payload, ok
}

// NewHeaderReader returns a header reader interface
func NewHeaderReader(header *commonpb.Header) HeaderReader {
return &headerReader{header}
return &headerReader{header: header}
}

type headerWriter struct {
Expand All @@ -95,5 +104,5 @@ func NewHeaderWriter(header *commonpb.Header) HeaderWriter {
if header != nil && header.Fields == nil {
header.Fields = make(map[string]*commonpb.Payload)
}
return &headerWriter{header}
return &headerWriter{header: header}
}
61 changes: 60 additions & 1 deletion internal/headers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func encodeString(t *testing.T, s string) *commonpb.Payload {
return p
}

func TestHeaderReader(t *testing.T) {
func TestHeaderReader_ForEachKey(t *testing.T) {
t.Parallel()
tests := []struct {
name string
Expand Down Expand Up @@ -157,3 +157,62 @@ func TestHeaderReader(t *testing.T) {
})
}
}

func TestHeaderReader_Get(t *testing.T) {
t.Parallel()
tests := []struct {
name string
header *commonpb.Header
key string
headerExists bool
}{
{
"valid key",
&commonpb.Header{
Fields: map[string]*commonpb.Payload{
"key1": encodeString(t, "val1"),
"key2": encodeString(t, "val2"),
},
},
"key1",
true,
},
{
"invalid key",
&commonpb.Header{
Fields: map[string]*commonpb.Payload{
"key1": encodeString(t, "val1"),
"key2": encodeString(t, "val2"),
},
},
"key3",
false,
},
{
"nil fields",
&commonpb.Header{},
"key1",
false,
},
}

for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
reader := NewHeaderReader(test.header)
_, headerExist := reader.Get(test.key)
if test.headerExists {
assert.True(t, headerExist)
} else {
assert.False(t, headerExist)
}
})
}

t.Run("nil panic", func(t *testing.T) {
reader := NewHeaderReader(nil)
assert.Panics(t, func() { reader.Get("") })
})

}
2 changes: 1 addition & 1 deletion internal/internal_task_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -1801,7 +1801,7 @@ func (ath *activityTaskHandlerImpl) Execute(taskQueue string, t *workflowservice
for _, ctxProp := range ath.contextPropagators {
var err error
if ctx, err = ctxProp.Extract(ctx, NewHeaderReader(t.Header)); err != nil {
return nil, fmt.Errorf("unable to propagate context %v", err)
return nil, fmt.Errorf("unable to propagate context: %w", err)
}
}

Expand Down
2 changes: 1 addition & 1 deletion internal/internal_task_pollers.go
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ func (lath *localActivityTaskHandler) executeLocalActivityTask(task *localActivi
result = &localActivityResult{
task: task,
result: nil,
err: fmt.Errorf("unable to propagate context %v", err),
err: fmt.Errorf("unable to propagate context: %w", err),
}
return result
}
Expand Down
3 changes: 2 additions & 1 deletion internal/internal_workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ func (d *syncWorkflowDefinition) Execute(env WorkflowEnvironment, header *common
for _, ctxProp := range env.GetContextPropagators() {
var err error
if rootCtx, err = ctxProp.ExtractToWorkflow(rootCtx, NewHeaderReader(header)); err != nil {
panic(fmt.Sprintf("Unable to propagate context %v", err))
panic(fmt.Sprintf("Unable to propagate context: %v", err))
}
}

Expand Down Expand Up @@ -1192,6 +1192,7 @@ func setWorkflowEnvOptionsIfNotExist(ctx Context) Context {
if newOptions.DataConverter == nil {
newOptions.DataConverter = converter.GetDefaultDataConverter()
}

return WithValue(ctx, workflowEnvOptionsContextKey, &newOptions)
}

Expand Down
111 changes: 38 additions & 73 deletions internal/internal_workflow_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,25 +68,21 @@ type (
}
)

// stringMapPropagator propagates the list of keys across a workflow,
// keysPropagator propagates the list of keys across a workflow,
// interpreting the payloads as strings.
type stringMapPropagator struct {
keys map[string]struct{}
type keysPropagator struct {
keys []string
}

// NewStringMapPropagator returns a context propagator that propagates a set of
// NewKeysPropagator returns a context propagator that propagates a set of
// string key-value pairs across a workflow
func NewStringMapPropagator(keys []string) ContextPropagator {
keyMap := make(map[string]struct{}, len(keys))
for _, key := range keys {
keyMap[key] = struct{}{}
}
return &stringMapPropagator{keyMap}
func NewKeysPropagator(keys []string) ContextPropagator {
return &keysPropagator{keys}
}

// Inject injects values from context into headers for propagation
func (s *stringMapPropagator) Inject(ctx context.Context, writer HeaderWriter) error {
for key := range s.keys {
func (s *keysPropagator) Inject(ctx context.Context, writer HeaderWriter) error {
for _, key := range s.keys {
value, ok := ctx.Value(contextKey(key)).(string)
if !ok {
return fmt.Errorf("unable to extract key from context %v", key)
Expand All @@ -101,8 +97,8 @@ func (s *stringMapPropagator) Inject(ctx context.Context, writer HeaderWriter) e
}

// InjectFromWorkflow injects values from context into headers for propagation
func (s *stringMapPropagator) InjectFromWorkflow(ctx Context, writer HeaderWriter) error {
for key := range s.keys {
func (s *keysPropagator) InjectFromWorkflow(ctx Context, writer HeaderWriter) error {
for _, key := range s.keys {
value, ok := ctx.Value(contextKey(key)).(string)
if !ok {
return fmt.Errorf("unable to extract key from context %v", key)
Expand All @@ -117,37 +113,37 @@ func (s *stringMapPropagator) InjectFromWorkflow(ctx Context, writer HeaderWrite
}

// Extract extracts values from headers and puts them into context
func (s *stringMapPropagator) Extract(ctx context.Context, reader HeaderReader) (context.Context, error) {
if err := reader.ForEachKey(func(key string, value *commonpb.Payload) error {
if _, ok := s.keys[key]; ok {
var decodedValue string
err := converter.GetDefaultDataConverter().FromPayload(value, &decodedValue)
if err != nil {
return err
}
ctx = context.WithValue(ctx, contextKey(key), decodedValue)
func (s *keysPropagator) Extract(ctx context.Context, reader HeaderReader) (context.Context, error) {
for _, key := range s.keys {
value, ok := reader.Get(key)
if !ok {
// If key that should be propagated doesn't exist in the header, ignore the key.
continue
}
var decodedValue string
err := converter.GetDefaultDataConverter().FromPayload(value, &decodedValue)
if err != nil {
return ctx, err
}
return nil
}); err != nil {
return nil, err
ctx = context.WithValue(ctx, contextKey(key), decodedValue)
}
return ctx, nil
}

// ExtractToWorkflow extracts values from headers and puts them into context
func (s *stringMapPropagator) ExtractToWorkflow(ctx Context, reader HeaderReader) (Context, error) {
if err := reader.ForEachKey(func(key string, value *commonpb.Payload) error {
if _, ok := s.keys[key]; ok {
var decodedValue string
err := converter.GetDefaultDataConverter().FromPayload(value, &decodedValue)
if err != nil {
return err
}
ctx = WithValue(ctx, contextKey(key), decodedValue)
func (s *keysPropagator) ExtractToWorkflow(ctx Context, reader HeaderReader) (Context, error) {
for _, key := range s.keys {
value, ok := reader.Get(key)
if !ok {
// If key that should be propagated doesn't exist in the header, ignore the key.
continue
}
var decodedValue string
err := converter.GetDefaultDataConverter().FromPayload(value, &decodedValue)
if err != nil {
return ctx, err
}
return nil
}); err != nil {
return nil, err
ctx = WithValue(ctx, contextKey(key), decodedValue)
}
return ctx, nil
}
Expand Down Expand Up @@ -1015,7 +1011,7 @@ func (s *workflowClientTestSuite) TestStartWorkflow() {
WorkflowTaskTimeout: timeoutInSeconds,
}
f1 := func(ctx Context, r []byte) string {
return "result"
panic("this is just a stub")
}

createResponse := &workflowservice.StartWorkflowExecutionResponse{
Expand All @@ -1029,37 +1025,6 @@ func (s *workflowClientTestSuite) TestStartWorkflow() {
s.Equal(createResponse.GetRunId(), resp.RunID)
}

func (s *workflowClientTestSuite) TestStartWorkflow_WithContext() {
s.client = NewServiceClient(s.service, nil, ClientOptions{
ContextPropagators: []ContextPropagator{NewStringMapPropagator([]string{testHeader})},
})
client, ok := s.client.(*WorkflowClient)
s.True(ok)
options := StartWorkflowOptions{
ID: workflowID,
TaskQueue: taskqueue,
WorkflowExecutionTimeout: timeoutInSeconds,
WorkflowTaskTimeout: timeoutInSeconds,
}
f1 := func(ctx Context, r []byte) error {
value := ctx.Value(contextKey(testHeader))
if val, ok := value.([]byte); ok {
s.Equal("test-data", string(val))
return nil
}
return fmt.Errorf("context did not propagate to workflow")
}

createResponse := &workflowservice.StartWorkflowExecutionResponse{
RunId: runID,
}
s.service.EXPECT().StartWorkflowExecution(gomock.Any(), gomock.Any(), gomock.Any()).Return(createResponse, nil)

resp, err := client.StartWorkflow(context.Background(), options, f1, []byte("test"))
s.Nil(err)
s.Equal(createResponse.GetRunId(), resp.RunID)
}

func (s *workflowClientTestSuite) TestStartWorkflowWithDataConverter() {
dc := iconverter.NewTestDataConverter()
s.client = NewServiceClient(s.service, nil, ClientOptions{DataConverter: dc})
Expand All @@ -1072,7 +1037,7 @@ func (s *workflowClientTestSuite) TestStartWorkflowWithDataConverter() {
WorkflowTaskTimeout: timeoutInSeconds,
}
f1 := func(ctx Context, r []byte) string {
return "result"
panic("this is just a stub")
}
input := []byte("test")

Expand Down Expand Up @@ -1111,7 +1076,7 @@ func (s *workflowClientTestSuite) TestStartWorkflow_WithMemoAndSearchAttr() {
SearchAttributes: searchAttributes,
}
wf := func(ctx Context) string {
return "result"
panic("this is just a stub")
}
startResp := &workflowservice.StartWorkflowExecutionResponse{}

Expand Down Expand Up @@ -1145,7 +1110,7 @@ func (s *workflowClientTestSuite) SignalWithStartWorkflowWithMemoAndSearchAttr()
SearchAttributes: searchAttributes,
}
wf := func(ctx Context) string {
return "result"
panic("this is just a stub")
}
startResp := &workflowservice.StartWorkflowExecutionResponse{}

Expand Down
Loading

0 comments on commit d2bd562

Please sign in to comment.