Skip to content

Commit 2c758a9

Browse files
committed
Rust: Add type inference for closures and calls to first-class functions
1 parent 8c6c28d commit 2c758a9

File tree

5 files changed

+341
-14
lines changed

5 files changed

+341
-14
lines changed

rust/ql/lib/codeql/rust/frameworks/stdlib/Stdlib.qll

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,26 @@ class FutureTrait extends Trait {
140140
}
141141
}
142142

143+
/**
144+
* The [`FnOnce` trait][1].
145+
*
146+
* [1]: https://doc.rust-lang.org/std/ops/trait.FnOnce.html
147+
*/
148+
class FnOnceTrait extends Trait {
149+
pragma[nomagic]
150+
FnOnceTrait() { this.getCanonicalPath() = "core::ops::function::FnOnce" }
151+
152+
/** Gets the type parameter of this trait. */
153+
TypeParam getTypeParam() { result = this.getGenericParamList().getGenericParam(0) }
154+
155+
/** Gets the `Output` associated type. */
156+
pragma[nomagic]
157+
TypeAlias getOutputType() {
158+
result = this.getAssocItemList().getAnAssocItem() and
159+
result.getName().getText() = "Output"
160+
}
161+
}
162+
143163
/**
144164
* The [`Iterator` trait][1].
145165
*

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,17 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
383383
prefix2.isEmpty() and
384384
s = getRangeType(n1)
385385
)
386+
or
387+
exists(ClosureExpr ce, int index |
388+
n1 = ce and
389+
n2 = ce.getParam(index).getPat() and
390+
prefix1 = closureParameterPath(ce.getNumberOfParams(), index) and
391+
prefix2.isEmpty()
392+
)
393+
or
394+
n1.(ClosureExpr).getBody() = n2 and
395+
prefix1 = closureReturnPath() and
396+
prefix2.isEmpty()
386397
}
387398

388399
pragma[nomagic]
@@ -1435,6 +1446,120 @@ private Type inferForLoopExprType(AstNode n, TypePath path) {
14351446
)
14361447
}
14371448

1449+
/**
1450+
* An invoked expression, the target of a call that is either a local variable
1451+
* or a non-path expression. This means that the expression denotes a
1452+
* first-class function.
1453+
*/
1454+
final private class InvokedClosureExpr extends Expr {
1455+
private CallExpr call;
1456+
1457+
InvokedClosureExpr() {
1458+
call.getFunction() = this and
1459+
(not this instanceof PathExpr or this = any(Variable v).getAnAccess())
1460+
}
1461+
1462+
Type getTypeAt(TypePath path) { result = inferType(this, path) }
1463+
1464+
CallExpr getCall() { result = call }
1465+
}
1466+
1467+
private module InvokedClosureSatisfiesConstraintInput implements
1468+
SatisfiesConstraintInputSig<InvokedClosureExpr>
1469+
{
1470+
predicate relevantConstraint(InvokedClosureExpr term, Type constraint) {
1471+
exists(term) and
1472+
constraint.(TraitType).getTrait() instanceof FnOnceTrait
1473+
}
1474+
}
1475+
1476+
/** Gets the type of `ce` when viewed as an implementation of `FnOnce`. */
1477+
private Type invokedClosureFnTypeAt(InvokedClosureExpr ce, TypePath path) {
1478+
SatisfiesConstraint<InvokedClosureExpr, InvokedClosureSatisfiesConstraintInput>::satisfiesConstraintType(ce,
1479+
_, path, result)
1480+
}
1481+
1482+
/** Gets the path to a closure's return type. */
1483+
private TypePath closureReturnPath() {
1484+
result = TypePath::singleton(TDynTraitTypeParameter(any(FnOnceTrait t).getOutputType()))
1485+
}
1486+
1487+
/** Gets the path to a closure with arity `arity`s `index`th parameter type. */
1488+
private TypePath closureParameterPath(int arity, int index) {
1489+
result =
1490+
TypePath::cons(TDynTraitTypeParameter(any(FnOnceTrait t).getTypeParam()),
1491+
TypePath::singleton(TTupleTypeParameter(arity, index)))
1492+
}
1493+
1494+
/** Gets the path to the return type of the `FnOnce` trait. */
1495+
private TypePath fnReturnPath() {
1496+
result = TypePath::singleton(TAssociatedTypeTypeParameter(any(FnOnceTrait t).getOutputType()))
1497+
}
1498+
1499+
/**
1500+
* Gets the path to the parameter type of the `FnOnce` trait with arity `arity`
1501+
* and index `index`.
1502+
*/
1503+
private TypePath fnParameterPath(int arity, int index) {
1504+
result =
1505+
TypePath::cons(TTypeParamTypeParameter(any(FnOnceTrait t).getTypeParam()),
1506+
TypePath::singleton(TTupleTypeParameter(arity, index)))
1507+
}
1508+
1509+
pragma[nomagic]
1510+
private Type inferDynamicCallExprType(Expr n, TypePath path) {
1511+
exists(InvokedClosureExpr ce |
1512+
// Propagate the function's return type to the call expression
1513+
exists(TypePath path0 | result = invokedClosureFnTypeAt(ce, path0) |
1514+
n = ce.getCall() and
1515+
path = path0.stripPrefix(fnReturnPath())
1516+
or
1517+
// Propagate the function's parameter type to the arguments
1518+
exists(int index |
1519+
n = ce.getCall().getArgList().getArg(index) and
1520+
path = path0.stripPrefix(fnParameterPath(ce.getCall().getNumberOfArgs(), index))
1521+
)
1522+
)
1523+
or
1524+
// _If_ the invoked expression has the type of a closure, then we propagate
1525+
// the surrounding types into the closure.
1526+
exists(int arity, TypePath path0 |
1527+
ce.getTypeAt(TypePath::nil()).(DynTraitType).getTrait() instanceof FnOnceTrait
1528+
|
1529+
// Propagate the type of arguments to the parameter types of closure
1530+
exists(int index |
1531+
n = ce and
1532+
arity = ce.getCall().getNumberOfArgs() and
1533+
result = inferType(ce.getCall().getArg(index), path0) and
1534+
path = closureParameterPath(arity, index).append(path0)
1535+
)
1536+
or
1537+
// Propagate the type of the call expression to the return type of the closure
1538+
n = ce and
1539+
arity = ce.getCall().getNumberOfArgs() and
1540+
result = inferType(ce.getCall(), path0) and
1541+
path = closureReturnPath().append(path0)
1542+
)
1543+
)
1544+
}
1545+
1546+
pragma[nomagic]
1547+
private Type inferClosureExprType(AstNode n, TypePath path) {
1548+
exists(ClosureExpr ce |
1549+
n = ce and
1550+
path.isEmpty() and
1551+
result = TDynTraitType(any(FnOnceTrait t))
1552+
or
1553+
n = ce and
1554+
path = TypePath::singleton(TDynTraitTypeParameter(any(FnOnceTrait t).getTypeParam())) and
1555+
result = TTuple(ce.getNumberOfParams())
1556+
or
1557+
// Propagate return type annotation to body
1558+
n = ce.getBody() and
1559+
result = ce.getRetType().getTypeRepr().(TypeMention).resolveTypeAt(path)
1560+
)
1561+
}
1562+
14381563
pragma[nomagic]
14391564
private Type inferCastExprType(CastExpr ce, TypePath path) {
14401565
result = ce.getTypeRepr().(TypeMention).resolveTypeAt(path)
@@ -2062,6 +2187,10 @@ private module Cached {
20622187
or
20632188
result = inferForLoopExprType(n, path)
20642189
or
2190+
result = inferDynamicCallExprType(n, path)
2191+
or
2192+
result = inferClosureExprType(n, path)
2193+
or
20652194
result = inferCastExprType(n, path)
20662195
or
20672196
result = inferStructPatType(n, path)

rust/ql/lib/codeql/rust/internal/TypeMention.qll

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
/** Provides classes for representing type mentions, used in type inference. */
22

33
private import rust
4+
private import codeql.rust.frameworks.stdlib.Stdlib
45
private import Type
56
private import PathResolution
67
private import TypeInference
@@ -26,6 +27,18 @@ class TupleTypeReprMention extends TypeMention instanceof TupleTypeRepr {
2627
}
2728
}
2829

30+
class ParenthesizedArgListMention extends TypeMention instanceof ParenthesizedArgList {
31+
override Type resolveTypeAt(TypePath path) {
32+
path.isEmpty() and
33+
result = TTuple(super.getNumberOfTypeArgs())
34+
or
35+
exists(TypePath suffix, int index |
36+
result = super.getTypeArg(index).getTypeRepr().(TypeMention).resolveTypeAt(suffix) and
37+
path = TypePath::cons(TTupleTypeParameter(super.getNumberOfTypeArgs(), index), suffix)
38+
)
39+
}
40+
}
41+
2942
class ArrayTypeReprMention extends TypeMention instanceof ArrayTypeRepr {
3043
override Type resolveTypeAt(TypePath path) {
3144
path.isEmpty() and
@@ -215,6 +228,17 @@ class NonAliasPathTypeMention extends PathTypeMention {
215228
.(TraitItemNode)
216229
.getAssocItem(pragma[only_bind_into](name)))
217230
)
231+
or
232+
// Handle the special syntactic sugar for function traits. For now we only
233+
// support `FnOnce` as we can't support the "inherited" associated types of
234+
// `Fn` and `FnMut` yet.
235+
exists(FnOnceTrait t | t = resolved |
236+
tp = TTypeParamTypeParameter(t.getTypeParam()) and
237+
result = this.getSegment().getParenthesizedArgList()
238+
or
239+
tp = TAssociatedTypeTypeParameter(t.getOutputType()) and
240+
result = this.getSegment().getRetType().getTypeRepr()
241+
)
218242
}
219243

220244
Type resolveRootType() {

rust/ql/test/library-tests/type-inference/closure.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,38 +6,38 @@ mod simple_closures {
66
let my_closure = |a, b| a && b;
77

88
let x: i64 = 1i64; // $ type=x:i64
9-
let add_one = |n| n + 1i64; // $ MISSING: target=add
10-
let _y = add_one(x); // $ MISSING: type=y:i64
9+
let add_one = |n| n + 1i64; // $ target=add
10+
let _y = add_one(x); // $ type=_y:i64
1111

1212
// The type of `x` is inferred from the closure's argument type.
13-
let x = Default::default(); // $ MISSING: type=x:i64 target=default
13+
let x = Default::default(); // $ type=x:i64 target=default
1414
let add_zero = |n: i64| n;
15-
let _y = add_zero(x); // $ MISSING: type=_y:i64
15+
let _y = add_zero(x); // $ type=_y:i64
1616

1717
let _get_bool = || -> bool {
1818
// The return type annotation on the closure lets us infer the type of `b`.
19-
let b = Default::default(); // $ MISSING: type=b:bool target=default
19+
let b = Default::default(); // $ type=b:bool target=default
2020
b
2121
};
2222

2323
// The parameter type of `id` is inferred from the argument.
24-
let id = |b| b; // $ MISSING: type=x:bool
25-
let _b = id(true); // $ MISSING: type=_b:bool
24+
let id = |b| b; // $ type=b:bool
25+
let _b = id(true); // $ type=_b:bool
2626

2727
// The return type of `id2` is inferred from the type of the call expression.
2828
let id2 = |b| b;
29-
let arg = Default::default(); // $ MISSING: target=default type=arg:bool
30-
let _b2: bool = id2(arg); // $ MISSING: type=_b:bool
29+
let arg = Default::default(); // $ target=default type=arg:bool
30+
let _b2: bool = id2(arg); // $ type=_b2:bool
3131
}
3232
}
3333

3434
mod fn_once_trait {
3535
fn return_type<F: FnOnce(bool) -> i64>(f: F) {
36-
let _return = f(true); // $ MISSING: type=_return:i64
36+
let _return = f(true); // $ type=_return:i64
3737
}
3838

3939
fn argument_type<F: FnOnce(bool) -> i64>(f: F) {
40-
let arg = Default::default(); // $ MISSING: type=arg:bool target=default
40+
let arg = Default::default(); // $ target=default type=arg:bool
4141
f(arg);
4242
}
4343

@@ -57,7 +57,7 @@ mod fn_once_trait {
5757
0
5858
}
5959
};
60-
let _r = apply(f, true); // $ target=apply MISSING: type=_r:i64
60+
let _r = apply(f, true); // $ target=apply type=_r:i64
6161

6262
let f = |x| x + 1; // $ MISSING: type=x:i64 target=add
6363
let _r2 = apply_two(f); // $ target=apply_two type=_r2:i64
@@ -70,7 +70,7 @@ mod dyn_fn_once {
7070
}
7171

7272
fn apply_boxed_dyn<A, B>(f: Box<dyn FnOnce(A) -> B>, arg: A) {
73-
let _r1 = apply_boxed(f, arg); // $ target=apply_boxed MISSING: type=_r1:B
74-
let _r2 = apply_boxed(Box::new(|_: i64| true), 3); // $ target=apply_boxed target=new MISSING: type=_r2:bool
73+
let _r1 = apply_boxed(f, arg); // $ target=apply_boxed type=_r1:B
74+
let _r2 = apply_boxed(Box::new(|_: i64| true), 3); // $ target=apply_boxed target=new type=_r2:bool
7575
}
7676
}

0 commit comments

Comments
 (0)