From 851e5de2e4b30d0ac59c847f4c0b2ec91dff2fdd Mon Sep 17 00:00:00 2001 From: qxyuan853 <142065397+qxyuan853@users.noreply.github.com> Date: Mon, 30 Sep 2024 10:48:13 +0800 Subject: [PATCH] optimize: optimize raftsnapshot read (#6896) --- .../raft/snapshot/RaftSnapshotSerializer.java | 35 ++++++++++++++++--- .../server/raft/RaftSyncMessageTest.java | 12 +++++++ 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/server/src/main/java/org/apache/seata/server/cluster/raft/snapshot/RaftSnapshotSerializer.java b/server/src/main/java/org/apache/seata/server/cluster/raft/snapshot/RaftSnapshotSerializer.java index cdf065a1552..1d275920326 100644 --- a/server/src/main/java/org/apache/seata/server/cluster/raft/snapshot/RaftSnapshotSerializer.java +++ b/server/src/main/java/org/apache/seata/server/cluster/raft/snapshot/RaftSnapshotSerializer.java @@ -21,8 +21,13 @@ import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import java.io.ObjectStreamClass; +import java.util.ArrayList; +import java.util.List; import java.util.Optional; +import org.apache.seata.common.exception.ErrorCode; +import org.apache.seata.common.exception.SeataRuntimeException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -37,9 +42,19 @@ public class RaftSnapshotSerializer { private static final Logger LOGGER = LoggerFactory.getLogger(RaftSnapshotSerializer.class); + private static final List PERMITS = new ArrayList<>(); + static { + PERMITS.add(RaftSnapshot.class.getName()); + PERMITS.add(RaftSnapshot.SnapshotType.class.getName()); + PERMITS.add(io.seata.server.cluster.raft.snapshot.RaftSnapshot.class.getName()); + PERMITS.add(io.seata.server.cluster.raft.snapshot.RaftSnapshot.SnapshotType.class.getName()); + PERMITS.add(java.lang.Enum.class.getName()); + PERMITS.add("[B"); + } + public static byte[] encode(RaftSnapshot raftSnapshot) throws IOException { try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); - ObjectOutputStream oos = new ObjectOutputStream(bos)) { + ObjectOutputStream oos = new ObjectOutputStream(bos)) { Serializer serializer = EnhancedServiceLoader.load(Serializer.class, SerializerType.getByCode(raftSnapshot.getCodec()).name()); Optional.ofNullable(raftSnapshot.getBody()).ifPresent(value -> raftSnapshot.setBody( @@ -51,7 +66,7 @@ public static byte[] encode(RaftSnapshot raftSnapshot) throws IOException { public static byte[] encode(io.seata.server.cluster.raft.snapshot.RaftSnapshot raftSnapshot) throws IOException { try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); - ObjectOutputStream oos = new ObjectOutputStream(bos)) { + ObjectOutputStream oos = new ObjectOutputStream(bos)) { Serializer serializer = EnhancedServiceLoader.load(Serializer.class, SerializerType.getByCode(raftSnapshot.getCodec()).name()); Optional.ofNullable(raftSnapshot.getBody()).ifPresent(value -> raftSnapshot.setBody( @@ -63,7 +78,16 @@ public static byte[] encode(io.seata.server.cluster.raft.snapshot.RaftSnapshot r public static RaftSnapshot decode(byte[] raftSnapshotByte) throws IOException { try (ByteArrayInputStream bin = new ByteArrayInputStream(raftSnapshotByte); - ObjectInputStream ois = new ObjectInputStream(bin)) { + ObjectInputStream ois = new ObjectInputStream(bin) { + @Override + protected Class resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException { + if (!PERMITS.contains(desc.getName())) { + throw new SeataRuntimeException(ErrorCode.ERR_DESERIALIZATION_SECURITY, + "Failed to deserialize object: " + desc.getName() + " is not permitted"); + } + return super.resolveClass(desc); + } + }) { Object object = ois.readObject(); RaftSnapshot raftSnapshot; if (object instanceof io.seata.server.cluster.raft.snapshot.RaftSnapshot) { @@ -83,8 +107,11 @@ public static RaftSnapshot decode(byte[] raftSnapshotByte) throws IOException { .ifPresent(value -> raftSnapshot.setBody(serializer.deserialize(CompressorFactory .getCompressor(raftSnapshot.getCompressor()).decompress((byte[])raftSnapshot.getBody())))); return raftSnapshot; - } catch (ClassNotFoundException e) { + } catch (Exception e) { LOGGER.info("Failed to read raft snapshot: {}", e.getMessage(), e); + if (e instanceof SeataRuntimeException) { + throw (SeataRuntimeException)e; + } throw new IOException(e); } } diff --git a/server/src/test/java/org/apache/seata/server/raft/RaftSyncMessageTest.java b/server/src/test/java/org/apache/seata/server/raft/RaftSyncMessageTest.java index 97b919516e1..e6f9c8f905d 100644 --- a/server/src/test/java/org/apache/seata/server/raft/RaftSyncMessageTest.java +++ b/server/src/test/java/org/apache/seata/server/raft/RaftSyncMessageTest.java @@ -116,6 +116,18 @@ public void testMsgSerializeCompatible() throws IOException { Assertions.assertEquals(1234, ((RaftBranchSessionSyncMsg) raftSyncMessageByBranch.getBody()).getBranchSession().getBranchId()); } + @Test + public void testSecuritySnapshotSerialize() throws IOException { + TestSecurity testSecurity = new TestSecurity(); + byte[] bytes; + try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutputStream oos = new ObjectOutputStream(bos)) { + oos.writeObject(testSecurity); + bytes = bos.toByteArray(); + } + Assertions.assertThrows(SeataRuntimeException.class,()->RaftSnapshotSerializer.decode(bytes)); + } + @Test public void testSnapshotSerialize() throws IOException, TransactionException { Map sessionMap = new HashMap<>();