diff --git a/internal/service/smesher.go b/internal/service/smesher.go index 56660e4..9024e7f 100644 --- a/internal/service/smesher.go +++ b/internal/service/smesher.go @@ -55,28 +55,13 @@ func (e *Service) CountSmesherRewards(ctx context.Context, smesherID string) (to return e.storage.CountSmesherRewards(ctx, smesherID) } -// TODO: optimize queries func (e *Service) getSmeshers(ctx context.Context, filter *bson.D, options *options.FindOptions) (smeshers []*model.Smesher, total int64, err error) { - atxs, err := e.storage.GetActivations(ctx, filter) - if err != nil { - return nil, 0, fmt.Errorf("error count smeshers: %w", err) - } - - smeshersList := make([]string, 0, len(atxs)) - var lastID string - for _, atx := range atxs { - if lastID != atx.SmesherId { - smeshersList = append(smeshersList, atx.SmesherId) - lastID = atx.SmesherId - } - } - - total, err = e.storage.CountSmeshers(ctx, &bson.D{{Key: "id", Value: bson.M{"$in": smeshersList}}}) + total, err = e.storage.CountEpochSmeshers(ctx, filter) if err != nil { return []*model.Smesher{}, 0, err } - smeshers, err = e.storage.GetSmeshers(ctx, &bson.D{{Key: "id", Value: bson.M{"$in": smeshersList}}}, options) + smeshers, err = e.storage.GetEpochSmeshers(ctx, filter, options) if err != nil { return nil, 0, fmt.Errorf("error load smeshers: %w", err) } diff --git a/internal/storage/storagereader/abstract.go b/internal/storage/storagereader/abstract.go index 542ff0a..a23171b 100644 --- a/internal/storage/storagereader/abstract.go +++ b/internal/storage/storagereader/abstract.go @@ -52,5 +52,7 @@ type StorageReader interface { CountSmeshers(ctx context.Context, query *bson.D, opts ...*options.CountOptions) (int64, error) GetSmeshers(ctx context.Context, query *bson.D, opts ...*options.FindOptions) ([]*model.Smesher, error) GetSmesher(ctx context.Context, smesherID string) (*model.Smesher, error) + CountEpochSmeshers(ctx context.Context, query *bson.D, opts ...*options.CountOptions) (int64, error) + GetEpochSmeshers(ctx context.Context, query *bson.D, opts ...*options.FindOptions) ([]*model.Smesher, error) CountSmesherRewards(ctx context.Context, smesherID string) (total, count int64, err error) } diff --git a/internal/storage/storagereader/smeshers.go b/internal/storage/storagereader/smeshers.go index 7d786df..de0c4bb 100644 --- a/internal/storage/storagereader/smeshers.go +++ b/internal/storage/storagereader/smeshers.go @@ -3,6 +3,7 @@ package storagereader import ( "context" "fmt" + "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" @@ -122,6 +123,103 @@ func (s *Reader) GetSmeshers(ctx context.Context, query *bson.D, opts ...*option return smeshers, nil } +// GetEpochSmeshers returns the smeshers for specific epoch +func (s *Reader) CountEpochSmeshers(ctx context.Context, query *bson.D, opts ...*options.CountOptions) (int64, error) { + pipeline := bson.A{ + bson.D{ + {"$match", query}, + }, + bson.D{ + {"$lookup", + bson.D{ + {"from", "smeshers"}, + {"localField", "smesher"}, + {"foreignField", "id"}, + {"as", "joinedData"}, + }, + }, + }, + bson.D{{"$unwind", bson.D{{"path", "$joinedData"}}}}, + bson.D{{"$replaceRoot", bson.D{{"newRoot", "$joinedData"}}}}, + bson.D{ + {"$group", + bson.D{ + {"_id", primitive.Null{}}, + {"total", bson.D{{"$sum", 1}}}, + }, + }, + }, + } + + cursor, err := s.db.Collection("activations").Aggregate(ctx, pipeline) + if err != nil { + return 0, fmt.Errorf("error get smeshers: %w", err) + } + + if !cursor.Next(ctx) { + return 0, nil + } + + doc := cursor.Current + return utils.GetAsInt64(doc.Lookup("total")), nil +} + +// GetEpochSmeshers returns the smeshers for specific epoch +func (s *Reader) GetEpochSmeshers(ctx context.Context, query *bson.D, opts ...*options.FindOptions) ([]*model.Smesher, error) { + skip := int64(0) + limit := int64(0) + if len(opts) > 0 { + if opts[0].Skip != nil { + skip = *opts[0].Skip + } + + if opts[0].Limit != nil { + limit = *opts[0].Limit + } + } + + pipeline := bson.A{ + bson.D{ + {"$match", query}, + }, + bson.D{ + {"$lookup", + bson.D{ + {"from", "smeshers"}, + {"localField", "smesher"}, + {"foreignField", "id"}, + {"as", "joinedData"}, + }, + }, + }, + bson.D{{"$unwind", bson.D{{"path", "$joinedData"}}}}, + bson.D{{"$addFields", bson.D{{"joinedData.atxLayer", "$layer"}}}}, + bson.D{{"$replaceRoot", bson.D{{"newRoot", "$joinedData"}}}}, + bson.D{{Key: "$sort", Value: bson.D{{Key: "atxLayer", Value: -1}}}}, + bson.D{{Key: "$skip", Value: skip}}, + } + + if limit > 0 { + pipeline = append(pipeline, bson.D{{Key: "$limit", Value: limit}}) + } + + cursor, err := s.db.Collection("activations").Aggregate(ctx, pipeline) + if err != nil { + return nil, fmt.Errorf("error get smeshers: %w", err) + } + + var smeshers []*model.Smesher + if err = cursor.All(ctx, &smeshers); err != nil { + return nil, fmt.Errorf("error decode smeshers: %w", err) + } + + for _, smesher := range smeshers { + smesher.Timestamp = s.GetLayerTimestamp(smesher.AtxLayer) + } + + return smeshers, nil +} + // GetSmesher returns the smesher matching the query. func (s *Reader) GetSmesher(ctx context.Context, smesherID string) (*model.Smesher, error) { matchStage := bson.D{{Key: "$match", Value: bson.D{{Key: "id", Value: smesherID}}}}