diff --git a/src/main/java/li/cil/oc2/common/block/entity/ComputerTileEntity.java b/src/main/java/li/cil/oc2/common/block/entity/ComputerTileEntity.java index 6bff3715..9cfb0434 100644 --- a/src/main/java/li/cil/oc2/common/block/entity/ComputerTileEntity.java +++ b/src/main/java/li/cil/oc2/common/block/entity/ComputerTileEntity.java @@ -4,6 +4,7 @@ import it.unimi.dsi.fastutil.bytes.ByteArrayFIFOQueue; import li.cil.ceres.api.Serialized; import li.cil.oc2.Constants; import li.cil.oc2.OpenComputers; +import li.cil.oc2.api.bus.Device; import li.cil.oc2.common.block.ComputerBlock; import li.cil.oc2.common.bus.TileEntityDeviceBusController; import li.cil.oc2.common.bus.TileEntityDeviceBusElement; @@ -48,6 +49,7 @@ import javax.annotation.Nullable; import java.io.BufferedInputStream; import java.io.InputStream; import java.nio.ByteBuffer; +import java.util.Set; import java.util.UUID; import static java.util.Objects.requireNonNull; @@ -474,12 +476,23 @@ public final class ComputerTileEntity extends AbstractTileEntity implements ITic @Override protected void onDevicesInvalid() { + runState = RunState.LOADING_DEVICES; virtualMachine.rpcAdapter.pause(); } @Override - protected void onDevicesValid() { - virtualMachine.rpcAdapter.resume(); + protected void onDevicesValid(final boolean didDevicesChange) { + virtualMachine.rpcAdapter.resume(didDevicesChange); + } + + @Override + protected void onDevicesAdded(final Set devices) { + virtualMachine.vmAdapter.addDevices(devices); + } + + @Override + protected void onDevicesRemoved(final Set devices) { + virtualMachine.vmAdapter.removeDevices(devices); } } diff --git a/src/main/java/li/cil/oc2/common/bus/RPCAdapter.java b/src/main/java/li/cil/oc2/common/bus/RPCAdapter.java index 9be012b7..cdbb2d24 100644 --- a/src/main/java/li/cil/oc2/common/bus/RPCAdapter.java +++ b/src/main/java/li/cil/oc2/common/bus/RPCAdapter.java @@ -90,8 +90,10 @@ public final class RPCAdapter implements Steppable { devicesById.clear(); } - public void resume() { - if (!isPaused) { + public void resume(final boolean didDevicesChange) { + isPaused = false; + + if (!didDevicesChange) { return; } @@ -136,8 +138,6 @@ public final class RPCAdapter implements Steppable { devices.add(new RPCDeviceWithIdentifier(identifier, device)); devicesById.put(identifier, device); }); - - isPaused = false; } public void tick() { diff --git a/src/main/java/li/cil/oc2/common/bus/TileEntityDeviceBusController.java b/src/main/java/li/cil/oc2/common/bus/TileEntityDeviceBusController.java index 8cb17b67..7501080a 100644 --- a/src/main/java/li/cil/oc2/common/bus/TileEntityDeviceBusController.java +++ b/src/main/java/li/cil/oc2/common/bus/TileEntityDeviceBusController.java @@ -61,9 +61,6 @@ public abstract class TileEntityDeviceBusController implements DeviceBusControll } elements.clear(); - devices.clear(); - deviceIds.clear(); - scanDelay = 0; // scan as soon as possible state = BusState.SCAN_PENDING; } @@ -72,18 +69,42 @@ public abstract class TileEntityDeviceBusController implements DeviceBusControll public void scanDevices() { onDevicesInvalid(); - devices.clear(); - deviceIds.clear(); - + final HashSet newDevices = new HashSet<>(); + final HashMap> newDeviceIds = new HashMap<>(); for (final DeviceBusElement element : elements) { for (final Device device : element.getLocalDevices()) { - devices.add(device); - element.getDeviceIdentifier(device).ifPresent(identifier -> deviceIds + newDevices.add(device); + element.getDeviceIdentifier(device).ifPresent(identifier -> newDeviceIds .computeIfAbsent(device, unused -> new HashSet<>()).add(identifier)); } } - onDevicesValid(); + final HashSet removedDevices = new HashSet<>(devices); + removedDevices.removeAll(newDevices); + onDevicesRemoved(removedDevices); + + final HashSet addedDevices = new HashSet<>(newDevices); + addedDevices.removeAll(devices); + onDevicesAdded(addedDevices); + + final boolean didDevicesChange = !removedDevices.isEmpty() || !addedDevices.isEmpty(); + final boolean didDeviceIdsChange; + if (didDevicesChange) { + devices.clear(); + devices.addAll(newDevices); + + didDeviceIdsChange = true; + } else { + didDeviceIdsChange = deviceIds.entrySet().stream().anyMatch(entry -> + !Objects.equals(entry.getValue(), newDeviceIds.get(entry.getKey()))); + } + + if (didDeviceIdsChange) { + deviceIds.clear(); + deviceIds.putAll(newDeviceIds); + } + + onDevicesValid(didDevicesChange || didDeviceIdsChange); } @Override @@ -208,7 +229,13 @@ public abstract class TileEntityDeviceBusController implements DeviceBusControll protected void onDevicesInvalid() { } - protected void onDevicesValid() { + protected void onDevicesValid(final boolean didDevicesChange) { + } + + protected void onDevicesAdded(final Set devices) { + } + + protected void onDevicesRemoved(final Set devices) { } /////////////////////////////////////////////////////////////////// diff --git a/src/main/java/li/cil/oc2/common/vm/ManagedInterruptAllocator.java b/src/main/java/li/cil/oc2/common/vm/ManagedInterruptAllocator.java index 3108e3f3..6b75c83b 100644 --- a/src/main/java/li/cil/oc2/common/vm/ManagedInterruptAllocator.java +++ b/src/main/java/li/cil/oc2/common/vm/ManagedInterruptAllocator.java @@ -19,13 +19,14 @@ public final class ManagedInterruptAllocator implements InterruptAllocator { public ManagedInterruptAllocator(final BitSet interrupts, final BitSet reservedInterrupts, final int interruptCount) { this.interrupts = interrupts; this.reservedInterrupts = reservedInterrupts; - this.managedInterrupts = new BitSet(); + this.managedInterrupts = new BitSet(interruptCount); this.interruptCount = interruptCount; } public void freeze() { + final long[] words = managedInterrupts.toLongArray(); + managedMask = words.length > 0 ? (int) words[0] : 0; isFrozen = true; - managedMask = (int) managedInterrupts.toLongArray()[0]; } public void invalidate() { @@ -64,7 +65,11 @@ public final class ManagedInterruptAllocator implements InterruptAllocator { throw new IllegalStateException(); } - final int interruptBit = reservedInterrupts.nextClearBit(0); + final BitSet claimedInterrupts = new BitSet(); + claimedInterrupts.or(interrupts); + claimedInterrupts.or(reservedInterrupts); + + final int interruptBit = claimedInterrupts.nextClearBit(0); if (interruptBit >= interruptCount) { return OptionalInt.empty(); } diff --git a/src/main/java/li/cil/oc2/common/vm/VirtualMachineDeviceBusAdapter.java b/src/main/java/li/cil/oc2/common/vm/VirtualMachineDeviceBusAdapter.java index 5d0837ef..845e9247 100644 --- a/src/main/java/li/cil/oc2/common/vm/VirtualMachineDeviceBusAdapter.java +++ b/src/main/java/li/cil/oc2/common/vm/VirtualMachineDeviceBusAdapter.java @@ -8,7 +8,10 @@ import li.cil.sedna.api.device.InterruptController; import li.cil.sedna.api.memory.MemoryMap; import li.cil.sedna.riscv.device.R5PlatformLevelInterruptController; -import java.util.*; +import java.util.ArrayList; +import java.util.BitSet; +import java.util.HashMap; +import java.util.Set; public final class VirtualMachineDeviceBusAdapter { private final MemoryMap memoryMap; @@ -63,7 +66,14 @@ public final class VirtualMachineDeviceBusAdapter { } } - return incompleteLoads.isEmpty(); + if (!incompleteLoads.isEmpty()) { + return false; + } + + reservedInterrupts.clear(); + reservedInterrupts.or(allocatedInterrupts); + + return true; } public void unload() { @@ -76,31 +86,35 @@ public final class VirtualMachineDeviceBusAdapter { incompleteLoads.addAll(deviceContexts.keySet()); } - public void setDevices(final Collection devices) { - final HashSet oldDevices = new HashSet<>(deviceContexts.keySet()); - final HashSet newDevices = new HashSet<>(); + public void addDevices(final Set devices) { for (final Device device : devices) { if (device instanceof VMDevice) { - newDevices.add((VMDevice) device); + final VMDevice vmDevice = (VMDevice) device; + + final ManagedVMContext context = deviceContexts.put(vmDevice, null); + if (context != null) { + context.invalidate(); + } + + incompleteLoads.add(vmDevice); } } + } - final HashSet removedDevices = new HashSet<>(oldDevices); - removedDevices.removeAll(newDevices); - for (final VMDevice device : removedDevices) { - deviceContexts.remove(device).invalidate(); - incompleteLoads.remove(device); - device.unload(); + public void removeDevices(final Set devices) { + for (final Device device : devices) { + if (device instanceof VMDevice) { + final VMDevice vmDevice = (VMDevice) device; + + final ManagedVMContext context = deviceContexts.remove(vmDevice); + if (context != null) { + context.invalidate(); + } + + incompleteLoads.remove(vmDevice); + + vmDevice.unload(); + } } - - final HashSet addedDevices = new HashSet<>(newDevices); - addedDevices.removeAll(oldDevices); - for (final VMDevice device : addedDevices) { - deviceContexts.put(device, null); - incompleteLoads.add(device); - } - - reservedInterrupts.clear(); - reservedInterrupts.or(allocatedInterrupts); } } diff --git a/src/test/java/li/cil/oc2/common/bus/VMDeviceTests.java b/src/test/java/li/cil/oc2/common/bus/VMDeviceTests.java new file mode 100644 index 00000000..387dd735 --- /dev/null +++ b/src/test/java/li/cil/oc2/common/bus/VMDeviceTests.java @@ -0,0 +1,125 @@ +package li.cil.oc2.common.bus; + +import li.cil.oc2.api.bus.device.vm.VMContext; +import li.cil.oc2.api.bus.device.vm.VMDevice; +import li.cil.oc2.api.bus.device.vm.VMDeviceLoadResult; +import li.cil.oc2.common.vm.VirtualMachineDeviceBusAdapter; +import li.cil.sedna.api.device.InterruptController; +import li.cil.sedna.api.memory.MemoryMap; +import li.cil.sedna.memory.SimpleMemoryMap; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.OptionalInt; + +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +public final class VMDeviceTests { + private MemoryMap memoryMap; + private InterruptController interruptController; + private VirtualMachineDeviceBusAdapter adapter; + + @BeforeEach + public void setupEach() { + memoryMap = new SimpleMemoryMap(); + interruptController = mock(InterruptController.class); + adapter = new VirtualMachineDeviceBusAdapter(memoryMap, interruptController); + } + + @Test + public void addedDevicesHaveLoadCalled() { + final VMDevice device1 = mock(VMDevice.class); + final VMDevice device2 = mock(VMDevice.class); + when(device1.load(any())).thenReturn(VMDeviceLoadResult.success()); + when(device2.load(any())).thenReturn(VMDeviceLoadResult.success()); + + adapter.addDevices(Collections.singleton(device1)); + assertTrue(adapter.load()); + verify(device1).load(any()); + + adapter.addDevices(Collections.singleton(device2)); + assertTrue(adapter.load()); + + verifyNoMoreInteractions(device1); + verify(device2).load(any()); + } + + @Test + public void removedDevicesHaveUnloadCalled() { + final VMDevice device = mock(VMDevice.class); + when(device.load(any())).thenReturn(VMDeviceLoadResult.success()); + + adapter.addDevices(Collections.singleton(device)); + assertTrue(adapter.load()); + + adapter.removeDevices(Collections.singleton(device)); + verify(device).unload(); + } + + @Test + public void devicesHaveUnloadCalledOnGlobalUnload() { + final VMDevice device = mock(VMDevice.class); + when(device.load(any())).thenReturn(VMDeviceLoadResult.success()); + + adapter.addDevices(Collections.singleton(device)); + assertTrue(adapter.load()); + + adapter.unload(); + verify(device).unload(); + } + + @Test + public void devicesHaveLoadCalledAfterGlobalUnload() { + final VMDevice device = mock(VMDevice.class); + when(device.load(any())).thenReturn(VMDeviceLoadResult.success()); + + adapter.addDevices(Collections.singleton(device)); + assertTrue(adapter.load()); + verify(device).load(any()); + + adapter.unload(); + verify(device).unload(); + + assertTrue(adapter.load()); + verify(device, times(2)).load(any()); + } + + @Test + public void deviceCanClaimInterrupts() { + final VMDevice device = mock(VMDevice.class); + when(device.load(any())).thenAnswer(invocation -> { + final VMContext context = invocation.getArgument(0); + final OptionalInt interrupt = context.getInterruptAllocator().claimInterrupt(); + assertTrue(interrupt.isPresent()); + return VMDeviceLoadResult.success(); + }); + + adapter.addDevices(Collections.singleton(device)); + assertTrue(adapter.load()); + + verify(device).load(any()); + } + + @Test + public void deviceCannotClaimClaimedInterrupts() { + final int claimedInterrupt = 1; + + final VMDevice device = mock(VMDevice.class); + when(device.load(any())).thenAnswer(invocation -> { + final VMContext context = invocation.getArgument(0); + final OptionalInt interrupt = context.getInterruptAllocator().claimInterrupt(claimedInterrupt); + assertTrue(interrupt.isPresent()); + assertNotEquals(claimedInterrupt, interrupt.getAsInt()); + return VMDeviceLoadResult.success(); + }); + + adapter.claimInterrupt(claimedInterrupt); + + adapter.addDevices(Collections.singleton(device)); + assertTrue(adapter.load()); + } +} diff --git a/src/test/java/li/cil/oc2/common/vm/RPCAdapterTests.java b/src/test/java/li/cil/oc2/common/vm/RPCAdapterTests.java index 0903e1e1..3da321bf 100644 --- a/src/test/java/li/cil/oc2/common/vm/RPCAdapterTests.java +++ b/src/test/java/li/cil/oc2/common/vm/RPCAdapterTests.java @@ -100,8 +100,7 @@ public class RPCAdapterTests { when(busController.getDeviceIdentifiers(device)).thenReturn(singleton(deviceId)); // trigger device cache rebuild - rpcAdapter.pause(); - rpcAdapter.resume(); + rpcAdapter.resume(true); } private JsonElement invokeMethod(final UUID deviceId, final String name, final Object... parameters) {