diff --git a/dagstore.go b/dagstore.go index 38989ac..e73eb7b 100644 --- a/dagstore.go +++ b/dagstore.go @@ -11,9 +11,17 @@ import ( "github.com/filecoin-project/dagstore/mount" "github.com/filecoin-project/dagstore/shard" ds "github.com/ipfs/go-datastore" + "github.com/ipfs/go-datastore/namespace" + "github.com/ipfs/go-datastore/query" + dssync "github.com/ipfs/go-datastore/sync" logging "github.com/ipfs/go-log/v2" ) +var ( + // StoreNamespace is the namespace under which shard state will be persisted. + StoreNamespace = ds.NewKey("dagstore") +) + var log = logging.Logger("dagstore") var ( @@ -39,6 +47,7 @@ type DAGStore struct { shards map[shard.Key]*Shard config Config indices index.FullIndexRepo + store ds.Datastore // externalCh receives external tasks. externalCh chan *task @@ -75,9 +84,9 @@ type ShardResult struct { } type Config struct { - // ScrapRoot is the path to the scratch space, where local copies of - // remote mounts are saved. - ScratchSpaceDir string + // TransientsDir is the path to directory where local transient files will + // be created for remote mounts. + TransientsDir string // IndexDir is the path where indices are stored. IndexDir string @@ -85,18 +94,17 @@ type Config struct { // Datastore is the datastore where shard state will be persisted. Datastore ds.Datastore - // MountTypes are the recognized mount types, bound to their corresponding - // URL schemes. - MountTypes map[string]mount.Type + // MountRegistry contains the set of recognized mount types. + MountRegistry *mount.Registry } // NewDAGStore constructs a new DAG store with the supplied configuration. func NewDAGStore(cfg Config) (*DAGStore, error) { // validate and manage scratch root directory. 
- if cfg.ScratchSpaceDir == "" { + if cfg.TransientsDir == "" { return nil, fmt.Errorf("missing scratch area root path") } - if err := ensureDir(cfg.ScratchSpaceDir); err != nil { + if err := ensureDir(cfg.TransientsDir); err != nil { return nil, fmt.Errorf("failed to create scratch root dir: %w", err) } @@ -119,25 +127,23 @@ func NewDAGStore(cfg Config) (*DAGStore, error) { // handle the datastore. if cfg.Datastore == nil { log.Warnf("no datastore provided; falling back to in-mem datastore; shard state will not survive restarts") - cfg.Datastore = ds.NewMapDatastore() + cfg.Datastore = dssync.MutexWrap(ds.NewMapDatastore()) // TODO can probably remove mutex wrap, since access is single-threaded } - // create the registry and register all mount types. - mounts := mount.NewRegistry() - for scheme, typ := range cfg.MountTypes { - if err := mounts.Register(scheme, typ); err != nil { - return nil, fmt.Errorf("failed to register mount factory: %w", err) - } - } + // namespace all store operations. + cfg.Datastore = namespace.Wrap(cfg.Datastore, StoreNamespace) - // TODO: recover persisted shard state from the Datastore. + if cfg.MountRegistry == nil { + cfg.MountRegistry = mount.NewRegistry() + } ctx, cancel := context.WithCancel(context.Background()) dagst := &DAGStore{ - mounts: mounts, + mounts: cfg.MountRegistry, config: cfg, indices: indices, shards: make(map[shard.Key]*Shard), + store: cfg.Datastore, externalCh: make(chan *task, 128), // len=128, concurrent external tasks that can be queued up before exercising backpressure. internalCh: make(chan *task, 1), // len=1, because eventloop will only ever stage another internal event. completionCh: make(chan *task, 64), // len=64, hitting this limit will just make async tasks wait. @@ -146,6 +152,24 @@ func NewDAGStore(cfg Config) (*DAGStore, error) { cancelFn: cancel, } + if err := dagst.restoreState(); err != nil { + // TODO add a lenient mode. 
+ return nil, fmt.Errorf("failed to restore dagstore state: %w", err) + } + + // reset in-progress states. + for _, s := range dagst.shards { + if s.state == ShardStateServing { + // no active acquirers at start. + s.state = ShardStateAvailable + } + if s.state == ShardStateInitializing { + // restart the registration. + s.state = ShardStateNew + _ = dagst.queueTask(&task{op: OpShardRegister, shard: s}, dagst.externalCh) + } + } + dagst.wg.Add(1) go dagst.control() @@ -157,7 +181,7 @@ func NewDAGStore(cfg Config) (*DAGStore, error) { type RegisterOpts struct { // ExistingTransient can be supplied when registering a shard to indicate that - // there's already an existing local transient local that can be used for + // there's already an existing local transient copy that can be used for // indexing. ExistingTransient string } @@ -174,7 +198,7 @@ func (d *DAGStore) RegisterShard(ctx context.Context, key shard.Key, mnt mount.M return fmt.Errorf("%s: %w", key.String(), ErrShardExists) } - upgraded, err := mount.Upgrade(mnt, opts.ExistingTransient) + upgraded, err := mount.Upgrade(mnt, d.config.TransientsDir, opts.ExistingTransient) if err != nil { d.lk.Unlock() return err @@ -184,6 +208,7 @@ func (d *DAGStore) RegisterShard(ctx context.Context, key shard.Key, mnt mount.M // add the shard to the shard catalogue, and drop the lock. 
s := &Shard{ + d: d, key: key, state: ShardStateNew, mount: upgraded, @@ -264,6 +289,7 @@ func (d *DAGStore) AllShardsInfo() AllShardsInfo { func (d *DAGStore) Close() error { d.cancelFn() d.wg.Wait() + _ = d.store.Sync(ds.Key{}) return nil } @@ -276,6 +302,25 @@ func (d *DAGStore) queueTask(tsk *task, ch chan<- *task) error { } } +func (d *DAGStore) restoreState() error { + results, err := d.store.Query(query.Query{}) + if err != nil { + return fmt.Errorf("failed to recover dagstore state from store: %w", err) + } + for { + res, ok := results.NextSync() + if !ok { + return nil + } + s := &Shard{d: d} + if err := s.UnmarshalJSON(res.Value); err != nil { + log.Warnf("failed to recover state of shard %s: %s; skipping", shard.KeyFromString(res.Key), err) + continue + } + d.shards[s.key] = s + } +} + // ensureDir checks whether the specified path is a directory, and if not it // attempts to create it. func ensureDir(path string) error { diff --git a/dagstore_control.go b/dagstore_control.go index 2d8d588..a17ac64 100644 --- a/dagstore_control.go +++ b/dagstore_control.go @@ -30,8 +30,16 @@ func (o OpType) String() string { func (d *DAGStore) control() { defer d.wg.Done() - tsk, err := d.consumeNext() - for ; err == nil; tsk, err = d.consumeNext() { + var ( + tsk *task + err error + ) + + for { + if tsk, err = d.consumeNext(); err != nil { + break + } + log.Debugw("processing task", "op", tsk.op, "shard", tsk.shard.key, "error", tsk.err) s := tsk.shard @@ -76,6 +84,7 @@ func (d *DAGStore) control() { s.state = ShardStateServing s.refs++ + go d.acquireAsync(tsk.ctx, w, s, s.mount) case OpShardRelease: @@ -128,8 +137,12 @@ func (d *DAGStore) control() { } - s.lk.Unlock() + // persist the current shard state. + if err := s.persist(d.config.Datastore); err != nil { // TODO maybe fail shard? 
+ log.Warnw("failed to persist shard", "shard", s.key, "error", err) + } + s.lk.Unlock() } if err != context.Canceled { @@ -152,6 +165,6 @@ func (d *DAGStore) consumeNext() (tsk *task, error error) { case tsk = <-d.completionCh: return tsk, nil case <-d.ctx.Done(): - return // TODO drain and process before returning? + return nil, d.ctx.Err() // TODO drain and process before returning? } } diff --git a/dagstore_test.go b/dagstore_test.go index 9aafdde..3a77f92 100644 --- a/dagstore_test.go +++ b/dagstore_test.go @@ -3,7 +3,7 @@ package dagstore import ( "bytes" "context" - _ "embed" + "embed" "fmt" "testing" @@ -11,6 +11,8 @@ import ( "github.com/filecoin-project/dagstore/shard" "github.com/ipfs/go-cid" "github.com/ipfs/go-datastore" + dsq "github.com/ipfs/go-datastore/query" + dssync "github.com/ipfs/go-datastore/sync" logging "github.com/ipfs/go-log/v2" "github.com/ipld/go-car/v2" @@ -18,10 +20,16 @@ import ( "golang.org/x/sync/errgroup" ) +const ( + carv1path = "testdata/sample-v1.car" + carv2path = "testdata/sample-wrapped-v2.car" +) + var ( - //go:embed testdata/sample-v1.car + //go:embed testdata/* + testdata embed.FS + carv1 []byte - //go:embed testdata/sample-wrapped-v2.car carv2 []byte // rootCID is the root CID of the carv2 for testing. 
@@ -31,6 +39,17 @@ var ( func init() { _ = logging.SetLogLevel("dagstore", "DEBUG") + var err error + carv1, err = testdata.ReadFile(carv1path) + if err != nil { + panic(err) + } + + carv2, err = testdata.ReadFile(carv2path) + if err != nil { + panic(err) + } + reader, err := car.NewReader(bytes.NewReader(carv2)) if err != nil { panic(fmt.Errorf("failed to parse carv2: %w", err)) @@ -48,16 +67,16 @@ func init() { } func TestRegisterCarV1(t *testing.T) { - dir := t.TempDir() dagst, err := NewDAGStore(Config{ - ScratchSpaceDir: dir, - Datastore: datastore.NewMapDatastore(), + MountRegistry: testRegistry(t), + TransientsDir: t.TempDir(), + Datastore: datastore.NewMapDatastore(), }) require.NoError(t, err) ch := make(chan ShardResult, 1) k := shard.KeyFromString("foo") - err = dagst.RegisterShard(context.Background(), k, &mount.BytesMount{Bytes: carv1}, ch, RegisterOpts{}) + err = dagst.RegisterShard(context.Background(), k, &mount.FSMount{FS: testdata, Path: carv1path}, ch, RegisterOpts{}) require.NoError(t, err) res := <-ch @@ -75,16 +94,16 @@ func TestRegisterCarV1(t *testing.T) { } func TestRegisterCarV2(t *testing.T) { - dir := t.TempDir() dagst, err := NewDAGStore(Config{ - ScratchSpaceDir: dir, - Datastore: datastore.NewMapDatastore(), + MountRegistry: testRegistry(t), + TransientsDir: t.TempDir(), + Datastore: datastore.NewMapDatastore(), }) require.NoError(t, err) ch := make(chan ShardResult, 1) k := shard.KeyFromString("foo") - err = dagst.RegisterShard(context.Background(), k, &mount.BytesMount{Bytes: carv2}, ch, RegisterOpts{}) + err = dagst.RegisterShard(context.Background(), k, &mount.FSMount{FS: testdata, Path: carv2path}, ch, RegisterOpts{}) require.NoError(t, err) res := <-ch @@ -102,38 +121,21 @@ func TestRegisterCarV2(t *testing.T) { func TestRegisterConcurrentShards(t *testing.T) { run := func(t *testing.T, n int) { - dir := t.TempDir() + store := dssync.MutexWrap(datastore.NewMapDatastore()) dagst, err := NewDAGStore(Config{ - ScratchSpaceDir: dir, 
- Datastore: datastore.NewMapDatastore(), + MountRegistry: testRegistry(t), + TransientsDir: t.TempDir(), + Datastore: store, }) require.NoError(t, err) - grp, _ := errgroup.WithContext(context.Background()) - for i := 0; i < n; i++ { - i := i - grp.Go(func() error { - ch := make(chan ShardResult, 1) - k := shard.KeyFromString(fmt.Sprintf("shard-%d", i)) - err := dagst.RegisterShard(context.Background(), k, &mount.BytesMount{Bytes: carv2}, ch, RegisterOpts{}) - if err != nil { - return err - } - res := <-ch - return res.Error - }) - } - - require.NoError(t, grp.Wait()) - - info := dagst.AllShardsInfo() - require.Len(t, info, n) - for _, ss := range info { - require.Equal(t, ShardStateAvailable, ss.ShardState) - require.NoError(t, ss.Error) - } + registerShards(t, dagst, n) } + t.Run("1", func(t *testing.T) { run(t, 1) }) + t.Run("2", func(t *testing.T) { run(t, 2) }) + t.Run("4", func(t *testing.T) { run(t, 4) }) + t.Run("8", func(t *testing.T) { run(t, 8) }) t.Run("16", func(t *testing.T) { run(t, 16) }) t.Run("32", func(t *testing.T) { run(t, 32) }) t.Run("64", func(t *testing.T) { run(t, 64) }) @@ -142,10 +144,10 @@ func TestRegisterConcurrentShards(t *testing.T) { } func TestAcquireInexistentShard(t *testing.T) { - dir := t.TempDir() dagst, err := NewDAGStore(Config{ - ScratchSpaceDir: dir, - Datastore: datastore.NewMapDatastore(), + MountRegistry: testRegistry(t), + TransientsDir: t.TempDir(), + Datastore: datastore.NewMapDatastore(), }) require.NoError(t, err) @@ -157,17 +159,16 @@ func TestAcquireInexistentShard(t *testing.T) { func TestAcquireAfterRegisterWait(t *testing.T) { t.Skip("uncomment when https://github.com/ipfs/go-cid/issues/126#issuecomment-872364155 is fixed") - - dir := t.TempDir() dagst, err := NewDAGStore(Config{ - ScratchSpaceDir: dir, - Datastore: datastore.NewMapDatastore(), + MountRegistry: testRegistry(t), + TransientsDir: t.TempDir(), + Datastore: datastore.NewMapDatastore(), }) require.NoError(t, err) ch := make(chan ShardResult, 1) k 
:= shard.KeyFromString("foo") - err = dagst.RegisterShard(context.Background(), k, &mount.BytesMount{Bytes: carv2}, ch, RegisterOpts{}) + err = dagst.RegisterShard(context.Background(), k, &mount.FSMount{FS: testdata, Path: carv2path}, ch, RegisterOpts{}) require.NoError(t, err) res := <-ch @@ -196,65 +197,80 @@ func TestAcquireAfterRegisterWait(t *testing.T) { } func TestConcurrentAcquires(t *testing.T) { - dir := t.TempDir() dagst, err := NewDAGStore(Config{ - ScratchSpaceDir: dir, - Datastore: datastore.NewMapDatastore(), + MountRegistry: testRegistry(t), + TransientsDir: t.TempDir(), }) require.NoError(t, err) ch := make(chan ShardResult, 1) k := shard.KeyFromString("foo") - err = dagst.RegisterShard(context.Background(), k, &mount.BytesMount{Bytes: carv2}, ch, RegisterOpts{}) + err = dagst.RegisterShard(context.Background(), k, &mount.FSMount{FS: testdata, Path: carv2path}, ch, RegisterOpts{}) require.NoError(t, err) res := <-ch require.NoError(t, res.Error) - run := func(t *testing.T, n int) { - grp, _ := errgroup.WithContext(context.Background()) - for i := 0; i < n; i++ { - grp.Go(func() error { - ch := make(chan ShardResult, 1) - err := dagst.AcquireShard(context.Background(), k, ch, AcquireOpts{}) - if err != nil { - return err - } - - res := <-ch - if res.Error != nil { - return res.Error - } - defer res.Accessor.Close() - - bs, err := res.Accessor.Blockstore() - if err != nil { - return err - } - - _, err = bs.Get(rootCID) - return err - }) - } - require.NoError(t, grp.Wait()) - - info := dagst.AllShardsInfo() - require.Len(t, info, 1) - for _, ss := range info { - require.Equal(t, ShardStateServing, ss.ShardState) - require.NoError(t, ss.Error) - } + t.Run("1", func(t *testing.T) { acquireShard(t, dagst, k, 1) }) + t.Run("2", func(t *testing.T) { acquireShard(t, dagst, k, 2) }) + t.Run("4", func(t *testing.T) { acquireShard(t, dagst, k, 4) }) + t.Run("8", func(t *testing.T) { acquireShard(t, dagst, k, 8) }) + t.Run("16", func(t *testing.T) { 
acquireShard(t, dagst, k, 16) }) + t.Run("32", func(t *testing.T) { acquireShard(t, dagst, k, 32) }) + t.Run("64", func(t *testing.T) { acquireShard(t, dagst, k, 64) }) + t.Run("128", func(t *testing.T) { acquireShard(t, dagst, k, 128) }) + t.Run("256", func(t *testing.T) { acquireShard(t, dagst, k, 256) }) + + info := dagst.AllShardsInfo() + require.Len(t, info, 1) + for _, ss := range info { + require.Equal(t, ShardStateServing, ss.ShardState) + require.NoError(t, ss.Error) } +} - t.Run("1", func(t *testing.T) { run(t, 1) }) - t.Run("2", func(t *testing.T) { run(t, 2) }) - t.Run("4", func(t *testing.T) { run(t, 4) }) - t.Run("8", func(t *testing.T) { run(t, 8) }) - t.Run("16", func(t *testing.T) { run(t, 16) }) - t.Run("32", func(t *testing.T) { run(t, 32) }) - t.Run("64", func(t *testing.T) { run(t, 64) }) - t.Run("128", func(t *testing.T) { run(t, 128) }) - t.Run("256", func(t *testing.T) { run(t, 256) }) +func TestRestartRestoresState(t *testing.T) { + dir := t.TempDir() + store := datastore.NewLogDatastore(dssync.MutexWrap(datastore.NewMapDatastore()), "trace") + dagst, err := NewDAGStore(Config{ + MountRegistry: testRegistry(t), + TransientsDir: dir, + Datastore: store, + }) + require.NoError(t, err) + + keys := registerShards(t, dagst, 100) + for _, k := range keys[0:20] { // acquire the first 20 keys. + acquireShard(t, dagst, k, 4) + } + + res, err := store.Query(dsq.Query{}) + require.NoError(t, err) + entries, err := res.Rest() + require.NoError(t, err) + require.Len(t, entries, 100) // we have 100 shards. + + // close the DAG store. + err = dagst.Close() + require.NoError(t, err) + + // create a new dagstore with the same datastore. 
+ dagst, err = NewDAGStore(Config{ + MountRegistry: testRegistry(t), + TransientsDir: dir, + Datastore: store, + }) + require.NoError(t, err) + info := dagst.AllShardsInfo() + require.Len(t, info, 100) + for _, ss := range info { + require.Equal(t, ShardStateAvailable, ss.ShardState) + require.NoError(t, ss.Error) + } +} + +func TestRestartResumesRegistration(t *testing.T) { + t.Skip("TODO") } // TestBlockCallback tests that blocking a callback blocks the dispatcher @@ -262,3 +278,68 @@ func TestConcurrentAcquires(t *testing.T) { func TestBlockCallback(t *testing.T) { t.Skip("TODO") } + +// registerShards registers n shards concurrently, using the CARv2 mount. +func registerShards(t *testing.T, dagst *DAGStore, n int) (ret []shard.Key) { + grp, _ := errgroup.WithContext(context.Background()) + for i := 0; i < n; i++ { + k := shard.KeyFromString(fmt.Sprintf("shard-%d", i)) + grp.Go(func() error { + ch := make(chan ShardResult, 1) + err := dagst.RegisterShard(context.Background(), k, &mount.FSMount{FS: testdata, Path: carv2path}, ch, RegisterOpts{}) + if err != nil { + return err + } + res := <-ch + return res.Error + }) + ret = append(ret, k) + } + + require.NoError(t, grp.Wait()) + + info := dagst.AllShardsInfo() + require.Len(t, info, n) + for _, ss := range info { + require.Equal(t, ShardStateAvailable, ss.ShardState) + require.NoError(t, ss.Error) + } + return ret +} + +// acquireShard acquires the shard known by key `k` concurrently `n` times. 
+func acquireShard(t *testing.T, dagst *DAGStore, k shard.Key, n int) { + grp, _ := errgroup.WithContext(context.Background()) + for i := 0; i < n; i++ { + grp.Go(func() error { + ch := make(chan ShardResult, 1) + err := dagst.AcquireShard(context.Background(), k, ch, AcquireOpts{}) + if err != nil { + return err + } + + res := <-ch + if res.Error != nil { + return res.Error + } + defer res.Accessor.Close() + + bs, err := res.Accessor.Blockstore() + if err != nil { + return err + } + + _, err = bs.Get(rootCID) + return err + }) + } + + require.NoError(t, grp.Wait()) +} + +func testRegistry(t *testing.T) *mount.Registry { + r := mount.NewRegistry() + err := r.Register("fs", &mount.FSMount{FS: testdata}) + require.NoError(t, err) + return r +} diff --git a/gen/main.go b/gen/main.go new file mode 100644 index 0000000..97a5a46 --- /dev/null +++ b/gen/main.go @@ -0,0 +1,19 @@ +package main + +import ( + "fmt" + "os" + + "github.com/filecoin-project/dagstore" + gen "github.com/whyrusleeping/cbor-gen" +) + +func main() { + err := gen.WriteMapEncodersToFile("./shard_gen.go", "dagstore", + dagstore.PersistedShard{}, + ) + if err != nil { + fmt.Println(err) + os.Exit(1) + } +} diff --git a/go.mod b/go.mod index 9502bb6..9cb853c 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/ipld/go-car/v2 v2.0.0-20210706083137-aa61149042cd github.com/mr-tron/base58 v1.2.0 github.com/stretchr/testify v1.7.0 + github.com/whyrusleeping/cbor-gen v0.0.0-20200123233031-1cdf64d27158 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 ) diff --git a/mount/byte.go b/mount/bytes.go similarity index 52% rename from mount/byte.go rename to mount/bytes.go index 32a1e6f..01b269d 100644 --- a/mount/byte.go +++ b/mount/bytes.go @@ -7,13 +7,15 @@ import ( "net/url" ) +// BytesMount encloses a byte slice. It is mainly used for testing. The +// Upgrader passes through it. 
type BytesMount struct { Bytes []byte } var _ Mount = (*BytesMount)(nil) -func (b *BytesMount) Fetch(ctx context.Context) (Reader, error) { +func (b *BytesMount) Fetch(_ context.Context) (Reader, error) { r := bytes.NewReader(b.Bytes) return &NopCloser{ Reader: r, @@ -23,26 +25,36 @@ func (b *BytesMount) Fetch(ctx context.Context) (Reader, error) { } func (b *BytesMount) Info() Info { - u := &url.URL{ - Scheme: "memory", - Host: base64.StdEncoding.EncodeToString(b.Bytes), - } return Info{ Kind: KindLocal, - URL: u, AccessSequential: true, AccessSeek: true, AccessRandom: true, } } -func (b *BytesMount) Stat(ctx context.Context) (Stat, error) { +func (b *BytesMount) Stat(_ context.Context) (Stat, error) { return Stat{ Exists: true, Size: int64(len(b.Bytes)), }, nil } +func (b *BytesMount) Serialize() *url.URL { + return &url.URL{ + Host: base64.StdEncoding.EncodeToString(b.Bytes), + } +} + +func (b *BytesMount) Deserialize(u *url.URL) error { + decoded, err := base64.StdEncoding.DecodeString(u.Host) + if err != nil { + return err + } + b.Bytes = decoded + return nil +} + func (b *BytesMount) Close() error { b.Bytes = nil // release return nil diff --git a/mount/file.go b/mount/file.go index 58320a8..38279f8 100644 --- a/mount/file.go +++ b/mount/file.go @@ -18,10 +18,8 @@ func (f *FileMount) Fetch(_ context.Context) (Reader, error) { } func (f *FileMount) Info() Info { - u, _ := url.Parse(fmt.Sprintf("file://%s", f.Path)) return Info{ Kind: KindLocal, - URL: u, AccessRandom: true, AccessSeek: true, AccessSequential: true, @@ -39,6 +37,20 @@ func (f *FileMount) Stat(_ context.Context) (Stat, error) { }, err } +func (f *FileMount) Serialize() *url.URL { + return &url.URL{ + Host: f.Path, + } +} + +func (f *FileMount) Deserialize(u *url.URL) error { + if u.Host == "" { + return fmt.Errorf("invalid host") + } + f.Path = u.Host + return nil +} + func (f *FileMount) Close() error { return nil } diff --git a/mount/file_test.go b/mount/file_test.go index 23c3c14..f9a4145 
100644 --- a/mount/file_test.go +++ b/mount/file_test.go @@ -35,10 +35,12 @@ func TestFileMount(t *testing.T) { require.True(t, stat.Exists) require.EqualValues(t, size, stat.Size) + // check URL. + require.Equal(t, mnt.Path, mnt.Serialize().Host) + info := mnt.Info() require.True(t, info.AccessSequential && info.AccessSeek && info.AccessRandom) // all flags true require.Equal(t, KindLocal, info.Kind) - require.Equal(t, "file://"+mnt.Path, info.URL.String()) reader, err := mnt.Fetch(context.Background()) require.NoError(t, err) diff --git a/mount/fs.go b/mount/fs.go new file mode 100644 index 0000000..68fb997 --- /dev/null +++ b/mount/fs.go @@ -0,0 +1,83 @@ +package mount + +import ( + "context" + "errors" + "fmt" + "io/fs" + "net/url" +) + +// FSMount is a mount that opens the file indicated by Path, using the +// provided fs.FS. Given that io/fs does not support random access patterns, +// this mount requires an Upgrade. It is suitable for testing. +type FSMount struct { + FS fs.FS + Path string +} + +var _ Mount = (*FSMount)(nil) + +func (f *FSMount) Close() error { + return nil // TODO +} + +func (f *FSMount) Fetch(_ context.Context) (Reader, error) { + file, err := f.FS.Open(f.Path) + return &fsReader{File: file}, err +} + +func (f *FSMount) Info() Info { + return Info{ + Kind: KindLocal, + AccessSequential: true, + AccessSeek: false, + AccessRandom: false, + } +} + +func (f *FSMount) Stat(_ context.Context) (Stat, error) { + st, err := fs.Stat(f.FS, f.Path) + if errors.Is(err, fs.ErrNotExist) { + return Stat{Exists: false, Size: 0}, nil + } + if err != nil { + return Stat{}, err + } + return Stat{ + Exists: true, + Size: st.Size(), + }, nil +} + +func (f *FSMount) Serialize() *url.URL { + u := new(url.URL) + if st, err := fs.Stat(f.FS, f.Path); err != nil { + u.Host = "irrecoverable" + } else { + u.Host = st.Name() + } + return u +} + +func (f *FSMount) Deserialize(u *url.URL) error { + if u.Host == "irrecoverable" || u.Host == "" { + return 
fmt.Errorf("invalid host") + } + f.Path = u.Host + return nil +} + +type fsReader struct { + fs.File +} + +var _ Reader = (*fsReader)(nil) + +func (f *fsReader) ReadAt(_ []byte, _ int64) (n int, err error) { + return 0, ErrRandomAccessUnsupported +} + +func (f *fsReader) Seek(_ int64, _ int) (int64, error) { + return 0, ErrSeekUnsupported +} diff --git a/mount/mount.go b/mount/mount.go index 1873316..2de490b 100644 --- a/mount/mount.go +++ b/mount/mount.go @@ -8,9 +8,13 @@ import ( ) var ( - // ErrNotSeekable is returned when FetchSeek is called on a mount that is + // ErrSeekUnsupported is returned when Seek is called on a mount that is // not seekable. - ErrNotSeekable = errors.New("mount not seekable") + ErrSeekUnsupported = errors.New("mount does not support seek") + + // ErrRandomAccessUnsupported is returned when ReadAt is called on a mount + // that does not support random access. + ErrRandomAccessUnsupported = errors.New("mount does not support random access") ) // Kind is an enum describing the source of a Mount. @@ -71,6 +75,13 @@ type Mount interface { // Stat describes the underlying resource. Stat(ctx context.Context) (Stat, error) + + // Serialize returns a canonical URL that can be used to revive the Mount + // after a restart. + Serialize() *url.URL + + // Deserialize configures this Mount from the specified URL. + Deserialize(*url.URL) error } // Reader is a fully-featured Reader returned from MountTypes. It is the @@ -87,8 +98,6 @@ type Reader interface { type Info struct { // Kind indicates the kind of mount. Kind Kind - // URL is the canonical URL this Mount serializes to. - URL *url.URL // TODO convert to bitfield AccessSequential bool @@ -104,13 +113,6 @@ type Stat struct { Size int64 } -// Type represents a mount type, and allows instantiation of a Mount from its -// URL serialized form. -type Type interface { - // Parse initializes the mount from a URL. 
- Parse(u *url.URL) (Mount, error) -} - type NopCloser struct { io.Reader io.ReaderAt diff --git a/mount/registry.go b/mount/registry.go index 0d5d075..17495bc 100644 --- a/mount/registry.go +++ b/mount/registry.go @@ -2,61 +2,113 @@ package mount import ( "errors" + "fmt" "net/url" + "reflect" "sync" - - "golang.org/x/xerrors" ) -// ErrUnrecognizedScheme is returned by Instantiate() when attempting to -// initialize a Mount with an unrecognized URL scheme. -var ErrUnrecognizedScheme = errors.New("unrecognized mount scheme") +var ( + // ErrUnrecognizedScheme is returned by Instantiate() when attempting to + // initialize a Mount with an unrecognized URL scheme. + ErrUnrecognizedScheme = errors.New("unrecognized mount scheme") + + // ErrUnrecognizedType is returned by Represent() when attempting to + // represent a Mount whose type has not been registered. + ErrUnrecognizedType = errors.New("unrecognized mount type") +) // Registry is a registry of Mount factories known to the DAG store. type Registry struct { - lk sync.RWMutex - m map[string]Type + lk sync.RWMutex + byScheme map[string]Mount + byType map[reflect.Type]string } // NewRegistry constructs a blank registry. func NewRegistry() *Registry { - return &Registry{} + return &Registry{byScheme: map[string]Mount{}, byType: map[reflect.Type]string{}} } -// Register adds a new Mount factory to the registry and maps it against the given URL scheme. -func (r *Registry) Register(scheme string, mount Type) error { +// Register adds a new mount type to the registry under the specified scheme. +// +// The supplied Mount is used as a template to create new instances. +// +// This means that the provided Mount can contain environmental configuration +// that will be automatically carried over to all instances. 
+func (r *Registry) Register(scheme string, template Mount) error { r.lk.Lock() defer r.lk.Unlock() - if _, ok := r.m[scheme]; ok { - return xerrors.New("mount factory already registered for scheme") + if _, ok := r.byScheme[scheme]; ok { + return fmt.Errorf("mount already registered for scheme: %s", scheme) } - r.m[scheme] = mount + if _, ok := r.byType[reflect.TypeOf(template)]; ok { + return fmt.Errorf("mount already registered for type: %T", template) + } + + r.byScheme[scheme] = template + r.byType[reflect.TypeOf(template)] = scheme return nil } // Instantiate instantiates a new Mount from a URL. // -// It looks up the Mount factory in the registry based on the URL scheme, -// calls Parse() on it to get a Mount and returns the Mount. +// It looks up the Mount template in the registry based on the URL scheme, +// creates a copy, and calls Deserialize() on it with the supplied URL before +// returning. // -// If it errors, it propagates the error returned by the Mount factory. If the scheme -// is not recognized, it returns ErrUnrecognizedScheme. +// It propagates any error returned by the Mount#Deserialize method. +// If the scheme is not recognized, it returns ErrUnrecognizedScheme. func (r *Registry) Instantiate(u *url.URL) (Mount, error) { r.lk.RLock() defer r.lk.RUnlock() - mft, ok := r.m[u.Scheme] + template, ok := r.byScheme[u.Scheme] if !ok { - return nil, ErrUnrecognizedScheme + return nil, fmt.Errorf("%w: %s", ErrUnrecognizedScheme, u.Scheme) + } + + instance := clone(template) + if err := instance.Deserialize(u); err != nil { + return nil, fmt.Errorf("failed to instantiate mount with url %s into type %T: %w", u.String(), template, err) } + return instance, nil +} - mt, err := mft.Parse(u) - if err != nil { - return nil, xerrors.Errorf("failed to instantiate mount with factory.Parse: %w", err) +// Represent returns the URL representation of a Mount, using the scheme that +// was registered for that type of mount. 
+func (r *Registry) Represent(mount Mount) (*url.URL, error) { + r.lk.RLock() + defer r.lk.RUnlock() + // special-case the upgrader, as it's transparent. + if up, ok := mount.(*Upgrader); ok { + mount = up.underlying } - return mt, nil + scheme, ok := r.byType[reflect.TypeOf(mount)] + if !ok { + return nil, fmt.Errorf("failed to represent mount with type %T: %w", mount, ErrUnrecognizedType) + } + + u := mount.Serialize() + u.Scheme = scheme + return u, nil +} + +// clone clones m1 into m2, casting it back to a Mount. It is only able to deal +// with pointer types that implement Mount. +func clone(m1 Mount) (m2 Mount) { + m2obj := reflect.New(reflect.TypeOf(m1).Elem()) + m1val := reflect.ValueOf(m1).Elem() + m2val := m2obj.Elem() + for i := 0; i < m1val.NumField(); i++ { + field := m2val.Field(i) + if field.CanSet() { + field.Set(m1val.Field(i)) + } + } + return m2obj.Interface().(Mount) } diff --git a/mount/registry_test.go b/mount/registry_test.go index 3f7ed71..3ab0403 100644 --- a/mount/registry_test.go +++ b/mount/registry_test.go @@ -15,71 +15,61 @@ import ( var _ Mount = (*MockMount)(nil) type MockMount struct { - Val string - URL *url.URL - StatSize int64 + Val string + URL *url.URL + StatSize int64 + Templated string } -func (m *MockMount) Close() error { - panic("implement me") -} - -func (m *MockMount) Fetch(_ context.Context) (Reader, error) { - r := strings.NewReader(m.Val) - return &NopCloser{Reader: r, ReaderAt: r, Seeker: r}, nil -} - -func (m *MockMount) Info() Info { - return Info{ - Kind: KindRemote, - URL: m.URL, +func (m *MockMount) Serialize() *url.URL { + u := &url.URL{ + Scheme: "aaa", // random, will get replaced + Host: m.Val, } + u.Query().Set("size", strconv.FormatInt(m.StatSize, 10)) + return u } -func (m *MockMount) Stat(_ context.Context) (Stat, error) { - return Stat{ - Exists: true, - Size: m.StatSize, - }, nil -} - -type MockMountFactory1 struct{} - -func (mf *MockMountFactory1) Parse(u *url.URL) (Mount, error) { +func (m 
*MockMount) Deserialize(u *url.URL) error { vals, err := url.ParseQuery(u.RawQuery) if err != nil { - return nil, err + return err } statSize, err := strconv.ParseInt(vals["size"][0], 10, 64) if err != nil { - return nil, err + return err } - return &MockMount{ - Val: u.Host, - URL: u, - StatSize: statSize, - }, nil + if v, err := strconv.ParseBool(vals["timestwo"][0]); err != nil { + return err + } else if v { + statSize *= 2 + } + + m.Val = u.Host + m.URL = u + m.StatSize = statSize + return nil } -type MockMountFactory2 struct{} +func (m *MockMount) Close() error { + panic("implement me") +} -func (mf *MockMountFactory2) Parse(u *url.URL) (Mount, error) { - vals, err := url.ParseQuery(u.RawQuery) - if err != nil { - return nil, err - } +func (m *MockMount) Fetch(_ context.Context) (Reader, error) { + r := strings.NewReader(m.Val) + return &NopCloser{Reader: r, ReaderAt: r, Seeker: r}, nil +} - statSize, err := strconv.ParseInt(vals["size"][0], 10, 64) - if err != nil { - return nil, err - } +func (m *MockMount) Info() Info { + return Info{Kind: KindRemote} +} - return &MockMount{ - Val: u.Host, - URL: u, - StatSize: statSize * 2, +func (m *MockMount) Stat(_ context.Context) (Stat, error) { + return Stat{ + Exists: true, + Size: m.StatSize, }, nil } @@ -88,30 +78,34 @@ func TestRegistry(t *testing.T) { m2StatSize := uint64(5678) // create a registry - r := Registry{ - m: make(map[string]Type), - } + r := NewRegistry() + + type ( + MockMount1 struct{ MockMount } + MockMount2 struct{ MockMount } + MockMount3 struct{ MockMount } + ) // create & register mock mount factory 1 - u := fmt.Sprintf("http://host1:123?size=%d", m1StatSize) - url, err := url.Parse(u) + url1 := fmt.Sprintf("http://host1:123?size=%d&timestwo=false", m1StatSize) + u1, err := url.Parse(url1) require.NoError(t, err) - m1 := &MockMountFactory1{} - require.NoError(t, r.Register("http", m1)) - // // register same scheme again -> fails - require.Error(t, r.Register("http", m1)) + require.NoError(t, 
r.Register("http", new(MockMount1))) + // register same scheme again -> fails + require.Error(t, r.Register("http", new(MockMount2))) + // register same type again -> fails, different scheme + require.Error(t, r.Register("http2", new(MockMount1))) // create and register mock mount factory 2 - url2 := fmt.Sprintf("ftp://host2:1234?size=%d", m2StatSize) - u2, err := url.Parse(url2) + url2 := fmt.Sprintf("ftp://host2:1234?size=%d&timestwo=true", m2StatSize) + u2, err := u1.Parse(url2) require.NoError(t, err) - m2 := &MockMountFactory2{} - require.NoError(t, r.Register("ftp", m2)) + require.NoError(t, r.Register("ftp", new(MockMount3))) // instantiate mount 1 and verify state is constructed correctly - m, err := r.Instantiate(url) + m, err := r.Instantiate(u1) require.NoError(t, err) - require.Equal(t, url.Host, fetchAndReadAll(t, m)) + require.Equal(t, u1.Host, fetchAndReadAll(t, m)) stat, err := m.Stat(context.TODO()) require.NoError(t, err) require.EqualValues(t, m1StatSize, stat.Size) @@ -125,6 +119,45 @@ func TestRegistry(t *testing.T) { require.EqualValues(t, m2StatSize*2, stat.Size) } +func TestRegistryHonoursTemplate(t *testing.T) { + r := NewRegistry() + + template := &MockMount{Templated: "give me proof"} + err := r.Register("foo", template) + require.NoError(t, err) + + u, err := url.Parse("foo://bang?size=100&timestwo=false") + require.NoError(t, err) + + m, err := r.Instantiate(u) + require.NoError(t, err) + + require.Equal(t, "give me proof", m.(*MockMount).Templated) +} + +func TestRegistryRecognizedType(t *testing.T) { + type ( + MockMount1 struct{ MockMount } + MockMount2 struct{ MockMount } + MockMount3 struct{ MockMount } + ) + + // register all three types under different schemes + r := NewRegistry() + err := r.Register("mount1", new(MockMount1)) + require.NoError(t, err) + err = r.Register("mount2", new(MockMount2)) + require.NoError(t, err) + err = r.Register("mount3", new(MockMount3)) + require.NoError(t, err) + + // now attempt to encode an instance of 
MockMount2 + u, err := r.Represent(&MockMount2{}) + require.NoError(t, err) + + require.Equal(t, "mount2", u.Scheme) +} + func fetchAndReadAll(t *testing.T, m Mount) string { rd, err := m.Fetch(context.Background()) require.NoError(t, err) diff --git a/mount/upgrader.go b/mount/upgrader.go index 0e191e4..759afc9 100644 --- a/mount/upgrader.go +++ b/mount/upgrader.go @@ -4,39 +4,41 @@ import ( "context" "fmt" "io" + "net/url" "os" "sync" ) -// Upgrader serves as a bridge to upgrade a Mount into one with full-featured -// Reader capabilities. It does this by caching a transient copy as file if -// the original mount type does not support all access patterns. +// Upgrader is a bridge to upgrade any Mount into one with full-featured +// Reader capabilities, whether the original mount is of remote or local kind. +// It does this by managing a local transient copy. // -// If the underlying mount is already fully-featured, the Upgrader is -// acts as a noop. -// -// TODO perform refcounts so we can track inactive transient files. -// TODO provide root directory for temp files (or better: temp file factory function). +// If the underlying mount is fully-featured, the Upgrader has no effect, and +// simply passes through to the underlying mount. type Upgrader struct { underlying Mount passthrough bool lk sync.Mutex transient string - // TODO refs int + rootdir string } var _ Mount = (*Upgrader)(nil) -// Upgrade constructs a new Upgrader for the underlying Mount. -func Upgrade(underlying Mount, initial string) (*Upgrader, error) { - ret := &Upgrader{underlying: underlying} +// Upgrade constructs a new Upgrader for the underlying Mount. If provided, it +// will reuse the file in path `initial` as the initial transient copy. Whenever +// a new transient copy has to be created, it will be created under `rootdir`. 
+func Upgrade(underlying Mount, rootdir, initial string) (*Upgrader, error) { + ret := &Upgrader{underlying: underlying, rootdir: rootdir} + if ret.rootdir == "" { + ret.rootdir = os.TempDir() // use the OS' default temp dir. + } - info := underlying.Info() - if !info.AccessSequential { + switch info := underlying.Info(); { + case !info.AccessSequential: return nil, fmt.Errorf("underlying mount must support sequential access") - } - if info.AccessSeek && info.AccessRandom { + case info.AccessSeek && info.AccessRandom: ret.passthrough = true return ret, nil } @@ -77,7 +79,6 @@ func (u *Upgrader) Fetch(ctx context.Context) (Reader, error) { func (u *Upgrader) Info() Info { return Info{ Kind: KindLocal, - URL: u.underlying.Info().URL, AccessSequential: true, AccessSeek: true, AccessRandom: true, @@ -94,6 +95,23 @@ func (u *Upgrader) Stat(ctx context.Context) (Stat, error) { return u.underlying.Stat(ctx) } +// TransientPath returns the local path of the transient file. If the Upgrader +// is passthrough, the return value will be "". +func (u *Upgrader) TransientPath() string { + u.lk.Lock() + defer u.lk.Unlock() + + return u.transient +} + +func (u *Upgrader) Serialize() *url.URL { + return u.underlying.Serialize() +} + +func (u *Upgrader) Deserialize(url *url.URL) error { + return u.underlying.Deserialize(url) +} + func (u *Upgrader) Close() error { panic("implement me") } @@ -102,7 +120,7 @@ func (u *Upgrader) refetch(ctx context.Context) error { if u.transient != "" { _ = os.Remove(u.transient) } - file, err := os.CreateTemp("dagstore", "transient") + file, err := os.CreateTemp(u.rootdir, "transient") if err != nil { return fmt.Errorf("failed to create temporary file: %w", err) } @@ -132,30 +150,22 @@ func (u *Upgrader) refetch(ctx context.Context) error { return nil } -// -// // Clean removes any transient assets. -// func (m *Upgrader) Clean() error { -// s.Lock() -// defer s.Unlock() -// -// // check if we have readers and refuse to clean if so. 
-// if s.refs != 0 { -// return fmt.Errorf("failed to delete shard: %w", ErrShardInUse) -// } -// -// if s.transient == nil { -// // nothing to do. -// return nil -// } -// -// // we can safely remove the transient. -// _ = s.transient.Close() -// err := os.Remove(s.transient.Name()) -// if err == nil { -// s.transient = nil -// } -// -// // refresh the availability. -// _, _ = s.refreshAvailability(nil) -// return nil -// } +// DeleteTransient deletes the transient associated with this Upgrader, if +// one exists. It is the caller's responsibility to ensure the transient is +// not in use. +func (u *Upgrader) DeleteTransient() error { + u.lk.Lock() + defer u.lk.Unlock() + + if u.transient == "" { + return nil // nothing to do. + } + + err := os.Remove(u.transient) + if err != nil { + return err + } + + u.transient = "" + return nil +} diff --git a/shard.go b/shard.go index 777a062..8dd1be1 100644 --- a/shard.go +++ b/shard.go @@ -11,9 +11,8 @@ import ( // waiter encapsulates a context passed by the user, and the channel they want // the result returned to. type waiter struct { - // context governing the operation if this is an external op. - ctx context.Context - outCh chan ShardResult + ctx context.Context // governs the op if it's external + outCh chan ShardResult // to send back the result } func (w waiter) deliver(res *ShardResult) { @@ -27,18 +26,21 @@ func (w waiter) deliver(res *ShardResult) { type Shard struct { lk sync.RWMutex - // IMMUTABLE FIELDS: safe to read outside the event loop without a lock. - key shard.Key - mount *mount.Upgrader + // IMMUTABLE FIELDS + // safe to read outside the event loop without a lock + d *DAGStore // backreference + key shard.Key // persisted in PersistedShard.Key + mount *mount.Upgrader // persisted in PersistedShard.URL (underlying) - // MUTABLE FIELDS: cannot read/write outside event loop. - state ShardState - err error // populated if shard state is errored. 
- indexed bool + // MUTABLE FIELDS + // cannot read/write outside event loop. + state ShardState // persisted in PersistedShard.State + err error // populated if shard state is errored; persisted in PersistedShard.Error + indexed bool // persisted in PersistedShard.Indexed wRegister *waiter wAcquire []*waiter wDestroy *waiter - refs uint32 // count of DAG accessors currently open + refs uint32 // number of DAG accessors currently open } diff --git a/shard_gen.go b/shard_gen.go new file mode 100644 index 0000000..d740c5a --- /dev/null +++ b/shard_gen.go @@ -0,0 +1,230 @@ +// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. + +package dagstore + +import ( + "fmt" + "io" + "math" + + cbg "github.com/whyrusleeping/cbor-gen" + xerrors "golang.org/x/xerrors" +) + +var _ = xerrors.Errorf + +func (t *PersistedShard) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + if _, err := w.Write([]byte{165}); err != nil { + return err + } + + // t.Key (string) (string) + if len("Key") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Key\" was too long") + } + + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("Key")))); err != nil { + return err + } + if _, err := w.Write([]byte("Key")); err != nil { + return err + } + + if len(t.Key) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.Key was too long") + } + + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len(t.Key)))); err != nil { + return err + } + if _, err := w.Write([]byte(t.Key)); err != nil { + return err + } + + // t.URL (string) (string) + if len("URL") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"URL\" was too long") + } + + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("URL")))); err != nil { + return err + } + if _, err := w.Write([]byte("URL")); err != nil { + return err + } + + if len(t.URL) > cbg.MaxLength { + return 
xerrors.Errorf("Value in field t.URL was too long") + } + + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len(t.URL)))); err != nil { + return err + } + if _, err := w.Write([]byte(t.URL)); err != nil { + return err + } + + // t.State (dagstore.ShardState) (uint8) + if len("State") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"State\" was too long") + } + + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("State")))); err != nil { + return err + } + if _, err := w.Write([]byte("State")); err != nil { + return err + } + + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajUnsignedInt, uint64(t.State))); err != nil { + return err + } + + // t.Indexed (bool) (bool) + if len("Indexed") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Indexed\" was too long") + } + + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("Indexed")))); err != nil { + return err + } + if _, err := w.Write([]byte("Indexed")); err != nil { + return err + } + + if err := cbg.WriteBool(w, t.Indexed); err != nil { + return err + } + + // t.TransientPath (string) (string) + if len("TransientPath") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"TransientPath\" was too long") + } + + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len("TransientPath")))); err != nil { + return err + } + if _, err := w.Write([]byte("TransientPath")); err != nil { + return err + } + + if len(t.TransientPath) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.TransientPath was too long") + } + + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len(t.TransientPath)))); err != nil { + return err + } + if _, err := w.Write([]byte(t.TransientPath)); err != nil { + return err + } + return nil +} + +func (t *PersistedShard) UnmarshalCBOR(r io.Reader) error { + br := cbg.GetPeeker(r) + + maj, extra, err := cbg.CborReadHeader(br) + if err != nil { + return err + 
} + if maj != cbg.MajMap { + return fmt.Errorf("cbor input should be of type map") + } + + if extra > cbg.MaxLength { + return fmt.Errorf("PersistedShard: map struct too large (%d)", extra) + } + + var name string + n := extra + + for i := uint64(0); i < n; i++ { + + { + sval, err := cbg.ReadString(br) + if err != nil { + return err + } + + name = string(sval) + } + + switch name { + // t.Key (string) (string) + case "Key": + + { + sval, err := cbg.ReadString(br) + if err != nil { + return err + } + + t.Key = string(sval) + } + // t.URL (string) (string) + case "URL": + + { + sval, err := cbg.ReadString(br) + if err != nil { + return err + } + + t.URL = string(sval) + } + // t.State (dagstore.ShardState) (uint8) + case "State": + + maj, extra, err = cbg.CborReadHeader(br) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint8 field") + } + if extra > math.MaxUint8 { + return fmt.Errorf("integer in input was too large for uint8 field") + } + t.State = ShardState(extra) + // t.Indexed (bool) (bool) + case "Indexed": + + maj, extra, err = cbg.CborReadHeader(br) + if err != nil { + return err + } + if maj != cbg.MajOther { + return fmt.Errorf("booleans must be major type 7") + } + switch extra { + case 20: + t.Indexed = false + case 21: + t.Indexed = true + default: + return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) + } + // t.TransientPath (string) (string) + case "TransientPath": + + { + sval, err := cbg.ReadString(br) + if err != nil { + return err + } + + t.TransientPath = string(sval) + } + + default: + return fmt.Errorf("unknown struct field %d: '%s'", i, name) + } + } + + return nil +} diff --git a/shard_persist.go b/shard_persist.go new file mode 100644 index 0000000..5eb0871 --- /dev/null +++ b/shard_persist.go @@ -0,0 +1,97 @@ +package dagstore + +import ( + "encoding/json" + "errors" + "fmt" + "net/url" + + "github.com/filecoin-project/dagstore/mount" + 
"github.com/filecoin-project/dagstore/shard" + ds "github.com/ipfs/go-datastore" +) + +// PersistedShard is the persistent representation of the Shard. +type PersistedShard struct { + Key string `json:"k"` + URL string `json:"u"` + State ShardState `json:"s"` + Indexed bool `json:"i"` + TransientPath string `json:"t"` + Error string `json:"e"` +} + +// MarshalJSON returns a serialized representation of the state. It must be +// called from inside the event loop, as it accesses mutable state, or under a +// shard read lock. +func (s *Shard) MarshalJSON() ([]byte, error) { + u, err := s.d.mounts.Represent(s.mount) + if err != nil { + return nil, fmt.Errorf("failed to encode mount: %w", err) + } + ps := PersistedShard{ + Key: s.key.String(), + URL: u.String(), + State: s.state, + Indexed: s.indexed, + TransientPath: s.mount.TransientPath(), + } + if s.err != nil { + ps.Error = s.err.Error() + } + + return json.Marshal(ps) + // TODO maybe switch to CBOR, as it's probably faster. + // var b bytes.Buffer + // if err := ps.MarshalCBOR(&b); err != nil { + // return nil, err + // } + // return b.Bytes(), nil +} + +func (s *Shard) UnmarshalJSON(b []byte) error { + var ps PersistedShard // TODO try to avoid this alloc by marshalling/unmarshalling directly. + if err := json.Unmarshal(b, &ps); err != nil { + return err + } + + // restore basics. + s.key = shard.KeyFromString(ps.Key) + s.state = ps.State + if ps.Error != "" { + s.err = errors.New(ps.Error) + } + + // restore mount. 
+ u, err := url.Parse(ps.URL) + if err != nil { + return fmt.Errorf("failed to parse mount URL: %w", err) + } + mnt, err := s.d.mounts.Instantiate(u) + if err != nil { + return fmt.Errorf("failed to instantiate mount from URL: %w", err) + } + s.mount, err = mount.Upgrade(mnt, s.d.config.TransientsDir, ps.TransientPath) + if err != nil { + return fmt.Errorf("failed to apply mount upgrader: %w", err) + } + + s.indexed = ps.Indexed + return nil +} + +func (s *Shard) persist(store ds.Datastore) error { + ps, err := s.MarshalJSON() + if err != nil { + return fmt.Errorf("failed to serialize shard state: %w", err) + } + // assuming that the datastore is namespaced if need be. + k := ds.NewKey(s.key.String()) + if err := store.Put(k, ps); err != nil { + return fmt.Errorf("failed to put shard state: %w", err) + } + if err := store.Sync(ds.Key{}); err != nil { + return fmt.Errorf("failed to sync shard state to store: %w", err) + } + return nil +}