From d4fffc40b920ad8cacb6855baa4e6239c4ffd2e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20N=C3=BCcke?= Date: Sat, 12 Dec 2020 12:36:10 +0100 Subject: [PATCH] VM device interrupts zero based since we use them for masking anyway. --- .../common/vm/ManagedInterruptAllocator.java | 22 +++-- .../li/cil/oc2/common/bus/VMDeviceTests.java | 86 ++++++++++++++++++- 2 files changed, 93 insertions(+), 15 deletions(-) diff --git a/src/main/java/li/cil/oc2/common/vm/ManagedInterruptAllocator.java b/src/main/java/li/cil/oc2/common/vm/ManagedInterruptAllocator.java index 6b75c83b..cccc3a8a 100644 --- a/src/main/java/li/cil/oc2/common/vm/ManagedInterruptAllocator.java +++ b/src/main/java/li/cil/oc2/common/vm/ManagedInterruptAllocator.java @@ -44,17 +44,16 @@ public final class ManagedInterruptAllocator implements InterruptAllocator { throw new IllegalStateException(); } - if (interrupt < 1 || interrupt > R5PlatformLevelInterruptController.INTERRUPT_COUNT) { + if (interrupt < 0 || interrupt >= R5PlatformLevelInterruptController.INTERRUPT_COUNT) { throw new IllegalArgumentException(); } - final int interruptBit = interrupt - 1; - if (interrupts.get(interruptBit)) { + if (interrupts.get(interrupt)) { return claimInterrupt(); } else { - interrupts.set(interruptBit); - reservedInterrupts.set(interruptBit); - managedInterrupts.set(interruptBit); + interrupts.set(interrupt); + reservedInterrupts.set(interrupt); + managedInterrupts.set(interrupt); return OptionalInt.of(interrupt); } } @@ -69,16 +68,15 @@ public final class ManagedInterruptAllocator implements InterruptAllocator { claimedInterrupts.or(interrupts); claimedInterrupts.or(reservedInterrupts); - final int interruptBit = claimedInterrupts.nextClearBit(0); - if (interruptBit >= interruptCount) { + final int interrupt = claimedInterrupts.nextClearBit(0); + if (interrupt >= interruptCount) { return OptionalInt.empty(); } - interrupts.set(interruptBit); - reservedInterrupts.set(interruptBit); - managedInterrupts.set(interruptBit); + interrupts.set(interrupt); + reservedInterrupts.set(interrupt); + managedInterrupts.set(interrupt); - final int interrupt = interruptBit + 1; return OptionalInt.of(interrupt); } } diff --git a/src/test/java/li/cil/oc2/common/bus/VMDeviceTests.java b/src/test/java/li/cil/oc2/common/bus/VMDeviceTests.java index 387dd735..28ed1ec5 100644 --- a/src/test/java/li/cil/oc2/common/bus/VMDeviceTests.java +++ b/src/test/java/li/cil/oc2/common/bus/VMDeviceTests.java @@ -7,14 +7,14 @@ import li.cil.oc2.common.vm.VirtualMachineDeviceBusAdapter; import li.cil.sedna.api.device.InterruptController; import li.cil.sedna.api.memory.MemoryMap; import li.cil.sedna.memory.SimpleMemoryMap; +import li.cil.sedna.riscv.device.R5PlatformLevelInterruptController; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import java.util.Collections; import java.util.OptionalInt; -import static org.junit.jupiter.api.Assertions.assertNotEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.*; @@ -26,7 +26,7 @@ public final class VMDeviceTests { @BeforeEach public void setupEach() { memoryMap = new SimpleMemoryMap(); - interruptController = mock(InterruptController.class); + interruptController = new R5PlatformLevelInterruptController(); adapter = new VirtualMachineDeviceBusAdapter(memoryMap, interruptController); } @@ -122,4 +122,84 @@ public final class VMDeviceTests { adapter.addDevices(Collections.singleton(device)); assertTrue(adapter.load()); } + + @Test + public void deviceCanRaiseClaimedInterrupts() { + final DeviceData deviceData = new DeviceData(); + final VMDevice device = mock(VMDevice.class); + when(device.load(any())).thenAnswer(invocation -> { + final VMContext context = invocation.getArgument(0); + final OptionalInt interrupt = context.getInterruptAllocator().claimInterrupt(); + assertTrue(interrupt.isPresent()); + + deviceData.context = context; + deviceData.interrupt = interrupt.getAsInt(); + + return VMDeviceLoadResult.success(); + }); + + adapter.addDevices(Collections.singleton(device)); + assertTrue(adapter.load()); + + verify(device).load(any()); + + final int claimedInterruptMask = 1 << deviceData.interrupt; + deviceData.context.getInterruptController().raiseInterrupts(claimedInterruptMask); + + assertTrue((interruptController.getRaisedInterrupts() & claimedInterruptMask) != 0); + } + + @Test + public void devicesCannotRaiseUnclaimedInterrupts() { + final DeviceData deviceData = new DeviceData(); + final VMDevice device = mock(VMDevice.class); + when(device.load(any())).thenAnswer(invocation -> { + deviceData.context = invocation.getArgument(0); + return VMDeviceLoadResult.success(); + }); + + adapter.addDevices(Collections.singleton(device)); + assertTrue(adapter.load()); + + verify(device).load(any()); + + final int someInterruptMask = 0x1; + assertThrows(IllegalArgumentException.class, () -> + deviceData.context.getInterruptController().raiseInterrupts(someInterruptMask)); + } + + @Test + public void unloadLowersClaimedInterrupts() { + final DeviceData deviceData = new DeviceData(); + final VMDevice device = mock(VMDevice.class); + when(device.load(any())).thenAnswer(invocation -> { + final VMContext context = invocation.getArgument(0); + final OptionalInt interrupt = context.getInterruptAllocator().claimInterrupt(); + assertTrue(interrupt.isPresent()); + + deviceData.context = context; + deviceData.interrupt = interrupt.getAsInt(); + + return VMDeviceLoadResult.success(); + }); + + adapter.addDevices(Collections.singleton(device)); + assertTrue(adapter.load()); + + verify(device).load(any()); + + final int claimedInterruptMask = 1 << deviceData.interrupt; + deviceData.context.getInterruptController().raiseInterrupts(claimedInterruptMask); + + assertTrue((interruptController.getRaisedInterrupts() & claimedInterruptMask) != 0); + + adapter.unload(); + + assertFalse((interruptController.getRaisedInterrupts() & claimedInterruptMask) != 0); + } + + private static final class DeviceData { + public VMContext context; + public int interrupt; + } }