From b8d3901724b5816d1e552e1bd2c75c0e4aa16a76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20N=C3=BCcke?= Date: Sun, 24 Jul 2022 22:02:40 +0200 Subject: [PATCH] Add a wrapper multipart message for sending messages that do not fit into a single custom payload packet. Fixes #109. --- .../item/FileImportExportCardItemDevice.java | 5 +- .../li/cil/oc2/common/network/Network.java | 4 + .../network/message/MultipartMessage.java | 177 ++++++++++++++++++ .../message/RequestImportedFileMessage.java | 2 +- .../resources/META-INF/accesstransformer.cfg | 1 + 5 files changed, 186 insertions(+), 3 deletions(-) create mode 100644 src/main/java/li/cil/oc2/common/network/message/MultipartMessage.java diff --git a/src/main/java/li/cil/oc2/common/bus/device/rpc/item/FileImportExportCardItemDevice.java b/src/main/java/li/cil/oc2/common/bus/device/rpc/item/FileImportExportCardItemDevice.java index 4a33dc9b..17726971 100644 --- a/src/main/java/li/cil/oc2/common/bus/device/rpc/item/FileImportExportCardItemDevice.java +++ b/src/main/java/li/cil/oc2/common/bus/device/rpc/item/FileImportExportCardItemDevice.java @@ -7,6 +7,7 @@ import li.cil.oc2.api.bus.device.object.Callback; import li.cil.oc2.api.bus.device.object.DocumentedDevice; import li.cil.oc2.api.bus.device.object.Parameter; import li.cil.oc2.api.capabilities.TerminalUserProvider; +import li.cil.oc2.common.Constants; import li.cil.oc2.common.network.Network; import li.cil.oc2.common.network.message.ExportedFileMessage; import li.cil.oc2.common.network.message.RequestImportedFileMessage; @@ -27,7 +28,7 @@ import java.util.Set; import java.util.WeakHashMap; public final class FileImportExportCardItemDevice extends AbstractItemRPCDevice implements DocumentedDevice { - public static final int MAX_TRANSFERRED_FILE_SIZE = 512 * 1024; + public static final int MAX_TRANSFERRED_FILE_SIZE = 512 * Constants.KILOBYTE; private static final String BEGIN_EXPORT_FILE = "beginExportFile"; private static final String WRITE_EXPORT_FILE = "writeExportFile"; @@ -257,7 +258,7 @@ public final class FileImportExportCardItemDevice extends AbstractItemRPCDevice return new byte[0]; } - final byte[] buffer = new byte[1024]; + final byte[] buffer = new byte[512]; final int count = importedFile.data.read(buffer); if (count <= 0) { reset(); diff --git a/src/main/java/li/cil/oc2/common/network/Network.java b/src/main/java/li/cil/oc2/common/network/Network.java index 554b1e3f..8e895efe 100644 --- a/src/main/java/li/cil/oc2/common/network/Network.java +++ b/src/main/java/li/cil/oc2/common/network/Network.java @@ -81,6 +81,10 @@ public final class Network { registerMessage(ProjectorStateMessage.class, ProjectorStateMessage::new, NetworkDirection.PLAY_TO_CLIENT); registerMessage(KeyboardInputMessage.class, KeyboardInputMessage::new, NetworkDirection.PLAY_TO_SERVER); + + registerMessage(MultipartMessage.class, MultipartMessage::new, NetworkDirection.PLAY_TO_SERVER); + + MultipartMessage.registerMessage(ImportedFileMessage.class, ImportedFileMessage::new); } public static void sendToServer(final T message) { diff --git a/src/main/java/li/cil/oc2/common/network/message/MultipartMessage.java b/src/main/java/li/cil/oc2/common/network/message/MultipartMessage.java new file mode 100644 index 00000000..ac79da22 --- /dev/null +++ b/src/main/java/li/cil/oc2/common/network/message/MultipartMessage.java @@ -0,0 +1,177 @@ +/* SPDX-License-Identifier: MIT */ + +package li.cil.oc2.common.network.message; + +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import it.unimi.dsi.fastutil.ints.Int2ObjectArrayMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectMap; +import li.cil.oc2.common.Constants; +import li.cil.oc2.common.network.Network; +import net.minecraft.network.FriendlyByteBuf; +import net.minecraft.network.protocol.game.ServerboundCustomPayloadPacket; +import net.minecraftforge.network.NetworkEvent; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.function.Function; +import java.util.function.Supplier; + +/** + * Utility wrapper message for client to server messages exceeding the regular custom payload size. + */ +public final class MultipartMessage extends AbstractMessage { + private static final Logger LOGGER = LogManager.getLogger(); + + private static final int MAX_MULTIPART_MESSAGE_SIZE = 1024 * Constants.KILOBYTE; + private static final int MAX_PAYLOAD_SIZE = ServerboundCustomPayloadPacket.MAX_PAYLOAD_SIZE; + private static final int HEADER_SIZE = + 1 /* forge message index */ + + 4 /* message id */ + + 4 /* multipart message id */ + + 2 /* length */; + + /////////////////////////////////////////////////////////////////// + + /** + * Cache for collecting multipart messages on the server into one big buffer again. Discard them after some + * time to avoid malicious clients being able to grow the memory used by this cache to grow infinitely. + */ + private static final Cache MULTIPART_MESSAGE_BUFFER_CACHE = CacheBuilder.newBuilder() + .expireAfterAccess(Duration.ofSeconds(30)) + .build(); + private static int lastAssignedMultipartMessageId; + + /////////////////////////////////////////////////////////////////// + + private static final Map, Entry> ENTRY_BY_TYPE = new HashMap<>(); + private static final Int2ObjectMap ENTRY_BY_ID = new Int2ObjectArrayMap<>(); + private static int lastAssignedId; + + public static void registerMessage(final Class type, final Function factory) { + if (ENTRY_BY_TYPE.containsKey(type)) { + throw new IllegalArgumentException("Message of this type has already been registered."); + } + final int id = ++lastAssignedId; + final Entry entry = new Entry(id, factory); + ENTRY_BY_TYPE.put(type, entry); + ENTRY_BY_ID.put(id, entry); + } + + /////////////////////////////////////////////////////////////////// + + public static void sendToServer(final AbstractMessage message) { + final FriendlyByteBuf buffer = new FriendlyByteBuf(Unpooled.buffer()); + message.toBytes(buffer); + if (buffer.readableBytes() <= MAX_PAYLOAD_SIZE) { + // Message fits into one custom payload packet, send it as is. + Network.sendToServer(message); + return; + } + if (buffer.readableBytes() > MAX_MULTIPART_MESSAGE_SIZE) { + throw new IllegalArgumentException("Message too large."); + } + + final Entry entry = ENTRY_BY_TYPE.get(message.getClass()); + if (entry == null) { + throw new IllegalArgumentException("Trying to send multipart message of unregistered message (" + message.getClass().getName() + ")."); + } + + final int messageId = entry.id(); + final int multipartMessageId = ++lastAssignedMultipartMessageId; + + while (buffer.readableBytes() > 0) { + final int dataLength = Math.min(buffer.readableBytes(), MAX_PAYLOAD_SIZE - HEADER_SIZE); + final byte[] data = new byte[dataLength]; + buffer.readBytes(data); + Network.sendToServer(new MultipartMessage(messageId, multipartMessageId, data)); + } + } + + /////////////////////////////////////////////////////////////////// + + /** + * Automatically computed on client. Implicit because all but last packets are max size. + */ + private boolean isFinalPart; + + private int messageId; + private int multipartMessageId; + private byte[] data; + + /////////////////////////////////////////////////////////////////// + + public MultipartMessage(final int messageId, final int multipartMessageId, final byte[] data) { + this.messageId = messageId; + this.multipartMessageId = multipartMessageId; + this.data = data; + } + + public MultipartMessage(final FriendlyByteBuf buffer) { + super(buffer); + } + + /////////////////////////////////////////////////////////////////// + + @Override + public void fromBytes(final FriendlyByteBuf buffer) { + isFinalPart = buffer.readableBytes() < MAX_PAYLOAD_SIZE - 1 /* forge message index */; + + messageId = buffer.readInt(); + multipartMessageId = buffer.readInt(); + final int length = buffer.readUnsignedShort(); + data = new byte[length]; + buffer.readBytes(data); + } + + @Override + public void toBytes(final FriendlyByteBuf buffer) { + buffer.writeInt(messageId); + buffer.writeInt(multipartMessageId); + buffer.writeShort(data.length); + buffer.writeBytes(data); + } + + /////////////////////////////////////////////////////////////////// + + @Override + protected void handleMessage(final Supplier contextSupplier) { + try { + final ByteBuf buffer = MULTIPART_MESSAGE_BUFFER_CACHE.get(lastAssignedMultipartMessageId, Unpooled::buffer); + if (buffer.capacity() == 0) { + return; // Invalidated entry due to being over-sized. + } + + buffer.writeBytes(data); + if (buffer.readableBytes() > MAX_MULTIPART_MESSAGE_SIZE) { + LOGGER.error("Received over-sized multipart message from client [{}], ignoring.", contextSupplier.get().getSender()); + MULTIPART_MESSAGE_BUFFER_CACHE.put(lastAssignedMultipartMessageId, Unpooled.buffer(0)); + return; + } + + if (isFinalPart) { + MULTIPART_MESSAGE_BUFFER_CACHE.invalidate(lastAssignedMultipartMessageId); + + final Entry entry = ENTRY_BY_ID.get(messageId); + if (entry == null) { + LOGGER.error("Received multipart message for unregistered message from client [{}]. Are the mod version on the server and client the same?", contextSupplier.get().getSender()); + return; + } + + entry.factory.apply(new FriendlyByteBuf(buffer)).handleMessage(contextSupplier); + } + } catch (final ExecutionException e) { + LOGGER.error("Error when handling multipart message received from client [{}]: {}", contextSupplier.get().getSender(), e); + } + } + + /////////////////////////////////////////////////////////////////// + + private record Entry(int id, Function factory) { } +} diff --git a/src/main/java/li/cil/oc2/common/network/message/RequestImportedFileMessage.java b/src/main/java/li/cil/oc2/common/network/message/RequestImportedFileMessage.java index def4a733..e559a918 100644 --- a/src/main/java/li/cil/oc2/common/network/message/RequestImportedFileMessage.java +++ b/src/main/java/li/cil/oc2/common/network/message/RequestImportedFileMessage.java @@ -64,7 +64,7 @@ public final class RequestImportedFileMessage extends AbstractMessage { Minecraft.getInstance().gui.getChat().addMessage(FILE_TOO_LARGE_TEXT .withStyle(s -> s.withColor(TextColor.fromRgb(0xFFA0A0)))); } else { - Network.sendToServer(new ImportedFileMessage(id, fileName, data)); + MultipartMessage.sendToServer(new ImportedFileMessage(id, fileName, data)); } } catch (final IOException e) { LOGGER.error(e); diff --git a/src/main/resources/META-INF/accesstransformer.cfg b/src/main/resources/META-INF/accesstransformer.cfg index 19d8a480..5c7ef746 100644 --- a/src/main/resources/META-INF/accesstransformer.cfg +++ b/src/main/resources/META-INF/accesstransformer.cfg @@ -1,3 +1,4 @@ public net.minecraft.client.MouseHandler f_91520_ # mouseGrabbed public-f net.minecraft.world.entity.Entity m_142467_(Lnet/minecraft/world/entity/Entity$RemovalReason;)V # setRemoved public-f net.minecraft.client.renderer.GameRenderer f_109054_ # mainCamera +public net.minecraft.network.protocol.game.ServerboundCustomPayloadPacket f_179586_ # MAX_PAYLOAD_SIZE