diff --git a/src/main/java/li/cil/oc2/api/inet/session/Session.java b/src/main/java/li/cil/oc2/api/inet/session/Session.java index 535f5dd4..ecc73376 100644 --- a/src/main/java/li/cil/oc2/api/inet/session/Session.java +++ b/src/main/java/li/cil/oc2/api/inet/session/Session.java @@ -2,6 +2,7 @@ package li.cil.oc2.api.inet.session; import javax.annotation.Nullable; import java.net.InetSocketAddress; +import java.time.Instant; public interface Session { long getId(); @@ -11,12 +12,14 @@ public interface Session { States getState(); @Nullable - Object getUserdata(); + Object getAttachment(); - void setUserdata(final Object userdata); + void setAttachment(@Nullable final Object userdata); InetSocketAddress getDestination(); + Instant getLastUpdateTime(); + default boolean isClosed() { return switch (getState()) { case FINISH, REJECT, EXPIRED -> true; diff --git a/src/main/java/li/cil/oc2/common/Config.java b/src/main/java/li/cil/oc2/common/Config.java index 8ebc8488..33285954 100644 --- a/src/main/java/li/cil/oc2/common/Config.java +++ b/src/main/java/li/cil/oc2/common/Config.java @@ -68,7 +68,7 @@ public final class Config { @Path("internet-card") public static String defaultNameServer = "1.1.1.1"; @Path("internet-card") public static boolean useSynchronisedNAT = false; @Path("internet-card") public static int streamBufferSize = 2000; - @Path("internet-card") public static int tcpRetransmissionTimeoutMs = 30 * 1000; + @Path("internet-card") public static int tcpRetransmissionTimeoutMs = 2 * 1000; public static boolean computersUseEnergy() { return computerEnergyPerTick > 0 && computerEnergyStorage > 0; diff --git a/src/main/java/li/cil/oc2/common/inet/DatagramSessionBase.java b/src/main/java/li/cil/oc2/common/inet/DatagramSessionBase.java new file mode 100644 index 00000000..b13a19f4 --- /dev/null +++ b/src/main/java/li/cil/oc2/common/inet/DatagramSessionBase.java @@ -0,0 +1,33 @@ +package li.cil.oc2.common.inet; + +public abstract class DatagramSessionBase extends SessionBase { + + private States state = States.NEW; + + public DatagramSessionBase(final int ipAddress, final short port) { + super(ipAddress, port); + } + + @Override + public void close() { + switch (state) { + case NEW -> state = States.REJECT; + case ESTABLISHED -> state = States.FINISH; + default -> throw new IllegalStateException(); + } + } + + @Override + public States getState() { + return state; + } + + @Override + public void expire() { + state = States.EXPIRED; + } + + public void setState(final States state) { + this.state = state; + } +} diff --git a/src/main/java/li/cil/oc2/common/inet/DatagramSessionImpl.java b/src/main/java/li/cil/oc2/common/inet/DatagramSessionImpl.java index c1f28c41..f7687ec6 100644 --- a/src/main/java/li/cil/oc2/common/inet/DatagramSessionImpl.java +++ b/src/main/java/li/cil/oc2/common/inet/DatagramSessionImpl.java @@ -2,7 +2,7 @@ package li.cil.oc2.common.inet; import li.cil.oc2.api.inet.session.DatagramSession; -public final class DatagramSessionImpl extends SessionBase implements DatagramSession { +public final class DatagramSessionImpl extends DatagramSessionBase implements DatagramSession { private final DatagramSessionDiscriminator discriminator; public DatagramSessionImpl(final int ipAddress, final short port, final DatagramSessionDiscriminator discriminator) { diff --git a/src/main/java/li/cil/oc2/common/inet/DefaultSessionLayer.java b/src/main/java/li/cil/oc2/common/inet/DefaultSessionLayer.java index 6f04d68e..13cab8b7 100644 --- a/src/main/java/li/cil/oc2/common/inet/DefaultSessionLayer.java +++ b/src/main/java/li/cil/oc2/common/inet/DefaultSessionLayer.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger; import org.jetbrains.annotations.Nullable; import java.io.IOException; +import java.net.ConnectException; import java.net.InetAddress; import java.net.SocketAddress; import java.nio.ByteBuffer; @@ -61,7 +62,35 @@ public final class DefaultSessionLayer implements SessionLayer { return; } - final boolean somethingRead = processQueue(readySessions.getToRead(), session -> { + final boolean somethingConnected = processQueue(readySessions.getToConnect(), session -> { + if (session instanceof StreamSession streamSession) { + LOGGER.info("Connected {}", session); + if (session.getState() != Session.States.NEW) { + return false; + } + receiver.receive(streamSession); + try { + final SocketChannel channel = getChannel(streamSession); + channel.finishConnect(); + streamSession.connect(); + return true; + } catch (final ConnectException exception) { + LOGGER.info("Connection rejected for {}", session); + closeSession(session); + return true; + } catch (final IOException exception) { + LOGGER.error("Error on socket.finishConnect()", exception); + closeSession(session); + return true; + } + } + return false; + }); + if (somethingConnected) { + return; + } + + processQueue(readySessions.getToRead(), session -> { if (session instanceof DatagramSession datagramSession) { LOGGER.info("Datagram received"); final DatagramChannel channel = getChannel(datagramSession); @@ -83,14 +112,15 @@ public final class DefaultSessionLayer implements SessionLayer { LOGGER.info("Datagram received"); } else if (session instanceof StreamSession streamSession) { LOGGER.info("Stream received"); - final SocketChannel channel = getChannel(streamSession); + final ByteBuffer stream = receiver.receive(streamSession); try { - final ByteBuffer stream = receiver.receive(streamSession); + final SocketChannel channel = getChannel(streamSession); assert stream != null; + assert false; final int read = channel.read(stream); - if (read != 0) { - // some data still remaining in socket, read it later - readySessions.getToRead().add(streamSession); + LOGGER.info("Read from real world: {}", read); + if (read == -1) { + closeSession(session); } return true; } catch (final IOException exception) { @@ -99,18 +129,6 @@ public final class DefaultSessionLayer implements SessionLayer { } return false; }); - if (somethingRead) { - return; - } - - processQueue(readySessions.getToConnect(), session -> { - if (session instanceof StreamSession streamSession) { - receiver.receive(streamSession); - streamSession.connect(); - return true; - } - return false; - }); } @Override @@ -136,7 +154,7 @@ public final class DefaultSessionLayer implements SessionLayer { case NEW: { final DatagramChannel channel = socketManager.createDatagramChannel(datagramSession, readySessions); - datagramSession.setUserdata(channel); + datagramSession.setAttachment(channel); LOGGER.info("Open datagram socket {}", session.getDestination()); /* Fallthrough */ } @@ -162,9 +180,9 @@ public final class DefaultSessionLayer implements SessionLayer { switch (session.getState()) { case NEW -> { final SocketChannel channel = socketManager.createStreamChannel(streamSession, readySessions); - session.setUserdata(channel); - channel.connect(session.getDestination()); - LOGGER.info("Open stream socket {}", session.getDestination()); + streamSession.setAttachment(channel); + channel.connect(streamSession.getDestination()); + LOGGER.info("Open stream socket {}", streamSession.getDestination()); } case ESTABLISHED -> { final SocketChannel channel = getChannel(streamSession); @@ -203,13 +221,16 @@ public final class DefaultSessionLayer implements SessionLayer { private void closeSession(final Session session) { try { getChannel(session).close(); + if (!session.isClosed()) { + session.close(); + } } catch (final IOException exception) { LOGGER.error("Error on closing channel", exception); } } private Object getExistingUserdata(final Session session) { - final Object channel = session.getUserdata(); + final Object channel = session.getAttachment(); assert channel != null; return channel; } diff --git a/src/main/java/li/cil/oc2/common/inet/DefaultTransportLayer.java b/src/main/java/li/cil/oc2/common/inet/DefaultTransportLayer.java index 3daacffa..3dcf82c0 100644 --- a/src/main/java/li/cil/oc2/common/inet/DefaultTransportLayer.java +++ b/src/main/java/li/cil/oc2/common/inet/DefaultTransportLayer.java @@ -16,6 +16,7 @@ import org.apache.logging.log4j.Logger; import javax.annotation.Nullable; import java.nio.ByteBuffer; import java.time.Instant; +import java.time.temporal.ChronoUnit; import java.util.*; import java.util.function.Consumer; import java.util.function.Function; @@ -42,7 +43,7 @@ public final class DefaultTransportLayer implements TransportLayer { private final SessionReceiver receiver = new SessionReceiver(); private final NavigableMap expirationQueue = new TreeMap<>(); - private final NavigableMap retransmissionQueue = new TreeMap<>(); + private StreamSessionImpl streamToAck = null; private final Map, SessionBase> sessions = new HashMap<>(); private ICMPReply icmpReply = null; @@ -54,49 +55,66 @@ public final class DefaultTransportLayer implements TransportLayer { this.sessionLayer = sessionLayer; } - private void processExpirationQueue(final Map queue, final Consumer action) { + private void processExpirationQueue(final NavigableMap queue, final Consumer action) { if (queue.isEmpty()) { return; } - final Instant now = Instant.now(); - final Iterator> iterator = queue.entrySet().iterator(); + final Instant expireTime = Instant.now().minus(Config.defaultSessionLifetimeMs, ChronoUnit.MILLIS); + final Iterator iterator = queue.navigableKeySet().iterator(); while (iterator.hasNext()) { - Map.Entry entry = iterator.next(); - if (entry.getKey().compareTo(now) < 0) { + final Instant time = iterator.next(); + if (time.compareTo(expireTime) < 0) { + final T value = queue.get(time); iterator.remove(); - final T entryValue = entry.getValue(); - action.accept(entryValue); + action.accept(value); } else { return; } } } + @Nullable + private StreamSessionImpl getNextStreamForRetransmission() { + if (expirationQueue.isEmpty()) { + return null; + } + final Instant retransmissionTime = Instant.now().minus(Config.tcpRetransmissionTimeoutMs, ChronoUnit.MILLIS); + for (final Instant time : expirationQueue.navigableKeySet()) { + if (time.compareTo(retransmissionTime) < 0) { + final SessionBase session = expirationQueue.get(time); + if (session instanceof StreamSessionImpl stream && stream.isNeedsAcknowledgment()) { + return stream; + } + } else { + break; + } + } + return null; + } + private void processSessionExpirationQueue() { processExpirationQueue(expirationQueue, session -> { sessions.remove(session.getDiscriminator()); --allSessionCount; LOGGER.info("Expired session {}", session.getDiscriminator()); - session.setState(Session.States.EXPIRED); + session.expire(); sessionLayer.sendSession(session, null); }); } private void updateSession(final SessionBase session) { - final Instant oldKey = session.getExpireTime(); - if (oldKey != null) { - expirationQueue.remove(oldKey); - } - session.updateExpireTime(); - final Instant newExpireTime = session.getExpireTime(); - SessionBase previous = expirationQueue.put(newExpireTime, session); + final Instant oldKey = session.getLastUpdateTime(); + expirationQueue.remove(oldKey); + session.update(); + final Instant newLastUpdateTime = session.getLastUpdateTime(); + SessionBase previous = expirationQueue.put(newLastUpdateTime, session); assert previous == null; } private void closeSession(final SessionBase session) { LOGGER.info("Close session {}", session.getDiscriminator()); sessions.remove(session.getDiscriminator()); - expirationQueue.remove(session.getExpireTime()); + expirationQueue.remove(session.getLastUpdateTime()); --allSessionCount; } @@ -148,7 +166,7 @@ public final class DefaultTransportLayer implements TransportLayer { ); } - private void sessionSendFinish(final SessionBase session, final ByteBuffer payload, final int srcIpAddress) { + private void sessionSendFinish(final DatagramSessionBase session, final ByteBuffer payload, final int srcIpAddress) { final Session.States state = session.getState(); switch (state) { case NEW: @@ -169,29 +187,36 @@ public final class DefaultTransportLayer implements TransportLayer { } } - private boolean prepareTCPSegment(final TransportMessage message, final StreamSessionImpl stream) { + private SessionActions prepareTCPSegment(final TransportMessage message, final StreamSessionImpl stream) { final ByteBuffer data = message.getData(); final StreamSessionDiscriminator discriminator = stream.getDiscriminator(); final int position = data.position(); final int limit = data.limit(); data.putShort(discriminator.getDstPort()); data.putShort(discriminator.getSrcPort()); - final boolean recv = stream.onReceive(data); - if (!recv) { - data.position(position); - data.limit(limit); - return false; + final SessionActions recv = stream.receive(data); + switch (recv) { + case DROP, IGNORE -> { + data.position(position); + data.limit(limit); + return recv; + } + case FORWARD -> { + data.position(position); + final short checksum = InetUtils.transportRfc1071Checksum( + data, + discriminator.getDstIpAddress(), + discriminator.getSrcIpAddress(), + PROTOCOL_TCP + ); + data.putShort(position + 16, checksum); + data.position(position); + message.updateIpv4(discriminator.getDstIpAddress(), discriminator.getSrcIpAddress()); + LOGGER.info("Prepared TCP packet to receive {}", stream.getHeader()); + return SessionActions.FORWARD; + } + default -> throw new IllegalStateException(); } - data.position(position); - final short checksum = InetUtils.transportRfc1071Checksum( - data, - discriminator.getDstIpAddress(), - discriminator.getSrcIpAddress(), - PROTOCOL_TCP - ); - data.putShort(position + 16, checksum); - data.position(position); - return true; } @Override @@ -201,8 +226,9 @@ public final class DefaultTransportLayer implements TransportLayer { while (true) { if (rejectedStream != null) { // This branch should be checked first! Stream needs to be closed properly - boolean success = prepareTCPSegment(message, rejectedStream); - assert success; + LOGGER.info("Rejecting stream {}", rejectedStream.getDiscriminator()); + final SessionActions success = prepareTCPSegment(message, rejectedStream); + assert success == SessionActions.FORWARD; closeSession(rejectedStream); rejectedStream = null; return PROTOCOL_TCP; @@ -222,11 +248,37 @@ public final class DefaultTransportLayer implements TransportLayer { return PROTOCOL_ICMP; } - if (!retransmissionQueue.isEmpty()) { + if (streamToAck != null) { + final StreamSessionImpl stream = streamToAck; + streamToAck = null; + updateSession(stream); + switch (prepareTCPSegment(message, stream)) { + case FORWARD -> { + if (stream.isClosed()) { + closeSession(stream); + } + return PROTOCOL_TCP; + } + case DROP -> closeSession(stream); + } + } + /* + final StreamSessionImpl retransmitSession = getNextStreamForRetransmission(); + if (retransmitSession != null) { // Process retransmission queue - processExpirationQueue(retransmissionQueue, stream -> prepareTCPSegment(message, stream)); + updateSession(retransmitSession); + switch (prepareTCPSegment(message, retransmitSession)) { + case FORWARD -> { + if (retransmitSession.isClosed()) { + closeSession(retransmitSession); + } + return PROTOCOL_TCP; + } + case DROP -> closeSession(retransmitSession); + } return PROTOCOL_TCP; } + */ receiver.prepare(message.getData()); sessionLayer.receiveSession(receiver); @@ -282,8 +334,14 @@ public final class DefaultTransportLayer implements TransportLayer { } } else if (session instanceof StreamSession) { final StreamSessionImpl streamSession = (StreamSessionImpl) session; - if (prepareTCPSegment(message, streamSession)) { - return PROTOCOL_TCP; + switch (prepareTCPSegment(message, streamSession)) { + case FORWARD -> { + if (streamSession.isClosed()) { + closeSession(streamSession); + } + return PROTOCOL_TCP; + } + case DROP -> closeSession(streamSession); } } else { throw new IllegalStateException(); @@ -303,7 +361,7 @@ public final class DefaultTransportLayer implements TransportLayer { @Override public void onStop() { for (final SessionBase session : sessions.values()) { - session.setState(Session.States.FINISH); + session.expire(); sessionLayer.sendSession(session, null); closeSession(session); } @@ -390,35 +448,29 @@ public final class DefaultTransportLayer implements TransportLayer { if (session == null) { reject(data, srcIpAddress); } else { - if (session.onSend(data)) { - if (session.getState() == Session.States.NEW) - sessionLayer.sendSession(session, data); - final Session.States state = session.getState(); - if (state == Session.States.REJECT || state == Session.States.FINISH) { - rejectedStream = session; + LOGGER.info("GOT TCP"); + switch (session.send(data)) { + case FORWARD -> { + switch (session.getState()) { + case NEW, FINISH -> sessionLayer.sendSession(session, null); + case ESTABLISHED -> sessionLayer.sendSession(session, session.getSendBuffer()); + } + final Session.States state = session.getState(); + if (state == Session.States.REJECT || state == Session.States.FINISH) { + rejectedStream = session; + } + if (session.isNeedsAcknowledgment()) { + streamToAck = session; + } } - } else { - closeSession(session); + case DROP -> closeSession(session); } } } } } - private static final class ICMPReply { - private final byte type; - private final byte code; - private final int srcIpAddress; - private final int dstIpAddress; - private final byte[] payload; - - public ICMPReply(final byte type, final byte code, final int srcIpAddress, final int dstIpAddress, final byte[] payload) { - this.type = type; - this.code = code; - this.srcIpAddress = srcIpAddress; - this.dstIpAddress = dstIpAddress; - this.payload = payload; - } + private record ICMPReply(byte type, byte code, int srcIpAddress, int dstIpAddress, byte[] payload) { } private static final class SessionReceiver implements SessionLayer.Receiver { diff --git a/src/main/java/li/cil/oc2/common/inet/EchoSessionImpl.java b/src/main/java/li/cil/oc2/common/inet/EchoSessionImpl.java index 80edbb6a..e080bc5c 100644 --- a/src/main/java/li/cil/oc2/common/inet/EchoSessionImpl.java +++ b/src/main/java/li/cil/oc2/common/inet/EchoSessionImpl.java @@ -2,7 +2,7 @@ package li.cil.oc2.common.inet; import li.cil.oc2.api.inet.session.EchoSession; -public final class EchoSessionImpl extends SessionBase implements EchoSession { +public final class EchoSessionImpl extends DatagramSessionBase implements EchoSession { private final EchoSessionDiscriminator discriminator; private byte ttl; private short sequenceNumber; diff --git a/src/main/java/li/cil/oc2/common/inet/SessionActions.java b/src/main/java/li/cil/oc2/common/inet/SessionActions.java new file mode 100644 index 00000000..cb72f2ae --- /dev/null +++ b/src/main/java/li/cil/oc2/common/inet/SessionActions.java @@ -0,0 +1,12 @@ +package li.cil.oc2.common.inet; + +public enum SessionActions { + // Bad session. Drop the whole session + DROP, + + // Transfer message to session layer on send and to network layer on receive + FORWARD, + + // Do nothing upon return + IGNORE, +} diff --git a/src/main/java/li/cil/oc2/common/inet/SessionBase.java b/src/main/java/li/cil/oc2/common/inet/SessionBase.java index 4a81480d..bf4682fe 100644 --- a/src/main/java/li/cil/oc2/common/inet/SessionBase.java +++ b/src/main/java/li/cil/oc2/common/inet/SessionBase.java @@ -1,7 +1,6 @@ package li.cil.oc2.common.inet; import li.cil.oc2.api.inet.session.Session; -import li.cil.oc2.common.Config; import javax.annotation.Nullable; import java.net.InetSocketAddress; @@ -13,13 +12,12 @@ public abstract class SessionBase implements Session { private final long id = idGenerator.getAndIncrement(); private final InetSocketAddress destination; - private States state; - private Instant expireTime; - private Object userdata; + private Instant lastUpdateTime = Instant.now(); + @Nullable + private Object attachment; public SessionBase(final int ipAddress, final short port) { destination = new InetSocketAddress(InetUtils.toJavaInetAddress(ipAddress), Short.toUnsignedInt(port)); - state = States.NEW; } @Override @@ -27,42 +25,24 @@ public abstract class SessionBase implements Session { return id; } - @Override - public void close() { - switch (state) { - case NEW -> state = States.REJECT; - case ESTABLISHED -> state = States.FINISH; - default -> throw new IllegalStateException(); - } + public void update() { + lastUpdateTime = Instant.now(); } @Override - public States getState() { - return state; - } - - public void setState(final States state) { - this.state = state; - } - - public void updateExpireTime() { - expireTime = Instant.now().plusMillis(Config.defaultSessionLifetimeMs); - } - - @Nullable - public Instant getExpireTime() { - return expireTime; + public Instant getLastUpdateTime() { + return lastUpdateTime; } @Nullable @Override - public Object getUserdata() { - return this.userdata; + public Object getAttachment() { + return this.attachment; } @Override - public void setUserdata(final Object userdata) { - this.userdata = userdata; + public void setAttachment(@Nullable final Object userdata) { + this.attachment = userdata; } @Override @@ -71,4 +51,6 @@ public abstract class SessionBase implements Session { } public abstract SessionDiscriminator getDiscriminator(); + + public abstract void expire(); } diff --git a/src/main/java/li/cil/oc2/common/inet/SocketManager.java b/src/main/java/li/cil/oc2/common/inet/SocketManager.java index a59142d9..f9c6f42d 100644 --- a/src/main/java/li/cil/oc2/common/inet/SocketManager.java +++ b/src/main/java/li/cil/oc2/common/inet/SocketManager.java @@ -12,7 +12,6 @@ import java.nio.channels.DatagramChannel; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; import java.nio.channels.SocketChannel; -import java.util.Set; public final class SocketManager { private static final Logger LOGGER = LogManager.getLogger(); diff --git a/src/main/java/li/cil/oc2/common/inet/StreamSessionImpl.java b/src/main/java/li/cil/oc2/common/inet/StreamSessionImpl.java index e68bf4f8..d2dd53b5 100644 --- a/src/main/java/li/cil/oc2/common/inet/StreamSessionImpl.java +++ b/src/main/java/li/cil/oc2/common/inet/StreamSessionImpl.java @@ -2,21 +2,23 @@ package li.cil.oc2.common.inet; import li.cil.oc2.api.inet.session.StreamSession; import li.cil.oc2.common.Config; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; -import javax.annotation.Nullable; import java.nio.ByteBuffer; -import java.time.Instant; -import java.time.temporal.ChronoUnit; import java.util.Random; public class StreamSessionImpl extends SessionBase implements StreamSession { + + private static final Logger LOGGER = LogManager.getLogger(); + private static final Random random = new Random(); private final StreamSessionDiscriminator discriminator; // Data from session layer implementation private final ByteBuffer receiveBuffer = ByteBuffer.allocate(Config.streamBufferSize); - private int receiveWindow = 0; + private int vmWindow = 0; private int nextSegmentMark = 0; // for retransmission // Data from virtual machine @@ -27,175 +29,48 @@ public class StreamSessionImpl extends SessionBase implements StreamSession { private final TcpHeader header = new TcpHeader(); + private TcpStates state = TcpStates.CONNECT; + private boolean needsAcknowledgment = false; - private Instant retransmitTime = Instant.now(); ///////////////////////////////////////////////////////////////////////// public StreamSessionImpl( - final int ipAddress, - final short port, - final StreamSessionDiscriminator discriminator + final int ipAddress, + final short port, + final StreamSessionDiscriminator discriminator ) { super(ipAddress, port); this.discriminator = discriminator; sendBuffer.limit(0); } - private int computeWindow() { - return sendBuffer.remaining(); + public SessionActions receive(final ByteBuffer segment) { + return state.receive(this, segment); } - public boolean newConnection(final ByteBuffer data) { - final boolean correct = header.read(data); - if (!correct) { - return false; - } - final boolean isInitiation = header.isConnectionInitiation(); - if (!isInitiation) { - return false; - } - vmSequence = header.sequenceNumber; - receiveWindow = header.window; - return true; + public SessionActions send(final ByteBuffer segment) { + return state.send(this, segment); } - private void acceptConnection(final ByteBuffer data) { - header.acceptConnection(mySequence++, ++vmSequence, computeWindow()); - header.write(data); - data.flip(); - setState(States.ESTABLISHED); + boolean isNeedsAcknowledgment() { + return needsAcknowledgment; } - private void denyConnection(final ByteBuffer data) { - header.denyConnection(mySequence, vmSequence + 1); - header.write(data); - data.flip(); + @Override + public ByteBuffer getReceiveBuffer() { + switch (state) { + case EXPIRED, FINISH, REJECT -> throw new IllegalStateException(); + } + return receiveBuffer; } - private boolean onPacket(final ByteBuffer data) { - final boolean correct = header.read(data); - if (!correct) { - return false; + @Override + public ByteBuffer getSendBuffer() { + switch (state) { + case EXPIRED, REJECT -> throw new IllegalStateException(); } - if (header.syn) { - return false; - } - if (header.sequenceNumber != vmSequence) { - return false; - } - if (header.ack) { - // Segment received - if (header.acknowledgmentNumber != mySequence) { - return false; - } - receiveWindow = header.window; - final int newPosition = receiveBuffer.position() - nextSegmentMark; - receiveBuffer.position(nextSegmentMark); - receiveBuffer.compact(); - receiveBuffer.position(newPosition); - receiveBuffer.limit(receiveBuffer.capacity()); - nextSegmentMark = 0; - } else { - receiveWindow = header.window; - } - if (header.psh) { - // Data to be sent - final int length = data.remaining(); - if (length > computeWindow()) { - // TODO: State changed, but packet rejected - return false; - } - vmSequence += length; - sendBuffer.compact(); - sendBuffer.limit(sendBuffer.limit() + length); - sendBuffer.put(data); - needsAcknowledgment = true; - } - if (header.fin) { - setState(States.FINISH); - ++vmSequence; - } - return true; - } - - private void pushNextReceivedDataTo(final ByteBuffer data) { - final int position = data.position(); - data.position(position + TcpHeader.MIN_HEADER_SIZE_NO_PORTS); - - // Copy payload (yes, it is easier to prepare payload first) - final int recvPos = receiveBuffer.position(); - final int recvLim = receiveBuffer.limit(); - receiveBuffer.limit(nextSegmentMark); - receiveBuffer.position(0); - data.put(receiveBuffer); - receiveBuffer.position(recvPos); - receiveBuffer.limit(recvLim); - data.position(position); - - // Update time - retransmitTime = Instant.now().plus(Config.tcpRetransmissionTimeoutMs, ChronoUnit.MILLIS); - } - - private boolean preparePacket(final ByteBuffer data) { - final int length = receiveBuffer.position(); - header.urg = false; - header.syn = false; - header.rst = false; - header.ack = needsAcknowledgment; - header.sequenceNumber = mySequence - nextSegmentMark; - header.acknowledgmentNumber = vmSequence; - header.maxSegmentSize = -1; - header.urgentPointer = 0; - header.window = computeWindow(); - header.psh = length != 0; - if (header.psh) { - header.fin = false; - // We have something to receive - if (nextSegmentMark == 0) { - // Acknowledged, prepare next segment - nextSegmentMark = Math.min(Math.min(receiveWindow, length), data.remaining() - TcpHeader.MIN_HEADER_SIZE_NO_PORTS); - mySequence += nextSegmentMark; - pushNextReceivedDataTo(data); - } else { - // Packet is already sent, is retransmission required? - if (retransmitTime.compareTo(Instant.now()) > 0) { - return false; // no - } else { - pushNextReceivedDataTo(data); - } - } - } else { - header.fin = getState() == States.FINISH; - header.window = 0; - } - header.write(data); - return true; - } - - public boolean onSend(final ByteBuffer data) { - return switch (getState()) { - case NEW -> newConnection(data); - case ESTABLISHED, FINISH -> onPacket(data); - case REJECT, EXPIRED -> throw new IllegalStateException(); - }; - } - - public boolean onReceive(final ByteBuffer data) { - switch (getState()) { - case NEW: - acceptConnection(data); - return true; - case ESTABLISHED: - case FINISH: - return preparePacket(data); - case REJECT: - denyConnection(data); - return true; - case EXPIRED: - throw new IllegalStateException(); - } - return false; + return sendBuffer; } @Override @@ -204,32 +79,271 @@ public class StreamSessionImpl extends SessionBase implements StreamSession { } @Override - public ByteBuffer getReceiveBuffer() { - return receiveBuffer; - } - - @Override - public ByteBuffer getSendBuffer() { - return sendBuffer; - } - - @Nullable - public Instant whenCoolOff() { - if (nextSegmentMark != 0) { - return retransmitTime; - } else { - return null; - } + public void expire() { + state = TcpStates.EXPIRED; } @Override public void connect() { - if (getState() != States.NEW) + if (state != TcpStates.CONNECT) { throw new IllegalStateException(); - setState(States.ESTABLISHED); + } + state = TcpStates.ACCEPT; + } + + @Override + public States getState() { + return state.toSessionState(); + } + + @Override + public void close() { + state = switch (state) { + case ESTABLISHED -> TcpStates.FINISH; + case CONNECT -> TcpStates.REJECT; + default -> throw new IllegalStateException(); + }; } public TcpHeader getHeader() { return header; } + + @Override + public String toString() { + return "StreamSession(" + discriminator + ")"; + } + + private int computeWindow() { + return sendBuffer.capacity() - sendBuffer.limit(); + } + + private enum TcpStates { + CONNECT { + @Override + SessionActions receive(final StreamSessionImpl session, final ByteBuffer segment) { + LOGGER.warn("Incorrect session layer implementation. Stream session is not updated."); + return SessionActions.IGNORE; + } + + @Override + SessionActions send(final StreamSessionImpl session, final ByteBuffer segment) { + final TcpHeader header = session.header; + if (!header.read(segment)) { + return SessionActions.DROP; + } + if (!header.isConnectionInitiation()) { + // weird packet; drop whole session + return SessionActions.DROP; + } + // initialize stream state + session.vmSequence = header.sequenceNumber; + session.vmWindow = header.window; + return SessionActions.FORWARD; + } + + @Override + States toSessionState() { + return States.NEW; + } + }, + ACCEPT { + @Override + SessionActions receive(final StreamSessionImpl session, final ByteBuffer segment) { + final TcpHeader header = session.header; + header.acceptConnection(session.mySequence, session.vmSequence + 1, session.computeWindow()); + header.write(segment); + segment.flip(); + return SessionActions.FORWARD; + } + + @Override + SessionActions send(final StreamSessionImpl session, final ByteBuffer segment) { + final TcpHeader header = session.header; + if (!header.read(segment)) { + // strange incorrect packet; let's ignore it + return SessionActions.IGNORE; + } + if (!header.isAcceptanceOrRejectionAcknowledged()) { + return SessionActions.IGNORE; + } + session.mySequence += 1; + session.vmSequence += 1; + session.state = TcpStates.ESTABLISHED; + session.vmWindow = header.window; + // session layer already knows about this session; do not bother it + return SessionActions.IGNORE; + } + + @Override + States toSessionState() { + return States.ESTABLISHED; + } + }, + REJECT { + @Override + SessionActions receive(final StreamSessionImpl session, final ByteBuffer segment) { + final TcpHeader header = session.header; + header.rejectConnection(session.mySequence, session.vmSequence + 1); + header.write(segment); + segment.flip(); + return SessionActions.FORWARD; + } + + @Override + SessionActions send(final StreamSessionImpl session, final ByteBuffer segment) { + // rejection sent and session should be closed now + throw new IllegalStateException(); + } + + @Override + States toSessionState() { + return States.REJECT; + } + }, + ESTABLISHED { + @Override + SessionActions receive(final StreamSessionImpl session, final ByteBuffer segment) { + final TcpHeader header = session.header; + final ByteBuffer receiveBuffer = session.receiveBuffer; + if (session.nextSegmentMark == 0) { + session.nextSegmentMark = Math.min(Math.min(session.vmWindow, receiveBuffer.position()), segment.remaining() - TcpHeader.MIN_HEADER_SIZE_NO_PORTS); + LOGGER.info("Next segment mark: {}", session.nextSegmentMark); + } + header.urg = false; + header.syn = false; + header.rst = false; + header.ack = true; //session.needsAcknowledgment; + header.sequenceNumber = session.mySequence; //- session.nextSegmentMark; + header.acknowledgmentNumber = /*header.ack ?*/ session.vmSequence /*: 0*/; + header.maxSegmentSize = -1; + header.urgentPointer = 0; + header.psh = session.nextSegmentMark != 0; + header.window = session.computeWindow(); + if (!header.ack && !header.psh && session.state != TcpStates.FINISH) { + // Nothing to send + LOGGER.info("Established session nothing to send"); + return SessionActions.IGNORE; + } + if (header.psh) { + header.fin = false; + header.write(segment); + // We have something to receive + + // Copy payload (yes, it is easier to prepare payload first) + final int recvPos = receiveBuffer.position(); + final int recvLim = receiveBuffer.limit(); + receiveBuffer.limit(session.nextSegmentMark); + receiveBuffer.position(0); + segment.put(receiveBuffer); + receiveBuffer.limit(recvLim); + receiveBuffer.position(recvPos); + } else { + header.fin = session.state == TcpStates.FINISH; + header.write(segment); + } + segment.flip(); + return SessionActions.FORWARD; + } + + @Override + SessionActions send(final StreamSessionImpl session, final ByteBuffer segment) { + final TcpHeader header = session.header; + final boolean correct = header.read(segment); + if (!correct) { + LOGGER.info("Got invalid TCP header"); + return SessionActions.IGNORE; + } + if (header.syn) { + LOGGER.info("Got syn on established connection"); + return SessionActions.IGNORE; + } + if (header.sequenceNumber != session.vmSequence) { + LOGGER.info("VM sent invalid sequence number (expected {}, got {})", session.vmSequence, header.sequenceNumber); + return SessionActions.IGNORE; + } + final int length = segment.remaining(); + if (header.psh && length > session.computeWindow()) { + LOGGER.info("Received length > window size"); + return SessionActions.IGNORE; + } + if (header.ack) { + // Segment received + if (header.acknowledgmentNumber != session.mySequence && header.acknowledgmentNumber != (session.mySequence + session.nextSegmentMark)) { + LOGGER.info("VM acked wrong number (expected {}, got {})", session.mySequence, header.acknowledgmentNumber); + return SessionActions.IGNORE; + } + if (header.acknowledgmentNumber == (session.mySequence + session.nextSegmentMark)) { + final ByteBuffer receiveBuffer = session.receiveBuffer; + // Remove acknowledged data from buffer + final int newPosition = receiveBuffer.position() - session.nextSegmentMark; + receiveBuffer.position(session.nextSegmentMark); + receiveBuffer.compact(); + receiveBuffer.position(newPosition); + receiveBuffer.limit(receiveBuffer.capacity()); + session.mySequence += session.nextSegmentMark; + session.nextSegmentMark = 0; + } + } + session.vmWindow = header.window; + if (header.psh) { + // Data to be sent + session.vmSequence += length; + final ByteBuffer sendBuffer = session.sendBuffer; + sendBuffer.compact(); + sendBuffer.put(segment); + sendBuffer.flip(); + session.needsAcknowledgment = true; + } + if (header.fin) { + ++session.vmSequence; + session.state = FINISH; + } + return SessionActions.FORWARD; + } + + @Override + States toSessionState() { + return States.ESTABLISHED; + } + }, + FINISH { + @Override + SessionActions receive(final StreamSessionImpl session, final ByteBuffer segment) { + return SessionActions.DROP; + } + + @Override + SessionActions send(final StreamSessionImpl session, final ByteBuffer segment) { + return SessionActions.DROP; + } + + @Override + States toSessionState() { + return States.FINISH; + } + }, + EXPIRED { + @Override + SessionActions receive(final StreamSessionImpl session, final ByteBuffer segment) { + return SessionActions.DROP; + } + + @Override + SessionActions send(final StreamSessionImpl session, final ByteBuffer segment) { + return SessionActions.DROP; + } + + @Override + States toSessionState() { + return States.EXPIRED; + } + }; + + abstract SessionActions receive(StreamSessionImpl session, ByteBuffer segment); + + abstract SessionActions send(StreamSessionImpl session, ByteBuffer segment); + + abstract States toSessionState(); + } } diff --git a/src/main/java/li/cil/oc2/common/inet/TcpHeader.java b/src/main/java/li/cil/oc2/common/inet/TcpHeader.java index d54e598f..cf730da5 100644 --- a/src/main/java/li/cil/oc2/common/inet/TcpHeader.java +++ b/src/main/java/li/cil/oc2/common/inet/TcpHeader.java @@ -30,8 +30,8 @@ public class TcpHeader { sequenceNumber = data.getInt(); acknowledgmentNumber = data.getInt(); final int dataOffset = position + ((data.get() >>> 2) & 0x3C) - 4; - if (dataOffset < data.limit()) { - System.out.println("C"); + if (dataOffset > data.limit()) { + System.out.println("C dataOffset=" + dataOffset + ", data.limit()=" + data.limit()); return false; } final int flags = Byte.toUnsignedInt(data.get()); @@ -82,7 +82,7 @@ public class TcpHeader { public void write(final ByteBuffer data) { data.putInt(sequenceNumber); data.putInt(acknowledgmentNumber); - final int headerLength = MIN_HEADER_SIZE_NO_PORTS + (maxSegmentSize == -1 ? 0 : 4); + final int headerLength = 4 + MIN_HEADER_SIZE_NO_PORTS + (maxSegmentSize == -1 ? 0 : 4); data.put((byte) (headerLength << 2)); final int flags = (bool2int(urg) << 5) | @@ -122,7 +122,11 @@ public class TcpHeader { maxSegmentSize = -1; } - public void denyConnection(final int sequence, final int acknowledgment) { + public boolean isAcceptanceOrRejectionAcknowledged() { + return !syn && !urg && ack && !psh && !rst && !fin; + } + + public void rejectConnection(final int sequence, final int acknowledgment) { sequenceNumber = sequence; acknowledgmentNumber = acknowledgment; urg = false; diff --git a/src/main/scripts/bin/curl b/src/main/scripts/bin/curl new file mode 100755 index 00000000..619a048f Binary files /dev/null and b/src/main/scripts/bin/curl differ