Skip to content

Commit

Permalink
fix regression in downloading out/err streams (bazelbuild#472)
Browse files Browse the repository at this point in the history
The server may terminate the streaming request before sending available
bytes, which prevents the download fallback from getting them.

This change implements a resuming fallback to download the remainder of
the streams.

Additionally, the digests of the out and err streams will be printed in
the output of `execution_action` command to allow the user to access
them even if the remote execution fails, in which case they will not be
retrievable by `show_action`.
  • Loading branch information
mrahs authored Jul 4, 2023
1 parent 2dfac22 commit 8ab3738
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 36 deletions.
28 changes: 23 additions & 5 deletions go/pkg/fakes/cas.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ import (
bspb "google.golang.org/genproto/googleapis/bytestream"
)

var zstdEncoder, _ = zstd.NewWriter(nil, zstd.WithZeroFrames(true))
var zstdDecoder, _ = zstd.NewReader(nil)
var (
zstdEncoder, _ = zstd.NewWriter(nil, zstd.WithZeroFrames(true))
zstdDecoder, _ = zstd.NewReader(nil)
)

// Reader implements ByteStream's Read interface, returning one blob.
type Reader struct {
Expand Down Expand Up @@ -684,8 +686,11 @@ func (f *CAS) Write(stream bsgrpc.ByteStream_WriteServer) (err error) {

// Read implements the corresponding RE API function.
func (f *CAS) Read(req *bspb.ReadRequest, stream bsgrpc.ByteStream_ReadServer) error {
if req.ReadOffset != 0 || req.ReadLimit != 0 {
return status.Error(codes.Unimplemented, "test fake does not implement read_offset or limit")
if req.ReadOffset < 0 {
return status.Error(codes.InvalidArgument, "test fake expected a positive value for offset")
}
if req.ReadLimit != 0 {
return status.Error(codes.Unimplemented, "test fake does not implement limit")
}

path := strings.Split(req.ResourceName, "/")
Expand Down Expand Up @@ -726,12 +731,25 @@ func (f *CAS) Read(req *bspb.ReadRequest, stream bsgrpc.ByteStream_ReadServer) e
}

resp := &bspb.ReadResponse{}
var offset int64
for ch.HasNext() {
chunk, err := ch.Next()
resp.Data = chunk.Data
if err != nil {
return err
}
// Seek to req.ReadOffset.
offset += int64(len(chunk.Data))
if offset < req.ReadOffset {
continue
}
// Scale the offset to the chunk.
offset = offset - req.ReadOffset // The chunk tail that we want.
offset = int64(len(chunk.Data)) - offset // The chunk head that we don't want.
if offset < 0 {
// The chunk is past the offset.
offset = 0
}
resp.Data = chunk.Data[int(offset):]
err = stream.Send(resp)
if err != nil {
return err
Expand Down
75 changes: 48 additions & 27 deletions go/pkg/rexec/rexec.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/bazelbuild/remote-apis-sdks/go/pkg/filemetadata"
"github.com/bazelbuild/remote-apis-sdks/go/pkg/outerr"
"github.com/bazelbuild/remote-apis-sdks/go/pkg/uploadinfo"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/prototext"
Expand Down Expand Up @@ -76,15 +75,20 @@ func (c *Client) NewContext(ctx context.Context, cmd *command.Command, opt *comm
}, nil
}

func (ec *Context) downloadStream(raw []byte, dgPb *repb.Digest, write func([]byte)) error {
// downloadStream reads the blob for the digest dgPb into memory and forwards the bytes to the write function.
func (ec *Context) downloadStream(raw []byte, dgPb *repb.Digest, offset int64, write func([]byte)) error {
if raw != nil {
write(raw)
o := int(offset)
if int64(o) != offset || o > len(raw) {
return fmt.Errorf("offset %d is out of range for length %d", offset, len(raw))
}
write(raw[o:])
} else if dgPb != nil {
dg, err := digest.NewFromProto(dgPb)
if err != nil {
return err
}
bytes, stats, err := ec.client.GrpcClient.ReadBlob(ec.ctx, dg)
bytes, stats, err := ec.client.GrpcClient.ReadBlobRange(ec.ctx, dg, offset, 0)
if err != nil {
return err
}
Expand Down Expand Up @@ -137,10 +141,10 @@ func (ec *Context) setOutputMetadata() {
}

func (ec *Context) downloadOutErr() *command.Result {
if err := ec.downloadStream(ec.resPb.StdoutRaw, ec.resPb.StdoutDigest, ec.oe.WriteOut); err != nil {
if err := ec.downloadStream(ec.resPb.StdoutRaw, ec.resPb.StdoutDigest, 0, ec.oe.WriteOut); err != nil {
return command.NewRemoteErrorResult(err)
}
if err := ec.downloadStream(ec.resPb.StderrRaw, ec.resPb.StderrDigest, ec.oe.WriteErr); err != nil {
if err := ec.downloadStream(ec.resPb.StderrRaw, ec.resPb.StderrDigest, 0, ec.oe.WriteErr); err != nil {
return command.NewRemoteErrorResult(err)
}
return command.NewResultFromExitCode((int)(ec.resPb.ExitCode))
Expand Down Expand Up @@ -332,48 +336,65 @@ func (ec *Context) ExecuteRemotely() {
ec.Metadata.RealBytesUploaded = bytesMoved
log.V(1).Infof("%s %s> Executing remotely...\n%s", cmdID, executionID, strings.Join(ec.cmd.Args, " "))
ec.Metadata.EventTimes[command.EventExecuteRemotely] = &command.TimeInterval{From: time.Now()}
eg, ctx := errgroup.WithContext(ec.ctx)
// Initiate each streaming request once at most.
var streamOut, streamErr sync.Once
op, err := ec.client.GrpcClient.ExecuteAndWaitProgress(ctx, &repb.ExecuteRequest{
var streamWg sync.WaitGroup
// These variables are owned by the progress callback (which is async but not concurrent) until the execution returns.
var nOutStreamed, nErrStreamed int64
op, err := ec.client.GrpcClient.ExecuteAndWaitProgress(ec.ctx, &repb.ExecuteRequest{
InstanceName: ec.client.GrpcClient.InstanceName,
SkipCacheLookup: !ec.opt.AcceptCached || ec.opt.DoNotCache,
ActionDigest: ec.Metadata.ActionDigest.ToProto(),
}, func(md *repb.ExecuteOperationMetadata) {
if !ec.opt.StreamOutErr {
return
}
// The server may return either, both, or neither of the stream names, and not necessarily in the same or first call.
// The streaming request for each must be initiated once at most.
if name := md.GetStdoutStreamName(); name != "" {
streamOut.Do(func() {
eg.Go(func() error {
streamWg.Add(1)
go func() {
defer streamWg.Done()
path := fmt.Sprintf("%s/logstreams/%s", ec.client.GrpcClient.InstanceName, name)
log.V(1).Infof("%s %s> Streaming to stdout from %q", cmdID, executionID, path)
_, err = ec.client.GrpcClient.ReadResourceTo(ctx, path, outerr.NewOutWriter(ec.oe))
return err
})
// Ignoring the error here since the net result is downloading the full stream after the fact.
n, err := ec.client.GrpcClient.ReadResourceTo(ec.ctx, path, outerr.NewOutWriter(ec.oe))
if err != nil {
log.Errorf("%s %s> error streaming stdout: %v", cmdID, executionID, err)
}
nOutStreamed += n
}()
})
}
if name := md.GetStderrStreamName(); name != "" {
streamErr.Do(func() {
eg.Go(func() error {
streamWg.Add(1)
go func() {
defer streamWg.Done()
path := fmt.Sprintf("%s/logstreams/%s", ec.client.GrpcClient.InstanceName, name)
log.V(1).Infof("%s %s> Streaming to stdout from %q", cmdID, executionID, path)
_, err = ec.client.GrpcClient.ReadResourceTo(ctx, path, outerr.NewErrWriter(ec.oe))
return err
})
// Ignoring the error here since the net result is downloading the full stream after the fact.
n, err := ec.client.GrpcClient.ReadResourceTo(ec.ctx, path, outerr.NewErrWriter(ec.oe))
if err != nil {
log.Errorf("%s %s> error streaming stderr: %v", cmdID, executionID, err)
}
nErrStreamed += n
}()
})
}
})
ec.Metadata.EventTimes[command.EventExecuteRemotely].To = time.Now()
// This will always be called after both of the Add calls above if any, because the execution call above returns
// after all invokations of the progress callback.
// The server will terminate the streams when the execution finishes, regardless of its result, which will ensure the goroutines
// will have terminated at this point.
streamWg.Wait()
if err != nil {
ec.Result = command.NewRemoteErrorResult(err)
return
}

if err := eg.Wait(); err != nil {
ec.Result = command.NewRemoteErrorResult(fmt.Errorf("failure writing output streams: %v", err))
return
}

or := op.GetResponse()
if or == nil {
ec.Result = command.NewRemoteErrorResult(fmt.Errorf("unexpected operation result type: %v", or))
Expand All @@ -397,16 +418,16 @@ func (ec *Context) ExecuteRemotely() {
ec.setOutputMetadata()
ec.Result = command.NewResultFromExitCode((int)(ec.resPb.ExitCode))
if ec.opt.DownloadOutErr {
streamOut.Do(func() {
if err := ec.downloadStream(ec.resPb.StdoutRaw, ec.resPb.StdoutDigest, ec.oe.WriteOut); err != nil {
if ec.resPb.StdoutDigest == nil || ec.resPb.StdoutDigest.SizeBytes > nOutStreamed {
if err := ec.downloadStream(ec.resPb.StdoutRaw, ec.resPb.StdoutDigest, nOutStreamed, ec.oe.WriteOut); err != nil {
ec.Result = command.NewRemoteErrorResult(err)
}
})
streamErr.Do(func() {
if err := ec.downloadStream(ec.resPb.StderrRaw, ec.resPb.StderrDigest, ec.oe.WriteErr); err != nil {
}
if ec.resPb.StderrDigest == nil || ec.resPb.StderrDigest.SizeBytes > nErrStreamed {
if err := ec.downloadStream(ec.resPb.StderrRaw, ec.resPb.StderrDigest, nErrStreamed, ec.oe.WriteErr); err != nil {
ec.Result = command.NewRemoteErrorResult(err)
}
})
}
}
if ec.Result.Err == nil && ec.opt.DownloadOutputs {
log.V(1).Infof("%s %s> Downloading outputs...", cmdID, executionID)
Expand Down
50 changes: 46 additions & 4 deletions go/pkg/rexec/rexec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,10 @@ func TestStreamOutErr(t *testing.T) {
requestStreams bool
hasStdOutStream bool
hasStdErrStream bool
outChunks []string
errChunks []string
outContent string
errContent string
wantRes *command.Result
wantStdOut string
wantStdErr string
Expand All @@ -413,6 +417,8 @@ func TestStreamOutErr(t *testing.T) {
requestStreams: false,
hasStdOutStream: true,
hasStdErrStream: true,
outContent: "stdout-blob",
errContent: "stderr-blob",
wantRes: &command.Result{Status: command.SuccessResultStatus},
wantStdOut: "stdout-blob",
wantStdErr: "stderr-blob",
Expand All @@ -422,6 +428,7 @@ func TestStreamOutErr(t *testing.T) {
requestStreams: true,
hasStdOutStream: true,
hasStdErrStream: false,
errContent: "stderr-blob",
wantRes: &command.Result{Status: command.SuccessResultStatus},
wantStdOut: "streaming-stdout",
wantStdErr: "stderr-blob",
Expand All @@ -431,6 +438,7 @@ func TestStreamOutErr(t *testing.T) {
requestStreams: true,
hasStdOutStream: false,
hasStdErrStream: true,
outContent: "stdout-blob",
wantRes: &command.Result{Status: command.SuccessResultStatus},
wantStdOut: "stdout-blob",
wantStdErr: "streaming-stderr",
Expand All @@ -440,6 +448,8 @@ func TestStreamOutErr(t *testing.T) {
requestStreams: true,
hasStdOutStream: false,
hasStdErrStream: false,
outContent: "stdout-blob",
errContent: "stderr-blob",
wantRes: &command.Result{Status: command.SuccessResultStatus},
wantStdOut: "stdout-blob",
wantStdErr: "stderr-blob",
Expand All @@ -459,6 +469,8 @@ func TestStreamOutErr(t *testing.T) {
requestStreams: true,
hasStdOutStream: true,
hasStdErrStream: true,
outContent: "stdout-blob",
errContent: "stderr-blob",
wantRes: &command.Result{Status: command.CacheHitResultStatus},
wantStdOut: "stdout-blob",
wantStdErr: "stderr-blob",
Expand All @@ -483,6 +495,20 @@ func TestStreamOutErr(t *testing.T) {
wantStdOut: "streaming-stdout",
wantStdErr: "streaming-stderr",
},
{
name: "remote failure partial stream",
requestStreams: true,
hasStdOutStream: true,
hasStdErrStream: true,
outChunks: []string{"streaming"},
errChunks: []string{"streaming"},
outContent: "streaming-stdout",
errContent: "streaming-stderr",
status: status.New(codes.Internal, "problem"),
wantRes: command.NewRemoteErrorResult(status.New(codes.Internal, "problem").Err()),
wantStdOut: "streaming-stdout",
wantStdErr: "streaming-stderr",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
Expand All @@ -500,11 +526,27 @@ func TestStreamOutErr(t *testing.T) {
DownloadOutErr: true,
StreamOutErr: tc.requestStreams,
}
outChunks := tc.outChunks
if outChunks == nil {
outChunks = []string{"streaming", "-", "stdout"}
}
errChunks := tc.errChunks
if errChunks == nil {
errChunks = []string{"streaming", "-", "stderr"}
}
outContent := tc.outContent
if outContent == "" {
outContent = "streaming-stdout"
}
errContent := tc.errContent
if errContent == "" {
errContent = "streaming-stderr"
}
opts := []fakes.Option{
fakes.StdOut("stdout-blob"),
fakes.StdErr("stderr-blob"),
&fakes.LogStream{Name: "stdout-stream", Chunks: []string{"streaming", "-", "stdout"}},
&fakes.LogStream{Name: "stderr-stream", Chunks: []string{"streaming", "-", "stderr"}},
fakes.StdOut(outContent),
fakes.StdErr(errContent),
&fakes.LogStream{Name: "stdout-stream", Chunks: outChunks},
&fakes.LogStream{Name: "stderr-stream", Chunks: errChunks},
fakes.ExecutionCacheHit(tc.cached),
}
if tc.hasStdOutStream {
Expand Down
2 changes: 2 additions & 0 deletions go/pkg/tool/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,8 @@ func (c *Client) ExecuteAction(ctx context.Context, actionDigest, actionRoot, ou
fmt.Printf("---------------\n")
fmt.Printf("Action digest: %v\n", ec.Metadata.ActionDigest.String())
fmt.Printf("Command digest: %v\n", ec.Metadata.CommandDigest.String())
fmt.Printf("Stdout digest: %v\n", ec.Metadata.StdoutDigest.String())
fmt.Printf("Stderr digest: %v\n", ec.Metadata.StderrDigest.String())
fmt.Printf("Number of Input Files: %v\n", ec.Metadata.InputFiles)
fmt.Printf("Number of Input Dirs: %v\n", ec.Metadata.InputDirectories)
fmt.Printf("Number of Output Files: %v\n", ec.Metadata.OutputFiles)
Expand Down

0 comments on commit 8ab3738

Please sign in to comment.