Skip to content

Commit 53822e6

Browse files
committed
feat: Java Bindings for FixedSizeList
Signed-off-by: JingsongLi <jingsonglee0@gmail.com>
1 parent 66236f8 commit 53822e6

File tree

8 files changed

+212
-8
lines changed

8 files changed

+212
-8
lines changed

java/vortex-jni/src/main/java/dev/vortex/api/DType.java

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,15 @@ public interface DType extends AutoCloseable {
4646
List<DType> getFieldTypes();
4747

4848
/**
49-
* Get the element type for a LIST type.
49+
* Get the element type for a LIST or FIXED_SIZE_LIST type.
5050
*/
5151
DType getElementType();
5252

53+
/**
54+
* Get the number of elements in each list of a FIXED_SIZE_LIST type.
55+
*/
56+
int getFixedSizeListSize();
57+
5358
/**
5459
* Checks if this data type represents a date.
5560
*
@@ -234,6 +239,19 @@ static DType newList(DType element, boolean isNullable) {
234239
return new JNIDType(NativeDTypeMethods.newList(jniType.getPointer(), isNullable), true);
235240
}
236241

242+
/**
243+
* Create a new FixedSizeList data type.
244+
*
245+
* @param element DType of the list elements
246+
* @param size The fixed size of each list
247+
* @param isNullable True if the values can be null
248+
* @return The new DType instance, allocated in native heap memory
249+
*/
250+
static DType newFixedSizeList(DType element, int size, boolean isNullable) {
251+
JNIDType jniType = (JNIDType) element;
252+
return new JNIDType(NativeDTypeMethods.newFixedSizeList(jniType.getPointer(), size, isNullable), true);
253+
}
254+
237255
/**
238256
* Create a new Struct data type.
239257
*
@@ -467,12 +485,17 @@ enum Variant {
467485
* Decimal type for precise numeric values
468486
*/
469487
DECIMAL,
488+
489+
/**
490+
* Fixed-size list type containing a fixed number of elements of a single type
491+
*/
492+
FIXED_SIZE_LIST,
470493
;
471494

472495
/**
473496
* Converts a byte value to the corresponding Variant enum.
474497
*
475-
* @param variant the byte value representing the variant (0-18)
498+
* @param variant the byte value representing the variant (0-19)
476499
* @return the corresponding {@link Variant} enum value
477500
* @throws RuntimeException if the variant value is not recognized
478501
*/
@@ -516,6 +539,8 @@ public static Variant from(byte variant) {
516539
return EXTENSION;
517540
case 18:
518541
return DECIMAL;
542+
case 19:
543+
return FIXED_SIZE_LIST;
519544
default:
520545
throw new IllegalArgumentException("Unknown DType variant: " + variant);
521546
}

java/vortex-jni/src/main/java/dev/vortex/jni/JNIDType.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ public DType getElementType() {
6464
return new JNIDType(NativeDTypeMethods.getElementType(pointer.getAsLong()));
6565
}
6666

67+
@Override
68+
public int getFixedSizeListSize() {
69+
return NativeDTypeMethods.getFixedSizeListSize(pointer.getAsLong());
70+
}
71+
6772
@Override
6873
public boolean isDate() {
6974
return NativeDTypeMethods.isDate(pointer.getAsLong());

java/vortex-jni/src/main/java/dev/vortex/jni/NativeDTypeMethods.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,16 @@ private NativeDTypeMethods() {}
103103
*/
104104
public static native long newList(long elementTypePtr, boolean isNullable);
105105

106+
/**
107+
* Create a new native DType for a FixedSizeList type. The created object lives in native memory.
108+
*
109+
* @param elementTypePtr A native pointer to a DType containing the type of the elements
110+
* @param size The fixed size of each list
111+
* @param isNullable true if the values can be null
112+
* @return Pointer to a new heap-allocated {@code DType}.
113+
*/
114+
public static native long newFixedSizeList(long elementTypePtr, int size, boolean isNullable);
115+
106116
/**
107117
* Create a new native DType for a Struct type. The created object lives in native memory.
108118
*
@@ -154,6 +164,8 @@ private NativeDTypeMethods() {}
154164

155165
public static native long getElementType(long pointer);
156166

167+
public static native int getFixedSizeListSize(long pointer);
168+
157169
public static native boolean isDate(long pointer);
158170

159171
public static native boolean isTime(long pointer);
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
package dev.vortex.api;
5+
6+
import static org.junit.jupiter.api.Assertions.*;
7+
8+
import org.junit.jupiter.api.Test;
9+
10+
public final class DTypeTest {
11+
12+
@Test
13+
public void testNewFixedSizeListNonNullable() {
14+
var elementType = DType.newInt(false);
15+
var fslType = DType.newFixedSizeList(elementType, 3, false);
16+
assertEquals(DType.Variant.FIXED_SIZE_LIST, fslType.getVariant());
17+
assertFalse(fslType.isNullable());
18+
assertEquals(3, fslType.getFixedSizeListSize());
19+
20+
var innerType = fslType.getElementType();
21+
assertEquals(DType.Variant.PRIMITIVE_I32, innerType.getVariant());
22+
}
23+
24+
@Test
25+
public void testNewFixedSizeListNullable() {
26+
var elementType = DType.newUtf8(true);
27+
var fslType = DType.newFixedSizeList(elementType, 5, true);
28+
assertEquals(DType.Variant.FIXED_SIZE_LIST, fslType.getVariant());
29+
assertTrue(fslType.isNullable());
30+
assertEquals(5, fslType.getFixedSizeListSize());
31+
32+
var innerType = fslType.getElementType();
33+
assertEquals(DType.Variant.UTF8, innerType.getVariant());
34+
}
35+
36+
@Test
37+
public void testNewListGetElementType() {
38+
var elementType = DType.newDouble(false);
39+
var listType = DType.newList(elementType, false);
40+
assertEquals(DType.Variant.LIST, listType.getVariant());
41+
42+
var innerType = listType.getElementType();
43+
assertEquals(DType.Variant.PRIMITIVE_F64, innerType.getVariant());
44+
}
45+
46+
@Test
47+
public void testNestedFixedSizeList() {
48+
var innerElement = DType.newLong(false);
49+
var innerFsl = DType.newFixedSizeList(innerElement, 2, false);
50+
var outerFsl = DType.newFixedSizeList(innerFsl, 4, true);
51+
assertEquals(DType.Variant.FIXED_SIZE_LIST, outerFsl.getVariant());
52+
assertTrue(outerFsl.isNullable());
53+
assertEquals(4, outerFsl.getFixedSizeListSize());
54+
55+
var inner = outerFsl.getElementType();
56+
assertEquals(DType.Variant.FIXED_SIZE_LIST, inner.getVariant());
57+
}
58+
59+
@Test
60+
public void testFixedSizeListInStruct() {
61+
var elementType = DType.newFloat(false);
62+
var fslType = DType.newFixedSizeList(elementType, 3, false);
63+
var structType = DType.newStruct(
64+
new String[] {"id", "embedding"},
65+
new DType[] {DType.newInt(false), fslType},
66+
false);
67+
assertEquals(DType.Variant.STRUCT, structType.getVariant());
68+
69+
var fieldTypes = structType.getFieldTypes();
70+
assertEquals(2, fieldTypes.size());
71+
72+
var embeddingType = fieldTypes.get(1);
73+
assertEquals(DType.Variant.FIXED_SIZE_LIST, embeddingType.getVariant());
74+
}
75+
}

java/vortex-spark/src/main/java/dev/vortex/spark/SparkTypes.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ public static DataType toDataType(DType dType) {
121121

122122
return DataTypes.createStructType(fields);
123123
case LIST:
124+
case FIXED_SIZE_LIST:
124125
return DataTypes.createArrayType(toDataType(dType.getElementType()), dType.isNullable());
125126
case EXTENSION:
126127
/*
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
package dev.vortex.spark;
5+
6+
import static org.junit.jupiter.api.Assertions.*;
7+
8+
import dev.vortex.api.DType;
9+
import dev.vortex.jni.NativeLoader;
10+
import org.apache.spark.sql.types.ArrayType;
11+
import org.apache.spark.sql.types.DataTypes;
12+
import org.junit.jupiter.api.BeforeAll;
13+
import org.junit.jupiter.api.DisplayName;
14+
import org.junit.jupiter.api.Test;
15+
16+
public final class SparkTypesTest {
17+
18+
@BeforeAll
19+
public static void loadLibrary() {
20+
NativeLoader.loadJni();
21+
}
22+
23+
@Test
24+
@DisplayName("toDataType should convert FIXED_SIZE_LIST to Spark ArrayType")
25+
public void testFixedSizeListToDataType() {
26+
var elementType = DType.newInt(false);
27+
var fslType = DType.newFixedSizeList(elementType, 3, true);
28+
var sparkType = SparkTypes.toDataType(fslType);
29+
assertInstanceOf(ArrayType.class, sparkType);
30+
ArrayType arrayType = (ArrayType) sparkType;
31+
assertEquals(DataTypes.IntegerType, arrayType.elementType());
32+
}
33+
34+
@Test
35+
@DisplayName("toDataType should convert LIST to Spark ArrayType")
36+
public void testListToDataType() {
37+
var elementType = DType.newDouble(false);
38+
var listType = DType.newList(elementType, true);
39+
var sparkType = SparkTypes.toDataType(listType);
40+
assertInstanceOf(ArrayType.class, sparkType);
41+
ArrayType arrayType = (ArrayType) sparkType;
42+
assertEquals(DataTypes.DoubleType, arrayType.elementType());
43+
}
44+
}

vortex-jni/src/array.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,12 @@ fn data_type_no_views(data_type: DataType) -> DataType {
158158
}
159159
DataType::Decimal128(precision, scale) => DataType::Decimal128(precision, scale),
160160
DataType::Decimal256(precision, scale) => DataType::Decimal256(precision, scale),
161-
DataType::FixedSizeList(..) => unreachable!("Vortex never returns FixedSizeList"),
161+
DataType::FixedSizeList(inner, size) => {
162+
let new_inner = (*inner)
163+
.clone()
164+
.with_data_type(data_type_no_views(inner.data_type().clone()));
165+
DataType::FixedSizeList(FieldRef::new(new_inner), size)
166+
}
162167
DataType::Union(..) => unreachable!("Vortex never returns Union"),
163168
DataType::Dictionary(..) => unreachable!("Vortex never returns Dictionary"),
164169
DataType::Map(..) => unreachable!("Vortex never returns Map"),

vortex-jni/src/dtype.rs

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ pub const DTYPE_STRUCT: jbyte = 15;
5252
pub const DTYPE_LIST: jbyte = 16;
5353
pub const DTYPE_EXTENSION: jbyte = 17;
5454
pub const DTYPE_DECIMAL: jbyte = 18;
55+
pub const DTYPE_FIXED_SIZE_LIST: jbyte = 19;
5556

5657
static LONG_CLASS: &str = "java/lang/Long";
5758

@@ -94,9 +95,7 @@ pub extern "system" fn Java_dev_vortex_jni_NativeDTypeMethods_getVariant(
9495
DType::Binary(_) => DTYPE_BINARY,
9596
DType::Struct(..) => DTYPE_STRUCT,
9697
DType::List(..) => DTYPE_LIST,
97-
DType::FixedSizeList(..) => {
98-
unimplemented!("TODO(connor)[FixedSizeList]")
99-
}
98+
DType::FixedSizeList(..) => DTYPE_FIXED_SIZE_LIST,
10099
DType::Extension(_) => DTYPE_EXTENSION,
101100
DType::Variant(_) => unimplemented!("Variant DType is not supported in JNI yet"),
102101
}
@@ -184,8 +183,11 @@ pub extern "system" fn Java_dev_vortex_jni_NativeDTypeMethods_getElementType(
184183
let dtype = unsafe { &*(dtype_ptr as *const DType) };
185184

186185
try_or_throw(&mut env, |_| {
187-
let Some(element_type) = dtype.as_list_element_opt() else {
188-
throw_runtime!("DType should be LIST, was {dtype}");
186+
let element_type = dtype
187+
.as_list_element_opt()
188+
.or_else(|| dtype.as_fixed_size_list_element_opt());
189+
let Some(element_type) = element_type else {
190+
throw_runtime!("DType should be LIST or FIXED_SIZE_LIST, was {dtype}");
189191
};
190192

191193
Ok(element_type.as_ref() as *const DType as jlong)
@@ -506,6 +508,41 @@ pub extern "system" fn Java_dev_vortex_jni_NativeDTypeMethods_newList(
506508
Box::into_raw(Box::new(list_type)) as jlong
507509
}
508510

511+
/// FixedSizeList constructor
512+
#[unsafe(no_mangle)]
513+
pub extern "system" fn Java_dev_vortex_jni_NativeDTypeMethods_newFixedSizeList(
514+
_env: JNIEnv,
515+
_class: JClass,
516+
element_ptr: jlong,
517+
size: jint,
518+
is_nullable: jboolean,
519+
) -> jlong {
520+
let element_dtype = unsafe { *Box::from_raw(element_ptr as *mut DType) };
521+
let element_dtype = Arc::new(element_dtype);
522+
523+
let fsl_type = DType::FixedSizeList(element_dtype, size as u32, to_nullability(is_nullable));
524+
525+
Box::into_raw(Box::new(fsl_type)) as jlong
526+
}
527+
528+
/// Get the fixed size of a FixedSizeList DType.
529+
#[unsafe(no_mangle)]
530+
pub extern "system" fn Java_dev_vortex_jni_NativeDTypeMethods_getFixedSizeListSize(
531+
mut env: JNIEnv,
532+
_class: JClass,
533+
dtype_ptr: jlong,
534+
) -> jint {
535+
let dtype = unsafe { &*(dtype_ptr as *const DType) };
536+
537+
try_or_throw(&mut env, |_| {
538+
let DType::FixedSizeList(_, size, _) = dtype else {
539+
throw_runtime!("DType should be FIXED_SIZE_LIST, was {dtype}");
540+
};
541+
542+
Ok(*size as jint)
543+
})
544+
}
545+
509546
/// Struct constructor
510547
#[unsafe(no_mangle)]
511548
pub extern "system" fn Java_dev_vortex_jni_NativeDTypeMethods_newStruct<'local>(

0 commit comments

Comments
 (0)