Skip to content

Commit

Permalink
Add wrapper for tiledb_vfs_ls_recursive. (#363)
Browse files Browse the repository at this point in the history
* Implement `ls_recursive`.

* Fix memory leaks.

* Add a test.

* Export the callback type.
  • Loading branch information
teo-tsirpanis authored Jan 10, 2025
1 parent 0ed55b8 commit 1ed8dab
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 2 deletions.
9 changes: 9 additions & 0 deletions clibrary.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ int32_t _vfs_ls(
return ret_val;
}

int32_t _vfs_ls_recursive(
tiledb_ctx_t* ctx,
tiledb_vfs_t* vfs,
const char* path,
void* data) {
int32_t ret_val = tiledb_vfs_ls_recursive(ctx, vfs, path, vfsLsRecursive, data);
return ret_val;
}

int32_t _tiledb_object_walk(
tiledb_ctx_t* ctx,
const char* path,
Expand Down
9 changes: 8 additions & 1 deletion clibrary.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
#ifndef CLIBRARY_H
#define CLIBRARY_H

#include <tiledb/tiledb.h>
#include <tiledb/tiledb_experimental.h>

typedef const char cchar_t;

int32_t numOfFragmentsInPath(cchar_t* path, void *data);
int32_t vfsLs(cchar_t* path, void *data);
int32_t vfsLsRecursive(cchar_t* path, size_t path_len, uint64_t size, void *data);
int32_t objectsInPath(cchar_t* path, tiledb_object_t objectType, void *data);

int32_t _num_of_folders_in_path(
Expand All @@ -21,6 +22,12 @@ int32_t _vfs_ls(
const char* path,
void* data);

int32_t _vfs_ls_recursive(
tiledb_ctx_t* ctx,
tiledb_vfs_t* vfs,
const char* path,
void* data);

int32_t _tiledb_object_walk(
tiledb_ctx_t* ctx,
const char* path,
Expand Down
55 changes: 54 additions & 1 deletion vfs.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ func (v *VFS) NumOfFragmentsInPath(path string) (int, error) {
Vfs: v,
}
data := pointer.Save(&numOfFragmentsData)
defer C.free(data)
defer pointer.Unref(data)

ret := C._num_of_folders_in_path(v.context.tiledbContext, v.tiledbVFS, cpath, data)

Expand Down Expand Up @@ -635,6 +635,7 @@ func (v *VFS) List(path string) ([]string, []string, error) {
Vfs: v,
}
data := pointer.Save(&folderData)
defer pointer.Unref(data)

ret := C._vfs_ls(v.context.tiledbContext, v.tiledbVFS, cpath, data)
if ret != C.TILEDB_OK {
Expand All @@ -643,3 +644,55 @@ func (v *VFS) List(path string) ([]string, []string, error) {

return folderData.Folders, folderData.Files, nil
}

// VisitRecursiveCallback gets called by VFS.VisitRecursive. It returns whether visiting should
// continue, and maybe an error to propagate to the caller. If err is not nil, visiting always
// stops.
type VisitRecursiveCallback = func(path string, size uint64) (doContinue bool, err error)

// visitRecursiveState contains the state of a call to VisitRecursive.
type visitRecursiveState struct {
callback VisitRecursiveCallback
lastError error
}

//export vfsLsRecursive
func vfsLsRecursive(path *C.cchar_t, path_len C.size_t, size C.uint64_t, data unsafe.Pointer) int32 {
state := pointer.Restore(data).(*visitRecursiveState)

if path_len > math.MaxInt {
state.lastError = errors.New("path is too long")
return 0
}

doContinue, err := state.callback(C.GoStringN(path, C.int(path_len)), uint64(size))

if err != nil || !doContinue {
// Save error to return to the user.
state.lastError = err
return 0
}

return 1
}

// VisitRecursive calls a function for every file in a path recursively.
// This function returns if the listing ends, or if the callback returns false or an error.
func (v *VFS) VisitRecursive(path string, callback VisitRecursiveCallback) error {
cpath := C.CString(path)
defer C.free(unsafe.Pointer(cpath))

state := &visitRecursiveState{
callback: callback,
lastError: nil,
}
data := pointer.Save(state)
defer pointer.Unref(data)

ret := C._vfs_ls_recursive(v.context.tiledbContext, v.tiledbVFS, cpath, data)
if ret != C.TILEDB_OK {
return fmt.Errorf("error in recursively listing path %s: %w", path, v.context.LastError())
}

return state.lastError
}
74 changes: 74 additions & 0 deletions vfs_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package tiledb

import (
"errors"
"fmt"
"io"
"os"
"path/filepath"
"slices"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -254,6 +256,78 @@ func TestVFSList(t *testing.T) {
tmpFilePath2, "file://" + tmpFilePath3}, fileList)
}

// TestVFSList validates vfs VisitRecursive operation is successful
func TestVFSVisitRecursive(t *testing.T) {
config, err := NewConfig()
require.NoError(t, err)

context, err := NewContext(config)
require.NoError(t, err)

vfs, err := NewVFS(context, config)
require.NoError(t, err)

tmpPath := filepath.Join(t.TempDir(), "somedir")
tmpPath2 := filepath.Join(tmpPath, "subdir")
tmpPath3 := filepath.Join(tmpPath, "subdir2")

tmpFilePath := filepath.Join(tmpPath, "somefile")
tmpFilePath2 := filepath.Join(tmpPath, "somefile2")
tmpFilePath3 := filepath.Join(tmpPath, "somefile3")

// Create directories
require.NoError(t, vfs.CreateDir(tmpPath))
require.NoError(t, vfs.CreateDir(tmpPath2))
require.NoError(t, vfs.CreateDir(tmpPath3))

// Create Files
createFile(t, vfs, tmpFilePath)
createFile(t, vfs, tmpFilePath2)
createFile(t, vfs, tmpFilePath3)

var fileList []string
err = vfs.VisitRecursive(tmpPath, func(path string, size uint64) (bool, error) {
// Do not use require inside the callback because panicing might have unforeseen consequences.
fileExists, err := vfs.IsFile(path)
if err != nil {
return false, err
}
if !fileExists {
dirExists, err := vfs.IsDir(path)
if err != nil {
return false, err
}
if !dirExists {
return false, fmt.Errorf("%s does not exist neither as a file nor as a directory", path)
}
} else {
if size != 3 {
return false, fmt.Errorf("file %s has unexpected size (%d)", path, size)
}
fileList = append(fileList, path)
}
return true, nil
})
require.NoError(t, err)
slices.Sort(fileList)
assert.EqualValues(t, []string{"file://" + tmpFilePath, "file://" +
tmpFilePath2, "file://" + tmpFilePath3}, fileList)

expectedErr := errors.New("dummy")
err = vfs.VisitRecursive(tmpPath, func(path string, size uint64) (bool, error) {
return false, expectedErr
})
assert.Equal(t, expectedErr, err)

count := 0
err = vfs.VisitRecursive(tmpPath, func(path string, size uint64) (bool, error) {
count++
return count < 2, nil
})
require.NoError(t, err)
assert.Equal(t, 2, count)
}

func createFile(t testing.TB, vfs *VFS, path string) {
t.Helper()
require.NoError(t, vfs.Touch(path))
Expand Down

0 comments on commit 1ed8dab

Please sign in to comment.