Skip to content

Commit 0ca8419

Browse files
committed
Rust: Improve type inference for closures and function traits
1 parent 196f6e1 commit 0ca8419

File tree

6 files changed

+121
-61
lines changed

6 files changed

+121
-61
lines changed

rust/ql/lib/codeql/rust/frameworks/stdlib/Stdlib.qll

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,23 +143,46 @@ class FutureTrait extends Trait {
143143
TypeAlias getOutputType() { result = this.(TraitItemNode).getAssocItem("Output") }
144144
}
145145

146+
/** Any of the function traits `FnOnce`, `FnMut`, or `Fn`. */
147+
class AnyFnTrait extends Trait {
148+
/** Gets the `Args` type parameter of this trait. */
149+
TypeParam getTypeParam() { result = this.getGenericParamList().getGenericParam(0) }
150+
}
151+
146152
/**
147153
* The [`FnOnce` trait][1].
148154
*
149155
* [1]: https://doc.rust-lang.org/std/ops/trait.FnOnce.html
150156
*/
151-
class FnOnceTrait extends Trait {
157+
class FnOnceTrait extends AnyFnTrait {
152158
pragma[nomagic]
153159
FnOnceTrait() { this.getCanonicalPath() = "core::ops::function::FnOnce" }
154160

155-
/** Gets the type parameter of this trait. */
156-
TypeParam getTypeParam() { result = this.getGenericParamList().getGenericParam(0) }
157-
158161
/** Gets the `Output` associated type. */
159162
pragma[nomagic]
160163
TypeAlias getOutputType() { result = this.(TraitItemNode).getAssocItem("Output") }
161164
}
162165

166+
/**
167+
* The [`FnMut` trait][1].
168+
*
169+
* [1]: https://doc.rust-lang.org/std/ops/trait.FnMut.html
170+
*/
171+
class FnMutTrait extends AnyFnTrait {
172+
pragma[nomagic]
173+
FnMutTrait() { this.getCanonicalPath() = "core::ops::function::FnMut" }
174+
}
175+
176+
/**
177+
* The [`Fn` trait][1].
178+
*
179+
* [1]: https://doc.rust-lang.org/std/ops/trait.Fn.html
180+
*/
181+
class FnTrait extends AnyFnTrait {
182+
pragma[nomagic]
183+
FnTrait() { this.getCanonicalPath() = "core::ops::function::Fn" }
184+
}
185+
163186
/**
164187
* The [`Iterator` trait][1].
165188
*

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3825,16 +3825,29 @@ private Type invokedClosureFnTypeAt(InvokedClosureExpr ce, TypePath path) {
38253825
_, path, result)
38263826
}
38273827

3828+
/**
3829+
* Gets the root type of a closure.
3830+
*
3831+
* We model closures as `dyn Fn` trait object types. A closure might implement
3832+
* only `Fn`, `FnMut`, or `FnOnce`. But since `Fn` is a subtrait of the others,
3833+
* giving closures the type `dyn Fn` works well in practice—even if not entirely
3834+
* accurate.
3835+
*/
3836+
private DynTraitType closureRootType() {
3837+
result = TDynTraitType(any(FnTrait t)) // always exists because of the mention in `builtins/mentions.rs`
3838+
}
3839+
38283840
/** Gets the path to a closure's return type. */
38293841
private TypePath closureReturnPath() {
3830-
result = TypePath::singleton(getDynTraitTypeParameter(any(FnOnceTrait t).getOutputType()))
3842+
result =
3843+
TypePath::singleton(TDynTraitTypeParameter(any(FnTrait t), any(FnOnceTrait t).getOutputType()))
38313844
}
38323845

38333846
/** Gets the path to a closure with arity `arity`s `index`th parameter type. */
38343847
pragma[nomagic]
38353848
private TypePath closureParameterPath(int arity, int index) {
38363849
result =
3837-
TypePath::cons(TDynTraitTypeParameter(_, any(FnOnceTrait t).getTypeParam()),
3850+
TypePath::cons(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam()),
38383851
TypePath::singleton(getTupleTypeParameter(arity, index)))
38393852
}
38403853

@@ -3872,9 +3885,7 @@ private Type inferDynamicCallExprType(Expr n, TypePath path) {
38723885
or
38733886
// _If_ the invoked expression has the type of a closure, then we propagate
38743887
// the surrounding types into the closure.
3875-
exists(int arity, TypePath path0 |
3876-
ce.getTypeAt(TypePath::nil()).(DynTraitType).getTrait() instanceof FnOnceTrait
3877-
|
3888+
exists(int arity, TypePath path0 | ce.getTypeAt(TypePath::nil()) = closureRootType() |
38783889
// Propagate the type of arguments to the parameter types of closure
38793890
exists(int index, ArgList args |
38803891
n = ce and
@@ -3898,10 +3909,10 @@ private Type inferClosureExprType(AstNode n, TypePath path) {
38983909
exists(ClosureExpr ce |
38993910
n = ce and
39003911
path.isEmpty() and
3901-
result = TDynTraitType(any(FnOnceTrait t)) // always exists because of the mention in `builtins/mentions.rs`
3912+
result = closureRootType()
39023913
or
39033914
n = ce and
3904-
path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnOnceTrait t).getTypeParam())) and
3915+
path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam())) and
39053916
result.(TupleType).getArity() = ce.getNumberOfParams()
39063917
or
39073918
// Propagate return type annotation to body

rust/ql/lib/codeql/rust/internal/typeinference/TypeMention.qll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,15 +212,15 @@ class NonAliasPathTypeMention extends PathTypeMention {
212212
// associated types of `Fn` and `FnMut` yet.
213213
//
214214
// [1]: https://doc.rust-lang.org/reference/paths.html#grammar-TypePathFn
215-
exists(FnOnceTrait t, PathSegment s |
215+
exists(AnyFnTrait t, PathSegment s |
216216
t = resolved and
217217
s = this.getSegment() and
218218
s.hasParenthesizedArgList()
219219
|
220220
tp = TTypeParamTypeParameter(t.getTypeParam()) and
221221
result = s.getParenthesizedArgList().(TypeMention).resolveTypeAt(path)
222222
or
223-
tp = TAssociatedTypeTypeParameter(t, t.getOutputType()) and
223+
tp = TAssociatedTypeTypeParameter(t, any(FnOnceTrait tr).getOutputType()) and
224224
(
225225
result = s.getRetType().getTypeRepr().(TypeMention).resolveTypeAt(path)
226226
or

rust/ql/test/library-tests/type-inference/closure.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,15 @@ mod fn_once_trait {
7070

7171
mod fn_mut_trait {
7272
fn return_type<F: FnMut(bool) -> i64>(mut f: F) {
73-
let _return = f(true); // $ MISSING: type=_return:i64
73+
let _return = f(true); // $ type=_return:i64
7474
}
7575

7676
fn return_type_omitted<F: FnMut(bool)>(mut f: F) {
77-
let _return = f(true); // $ MISSING: type=_return:()
77+
let _return = f(true); // $ type=_return:()
7878
}
7979

8080
fn argument_type<F: FnMut(bool) -> i64>(mut f: F) {
81-
let arg = Default::default(); // $ MISSING: target=default type=arg:bool
81+
let arg = Default::default(); // $ target=default type=arg:bool
8282
f(arg);
8383
}
8484

@@ -98,7 +98,7 @@ mod fn_mut_trait {
9898
0
9999
}
100100
};
101-
let _r = apply(f, true); // $ target=apply MISSING: type=_r:i64
101+
let _r = apply(f, true); // $ target=apply type=_r:i64
102102

103103
let f = |x| x + 1; // $ MISSING: type=x:i64 target=add
104104
let _r2 = apply_two(f); // $ target=apply_two certainType=_r2:i64
@@ -107,15 +107,15 @@ mod fn_mut_trait {
107107

108108
mod fn_trait {
109109
fn return_type<F: Fn(bool) -> i64>(f: F) {
110-
let _return = f(true); // $ MISSING: type=_return:i64
110+
let _return = f(true); // $ type=_return:i64
111111
}
112112

113113
fn return_type_omitted<F: Fn(bool)>(f: F) {
114-
let _return = f(true); // $ MISSING: type=_return:()
114+
let _return = f(true); // $ type=_return:()
115115
}
116116

117117
fn argument_type<F: Fn(bool) -> i64>(f: F) {
118-
let arg = Default::default(); // $ MISSING: target=default type=arg:bool
118+
let arg = Default::default(); // $ target=default type=arg:bool
119119
f(arg);
120120
}
121121

@@ -135,7 +135,7 @@ mod fn_trait {
135135
0
136136
}
137137
};
138-
let _r = apply(f, true); // $ target=apply MISSING: type=_r:i64
138+
let _r = apply(f, true); // $ target=apply type=_r:i64
139139

140140
let f = |x| x + 1; // $ MISSING: type=x:i64 target=add
141141
let _r2 = apply_two(f); // $ target=apply_two certainType=_r2:i64

0 commit comments

Comments
 (0)