Skip to content

Commit 3bc1d47

Browse files
authored
Merge pull request #20130 from paldepind/rust/type-inference-fn
Rust: Implement type inference for closures and calls to closures
2 parents 5ca9c09 + 5b152cf commit 3bc1d47

File tree

7 files changed

+479
-78
lines changed

7 files changed

+479
-78
lines changed

rust/ql/lib/codeql/rust/frameworks/stdlib/Stdlib.qll

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,26 @@ class FutureTrait extends Trait {
140140
}
141141
}
142142

143+
/**
144+
* The [`FnOnce` trait][1].
145+
*
146+
* [1]: https://doc.rust-lang.org/std/ops/trait.FnOnce.html
147+
*/
148+
class FnOnceTrait extends Trait {
149+
pragma[nomagic]
150+
FnOnceTrait() { this.getCanonicalPath() = "core::ops::function::FnOnce" }
151+
152+
/** Gets the type parameter of this trait. */
153+
TypeParam getTypeParam() { result = this.getGenericParamList().getGenericParam(0) }
154+
155+
/** Gets the `Output` associated type. */
156+
pragma[nomagic]
157+
TypeAlias getOutputType() {
158+
result = this.getAssocItemList().getAnAssocItem() and
159+
result.getName().getText() = "Output"
160+
}
161+
}
162+
143163
/**
144164
* The [`Iterator` trait][1].
145165
*

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,17 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
389389
prefix2.isEmpty() and
390390
s = getRangeType(n1)
391391
)
392+
or
393+
exists(ClosureExpr ce, int index |
394+
n1 = ce and
395+
n2 = ce.getParam(index).getPat() and
396+
prefix1 = closureParameterPath(ce.getNumberOfParams(), index) and
397+
prefix2.isEmpty()
398+
)
399+
or
400+
n1.(ClosureExpr).getBody() = n2 and
401+
prefix1 = closureReturnPath() and
402+
prefix2.isEmpty()
392403
}
393404

394405
pragma[nomagic]
@@ -1441,6 +1452,120 @@ private Type inferForLoopExprType(AstNode n, TypePath path) {
14411452
)
14421453
}
14431454

1455+
/**
1456+
* An invoked expression, the target of a call that is either a local variable
1457+
* or a non-path expression. This means that the expression denotes a
1458+
* first-class function.
1459+
*/
1460+
final private class InvokedClosureExpr extends Expr {
1461+
private CallExpr call;
1462+
1463+
InvokedClosureExpr() {
1464+
call.getFunction() = this and
1465+
(not this instanceof PathExpr or this = any(Variable v).getAnAccess())
1466+
}
1467+
1468+
Type getTypeAt(TypePath path) { result = inferType(this, path) }
1469+
1470+
CallExpr getCall() { result = call }
1471+
}
1472+
1473+
private module InvokedClosureSatisfiesConstraintInput implements
1474+
SatisfiesConstraintInputSig<InvokedClosureExpr>
1475+
{
1476+
predicate relevantConstraint(InvokedClosureExpr term, Type constraint) {
1477+
exists(term) and
1478+
constraint.(TraitType).getTrait() instanceof FnOnceTrait
1479+
}
1480+
}
1481+
1482+
/** Gets the type of `ce` when viewed as an implementation of `FnOnce`. */
1483+
private Type invokedClosureFnTypeAt(InvokedClosureExpr ce, TypePath path) {
1484+
SatisfiesConstraint<InvokedClosureExpr, InvokedClosureSatisfiesConstraintInput>::satisfiesConstraintType(ce,
1485+
_, path, result)
1486+
}
1487+
1488+
/** Gets the path to a closure's return type. */
1489+
private TypePath closureReturnPath() {
1490+
result = TypePath::singleton(TDynTraitTypeParameter(any(FnOnceTrait t).getOutputType()))
1491+
}
1492+
1493+
/** Gets the path to a closure with arity `arity`s `index`th parameter type. */
1494+
private TypePath closureParameterPath(int arity, int index) {
1495+
result =
1496+
TypePath::cons(TDynTraitTypeParameter(any(FnOnceTrait t).getTypeParam()),
1497+
TypePath::singleton(TTupleTypeParameter(arity, index)))
1498+
}
1499+
1500+
/** Gets the path to the return type of the `FnOnce` trait. */
1501+
private TypePath fnReturnPath() {
1502+
result = TypePath::singleton(TAssociatedTypeTypeParameter(any(FnOnceTrait t).getOutputType()))
1503+
}
1504+
1505+
/**
1506+
* Gets the path to the parameter type of the `FnOnce` trait with arity `arity`
1507+
* and index `index`.
1508+
*/
1509+
private TypePath fnParameterPath(int arity, int index) {
1510+
result =
1511+
TypePath::cons(TTypeParamTypeParameter(any(FnOnceTrait t).getTypeParam()),
1512+
TypePath::singleton(TTupleTypeParameter(arity, index)))
1513+
}
1514+
1515+
pragma[nomagic]
1516+
private Type inferDynamicCallExprType(Expr n, TypePath path) {
1517+
exists(InvokedClosureExpr ce |
1518+
// Propagate the function's return type to the call expression
1519+
exists(TypePath path0 | result = invokedClosureFnTypeAt(ce, path0) |
1520+
n = ce.getCall() and
1521+
path = path0.stripPrefix(fnReturnPath())
1522+
or
1523+
// Propagate the function's parameter type to the arguments
1524+
exists(int index |
1525+
n = ce.getCall().getArgList().getArg(index) and
1526+
path = path0.stripPrefix(fnParameterPath(ce.getCall().getNumberOfArgs(), index))
1527+
)
1528+
)
1529+
or
1530+
// _If_ the invoked expression has the type of a closure, then we propagate
1531+
// the surrounding types into the closure.
1532+
exists(int arity, TypePath path0 |
1533+
ce.getTypeAt(TypePath::nil()).(DynTraitType).getTrait() instanceof FnOnceTrait
1534+
|
1535+
// Propagate the type of arguments to the parameter types of closure
1536+
exists(int index |
1537+
n = ce and
1538+
arity = ce.getCall().getNumberOfArgs() and
1539+
result = inferType(ce.getCall().getArg(index), path0) and
1540+
path = closureParameterPath(arity, index).append(path0)
1541+
)
1542+
or
1543+
// Propagate the type of the call expression to the return type of the closure
1544+
n = ce and
1545+
arity = ce.getCall().getNumberOfArgs() and
1546+
result = inferType(ce.getCall(), path0) and
1547+
path = closureReturnPath().append(path0)
1548+
)
1549+
)
1550+
}
1551+
1552+
pragma[nomagic]
1553+
private Type inferClosureExprType(AstNode n, TypePath path) {
1554+
exists(ClosureExpr ce |
1555+
n = ce and
1556+
path.isEmpty() and
1557+
result = TDynTraitType(any(FnOnceTrait t))
1558+
or
1559+
n = ce and
1560+
path = TypePath::singleton(TDynTraitTypeParameter(any(FnOnceTrait t).getTypeParam())) and
1561+
result = TTuple(ce.getNumberOfParams())
1562+
or
1563+
// Propagate return type annotation to body
1564+
n = ce.getBody() and
1565+
result = ce.getRetType().getTypeRepr().(TypeMention).resolveTypeAt(path)
1566+
)
1567+
}
1568+
14441569
pragma[nomagic]
14451570
private Type inferCastExprType(CastExpr ce, TypePath path) {
14461571
result = ce.getTypeRepr().(TypeMention).resolveTypeAt(path)
@@ -2068,6 +2193,10 @@ private module Cached {
20682193
or
20692194
result = inferForLoopExprType(n, path)
20702195
or
2196+
result = inferDynamicCallExprType(n, path)
2197+
or
2198+
result = inferClosureExprType(n, path)
2199+
or
20712200
result = inferCastExprType(n, path)
20722201
or
20732202
result = inferStructPatType(n, path)

rust/ql/lib/codeql/rust/internal/TypeMention.qll

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
/** Provides classes for representing type mentions, used in type inference. */
22

33
private import rust
4+
private import codeql.rust.frameworks.stdlib.Stdlib
45
private import Type
56
private import PathResolution
67
private import TypeInference
@@ -26,6 +27,18 @@ class TupleTypeReprMention extends TypeMention instanceof TupleTypeRepr {
2627
}
2728
}
2829

30+
class ParenthesizedArgListMention extends TypeMention instanceof ParenthesizedArgList {
31+
override Type resolveTypeAt(TypePath path) {
32+
path.isEmpty() and
33+
result = TTuple(super.getNumberOfTypeArgs())
34+
or
35+
exists(TypePath suffix, int index |
36+
result = super.getTypeArg(index).getTypeRepr().(TypeMention).resolveTypeAt(suffix) and
37+
path = TypePath::cons(TTupleTypeParameter(super.getNumberOfTypeArgs(), index), suffix)
38+
)
39+
}
40+
}
41+
2942
class ArrayTypeReprMention extends TypeMention instanceof ArrayTypeRepr {
3043
override Type resolveTypeAt(TypePath path) {
3144
path.isEmpty() and
@@ -215,6 +228,17 @@ class NonAliasPathTypeMention extends PathTypeMention {
215228
.(TraitItemNode)
216229
.getAssocItem(pragma[only_bind_into](name)))
217230
)
231+
or
232+
// Handle the special syntactic sugar for function traits. For now we only
233+
// support `FnOnce` as we can't support the "inherited" associated types of
234+
// `Fn` and `FnMut` yet.
235+
exists(FnOnceTrait t | t = resolved |
236+
tp = TTypeParamTypeParameter(t.getTypeParam()) and
237+
result = this.getSegment().getParenthesizedArgList()
238+
or
239+
tp = TAssociatedTypeTypeParameter(t.getOutputType()) and
240+
result = this.getSegment().getRetType().getTypeRepr()
241+
)
218242
}
219243

220244
Type resolveRootType() {
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
category: minorAnalysis
3+
---
4+
* Type inference now supports closures, calls to closures, and trait bounds
5+
using the `FnOnce` trait.
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/// Tests for type inference for closures and higher-order functions.
2+
3+
mod simple_closures {
4+
pub fn test() {
5+
// A simple closure without type annotations or invocations.
6+
let my_closure = |a, b| a && b;
7+
8+
let x: i64 = 1i64; // $ type=x:i64
9+
let add_one = |n| n + 1i64; // $ target=add
10+
let _y = add_one(x); // $ type=_y:i64
11+
12+
// The type of `x` is inferred from the closure's argument type.
13+
let x = Default::default(); // $ type=x:i64 target=default
14+
let add_zero = |n: i64| n;
15+
let _y = add_zero(x); // $ type=_y:i64
16+
17+
let _get_bool = || -> bool {
18+
// The return type annotation on the closure lets us infer the type of `b`.
19+
let b = Default::default(); // $ type=b:bool target=default
20+
b
21+
};
22+
23+
// The parameter type of `id` is inferred from the argument.
24+
let id = |b| b; // $ type=b:bool
25+
let _b = id(true); // $ type=_b:bool
26+
27+
// The return type of `id2` is inferred from the type of the call expression.
28+
let id2 = |b| b;
29+
let arg = Default::default(); // $ target=default type=arg:bool
30+
let _b2: bool = id2(arg); // $ type=_b2:bool
31+
}
32+
}
33+
34+
mod fn_once_trait {
35+
fn return_type<F: FnOnce(bool) -> i64>(f: F) {
36+
let _return = f(true); // $ type=_return:i64
37+
}
38+
39+
fn argument_type<F: FnOnce(bool) -> i64>(f: F) {
40+
let arg = Default::default(); // $ target=default type=arg:bool
41+
f(arg);
42+
}
43+
44+
fn apply<A, B, F: FnOnce(A) -> B>(f: F, a: A) -> B {
45+
f(a)
46+
}
47+
48+
fn apply_two(f: impl FnOnce(i64) -> i64) -> i64 {
49+
f(2)
50+
}
51+
52+
fn test() {
53+
let f = |x: bool| -> i64 {
54+
if x {
55+
1
56+
} else {
57+
0
58+
}
59+
};
60+
let _r = apply(f, true); // $ target=apply type=_r:i64
61+
62+
let f = |x| x + 1; // $ MISSING: type=x:i64 target=add
63+
let _r2 = apply_two(f); // $ target=apply_two type=_r2:i64
64+
}
65+
}
66+
67+
mod dyn_fn_once {
68+
fn apply_boxed<A, B, F: FnOnce(A) -> B + ?Sized>(f: Box<F>, arg: A) -> B {
69+
f(arg)
70+
}
71+
72+
fn apply_boxed_dyn<A, B>(f: Box<dyn FnOnce(A) -> B>, arg: A) {
73+
let _r1 = apply_boxed(f, arg); // $ target=apply_boxed type=_r1:B
74+
let _r2 = apply_boxed(Box::new(|_: i64| true), 3); // $ target=apply_boxed target=new type=_r2:bool
75+
}
76+
}

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2484,46 +2484,7 @@ pub mod pattern_matching_experimental {
24842484
}
24852485
}
24862486

2487-
mod closures {
2488-
struct Row {
2489-
data: i64,
2490-
}
2491-
2492-
impl Row {
2493-
fn get(&self) -> i64 {
2494-
self.data // $ fieldof=Row
2495-
}
2496-
}
2497-
2498-
struct Table {
2499-
rows: Vec<Row>,
2500-
}
2501-
2502-
impl Table {
2503-
fn new() -> Self {
2504-
Table { rows: Vec::new() } // $ target=new
2505-
}
2506-
2507-
fn count_with(&self, property: impl Fn(Row) -> bool) -> i64 {
2508-
0 // (not implemented)
2509-
}
2510-
}
2511-
2512-
pub fn f() {
2513-
Some(1).map(|x| {
2514-
let x = x; // $ MISSING: type=x:i32
2515-
println!("{x}");
2516-
}); // $ target=map
2517-
2518-
let table = Table::new(); // $ target=new type=table:Table
2519-
let result = table.count_with(|row| // $ type=result:i64
2520-
{
2521-
let v = row.get(); // $ MISSING: target=get type=v:i64
2522-
v > 0 // $ MISSING: target=gt
2523-
}); // $ target=count_with
2524-
}
2525-
}
2526-
2487+
mod closure;
25272488
mod dereference;
25282489
mod dyn_type;
25292490

@@ -2557,6 +2518,5 @@ fn main() {
25572518
dereference::test(); // $ target=test
25582519
pattern_matching::test_all_patterns(); // $ target=test_all_patterns
25592520
pattern_matching_experimental::box_patterns(); // $ target=box_patterns
2560-
closures::f(); // $ target=f
25612521
dyn_type::test(); // $ target=test
25622522
}

0 commit comments

Comments
 (0)