diff --git a/libs/filer/dbfs_client.go b/libs/filer/dbfs_client.go
index 38e8f9f3f3..679e187980 100644
--- a/libs/filer/dbfs_client.go
+++ b/libs/filer/dbfs_client.go
@@ -1,11 +1,15 @@
 package filer
 
 import (
+	"bytes"
 	"context"
 	"errors"
+	"fmt"
 	"io"
 	"io/fs"
+	"mime/multipart"
 	"net/http"
+	"os"
 	"path"
 	"slices"
 	"sort"
@@ -14,6 +18,7 @@ import (
 
 	"github.com/databricks/databricks-sdk-go"
 	"github.com/databricks/databricks-sdk-go/apierr"
+	"github.com/databricks/databricks-sdk-go/client"
 	"github.com/databricks/databricks-sdk-go/service/files"
 )
 
@@ -63,22 +68,118 @@ func (info dbfsFileInfo) Sys() any {
 	return info.fi
 }
 
+// Interface to allow mocking of the Databricks API client.
+type databricksClient interface {
+	Do(ctx context.Context, method, path string, headers map[string]string,
+		requestBody any, responseBody any, visitors ...func(*http.Request) error) error
+}
+
 // DbfsClient implements the [Filer] interface for the DBFS backend.
 type DbfsClient struct {
 	workspaceClient *databricks.WorkspaceClient
 
+	apiClient databricksClient
+
 	// File operations will be relative to this path.
 	root WorkspaceRootPath
 }
 
 func NewDbfsClient(w *databricks.WorkspaceClient, root string) (Filer, error) {
+	apiClient, err := client.New(w.Config)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create API client: %w", err)
+	}
+
 	return &DbfsClient{
 		workspaceClient: w,
+		apiClient:       apiClient,
 
 		root: NewWorkspaceRootPath(root),
 	}, nil
 }
 
+func (w *DbfsClient) uploadUsingDbfsPutApi(ctx context.Context, path string, overwrite bool, file *os.File) error {
+	overwriteField := "False"
+	if overwrite {
+		overwriteField = "True"
+	}
+
+	buf := &bytes.Buffer{}
+	writer := multipart.NewWriter(buf)
+	err := writer.WriteField("path", path)
+	if err != nil {
+		return err
+	}
+	err = writer.WriteField("overwrite", overwriteField)
+	if err != nil {
+		return err
+	}
+	contents, err := writer.CreateFormFile("contents", "")
+	if err != nil {
+		return err
+	}
+
+	_, err = io.Copy(contents, file)
+	if err != nil {
+		return err
+	}
+
+	err = writer.Close()
+	if err != nil {
+		return err
+	}
+
+	// Request bodies with Content-Type multipart/form-data are not supported
+	// directly by the Go SDK, so we use the API client's Do method instead.
+	return w.apiClient.Do(ctx, http.MethodPost, "/api/2.0/dbfs/put", map[string]string{
+		"Content-Type": writer.FormDataContentType(),
+	}, buf.Bytes(), nil)
+}
+
+func (w *DbfsClient) uploadUsingDbfsStreamingApi(ctx context.Context, path string, overwrite bool, reader io.Reader) error {
+	fileMode := files.FileModeWrite
+	if overwrite {
+		fileMode |= files.FileModeOverwrite
+	}
+
+	handle, err := w.workspaceClient.Dbfs.Open(ctx, path, fileMode)
+	if err != nil {
+		var aerr *apierr.APIError
+		if !errors.As(err, &aerr) {
+			return err
+		}
+
+		// This API returns a 400 if the file already exists.
+		if aerr.StatusCode == http.StatusBadRequest {
+			if aerr.ErrorCode == "RESOURCE_ALREADY_EXISTS" {
+				return FileAlreadyExistsError{path}
+			}
+		}
+
+		return err
+	}
+
+	_, err = io.Copy(handle, reader)
+	cerr := handle.Close()
+	if err == nil {
+		err = cerr
+	}
+	return err
+}
+
+// TODO CONTINUE:
+// 1. Write unit tests that make sure the filer Write method works correctly
+//    in either case.
+// 2. Write an integration test that asserts Write continues to work for big
+//    file uploads. Also test the overwrite flag in the integration test.
+//    We can change MaxDbfsUploadLimitForPutApi in the test to avoid creating
+//    massive test fixtures. (Sketches of both tests follow this patch.)
+
+// MaxDbfsUploadLimitForPutApi is the maximum size in bytes of a file that can be
+// uploaded using the /dbfs/put API. If the file is larger than this limit, the
+// streaming API (/dbfs/create and /dbfs/add-block) will be used instead.
+var MaxDbfsUploadLimitForPutApi int64 = 2 * 1024 * 1024
+
 func (w *DbfsClient) Write(ctx context.Context, name string, reader io.Reader, mode ...WriteMode) error {
 	absPath, err := w.root.Join(name)
 	if err != nil {
@@ -114,30 +215,27 @@ func (w *DbfsClient) Write(ctx context.Context, name string, reader io.Reader, m
 		}
 	}
 
-	handle, err := w.workspaceClient.Dbfs.Open(ctx, absPath, fileMode)
-	if err != nil {
-		var aerr *apierr.APIError
-		if !errors.As(err, &aerr) {
-			return err
-		}
+	localFile, ok := reader.(*os.File)
 
-		// This API returns a 400 if the file already exists.
-		if aerr.StatusCode == http.StatusBadRequest {
-			if aerr.ErrorCode == "RESOURCE_ALREADY_EXISTS" {
-				return FileAlreadyExistsError{absPath}
-			}
-		}
+	// If the source is not a local file, we'll always use the streaming API endpoint.
+	if !ok {
+		return w.uploadUsingDbfsStreamingApi(ctx, absPath, slices.Contains(mode, OverwriteIfExists), reader)
+	}
 
-		return err
+	stat, err := localFile.Stat()
+	if err != nil {
+		return fmt.Errorf("failed to stat file: %w", err)
 	}
 
-	_, err = io.Copy(handle, reader)
-	cerr := handle.Close()
-	if err == nil {
-		err = cerr
+	// If the source is a local file but is too large, we'll use the streaming API endpoint.
+	if stat.Size() > MaxDbfsUploadLimitForPutApi {
+		return w.uploadUsingDbfsStreamingApi(ctx, absPath, slices.Contains(mode, OverwriteIfExists), localFile)
 	}
 
-	return err
+	// Use the /dbfs/put API when the file is on the local filesystem
+	// and is small enough. This is the most common case when users use the
+	// `databricks fs cp` command.
+	return w.uploadUsingDbfsPutApi(ctx, absPath, slices.Contains(mode, OverwriteIfExists), localFile)
 }
 
 func (w *DbfsClient) Read(ctx context.Context, name string) (io.ReadCloser, error) {
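
Item 1 of the TODO can be covered without a workspace: `uploadUsingDbfsPutApi` only talks to the `databricksClient` interface, so a test double is enough. Below is a minimal sketch, assuming the test lives in the `filer` package so it can populate the unexported `apiClient` field directly; `mockDatabricksClient` and the test name are illustrative, not part of this patch.

```go
package filer

import (
	"bytes"
	"context"
	"net/http"
	"os"
	"path/filepath"
	"testing"
)

// mockDatabricksClient is a hypothetical test double for the databricksClient
// interface; it records the request so the test can assert on it.
type mockDatabricksClient struct {
	method string
	path   string
	body   []byte
}

func (m *mockDatabricksClient) Do(ctx context.Context, method, path string,
	headers map[string]string, requestBody any, responseBody any,
	visitors ...func(*http.Request) error) error {
	m.method = method
	m.path = path
	m.body, _ = requestBody.([]byte)
	return nil
}

func TestDbfsPutApiUpload(t *testing.T) {
	mock := &mockDatabricksClient{}
	// Construct the client directly so no workspace credentials are needed.
	f := &DbfsClient{apiClient: mock}

	// Create a small local file to upload.
	name := filepath.Join(t.TempDir(), "small.txt")
	if err := os.WriteFile(name, []byte("hello"), 0o644); err != nil {
		t.Fatal(err)
	}
	file, err := os.Open(name)
	if err != nil {
		t.Fatal(err)
	}
	defer file.Close()

	if err := f.uploadUsingDbfsPutApi(context.Background(), "/tmp/small.txt", true, file); err != nil {
		t.Fatal(err)
	}

	// The upload must go through /dbfs/put with a multipart body carrying
	// the path, the overwrite flag, and the file contents.
	if mock.method != http.MethodPost {
		t.Fatalf("expected POST, got %s", mock.method)
	}
	if mock.path != "/api/2.0/dbfs/put" {
		t.Fatalf("expected /api/2.0/dbfs/put, got %s", mock.path)
	}
	for _, want := range []string{`name="path"`, `name="overwrite"`, "hello"} {
		if !bytes.Contains(mock.body, []byte(want)) {
			t.Fatalf("multipart body missing %q", want)
		}
	}
}
```

Asserting on the raw multipart body keeps the test independent of how the SDK would serialize a typed request struct.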
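For item 2, the TODO's idea of shrinking `MaxDbfsUploadLimitForPutApi` avoids multi-megabyte fixtures. Here is a sketch under stated assumptions: credentials are picked up from the environment by `databricks.NewWorkspaceClient`, the DBFS root path is writable, and the test name and paths are placeholders.

```go
package filer_test

import (
	"bytes"
	"context"
	"io"
	"os"
	"path/filepath"
	"testing"

	"github.com/databricks/cli/libs/filer"
	"github.com/databricks/databricks-sdk-go"
)

func TestAccDbfsWriteUsesStreamingApiForLargeFiles(t *testing.T) {
	ctx := context.Background()
	w := databricks.Must(databricks.NewWorkspaceClient())

	// Lower the /dbfs/put cutoff so a tiny fixture takes the streaming
	// (/dbfs/create + /dbfs/add-block) path, and restore it afterwards.
	prev := filer.MaxDbfsUploadLimitForPutApi
	filer.MaxDbfsUploadLimitForPutApi = 16
	t.Cleanup(func() { filer.MaxDbfsUploadLimitForPutApi = prev })

	f, err := filer.NewDbfsClient(w, "/tmp/dbfs-client-test")
	if err != nil {
		t.Fatal(err)
	}

	// Any payload above 16 bytes now exceeds the limit.
	payload := bytes.Repeat([]byte("x"), 64)
	name := filepath.Join(t.TempDir(), "large")
	if err := os.WriteFile(name, payload, 0o644); err != nil {
		t.Fatal(err)
	}
	file, err := os.Open(name)
	if err != nil {
		t.Fatal(err)
	}
	defer file.Close()

	if err := f.Write(ctx, "large", file, filer.OverwriteIfExists, filer.CreateParentDirectories); err != nil {
		t.Fatal(err)
	}

	// Read the file back and verify the contents survived the round trip.
	r, err := f.Read(ctx, "large")
	if err != nil {
		t.Fatal(err)
	}
	defer r.Close()
	got, err := io.ReadAll(r)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(got, payload) {
		t.Fatalf("round-trip mismatch: got %d bytes", len(got))
	}
}
```

Because the limit is a package-level `var` rather than a `const`, the test can swap it out and restore it in `t.Cleanup`, which is exactly why the patch exposes it that way.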