Skip to content

Commit

Permalink
Try #6074:
Browse files Browse the repository at this point in the history
  • Loading branch information
spacemesh-bors[bot] authored Jun 26, 2024
2 parents 568a5a1 + af2b7b7 commit 7d3677b
Show file tree
Hide file tree
Showing 18 changed files with 1,323 additions and 625 deletions.
2 changes: 0 additions & 2 deletions activation/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,6 @@ func (h *Handler) determineVersion(msg []byte) (*types.AtxVersion, error) {

type opaqueAtx interface {
ID() types.ATXID
Published() types.EpochID
TotalNumUnits() uint32
}

func (h *Handler) decodeATX(msg []byte) (opaqueAtx, error) {
Expand Down
54 changes: 11 additions & 43 deletions activation/handler_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,47 +168,6 @@ func (h *HandlerV1) commitment(atx *wire.ActivationTxV1) (types.ATXID, error) {
return atxs.CommitmentATX(h.cdb, atx.SmesherID)
}

// Obtain the previous ATX for the given ATX.
// We need to decode it from the blob because we are interested in the true NumUnits value
// that was declared by the previous ATX and the `atxs` table only holds the effective NumUnits.
// However, in case of a golden ATX, the blob is not available and we fallback to fetching the ATX from the DB
// to use the effective num units.
func (h *HandlerV1) previous(ctx context.Context, atx *wire.ActivationTxV1) (*types.ActivationTx, error) {
var blob sql.Blob
v, err := atxs.LoadBlob(ctx, h.cdb, atx.PrevATXID[:], &blob)
if err != nil {
return nil, err
}

if len(blob.Bytes) == 0 {
// An empty blob indicates a golden ATX (after a checkpoint-recovery).
// Fallback to fetching it from the DB to get the effective NumUnits.
atx, err := atxs.Get(h.cdb, atx.PrevATXID)
if err != nil {
return nil, fmt.Errorf("fetching golden previous atx: %w", err)
}
return atx, nil
}
if v != types.AtxV1 {
return nil, fmt.Errorf("previous atx %s is not of version 1", atx.PrevATXID)
}

var prev wire.ActivationTxV1
if err := codec.Decode(blob.Bytes, &prev); err != nil {
return nil, fmt.Errorf("decoding previous atx: %w", err)
}
prev.SetID(atx.PrevATXID)
if prev.VRFNonce == nil {
nonce, err := atxs.NonceByID(h.cdb, prev.ID())
if err != nil {
return nil, fmt.Errorf("failed to get nonce of previous ATX %s: %w", prev.ID(), err)
}
prev.VRFNonce = (*uint64)(&nonce)
}

return wire.ActivationTxFromWireV1(&prev, blob.Bytes...), nil
}

func (h *HandlerV1) syntacticallyValidateDeps(
ctx context.Context,
atx *wire.ActivationTxV1,
Expand All @@ -224,14 +183,18 @@ func (h *HandlerV1) syntacticallyValidateDeps(
}
effectiveNumUnits = atx.NumUnits
} else {
previous, err := h.previous(ctx, atx)
previous, err := atxs.Get(h.cdb, atx.PrevATXID)
if err != nil {
return 0, 0, nil, fmt.Errorf("fetching previous atx %s: %w", atx.PrevATXID, err)
}
if err := h.validateNonInitialAtx(ctx, atx, previous, commitmentATX); err != nil {
return 0, 0, nil, err
}
effectiveNumUnits = min(previous.NumUnits, atx.NumUnits)
prevUnits, err := atxs.Units(h.cdb, atx.PrevATXID, atx.SmesherID)
if err != nil {
return 0, 0, nil, fmt.Errorf("fetching previous atx units: %w", err)
}
effectiveNumUnits = min(prevUnits, atx.NumUnits)
}

err = h.nipostValidator.PositioningAtx(atx.PositioningATXID, h.cdb, h.goldenATXID, atx.PublishEpoch)
Expand Down Expand Up @@ -591,6 +554,11 @@ func (h *HandlerV1) storeAtx(
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("add atx to db: %w", err)
}
err = atxs.SetUnits(tx, atx.ID(), map[types.NodeID]uint32{atx.SmesherID: watx.NumUnits})
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("set atx units: %w", err)
}

return nil
}); err != nil {
return nil, fmt.Errorf("store atx: %w", err)
Expand Down
121 changes: 32 additions & 89 deletions activation/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ func (h *HandlerV2) processATX(
SmesherID: watx.SmesherID,
AtxBlob: types.AtxBlob{Blob: blob, Version: types.AtxV2},
}

if watx.Initial == nil {
// FIXME: update to keep many previous ATXs to support merged ATXs
atx.PrevATXID = watx.PreviousATXs[0]
Expand All @@ -144,7 +145,7 @@ func (h *HandlerV2) processATX(
atx.SetID(watx.ID())
atx.SetReceived(received)

proof, err = h.storeAtx(ctx, atx, watx, marrying)
proof, err = h.storeAtx(ctx, atx, watx, marrying, parts.units)
if err != nil {
return nil, fmt.Errorf("cannot store atx %s: %w", atx.ShortString(), err)
}
Expand Down Expand Up @@ -313,86 +314,22 @@ func (h *HandlerV2) collectAtxDeps(atx *wire.ActivationTxV2) ([]types.Hash32, []
return maps.Keys(poetRefs), maps.Keys(filtered)
}

func (h *HandlerV2) previous(ctx context.Context, id types.ATXID) (opaqueAtx, error) {
var blob sql.Blob
version, err := atxs.LoadBlob(ctx, h.cdb, id[:], &blob)
if err != nil {
return nil, err
}

if len(blob.Bytes) == 0 {
// An empty blob indicates a golden ATX (after a checkpoint-recovery).
// Fallback to fetching it from the DB to get the effective NumUnits.
atx, err := atxs.Get(h.cdb, id)
if err != nil {
return nil, fmt.Errorf("fetching golden previous atx: %w", err)
}
return atx, nil
}

switch version {
case types.AtxV1:
var prev wire.ActivationTxV1
if err := codec.Decode(blob.Bytes, &prev); err != nil {
return nil, fmt.Errorf("decoding previous atx v1: %w", err)
}
return &prev, nil
case types.AtxV2:
var prev wire.ActivationTxV2
if err := codec.Decode(blob.Bytes, &prev); err != nil {
return nil, fmt.Errorf("decoding previous atx v2: %w", err)
}
return &prev, nil
}
return nil, fmt.Errorf("unexpected previous ATX version: %d", version)
}

// Validate the previous ATX for the given PoST and return the effective numunits.
func (h *HandlerV2) validatePreviousAtx(id types.NodeID, post *wire.SubPostV2, prevAtxs []opaqueAtx) (uint32, error) {
func (h *HandlerV2) validatePreviousAtx(
id types.NodeID,
post *wire.SubPostV2,
prevAtxs []*types.ActivationTx,
) (uint32, error) {
if post.PrevATXIndex >= uint32(len(prevAtxs)) {
return 0, fmt.Errorf("prevATXIndex out of bounds: %d > %d", post.PrevATXIndex, len(prevAtxs))
}
prev := prevAtxs[post.PrevATXIndex]

switch prev := prev.(type) {
case *types.ActivationTx:
// A golden ATX
// TODO: support merged golden ATX
if prev.SmesherID != id {
return 0, fmt.Errorf("prev golden ATX has different owner: %s (expected %s)", prev.SmesherID, id)
}
return min(prev.NumUnits, post.NumUnits), nil

case *wire.ActivationTxV1:
if prev.SmesherID != id {
return 0, fmt.Errorf("prev ATX V1 has different owner: %s (expected %s)", prev.SmesherID, id)
}
return min(prev.NumUnits, post.NumUnits), nil
case *wire.ActivationTxV2:
if prev.MarriageATX != nil {
// Previous is a merged ATX
// need to find out if the given ID was present in the previous ATX
_, idx, err := identities.MarriageInfo(h.cdb, id)
if err != nil {
return 0, fmt.Errorf("fetching marriage info for ID %s: %w", id, err)
}
for _, nipost := range prev.NiPosts {
for _, post := range nipost.Posts {
if post.MarriageIndex == uint32(idx) {
return min(post.NumUnits, post.NumUnits), nil
}
}
}
} else {
// Previous is a solo ATX
if prev.SmesherID == id {
return min(prev.NiPosts[0].Posts[0].NumUnits, post.NumUnits), nil
}
}

return 0, fmt.Errorf("previous ATX V2 doesn't contain %s", id)
prevUnits, err := atxs.Units(h.cdb, prev.ID(), id)
if err != nil {
return 0, fmt.Errorf("fetching previous atx %s units for ID %s: %w", prev.ID(), id, err)
}
return 0, fmt.Errorf("unexpected previous ATX type: %T", prev)

return min(prevUnits, post.NumUnits), nil
}

func (h *HandlerV2) validateCommitmentAtx(golden, commitmentAtxId types.ATXID, publish types.EpochID) error {
Expand Down Expand Up @@ -498,6 +435,7 @@ type atxParts struct {
ticks uint64
weight uint64
effectiveUnits uint32
units map[types.NodeID]uint32
}

type nipostSize struct {
Expand Down Expand Up @@ -556,23 +494,26 @@ func (h *HandlerV2) syntacticallyValidateDeps(
ctx context.Context,
atx *wire.ActivationTxV2,
) (*atxParts, *mwire.MalfeasanceProof, error) {
parts := atxParts{
units: make(map[types.NodeID]uint32),
}
if atx.Initial != nil {
if err := h.validateCommitmentAtx(h.goldenATXID, atx.Initial.CommitmentATX, atx.PublishEpoch); err != nil {
return nil, nil, fmt.Errorf("verifying commitment ATX: %w", err)
}
}

previousAtxs := make([]opaqueAtx, len(atx.PreviousATXs))
prevAtxs := make([]*types.ActivationTx, len(atx.PreviousATXs))
for i, prev := range atx.PreviousATXs {
prevAtx, err := h.previous(ctx, prev)
prevAtx, err := atxs.Get(h.cdb, prev)
if err != nil {
return nil, nil, fmt.Errorf("fetching previous atx: %w", err)
}
if prevAtx.Published() >= atx.PublishEpoch {
err := fmt.Errorf("previous atx is too new (%d >= %d) (%s) ", prevAtx.Published(), atx.PublishEpoch, prev)
if prevAtx.PublishEpoch >= atx.PublishEpoch {
err := fmt.Errorf("previous atx is too new (%d >= %d) (%s) ", prevAtx.PublishEpoch, atx.PublishEpoch, prev)
return nil, nil, err
}
previousAtxs[i] = prevAtx
prevAtxs[i] = prevAtx
}

equivocationSet, err := h.equivocationSet(atx)
Expand All @@ -594,7 +535,7 @@ func (h *HandlerV2) syntacticallyValidateDeps(
effectiveNumUnits := post.NumUnits
if atx.Initial == nil {
var err error
effectiveNumUnits, err = h.validatePreviousAtx(id, &post, previousAtxs)
effectiveNumUnits, err = h.validatePreviousAtx(id, &post, prevAtxs)
if err != nil {
return nil, nil, fmt.Errorf("validating previous atx: %w", err)
}
Expand Down Expand Up @@ -644,7 +585,7 @@ func (h *HandlerV2) syntacticallyValidateDeps(
nipostSizes[i].ticks = leaves / h.tickSize
}

totalEffectiveNumUnits, totalWeight, err := nipostSizes.sumUp()
parts.effectiveUnits, parts.weight, err = nipostSizes.sumUp()
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -689,15 +630,10 @@ func (h *HandlerV2) syntacticallyValidateDeps(
if err != nil {
return nil, nil, fmt.Errorf("invalid post for ID %s: %w", id, err)
}
parts.units[id] = post.NumUnits
}
}

parts := &atxParts{
ticks: nipostSizes.minTicks(),
effectiveUnits: totalEffectiveNumUnits,
weight: totalWeight,
}

if atx.Initial == nil {
if smesherCommitment == nil {
return nil, nil, errors.New("ATX signer not present in merged ATX")
Expand All @@ -708,7 +644,9 @@ func (h *HandlerV2) syntacticallyValidateDeps(
}
}

return parts, nil, nil
parts.ticks = nipostSizes.minTicks()

return &parts, nil, nil
}

func (h *HandlerV2) checkMalicious(
Expand Down Expand Up @@ -768,6 +706,7 @@ func (h *HandlerV2) storeAtx(
atx *types.ActivationTx,
watx *wire.ActivationTxV2,
marrying []types.NodeID,
units map[types.NodeID]uint32,
) (*mwire.MalfeasanceProof, error) {
var (
malicious bool
Expand Down Expand Up @@ -799,6 +738,10 @@ func (h *HandlerV2) storeAtx(
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("add atx to db: %w", err)
}
err = atxs.SetUnits(tx, atx.ID(), units)
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("set atx units: %w", err)
}
return nil
}); err != nil {
return nil, fmt.Errorf("store atx: %w", err)
Expand Down
Loading

0 comments on commit 7d3677b

Please sign in to comment.