diff --git a/types/context.go b/types/context.go index 60aaa72144f2..ccd860efa057 100644 --- a/types/context.go +++ b/types/context.go @@ -237,6 +237,10 @@ func (c Context) WithValue(key, value interface{}) Context { } func (c Context) Value(key interface{}) interface{} { + if key == SdkContextKey { + return c + } + return c.baseCtx.Value(key) } diff --git a/types/context_test.go b/types/context_test.go index 3052890a3f0c..fafbfb441211 100644 --- a/types/context_test.go +++ b/types/context_test.go @@ -220,4 +220,9 @@ func (s *contextTestSuite) TestUnwrapSDKContext() { ctx = context.Background() s.Require().Panics(func() { types.UnwrapSDKContext(ctx) }) + + // test unwrapping when we've used context.WithValue + ctx = context.WithValue(sdkCtx, "foo", "bar") + sdkCtx2 = types.UnwrapSDKContext(ctx) + s.Require().Equal(sdkCtx, sdkCtx2) }