diff --git a/persist/filesystem.go b/persist/filesystem.go index a6b9abe..af8c251 100644 --- a/persist/filesystem.go +++ b/persist/filesystem.go @@ -5,13 +5,10 @@ import ( "os" "path/filepath" "time" - - "github.com/spf13/afero" ) // FsStore is a Store that uses the filesystem to store cache data type FsStore struct { - afs afero.Fs dir string useSafeKey bool } @@ -19,7 +16,6 @@ type FsStore struct { // NewFsStore creates a new FsStore, dir is the rood directory where all cached files will be stored func NewFsStore(dir string, useSafeKey bool) *FsStore { return &FsStore{ - afs: afero.NewOsFs(), dir: dir, useSafeKey: useSafeKey, } @@ -40,7 +36,7 @@ func (c *FsStore) Get(_ context.Context, key string) ([]byte, time.Time, error) return nil, time.Time{}, err } - raw, err := afero.ReadFile(c.afs, file) + raw, err := os.ReadFile(file) if err != nil { return nil, time.Time{}, err } @@ -51,8 +47,8 @@ func (c *FsStore) Get(_ context.Context, key string) ([]byte, time.Time, error) // Set writes or updates a file that matches the provided key in the stores root directory. The file will contain // the raw bytes passed in by val func (c *FsStore) Set(_ context.Context, key string, val []byte) error { - if _, err := c.afs.Stat(c.dir); os.IsNotExist(err) { - err := c.afs.MkdirAll(c.dir, 0750) + if _, err := os.Stat(c.dir); os.IsNotExist(err) { + err := os.MkdirAll(c.dir, 0750) if err != nil { return err } @@ -62,7 +58,7 @@ func (c *FsStore) Set(_ context.Context, key string, val []byte) error { key = SafeKey(key) } file := filepath.Join(c.dir, key) - err := afero.WriteFile(c.afs, file, val, 0666) + err := os.WriteFile(file, val, 0666) if err != nil { return err } diff --git a/persist/filesystem_test.go b/persist/filesystem_test.go new file mode 100644 index 0000000..6dfec9a --- /dev/null +++ b/persist/filesystem_test.go @@ -0,0 +1,206 @@ +package persist + +import ( + "context" + "os" + "path/filepath" + "reflect" + "testing" + "time" +) + +func TestFsStore_Get(t *testing.T) { + type fields struct { + dir string + useSafeKey bool + } + type args struct { + ctx context.Context + key string + } + tests := []struct { + name string + fields fields + args args + wantBytes []byte + wantTS time.Time + wantErr bool + }{ + { + "use safe key", + fields{ + dir: func() string { + dir := t.TempDir() + err := os.WriteFile(filepath.Join(dir, SafeKey("safe_key")), []byte(`test`), 0o0644) + if err != nil { + t.Error("failed to write test file", err) + } + + return dir + }(), + useSafeKey: true, + }, + args{ + ctx: context.Background(), + key: "safe_key", + }, + []byte(`test`), + time.Now(), + false, + }, + { + "file does not exist", + fields{ + dir: t.TempDir(), + useSafeKey: false, + }, + args{ + ctx: context.Background(), + key: "test_key", + }, + nil, + time.Time{}, + false, + }, + { + "permission error", + fields{ + dir: func() string { + dir := t.TempDir() + err := os.WriteFile(filepath.Join(dir, "test_key"), []byte(`test`), 0o0222) + if err != nil { + t.Error("failed to write test file", err) + } + return dir + }(), + }, + args{ + ctx: context.Background(), + key: "test_key", + }, + nil, + time.Time{}, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &FsStore{ + dir: tt.fields.dir, + useSafeKey: tt.fields.useSafeKey, + } + gotBytes, gotTS, err := c.Get(tt.args.ctx, tt.args.key) + if (err != nil) != tt.wantErr { + t.Errorf("FsStore.Get() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotBytes, tt.wantBytes) { + t.Errorf("FsStore.Get() gotBytes = %v, want %v", gotBytes, tt.wantBytes) + } + + // use a range so time stamps dont have to match exactly + if tt.wantTS.Before(gotTS.Add(-time.Millisecond*10)) || + tt.wantTS.After(gotTS.Add(time.Millisecond*10)) { + t.Errorf("FsStore.Get() gotTS = %v, want %v", gotTS, tt.wantTS) + } + }) + } +} + +func TestFsStore_Set(t *testing.T) { + type fields struct { + dir string + useSafeKey bool + } + type args struct { + ctx context.Context + key string + val []byte + } + tests := []struct { + name string + fields fields + args args + wantFile []byte + wantErr bool + }{ + { + "dir does not exist", + fields{ + dir: filepath.Join(t.TempDir(), "nested"), + useSafeKey: false, + }, + args{ + ctx: context.Background(), + key: "test_key", + val: []byte(`test`), + }, + []byte(`test`), + false, + }, + { + "use safe key", + fields{ + dir: t.TempDir(), + useSafeKey: true, + }, + args{ + ctx: context.Background(), + key: "safe_key", + val: []byte(`test`), + }, + []byte(`test`), + false, + }, + { + "permission denyed", + fields{ + dir: func() string { + dir := t.TempDir() + err := os.Chmod(dir, 0o0666) + if err != nil { + t.Error("FsStore.Set() failed to set up dir", err) + } + + return dir + }(), + useSafeKey: false, + }, + args{ + ctx: context.Background(), + key: "test_key", + val: []byte(`test`), + }, + nil, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &FsStore{ + dir: tt.fields.dir, + useSafeKey: tt.fields.useSafeKey, + } + if err := c.Set(tt.args.ctx, tt.args.key, tt.args.val); (err != nil) != tt.wantErr { + t.Errorf("FsStore.Set() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.wantErr { + return + } + + // check the file was written + key := tt.args.key + if c.useSafeKey { + key = SafeKey(key) + } + + got, err := os.ReadFile(filepath.Join(c.dir, key)) + if err != nil { + t.Error("FsStore.Set() failed to read cache file", err) + } + if !reflect.DeepEqual(got, tt.wantFile) { + t.Errorf("FsStore.Set() cache file = %s, wanted = %s", string(got), string(tt.wantFile)) + } + }) + } +}