From fd44931ae4595e5a594d9b927c66be7b648bd122 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Robin=20Br=C3=A4mer?= Date: Thu, 28 Dec 2023 22:01:53 +0100 Subject: [PATCH] Reduce extensive error checking (#266) --- pkg/edition/java/proto/codec/decoder.go | 10 +- pkg/edition/java/proto/codec/encoder.go | 4 +- pkg/edition/java/proto/packet/joingame.go | 341 +++++----------------- pkg/edition/java/proto/packet/login.go | 25 +- pkg/edition/java/proto/util/preader.go | 116 ++++++++ pkg/edition/java/proto/util/pwriter.go | 131 +++++++++ 6 files changed, 335 insertions(+), 292 deletions(-) create mode 100644 pkg/edition/java/proto/util/preader.go create mode 100644 pkg/edition/java/proto/util/pwriter.go diff --git a/pkg/edition/java/proto/codec/decoder.go b/pkg/edition/java/proto/codec/decoder.go index 3c7d0226..47e280a6 100644 --- a/pkg/edition/java/proto/codec/decoder.go +++ b/pkg/edition/java/proto/codec/decoder.go @@ -226,9 +226,13 @@ func (d *Decoder) decodePayload(p []byte) (ctx *proto.PacketContext, err error) } // Packet is known, decode data into it. - if err = ctx.Packet.Decode(ctx, payload); err != nil { - if err == io.EOF { // payload was too short or decoder has a bug - err = io.ErrUnexpectedEOF + err = util.RecoverFunc(func() error { + return ctx.Packet.Decode(ctx, payload) + }) + if err != nil { + if errors.Is(err, io.EOF) { + // payload was too short or packet decoder has a bug + err = errors.Join(err, io.ErrUnexpectedEOF) } return ctx, errs.NewSilentErr("error decoding packet (type: %T, id: %s, protocol: %s, direction: %s): %w", ctx.Packet, ctx.PacketID, ctx.Protocol, ctx.Direction, err) diff --git a/pkg/edition/java/proto/codec/encoder.go b/pkg/edition/java/proto/codec/encoder.go index e869db16..8768c3f8 100644 --- a/pkg/edition/java/proto/codec/encoder.go +++ b/pkg/edition/java/proto/codec/encoder.go @@ -90,7 +90,9 @@ func (e *Encoder) WritePacket(packet proto.Packet) (n int, err error) { Payload: nil, } - if err = packet.Encode(ctx, buf); err != nil { + if err = util.RecoverFunc(func() error { + return packet.Encode(ctx, buf) + }); err != nil { return } diff --git a/pkg/edition/java/proto/packet/joingame.go b/pkg/edition/java/proto/packet/joingame.go index 51636ecc..fe7bd0fc 100644 --- a/pkg/edition/java/proto/packet/joingame.go +++ b/pkg/edition/java/proto/packet/joingame.go @@ -46,40 +46,23 @@ type DeathPosition struct { } func (d *DeathPosition) encode(wr io.Writer) error { - err := util.WriteBool(wr, d != nil) - if err != nil { - return err - } + w := util.PanicWriter(wr) + w.Bool(d != nil) if d != nil { - err = util.WriteString(wr, d.Key) - if err != nil { - return err - } - err = util.WriteInt64(wr, d.Value) - if err != nil { - return err - } + w.String(d.Key) + w.Int64(d.Value) } return nil } func decodeDeathPosition(rd io.Reader) (*DeathPosition, error) { - ok, err := util.ReadBool(rd) - if err != nil { - return nil, err - } - if !ok { + r := util.PanicReader(rd) + if !r.Ok() { return nil, nil } dp := new(DeathPosition) - dp.Key, err = util.ReadString(rd) - if err != nil { - return nil, err - } - dp.Value, err = util.ReadInt64(rd) - if err != nil { - return nil, err - } + r.String(&dp.Key) + r.Int64(&dp.Value) return dp, nil } @@ -103,192 +86,89 @@ func (j *JoinGame) Encode(c *proto.PacketContext, wr io.Writer) error { } func (j *JoinGame) encode116Up(c *proto.PacketContext, wr io.Writer) error { - err := util.WriteInt(wr, j.EntityID) - if err != nil { - return err - } + w := util.PanicWriter(wr) + w.Int(j.EntityID) if c.Protocol.GreaterEqual(version.Minecraft_1_16_2) { - err = util.WriteBool(wr, j.Hardcore) - if err != nil { - return err - } - err = util.WriteByte(wr, byte(j.Gamemode)) - if err != nil { - return err - } + w.Bool(j.Hardcore) + w.Byte(byte(j.Gamemode)) } else { b := byte(j.Gamemode) if j.Hardcore { b = byte(j.Gamemode) | 0x8 } - err = util.WriteByte(wr, b) - if err != nil { - return err - } - } - err = util.WriteByte(wr, byte(j.PreviousGamemode)) - if err != nil { - return err - } - - err = util.WriteStrings(wr, j.LevelNames) - if err != nil { - return err - } - err = j.Registry.Write(wr) - if err != nil { - return err + w.Byte(b) } + w.Byte(byte(j.PreviousGamemode)) + w.Strings(j.LevelNames) + w.NBT(j.Registry) if c.Protocol.GreaterEqual(version.Minecraft_1_16_2) && c.Protocol.Lower(version.Minecraft_1_19) { - err = j.CurrentDimensionData.Write(wr) - if err != nil { - return err - } - err = util.WriteString(wr, j.DimensionInfo.RegistryIdentifier) - if err != nil { - return err - } + w.NBT(j.CurrentDimensionData) + w.String(j.DimensionInfo.RegistryIdentifier) } else { - err = util.WriteString(wr, j.DimensionInfo.RegistryIdentifier) - if err != nil { - return err - } + w.String(j.DimensionInfo.RegistryIdentifier) if j.DimensionInfo.LevelName == nil { return errors.New("dimension info level name must not be nil") } - err = util.WriteString(wr, *j.DimensionInfo.LevelName) - if err != nil { - return err - } - } - - err = util.WriteInt64(wr, j.PartialHashedSeed) - if err != nil { - return err + w.String(*j.DimensionInfo.LevelName) } + w.Int64(j.PartialHashedSeed) if c.Protocol.GreaterEqual(version.Minecraft_1_16_2) { - err = util.WriteVarInt(wr, j.MaxPlayers) - if err != nil { - return err - } + w.VarInt(j.MaxPlayers) } else { - err = util.WriteByte(wr, byte(j.MaxPlayers)) - if err != nil { - return err - } - } - - err = util.WriteVarInt(wr, j.ViewDistance) - if err != nil { - return err + w.Byte(byte(j.MaxPlayers)) } + w.VarInt(j.ViewDistance) if c.Protocol.GreaterEqual(version.Minecraft_1_18) { - err = util.WriteVarInt(wr, j.SimulationDistance) - if err != nil { - return err - } - } - - err = util.WriteBool(wr, j.ReducedDebugInfo) - if err != nil { - return err - } - err = util.WriteBool(wr, j.ShowRespawnScreen) - - if err != nil { - return err + w.VarInt(j.SimulationDistance) } - err = util.WriteBool(wr, j.DimensionInfo.DebugType) - if err != nil { - return err - } - err = util.WriteBool(wr, j.DimensionInfo.Flat) - if err != nil { - return err - } - + w.Bool(j.ReducedDebugInfo) + w.Bool(j.ShowRespawnScreen) + w.Bool(j.DimensionInfo.DebugType) + w.Bool(j.DimensionInfo.Flat) if c.Protocol.GreaterEqual(version.Minecraft_1_19) { err = j.LastDeathPosition.encode(wr) if err != nil { return err } } - if c.Protocol.GreaterEqual(version.Minecraft_1_20) { - err = util.WriteVarInt(wr, j.PortalCooldown) - if err != nil { - return err - } + w.VarInt(j.PortalCooldown) } - return nil } func (j *JoinGame) encodeLegacy(c *proto.PacketContext, wr io.Writer) error { - err := util.WriteInt32(wr, int32(j.EntityID)) - if err != nil { - return err - } + w := util.PanicWriter(wr) + w.Int(j.EntityID) b := byte(j.Gamemode) if j.Hardcore { b = byte(j.Gamemode) | 0x8 } - err = util.WriteByte(wr, b) - if err != nil { - return err - } + w.Byte(b) if c.Protocol.GreaterEqual(version.Minecraft_1_9_1) { - err = util.WriteInt32(wr, int32(j.Dimension)) - if err != nil { - return err - } + w.Int(j.Dimension) } else { - err = util.WriteByte(wr, byte(j.Dimension)) - if err != nil { - return err - } + w.Byte(byte(j.Dimension)) } if c.Protocol.LowerEqual(version.Minecraft_1_13_2) { - err = util.WriteByte(wr, byte(j.Difficulty)) - if err != nil { - return err - } + w.Byte(byte(j.Difficulty)) } if c.Protocol.GreaterEqual(version.Minecraft_1_15) { - err = util.WriteInt64(wr, j.PartialHashedSeed) - if err != nil { - return err - } - } - err = util.WriteByte(wr, byte(j.MaxPlayers)) - if err != nil { - return err + w.Int64(j.PartialHashedSeed) } + w.Byte(byte(j.MaxPlayers)) if j.LevelType == nil { return errors.New("no level type specified") } - err = util.WriteString(wr, *j.LevelType) - if err != nil { - return err - } + w.String(*j.LevelType) if c.Protocol.GreaterEqual(version.Minecraft_1_14) { - err = util.WriteVarInt(wr, j.ViewDistance) - if err != nil { - return err - } + w.VarInt(j.ViewDistance) } if c.Protocol.GreaterEqual(version.Minecraft_1_8) { - err = util.WriteBool(wr, j.ReducedDebugInfo) - if err != nil { - return err - } + w.Bool(j.ReducedDebugInfo) } - if c.Protocol.GreaterEqual(version.Minecraft_1_15) { - err = util.WriteBool(wr, j.ShowRespawnScreen) - if err != nil { - return err - } + w.Bool(j.ShowRespawnScreen) } return nil } @@ -410,67 +290,38 @@ func (j *JoinGame) Decode(c *proto.PacketContext, rd io.Reader) (err error) { } func (j *JoinGame) decodeLegacy(c *proto.PacketContext, rd io.Reader) (err error) { - j.EntityID, err = util.ReadInt(rd) - if err != nil { - return err - } + r := util.PanicReader(rd) + r.Int(&j.EntityID) if err = j.readGamemode(rd); err != nil { return err } j.Hardcore = (j.Gamemode & 0x08) != 0 j.Gamemode &= ^0x08 // bitwise complement if c.Protocol.GreaterEqual(version.Minecraft_1_9_1) { - j.Dimension, err = util.ReadInt(rd) - if err != nil { - return err - } + r.Int(&j.Dimension) } else { - d, err := util.ReadByte(rd) - if err != nil { - return err - } - j.Dimension = int(d) + j.Dimension = int(util.PReadByteVal(rd)) } if c.Protocol.LowerEqual(version.Minecraft_1_13_2) { - difficulty, err := util.ReadByte(rd) - if err != nil { - return err - } - j.Difficulty = int16(difficulty) + j.Difficulty = int16(util.PReadByteVal(rd)) } if c.Protocol.GreaterEqual(version.Minecraft_1_15) { - j.PartialHashedSeed, err = util.ReadInt64(rd) - if err != nil { - return err - } - } - maxPlayers, err := util.ReadByte(rd) - j.MaxPlayers = int(maxPlayers) - if err != nil { - return err + r.Int64(&j.PartialHashedSeed) } + j.MaxPlayers = int(util.PReadByteVal(rd)) lt, err := util.ReadStringMax(rd, 16) if err != nil { return err } j.LevelType = < if c.Protocol.GreaterEqual(version.Minecraft_1_14) { - j.ViewDistance, err = util.ReadVarInt(rd) - if err != nil { - return err - } + r.VarInt(&j.ViewDistance) } if c.Protocol.GreaterEqual(version.Minecraft_1_8) { - j.ReducedDebugInfo, err = util.ReadBool(rd) - if err != nil { - return err - } + r.Bool(&j.ReducedDebugInfo) } if c.Protocol.GreaterEqual(version.Minecraft_1_15) { - j.ShowRespawnScreen, err = util.ReadBool(rd) - if err != nil { - return err - } + r.Bool(&j.ShowRespawnScreen) } return nil } @@ -482,15 +333,10 @@ func (j *JoinGame) readGamemode(rd io.Reader) (err error) { } func (j *JoinGame) decode116Up(c *proto.PacketContext, rd io.Reader) (err error) { - j.EntityID, err = util.ReadInt(rd) - if err != nil { - return err - } + r := util.PanicReader(rd) + r.Int(&j.EntityID) if c.Protocol.GreaterEqual(version.Minecraft_1_16_2) { - j.Hardcore, err = util.ReadBool(rd) - if err != nil { - return err - } + r.Bool(&j.Hardcore) if err = j.readGamemode(rd); err != nil { return err } @@ -501,16 +347,9 @@ func (j *JoinGame) decode116Up(c *proto.PacketContext, rd io.Reader) (err error) j.Hardcore = (j.Gamemode & 0x08) != 0 j.Gamemode &= ^0x08 // bitwise complement } - previousGamemode, err := util.ReadByte(rd) - if err != nil { - return err - } - j.PreviousGamemode = int16(previousGamemode) + j.PreviousGamemode = int16(util.PReadByteVal(rd)) - j.LevelNames, err = util.ReadStringArray(rd) - if err != nil { - return err - } + r.Strings(&j.LevelNames) nbtDecoder := util.NewNBTDecoder(rd) j.Registry, err = util.DecodeNBT(nbtDecoder) if err != nil { @@ -524,65 +363,28 @@ func (j *JoinGame) decode116Up(c *proto.PacketContext, rd io.Reader) (err error) if err != nil { return err } - dimensionIdentifier, err = util.ReadString(rd) - if err != nil { - return err - } + r.String(&dimensionIdentifier) } else { - dimensionIdentifier, err = util.ReadString(rd) - if err != nil { - return err - } - levelName, err = util.ReadString(rd) - if err != nil { - return err - } + r.String(&dimensionIdentifier) + r.String(&levelName) } - j.PartialHashedSeed, err = util.ReadInt64(rd) - if err != nil { - return err - } + r.Int64(&j.PartialHashedSeed) if c.Protocol.GreaterEqual(version.Minecraft_1_16_2) { - j.MaxPlayers, err = util.ReadVarInt(rd) - if err != nil { - return err - } + r.VarInt(&j.MaxPlayers) } else { - maxPlayers, err := util.ReadByte(rd) - j.MaxPlayers = int(maxPlayers) - if err != nil { - return err - } + j.MaxPlayers = int(util.PReadByteVal(rd)) } - j.ViewDistance, err = util.ReadVarInt(rd) - if err != nil { - return err - } + r.VarInt(&j.ViewDistance) if c.Protocol.GreaterEqual(version.Minecraft_1_18) { - j.SimulationDistance, err = util.ReadVarInt(rd) - if err != nil { - return err - } - } - j.ReducedDebugInfo, err = util.ReadBool(rd) - if err != nil { - return err - } - j.ShowRespawnScreen, err = util.ReadBool(rd) - if err != nil { - return err + r.VarInt(&j.SimulationDistance) } + r.Bool(&j.ReducedDebugInfo) + r.Bool(&j.ShowRespawnScreen) - debug, err := util.ReadBool(rd) - if err != nil { - return err - } - flat, err := util.ReadBool(rd) - if err != nil { - return err - } + debug := r.Ok() + flat := r.Ok() j.DimensionInfo = &DimensionInfo{ RegistryIdentifier: dimensionIdentifier, LevelName: &levelName, @@ -599,10 +401,7 @@ func (j *JoinGame) decode116Up(c *proto.PacketContext, rd io.Reader) (err error) } if c.Protocol.GreaterEqual(version.Minecraft_1_20) { - j.PortalCooldown, err = util.ReadVarInt(rd) - if err != nil { - return err - } + r.VarInt(&j.PortalCooldown) } return nil } diff --git a/pkg/edition/java/proto/packet/login.go b/pkg/edition/java/proto/packet/login.go index 5b6fdf43..373bcb5a 100644 --- a/pkg/edition/java/proto/packet/login.go +++ b/pkg/edition/java/proto/packet/login.go @@ -360,26 +360,17 @@ type LoginPluginMessage struct { } func (l *LoginPluginMessage) Encode(_ *proto.PacketContext, wr io.Writer) error { - err := util.WriteVarInt(wr, l.ID) - if err != nil { - return err - } - err = util.WriteString(wr, l.Channel) - if err != nil { - return err - } - return util.WriteBytes(wr, l.Data) + w := util.PanicWriter(wr) + w.VarInt(l.ID) + w.String(l.Channel) + w.Bytes(l.Data) + return nil } func (l *LoginPluginMessage) Decode(_ *proto.PacketContext, rd io.Reader) (err error) { - l.ID, err = util.ReadVarInt(rd) - if err != nil { - return err - } - l.Channel, err = util.ReadString(rd) - if err != nil { - return err - } + r := util.PanicReader(rd) + r.VarInt(&l.ID) + r.String(&l.Channel) l.Data, err = util.ReadBytes(rd) if errors.Is(err, io.EOF) { // Ignore if we couldn't read data diff --git a/pkg/edition/java/proto/util/preader.go b/pkg/edition/java/proto/util/preader.go new file mode 100644 index 00000000..b3525c99 --- /dev/null +++ b/pkg/edition/java/proto/util/preader.go @@ -0,0 +1,116 @@ +package util + +import "io" + +type PReader struct { + r io.Reader +} + +func PanicReader(r io.Reader) *PReader { + return &PReader{r} +} + +func (r *PReader) VarInt(i *int) { + PVarInt(r.r, i) +} + +func (r *PReader) String(s *string) { + PReadString(r.r, s) +} + +func (r *PReader) Bytes(b *[]byte) { + PReadBytes(r.r, b) +} + +func (r *PReader) Bool(b *bool) { + PReadBool(r.r, b) +} +func (r *PReader) Ok() bool { + var ok bool + PReadBool(r.r, &ok) + return ok +} + +func (r *PReader) Int64(i *int64) { + PReadInt64(r.r, i) +} + +func (r *PReader) Int(i *int) { + PReadInt(r.r, i) +} + +func (r *PReader) Strings(i *[]string) { + PReadStrings(r.r, i) +} + +func PReadStrings(r io.Reader, i *[]string) { + v, err := ReadStringArray(r) + if err != nil { + panic(err) + } + *i = v +} + +func PReadInt(r io.Reader, i *int) { + v, err := ReadInt(r) + if err != nil { + panic(err) + } + *i = v +} + +func PReadInt64(r io.Reader, i *int64) { + v, err := ReadInt64(r) + if err != nil { + panic(err) + } + *i = v +} + +func PReadBool(r io.Reader, b *bool) { + v, err := ReadBool(r) + if err != nil { + panic(err) + } + *b = v +} + +func PVarInt(rd io.Reader, i *int) { + v, err := ReadVarInt(rd) + if err != nil { + panic(err) + } + *i = v +} + +func PReadString(rd io.Reader, s *string) { + v, err := ReadString(rd) + if err != nil { + panic(err) + } + *s = v +} + +func PReadBytes(rd io.Reader, b *[]byte) { + v, err := ReadBytes(rd) + if err != nil { + panic(err) + } + *b = v +} + +func PReadByte(rd io.Reader, b *byte) { + v, err := ReadByte(rd) + if err != nil { + panic(err) + } + *b = v +} + +func PReadByteVal(rd io.Reader) byte { + v, err := ReadByte(rd) + if err != nil { + panic(err) + } + return v +} diff --git a/pkg/edition/java/proto/util/pwriter.go b/pkg/edition/java/proto/util/pwriter.go new file mode 100644 index 00000000..e50a8416 --- /dev/null +++ b/pkg/edition/java/proto/util/pwriter.go @@ -0,0 +1,131 @@ +package util + +import "io" + +// Recover is a helper function to recover from a panic and set the error pointer to the recovered error. +// If the panic is not an error, it will be re-panicked. +// +// Usage: +// +// func fn() (err error) { +// defer Recover(&err) +// // code that may panic(err) +// } +func Recover(err *error) { + if r := recover(); r != nil { + if e, ok := r.(error); ok { + *err = e + } else { + panic(r) + } + } +} + +// RecoverFunc is a helper function to recover from a panic and set the error pointer to the recovered error. +// If the panic is not an error, it will be re-panicked. +// +// Usage: +// +// return RecoverFunc(func() error { +// // code that may panic(err) +// }) +func RecoverFunc(fn func() error) (err error) { + defer Recover(&err) + return fn() +} + +type PWriter struct { + w io.Writer +} + +func PanicWriter(w io.Writer) *PWriter { + return &PWriter{w} +} + +func (w *PWriter) VarInt(i int) { + PWriteVarInt(w.w, i) +} + +func (w *PWriter) String(s string) { + PWriteString(w.w, s) +} + +func (w *PWriter) Bytes(b []byte) { + PWriteBytes(w.w, b) +} + +func (w *PWriter) Bool(b bool) { + PWriteBool(w.w, b) +} + +func (w *PWriter) Int64(i int64) { + PWriteInt64(w.w, i) +} + +func (w *PWriter) Int(i int) { + PWriteInt(w.w, i) +} + +func (w *PWriter) Byte(b byte) { + PWriteByte(w.w, b) +} + +func (w *PWriter) Strings(s []string) { + PWriteStrings(w.w, s) +} + +func (w *PWriter) NBT(nbt NBT) { + PWriteNBT(w.w, nbt) +} + +func PWriteNBT(w io.Writer, nbt NBT) { + if err := WriteNBT(w, nbt); err != nil { + panic(err) + } +} + +func PWriteStrings(w io.Writer, s []string) { + if err := WriteStrings(w, s); err != nil { + panic(err) + } +} + +func PWriteByte(w io.Writer, b byte) { + if err := WriteByte(w, b); err != nil { + panic(err) + } +} + +func PWriteInt(w io.Writer, i int) { + if err := WriteInt(w, i); err != nil { + panic(err) + } +} + +func PWriteInt64(w io.Writer, i int64) { + if err := WriteInt64(w, i); err != nil { + panic(err) + } +} + +func PWriteBool(w io.Writer, b bool) { + if err := WriteBool(w, b); err != nil { + panic(err) + } +} + +func PWriteVarInt(wr io.Writer, i int) { + if err := WriteVarInt(wr, i); err != nil { + panic(err) + } +} +func PWriteString(wr io.Writer, s string) { + if err := WriteString(wr, s); err != nil { + panic(err) + } +} +func PWriteBytes(wr io.Writer, b []byte) { + if err := WriteBytes(wr, b); err != nil { + panic(err) + } +}