Support nested arrays in serialization.

This commit is contained in:
Florian Nücke
2020-12-18 19:35:24 +01:00
parent 17e54fdbb9
commit 7ce0eccd43

View File

@@ -1,17 +1,19 @@
package li.cil.oc2.common.serialization;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import li.cil.ceres.Ceres;
import li.cil.ceres.api.DeserializationVisitor;
import li.cil.ceres.api.SerializationException;
import li.cil.ceres.api.SerializationVisitor;
import li.cil.oc2.common.util.NBTTagIds;
import net.minecraft.nbt.CompoundNBT;
import net.minecraft.nbt.ListNBT;
import net.minecraft.nbt.StringNBT;
import net.minecraft.nbt.*;
import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.NotNull;
import javax.annotation.Nullable;
import java.lang.reflect.Array;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
public final class NBTSerialization {
@@ -42,6 +44,22 @@ public final class NBTSerialization {
///////////////////////////////////////////////////////////////////
private static final String IS_NULL_KEY = "<is_null>";
private static final Map<Class<?>, ArraySerializer> ARRAY_SERIALIZERS;
static {
ARRAY_SERIALIZERS = new HashMap<>();
ARRAY_SERIALIZERS.put(boolean.class, new BooleanArraySerializer());
ARRAY_SERIALIZERS.put(byte.class, new ByteArraySerializer());
ARRAY_SERIALIZERS.put(char.class, new CharArraySerializer());
ARRAY_SERIALIZERS.put(short.class, new ShortArraySerializer());
ARRAY_SERIALIZERS.put(int.class, new IntArraySerializer());
ARRAY_SERIALIZERS.put(long.class, new LongArraySerializer());
ARRAY_SERIALIZERS.put(float.class, new FloatArraySerializer());
ARRAY_SERIALIZERS.put(double.class, new DoubleArraySerializer());
ARRAY_SERIALIZERS.put(Enum.class, new EnumArraySerializer());
ARRAY_SERIALIZERS.put(String.class, new StringArraySerializer());
ARRAY_SERIALIZERS.put(UUID.class, new UUIDArraySerializer());
}
private static final class Serializer implements SerializationVisitor {
private final CompoundNBT nbt;
@@ -98,88 +116,7 @@ public final class NBTSerialization {
}
if (type.isArray()) {
final Class<?> componentType = type.getComponentType();
if (componentType == boolean.class) {
final boolean[] data = (boolean[]) value;
final byte[] convertedData = new byte[data.length];
for (int i = 0; i < data.length; i++) {
convertedData[i] = data[i] ? (byte) 1 : (byte) 0;
}
nbt.putByteArray(name, convertedData);
} else if (componentType == byte.class) {
nbt.putByteArray(name, (byte[]) value);
} else if (componentType == char.class) {
final char[] data = (char[]) value;
final int[] convertedData = new int[data.length];
for (int i = 0; i < data.length; i++) {
convertedData[i] = data[i];
}
nbt.putIntArray(name, convertedData);
} else if (componentType == short.class) {
final short[] data = (short[]) value;
final int[] convertedData = new int[data.length];
for (int i = 0; i < data.length; i++) {
convertedData[i] = data[i];
}
nbt.putIntArray(name, convertedData);
} else if (componentType == int.class) {
nbt.putIntArray(name, (int[]) value);
} else if (componentType == long.class) {
nbt.putLongArray(name, (long[]) value);
} else if (componentType == float.class) {
final float[] data = (float[]) value;
final int[] convertedData = new int[data.length];
for (int i = 0; i < data.length; i++) {
convertedData[i] = Float.floatToRawIntBits(data[i]);
}
nbt.putIntArray(name, convertedData);
} else if (componentType == double.class) {
final double[] data = (double[]) value;
final long[] convertedData = new long[data.length];
for (int i = 0; i < data.length; i++) {
convertedData[i] = Double.doubleToRawLongBits(data[i]);
}
nbt.putLongArray(name, convertedData);
} else if (componentType.isEnum()) {
final Enum[] data = (Enum[]) value;
final int[] convertedData = new int[data.length];
for (int i = 0; i < data.length; i++) {
convertedData[i] = data[i].ordinal();
}
nbt.putIntArray(name, convertedData);
} else if (componentType == UUID.class) {
final UUID[] data = (UUID[]) value;
final ListNBT list = new ListNBT();
for (final UUID datum : data) {
list.add(StringNBT.valueOf(datum.toString()));
}
nbt.put(name, list);
} else if (componentType == String.class) {
final String[] data = (String[]) value;
final ListNBT list = new ListNBT();
for (final String datum : data) {
list.add(StringNBT.valueOf(datum));
}
nbt.put(name, list);
} else {
final li.cil.ceres.api.Serializer<?> serializer = Ceres.getSerializer(componentType);
final Object[] data = (Object[]) value;
final ListNBT listNBT = new ListNBT();
for (final Object datum : data) {
final CompoundNBT itemNBT = new CompoundNBT();
if (datum == null) {
itemNBT.putBoolean(IS_NULL_KEY, true);
} else {
if (datum.getClass() != componentType) {
throw new SerializationException(String.format("Polymorphism detected in generic array [%s]. This is not supported.", name));
}
serializer.serialize(new Serializer(itemNBT), (Class) componentType, datum);
}
listNBT.add(itemNBT);
}
nbt.put(name, listNBT);
}
nbt.put(name, putArray(name, type, value));
} else if (type.isEnum()) {
nbt.putString(name, ((Enum) value).name());
} else if (type == String.class) {
@@ -197,6 +134,59 @@ public final class NBTSerialization {
}
}
@FunctionalInterface
private interface ArrayComponentSerializer {
INBT serialize(Class<?> type, Object value);
}
@SuppressWarnings({"unchecked", "rawtypes"})
private INBT putArray(final String name, final Class<?> type, final @NotNull Object value) {
final Class<?> componentType = type.getComponentType();
final ArraySerializer arraySerializer = ARRAY_SERIALIZERS.get(componentType);
if (arraySerializer != null) {
return arraySerializer.serialize(value);
} else {
final ArrayComponentSerializer componentSerializer;
if (componentType.isArray()) {
componentSerializer = (t, v) -> putArray(name, t, v);
} else {
final li.cil.ceres.api.Serializer<?> serializer = Ceres.getSerializer(componentType);
componentSerializer = (t, v) -> {
final CompoundNBT nbt = new CompoundNBT();
serializer.serialize(new Serializer(nbt), (Class) t, v);
return nbt;
};
}
final ListNBT listNBT = new ListNBT();
final IntArrayList nullIndices = new IntArrayList();
final Object[] data = (Object[]) value;
for (int i = 0; i < data.length; i++) {
final Object datum = data[i];
if (datum == null) {
nullIndices.add(i);
} else {
if (datum.getClass() != componentType) {
throw new SerializationException(String.format("Polymorphism detected in generic array [%s]. This is not supported.", name));
}
listNBT.add(componentSerializer.serialize(componentType, datum));
}
}
if (nullIndices.isEmpty()) {
return listNBT;
} else {
final CompoundNBT arrayNbt = new CompoundNBT();
arrayNbt.put("value", listNBT);
arrayNbt.putIntArray("nulls", nullIndices);
return arrayNbt;
}
}
}
@Contract(value = "_, null -> true")
private boolean putIsNull(final String name, @Nullable final Object value) {
final boolean isNull = value == null;
@@ -270,155 +260,9 @@ public final class NBTSerialization {
}
if (type.isArray()) {
final Class<?> componentType = type.getComponentType();
if (componentType == boolean.class) {
boolean[] data = (boolean[]) into;
if (nbt.contains(name, NBTTagIds.TAG_BYTE_ARRAY)) {
final byte[] convertedData = nbt.getByteArray(name);
if (data == null || data.length != convertedData.length) {
data = new boolean[convertedData.length];
}
for (int i = 0; i < convertedData.length; i++) {
data[i] = convertedData[i] != 0;
}
}
return data;
} else if (componentType == byte.class) {
final byte[] data = (byte[]) into;
if (nbt.contains(name, NBTTagIds.TAG_BYTE_ARRAY)) {
final byte[] serializedData = nbt.getByteArray(name);
if (data == null || data.length != serializedData.length) {
return serializedData;
}
System.arraycopy(serializedData, 0, data, 0, serializedData.length);
}
return data;
} else if (componentType == char.class) {
char[] data = (char[]) into;
if (nbt.contains(name, NBTTagIds.TAG_INT_ARRAY)) {
final int[] convertedData = nbt.getIntArray(name);
if (data == null || data.length != convertedData.length) {
data = new char[convertedData.length];
}
for (int i = 0; i < convertedData.length; i++) {
data[i] = (char) convertedData[i];
}
}
return data;
} else if (componentType == short.class) {
short[] data = (short[]) into;
if (nbt.contains(name, NBTTagIds.TAG_INT_ARRAY)) {
final int[] convertedData = nbt.getIntArray(name);
if (data == null || data.length != convertedData.length) {
data = new short[convertedData.length];
}
for (int i = 0; i < convertedData.length; i++) {
data[i] = (short) convertedData[i];
}
}
return data;
} else if (componentType == int.class) {
final int[] data = (int[]) into;
if (nbt.contains(name, NBTTagIds.TAG_INT_ARRAY)) {
final int[] serializedData = nbt.getIntArray(name);
if (data == null || data.length != serializedData.length) {
return serializedData;
}
System.arraycopy(serializedData, 0, data, 0, serializedData.length);
}
return data;
} else if (componentType == long.class) {
final long[] data = (long[]) into;
if (nbt.contains(name, NBTTagIds.TAG_LONG_ARRAY)) {
final long[] serializedData = nbt.getLongArray(name);
if (data == null || data.length != serializedData.length) {
return serializedData;
}
System.arraycopy(serializedData, 0, data, 0, serializedData.length);
}
return data;
} else if (componentType == float.class) {
float[] data = (float[]) into;
if (nbt.contains(name, NBTTagIds.TAG_INT_ARRAY)) {
final int[] convertedData = nbt.getIntArray(name);
if (data == null || data.length != convertedData.length) {
data = new float[convertedData.length];
}
for (int i = 0; i < convertedData.length; i++) {
data[i] = Float.intBitsToFloat(convertedData[i]);
}
}
return data;
} else if (componentType == double.class) {
double[] data = (double[]) into;
if (nbt.contains(name, NBTTagIds.TAG_LONG_ARRAY)) {
final long[] convertedData = nbt.getLongArray(name);
if (data == null || data.length != convertedData.length) {
data = new double[convertedData.length];
}
for (int i = 0; i < convertedData.length; i++) {
data[i] = Double.longBitsToDouble(convertedData[i]);
}
}
return data;
} else if (componentType.isEnum()) {
Enum[] data = (Enum[]) into;
if (nbt.contains(name, NBTTagIds.TAG_INT_ARRAY)) {
final int[] serializedData = nbt.getIntArray(name);
if (data == null || data.length != serializedData.length) {
data = (Enum[]) Array.newInstance(componentType, serializedData.length);
}
for (int i = 0; i < serializedData.length; i++) {
data[i] = (Enum) componentType.getEnumConstants()[serializedData[i]];
}
}
return data;
} else if (componentType == String.class) {
String[] data = (String[]) into;
if (nbt.contains(name, NBTTagIds.TAG_LIST)) {
final ListNBT serializedData = nbt.getList(name, NBTTagIds.TAG_STRING);
if (data == null || data.length != serializedData.size()) {
data = new String[serializedData.size()];
}
for (int i = 0; i < serializedData.size(); i++) {
data[i] = serializedData.getString(i);
}
}
return data;
} else if (componentType == UUID.class) {
UUID[] data = (UUID[]) into;
if (nbt.contains(name, NBTTagIds.TAG_LIST)) {
final ListNBT serializedData = nbt.getList(name, NBTTagIds.TAG_STRING);
if (data == null || data.length != serializedData.size()) {
data = new UUID[serializedData.size()];
}
for (int i = 0; i < serializedData.size(); i++) {
data[i] = UUID.fromString(serializedData.getString(i));
}
}
return data;
} else {
Object[] data = (Object[]) into;
if (nbt.contains(name, NBTTagIds.TAG_LIST)) {
final ListNBT listNBT = nbt.getList(name, NBTTagIds.TAG_COMPOUND);
final int length = listNBT.size();
if (data == null || data.length != length) {
data = (Object[]) Array.newInstance(componentType, length);
}
final li.cil.ceres.api.Serializer<?> serializer = Ceres.getSerializer(componentType);
for (int i = 0; i < length; i++) {
final CompoundNBT itemNBT = listNBT.getCompound(i);
if (itemNBT.contains(IS_NULL_KEY)) {
continue;
}
data[i] = serializer.deserialize(new Deserializer(itemNBT), (Class) componentType, data[i]);
}
}
return data;
}
final INBT arrayNbt = nbt.get(name);
assert arrayNbt != null;
return getArray(arrayNbt, type, into);
} else if (type.isEnum()) {
return Enum.valueOf((Class) type, nbt.getString(name));
} else if (type == String.class) {
@@ -431,6 +275,70 @@ public final class NBTSerialization {
}
}
@FunctionalInterface
private interface ArrayComponentDeserializer {
@Nullable
Object deserialize(INBT nbt, Class<?> type, @Nullable Object into);
}
@SuppressWarnings({"unchecked", "rawtypes"})
@Nullable
private static Object getArray(final INBT nbt, final Class<?> type, final @Nullable Object into) {
final Class<?> componentType = type.getComponentType();
final ArraySerializer arraySerializer = ARRAY_SERIALIZERS.get(componentType);
if (arraySerializer != null) {
return arraySerializer.deserialize(nbt, type, into);
} else {
final ArrayComponentDeserializer componentDeserializer;
if (componentType.isArray()) {
componentDeserializer = Deserializer::getArray;
} else {
final li.cil.ceres.api.Serializer<?> serializer = Ceres.getSerializer(componentType);
componentDeserializer = (n, t, i) -> serializer.deserialize(new Deserializer((CompoundNBT) n), (Class) t, i);
}
Object[] data = (Object[]) into;
final ListNBT listNBT;
final int[] nulls;
int nullsIndex = 0;
if (nbt instanceof ListNBT) {
listNBT = (ListNBT) nbt;
nulls = new int[0];
} else if (nbt instanceof CompoundNBT) {
listNBT = (ListNBT) ((CompoundNBT) nbt).get("value");
nulls = ((CompoundNBT) nbt).getIntArray("nulls");
} else {
return data;
}
if (listNBT == null) {
return data;
}
final int length = listNBT.size() + nulls.length;
if (data == null || data.length != length) {
data = (Object[]) Array.newInstance(componentType, length);
}
for (int i = 0; i < length; i++) {
if (nullsIndex < nulls.length && i == nulls[nullsIndex]) {
nullsIndex++;
continue;
}
final INBT itemNBT = listNBT.get(i - nullsIndex);
if (itemNBT == null) {
continue;
}
data[i] = componentDeserializer.deserialize(itemNBT, componentType, data[i]);
}
return data;
}
}
@Override
public boolean exists(final String name) {
return nbt.contains(name);
@@ -440,4 +348,293 @@ public final class NBTSerialization {
return nbt.getCompound(name).getBoolean(IS_NULL_KEY);
}
}
///////////////////////////////////////////////////////////////////
private interface ArraySerializer {
INBT serialize(Object value);
@Nullable
Object deserialize(INBT nbt, final Class<?> type, @Nullable final Object into);
}
private static final class BooleanArraySerializer implements ArraySerializer {
@Override
public INBT serialize(final Object value) {
final boolean[] data = (boolean[]) value;
final byte[] convertedData = new byte[data.length];
for (int i = 0; i < data.length; i++) {
convertedData[i] = data[i] ? (byte) 1 : (byte) 0;
}
return new ByteArrayNBT(convertedData);
}
@Override
public Object deserialize(final INBT nbt, final Class<?> type, @Nullable final Object into) {
boolean[] data = (boolean[]) into;
if (nbt instanceof ByteArrayNBT) {
final byte[] convertedData = ((ByteArrayNBT) nbt).getByteArray();
if (data == null || data.length != convertedData.length) {
data = new boolean[convertedData.length];
}
for (int i = 0; i < convertedData.length; i++) {
data[i] = convertedData[i] != 0;
}
}
return data;
}
}
private static final class ByteArraySerializer implements ArraySerializer {
@Override
public INBT serialize(final Object value) {
return new ByteArrayNBT((byte[]) value);
}
@Override
public Object deserialize(final INBT nbt, final Class<?> type, @Nullable final Object into) {
final byte[] data = (byte[]) into;
if (nbt instanceof ByteArrayNBT) {
final byte[] serializedData = ((ByteArrayNBT) nbt).getByteArray();
if (data == null || data.length != serializedData.length) {
return serializedData;
}
System.arraycopy(serializedData, 0, data, 0, serializedData.length);
}
return data;
}
}
private static final class CharArraySerializer implements ArraySerializer {
@Override
public INBT serialize(final Object value) {
final char[] data = (char[]) value;
final int[] convertedData = new int[data.length];
for (int i = 0; i < data.length; i++) {
convertedData[i] = data[i];
}
return new IntArrayNBT(convertedData);
}
@Override
public Object deserialize(final INBT nbt, final Class<?> type, @Nullable final Object into) {
char[] data = (char[]) into;
if (nbt instanceof IntArrayNBT) {
final int[] convertedData = ((IntArrayNBT) nbt).getIntArray();
if (data == null || data.length != convertedData.length) {
data = new char[convertedData.length];
}
for (int i = 0; i < convertedData.length; i++) {
data[i] = (char) convertedData[i];
}
}
return data;
}
}
private static final class ShortArraySerializer implements ArraySerializer {
@Override
public INBT serialize(final Object value) {
final short[] data = (short[]) value;
final int[] convertedData = new int[data.length];
for (int i = 0; i < data.length; i++) {
convertedData[i] = data[i];
}
return new IntArrayNBT(convertedData);
}
@Override
public Object deserialize(final INBT nbt, final Class<?> type, @Nullable final Object into) {
short[] data = (short[]) into;
if (nbt instanceof IntArrayNBT) {
final int[] convertedData = ((IntArrayNBT) nbt).getIntArray();
if (data == null || data.length != convertedData.length) {
data = new short[convertedData.length];
}
for (int i = 0; i < convertedData.length; i++) {
data[i] = (short) convertedData[i];
}
}
return data;
}
}
private static final class IntArraySerializer implements ArraySerializer {
@Override
public INBT serialize(final Object value) {
return new IntArrayNBT((int[]) value);
}
@Override
public Object deserialize(final INBT nbt, final Class<?> type, @Nullable final Object into) {
final int[] data = (int[]) into;
if (nbt instanceof IntArrayNBT) {
final int[] serializedData = ((IntArrayNBT) nbt).getIntArray();
if (data == null || data.length != serializedData.length) {
return serializedData;
}
System.arraycopy(serializedData, 0, data, 0, serializedData.length);
}
return data;
}
}
private static final class LongArraySerializer implements ArraySerializer {
@Override
public INBT serialize(final Object value) {
return new LongArrayNBT((long[]) value);
}
@Override
public Object deserialize(final INBT nbt, final Class<?> type, @Nullable final Object into) {
final long[] data = (long[]) into;
if (nbt instanceof LongArrayNBT) {
final long[] serializedData = ((LongArrayNBT) nbt).getAsLongArray();
if (data == null || data.length != serializedData.length) {
return serializedData;
}
System.arraycopy(serializedData, 0, data, 0, serializedData.length);
}
return data;
}
}
private static final class FloatArraySerializer implements ArraySerializer {
@Override
public INBT serialize(final Object value) {
final float[] data = (float[]) value;
final int[] convertedData = new int[data.length];
for (int i = 0; i < data.length; i++) {
convertedData[i] = Float.floatToRawIntBits(data[i]);
}
return new IntArrayNBT(convertedData);
}
@Override
public Object deserialize(final INBT nbt, final Class<?> type, @Nullable final Object into) {
float[] data = (float[]) into;
if (nbt instanceof IntArrayNBT) {
final int[] convertedData = ((IntArrayNBT) nbt).getIntArray();
if (data == null || data.length != convertedData.length) {
data = new float[convertedData.length];
}
for (int i = 0; i < convertedData.length; i++) {
data[i] = Float.intBitsToFloat(convertedData[i]);
}
}
return data;
}
}
private static final class DoubleArraySerializer implements ArraySerializer {
@Override
public INBT serialize(final Object value) {
final double[] data = (double[]) value;
final long[] convertedData = new long[data.length];
for (int i = 0; i < data.length; i++) {
convertedData[i] = Double.doubleToRawLongBits(data[i]);
}
return new LongArrayNBT(convertedData);
}
@Override
public Object deserialize(final INBT nbt, final Class<?> type, @Nullable final Object into) {
double[] data = (double[]) into;
if (nbt instanceof LongArrayNBT) {
final long[] convertedData = ((LongArrayNBT) nbt).getAsLongArray();
if (data == null || data.length != convertedData.length) {
data = new double[convertedData.length];
}
for (int i = 0; i < convertedData.length; i++) {
data[i] = Double.longBitsToDouble(convertedData[i]);
}
}
return data;
}
}
@SuppressWarnings("rawtypes")
private static final class EnumArraySerializer implements ArraySerializer {
@Override
public INBT serialize(final Object value) {
final Enum[] data = (Enum[]) value;
final int[] convertedData = new int[data.length];
for (int i = 0; i < data.length; i++) {
convertedData[i] = data[i].ordinal();
}
return new IntArrayNBT(convertedData);
}
@Override
public Object deserialize(final INBT nbt, final Class<?> type, @Nullable final Object into) {
final Class<?> componentType = type.getComponentType();
final Object[] enumConstants = componentType.getEnumConstants();
Enum[] data = (Enum[]) into;
if (nbt instanceof IntArrayNBT) {
final int[] serializedData = ((IntArrayNBT) nbt).getIntArray();
if (data == null || data.length != serializedData.length) {
data = (Enum[]) Array.newInstance(componentType, serializedData.length);
}
for (int i = 0; i < serializedData.length; i++) {
data[i] = (Enum) enumConstants[serializedData[i]];
}
}
return data;
}
}
private static final class StringArraySerializer implements ArraySerializer {
@Override
public INBT serialize(final Object value) {
final String[] data = (String[]) value;
final ListNBT list = new ListNBT();
for (final String datum : data) {
list.add(StringNBT.valueOf(datum));
}
return list;
}
@Override
public Object deserialize(final INBT nbt, final Class<?> type, @Nullable final Object into) {
String[] data = (String[]) into;
if (nbt instanceof ListNBT && ((ListNBT) nbt).getTagType() == NBTTagIds.TAG_STRING) {
final ListNBT serializedData = (ListNBT) nbt;
if (data == null || data.length != serializedData.size()) {
data = new String[serializedData.size()];
}
for (int i = 0; i < serializedData.size(); i++) {
data[i] = serializedData.getString(i);
}
}
return data;
}
}
private static final class UUIDArraySerializer implements ArraySerializer {
@Override
public INBT serialize(final Object value) {
final UUID[] data = (UUID[]) value;
final ListNBT list = new ListNBT();
for (final UUID datum : data) {
list.add(StringNBT.valueOf(datum.toString()));
}
return list;
}
@Override
public Object deserialize(final INBT nbt, final Class<?> type, @Nullable final Object into) {
UUID[] data = (UUID[]) into;
if (nbt instanceof ListNBT && ((ListNBT) nbt).getTagType() == NBTTagIds.TAG_STRING) {
final ListNBT serializedData = (ListNBT) nbt;
if (data == null || data.length != serializedData.size()) {
data = new UUID[serializedData.size()];
}
for (int i = 0; i < serializedData.size(); i++) {
data[i] = UUID.fromString(serializedData.getString(i));
}
}
return data;
}
}
}