-
Notifications
You must be signed in to change notification settings - Fork 145
Expand file tree
/
Copy pathmod.rs
More file actions
208 lines (179 loc) · 6.25 KB
/
mod.rs
File metadata and controls
208 lines (179 loc) · 6.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
mod kernel;
pub use kernel::*;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_session::VortexSession;
use crate::ArrayRef;
use crate::Canonical;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::arrays::BoolArray;
use crate::arrays::Constant;
use crate::arrays::ConstantArray;
use crate::arrays::bool::BoolArrayExt;
use crate::arrays::masked::mask_validity_canonical;
use crate::builtins::ArrayBuiltins;
use crate::dtype::DType;
use crate::dtype::Nullability;
use crate::expr::Expression;
use crate::expr::and;
use crate::expr::lit;
use crate::scalar::Scalar;
use crate::scalar_fn::Arity;
use crate::scalar_fn::ChildName;
use crate::scalar_fn::EmptyOptions;
use crate::scalar_fn::ExecutionArgs;
use crate::scalar_fn::ScalarFnId;
use crate::scalar_fn::ScalarFnVTable;
use crate::scalar_fn::SimplifyCtx;
use crate::scalar_fn::fns::literal::Literal;
/// An expression that masks an input based on a boolean mask.
///
/// Where the mask is true, the input value is retained; where the mask is false, the output is
/// null. In other words, this performs an intersection of the input's validity with the mask.
#[derive(Clone)]
pub struct Mask;
impl ScalarFnVTable for Mask {
type Options = EmptyOptions;
fn id(&self) -> ScalarFnId {
ScalarFnId::from("vortex.mask")
}
fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
Ok(Some(vec![]))
}
fn deserialize(
&self,
_metadata: &[u8],
_session: &VortexSession,
) -> VortexResult<Self::Options> {
Ok(EmptyOptions)
}
fn arity(&self, _options: &Self::Options) -> Arity {
Arity::Exact(2)
}
fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
match child_idx {
0 => ChildName::from("input"),
1 => ChildName::from("mask"),
_ => unreachable!("Invalid child index {} for Mask expression", child_idx),
}
}
fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
vortex_ensure!(
arg_dtypes[1] == DType::Bool(Nullability::NonNullable),
"The mask argument to 'mask' must be a non-nullable boolean array, got {}",
arg_dtypes[1]
);
Ok(arg_dtypes[0].as_nullable())
}
fn execute(
&self,
_options: &Self::Options,
args: &dyn ExecutionArgs,
ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef> {
let input = args.get(0)?;
let mask_array = args.get(1)?;
if let Some(result) = execute_constant(&input, &mask_array)? {
return Ok(result);
}
execute_canonical(input, mask_array, ctx)
}
fn simplify(
&self,
_options: &Self::Options,
expr: &Expression,
ctx: &dyn SimplifyCtx,
) -> VortexResult<Option<Expression>> {
let Some(mask_lit) = expr.child(1).as_opt::<Literal>() else {
return Ok(None);
};
let mask_lit = mask_lit
.as_bool()
.value()
.vortex_expect("Mask must be non-nullable");
if mask_lit {
// Mask is all true, so the output is just the input.
Ok(Some(expr.child(0).clone()))
} else {
// Mask is all false, so the output is all nulls.
let input_dtype = ctx.return_dtype(expr.child(0))?;
Ok(Some(lit(Scalar::null(input_dtype.as_nullable()))))
}
}
fn validity(
&self,
_options: &Self::Options,
expression: &Expression,
) -> VortexResult<Option<Expression>> {
Ok(Some(and(
expression.child(0).validity()?,
expression.child(1).clone(),
)))
}
}
/// Try to handle masking when at least one of the input or mask is a constant array.
///
/// Returns `Ok(Some(result))` if the constant case was handled, `Ok(None)` if not.
fn execute_constant(input: &ArrayRef, mask_array: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
let len = input.len();
if let Some(constant_mask) = mask_array.as_opt::<Constant>() {
let mask_value = constant_mask.scalar().as_bool().value().unwrap_or(false);
return if mask_value {
input.cast(input.dtype().as_nullable()).map(Some)
} else {
Ok(Some(
ConstantArray::new(Scalar::null(input.dtype().as_nullable()), len).into_array(),
))
};
}
if let Some(constant_input) = input.as_opt::<Constant>()
&& constant_input.scalar().is_null()
{
return Ok(Some(
ConstantArray::new(Scalar::null(input.dtype().as_nullable()), len).into_array(),
));
}
Ok(None)
}
/// Execute the mask by materializing both inputs to their canonical forms.
fn execute_canonical(
input: ArrayRef,
mask_array: ArrayRef,
ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef> {
let mask_bool = mask_array.execute::<BoolArray>(ctx)?;
let validity_mask = vortex_mask::Mask::from(mask_bool.to_bit_buffer());
let canonical = input.execute::<Canonical>(ctx)?;
Ok(mask_validity_canonical(canonical, &validity_mask, ctx)?.into_array())
}
#[cfg(test)]
mod test {
use vortex_error::VortexExpect;
use crate::dtype::DType;
use crate::dtype::Nullability::Nullable;
use crate::dtype::PType;
use crate::expr::lit;
use crate::expr::mask;
use crate::scalar::Scalar;
#[test]
fn test_simplify() {
let input_expr = lit(42u32);
let true_mask_expr = lit(true);
let false_mask_expr = lit(false);
let mask_true_expr = mask(input_expr.clone(), true_mask_expr);
let simplified_true = mask_true_expr
.optimize(&DType::Null)
.vortex_expect("Simplification");
assert_eq!(&simplified_true, &input_expr);
let mask_false_expr = mask(input_expr, false_mask_expr);
let simplified_false = mask_false_expr
.optimize(&DType::Null)
.vortex_expect("Simplification");
let expected_null_expr = lit(Scalar::null(DType::Primitive(PType::U32, Nullable)));
assert_eq!(&simplified_false, &expected_null_expr);
}
}