Skip to content

Rust: Implement type inference for closures and calls to closures #20130

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions rust/ql/lib/codeql/rust/frameworks/stdlib/Stdlib.qll
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,26 @@ class FutureTrait extends Trait {
}
}

/**
* The [`FnOnce` trait][1].
*
* [1]: https://doc.rust-lang.org/std/ops/trait.FnOnce.html
*/
class FnOnceTrait extends Trait {
pragma[nomagic]
FnOnceTrait() { this.getCanonicalPath() = "core::ops::function::FnOnce" }

/** Gets the type parameter of this trait. */
TypeParam getTypeParam() { result = this.getGenericParamList().getGenericParam(0) }

/** Gets the `Output` associated type. */
pragma[nomagic]
TypeAlias getOutputType() {
result = this.getAssocItemList().getAnAssocItem() and
result.getName().getText() = "Output"
}
}

/**
* The [`Iterator` trait][1].
*
Expand Down
129 changes: 129 additions & 0 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,17 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
prefix2.isEmpty() and
s = getRangeType(n1)
)
or
exists(ClosureExpr ce, int index |
n1 = ce and
n2 = ce.getParam(index).getPat() and
prefix1 = closureParameterPath(ce.getNumberOfParams(), index) and
prefix2.isEmpty()
)
or
n1.(ClosureExpr).getBody() = n2 and
prefix1 = closureReturnPath() and
prefix2.isEmpty()
}

pragma[nomagic]
Expand Down Expand Up @@ -1441,6 +1452,120 @@ private Type inferForLoopExprType(AstNode n, TypePath path) {
)
}

/**
* An invoked expression, the target of a call that is either a local variable
* or a non-path expression. This means that the expression denotes a
* first-class function.
*/
final private class InvokedClosureExpr extends Expr {
private CallExpr call;

InvokedClosureExpr() {
call.getFunction() = this and
(not this instanceof PathExpr or this = any(Variable v).getAnAccess())
}

Type getTypeAt(TypePath path) { result = inferType(this, path) }

CallExpr getCall() { result = call }
}

private module InvokedClosureSatisfiesConstraintInput implements
SatisfiesConstraintInputSig<InvokedClosureExpr>
{
predicate relevantConstraint(InvokedClosureExpr term, Type constraint) {
exists(term) and
constraint.(TraitType).getTrait() instanceof FnOnceTrait
}
}

/** Gets the type of `ce` when viewed as an implementation of `FnOnce`. */
private Type invokedClosureFnTypeAt(InvokedClosureExpr ce, TypePath path) {
SatisfiesConstraint<InvokedClosureExpr, InvokedClosureSatisfiesConstraintInput>::satisfiesConstraintType(ce,
_, path, result)
}

/** Gets the path to a closure's return type. */
private TypePath closureReturnPath() {
result = TypePath::singleton(TDynTraitTypeParameter(any(FnOnceTrait t).getOutputType()))
}

/** Gets the path to a closure with arity `arity`s `index`th parameter type. */
private TypePath closureParameterPath(int arity, int index) {
result =
TypePath::cons(TDynTraitTypeParameter(any(FnOnceTrait t).getTypeParam()),
TypePath::singleton(TTupleTypeParameter(arity, index)))
}

/** Gets the path to the return type of the `FnOnce` trait. */
private TypePath fnReturnPath() {
result = TypePath::singleton(TAssociatedTypeTypeParameter(any(FnOnceTrait t).getOutputType()))
}

/**
* Gets the path to the parameter type of the `FnOnce` trait with arity `arity`
* and index `index`.
*/
private TypePath fnParameterPath(int arity, int index) {
result =
TypePath::cons(TTypeParamTypeParameter(any(FnOnceTrait t).getTypeParam()),
TypePath::singleton(TTupleTypeParameter(arity, index)))
}

pragma[nomagic]
private Type inferDynamicCallExprType(Expr n, TypePath path) {
exists(InvokedClosureExpr ce |
// Propagate the function's return type to the call expression
exists(TypePath path0 | result = invokedClosureFnTypeAt(ce, path0) |
n = ce.getCall() and
path = path0.stripPrefix(fnReturnPath())
or
// Propagate the function's parameter type to the arguments
exists(int index |
n = ce.getCall().getArgList().getArg(index) and
path = path0.stripPrefix(fnParameterPath(ce.getCall().getNumberOfArgs(), index))
)
)
or
// _If_ the invoked expression has the type of a closure, then we propagate
// the surrounding types into the closure.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to understand how these differs from the cases above a bit better.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The key difference is the direction.

  • The cases above propagate type information out from the invoked expression (which implements FnOnce) to the arguments and the entire call.
  • The cases below propagate type information in to the invoked expression. This however we only do if the invoked expression has a closure type (isn't some other type that implements FnOnce).

Propagating type information into the invoked expression for arbitrary FnOnce implementations would make sense, but we don't have that machinery as of right now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, thanks for the explanation.

exists(int arity, TypePath path0 |
ce.getTypeAt(TypePath::nil()).(DynTraitType).getTrait() instanceof FnOnceTrait
|
// Propagate the type of arguments to the parameter types of closure
exists(int index |
n = ce and
arity = ce.getCall().getNumberOfArgs() and
result = inferType(ce.getCall().getArg(index), path0) and
path = closureParameterPath(arity, index).append(path0)
)
or
// Propagate the type of the call expression to the return type of the closure
n = ce and
arity = ce.getCall().getNumberOfArgs() and
result = inferType(ce.getCall(), path0) and
path = closureReturnPath().append(path0)
)
)
}

pragma[nomagic]
private Type inferClosureExprType(AstNode n, TypePath path) {
exists(ClosureExpr ce |
n = ce and
path.isEmpty() and
result = TDynTraitType(any(FnOnceTrait t))
or
n = ce and
path = TypePath::singleton(TDynTraitTypeParameter(any(FnOnceTrait t).getTypeParam())) and
result = TTuple(ce.getNumberOfParams())
or
// Propagate return type annotation to body
n = ce.getBody() and
result = ce.getRetType().getTypeRepr().(TypeMention).resolveTypeAt(path)
)
}

pragma[nomagic]
private Type inferCastExprType(CastExpr ce, TypePath path) {
result = ce.getTypeRepr().(TypeMention).resolveTypeAt(path)
Expand Down Expand Up @@ -2068,6 +2193,10 @@ private module Cached {
or
result = inferForLoopExprType(n, path)
or
result = inferDynamicCallExprType(n, path)
or
result = inferClosureExprType(n, path)
or
result = inferCastExprType(n, path)
or
result = inferStructPatType(n, path)
Expand Down
24 changes: 24 additions & 0 deletions rust/ql/lib/codeql/rust/internal/TypeMention.qll
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/** Provides classes for representing type mentions, used in type inference. */

private import rust
private import codeql.rust.frameworks.stdlib.Stdlib
private import Type
private import PathResolution
private import TypeInference
Expand All @@ -26,6 +27,18 @@ class TupleTypeReprMention extends TypeMention instanceof TupleTypeRepr {
}
}

class ParenthesizedArgListMention extends TypeMention instanceof ParenthesizedArgList {
override Type resolveTypeAt(TypePath path) {
path.isEmpty() and
result = TTuple(super.getNumberOfTypeArgs())
or
exists(TypePath suffix, int index |
result = super.getTypeArg(index).getTypeRepr().(TypeMention).resolveTypeAt(suffix) and
path = TypePath::cons(TTupleTypeParameter(super.getNumberOfTypeArgs(), index), suffix)
)
}
}

class ArrayTypeReprMention extends TypeMention instanceof ArrayTypeRepr {
override Type resolveTypeAt(TypePath path) {
path.isEmpty() and
Expand Down Expand Up @@ -215,6 +228,17 @@ class NonAliasPathTypeMention extends PathTypeMention {
.(TraitItemNode)
.getAssocItem(pragma[only_bind_into](name)))
)
or
// Handle the special syntactic sugar for function traits. For now we only
// support `FnOnce` as we can't support the "inherited" associated types of
// `Fn` and `FnMut` yet.
exists(FnOnceTrait t | t = resolved |
tp = TTypeParamTypeParameter(t.getTypeParam()) and
result = this.getSegment().getParenthesizedArgList()
or
tp = TAssociatedTypeTypeParameter(t.getOutputType()) and
result = this.getSegment().getRetType().getTypeRepr()
)
}

Type resolveRootType() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
category: minorAnalysis
---
* Type inference now supports closures, calls to closures, and trait bounds
using the `FnOnce` trait.
76 changes: 76 additions & 0 deletions rust/ql/test/library-tests/type-inference/closure.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/// Tests for type inference for closures and higher-order functions.
mod simple_closures {
pub fn test() {
// A simple closure without type annotations or invocations.
let my_closure = |a, b| a && b;

let x: i64 = 1i64; // $ type=x:i64
let add_one = |n| n + 1i64; // $ target=add
let _y = add_one(x); // $ type=_y:i64

// The type of `x` is inferred from the closure's argument type.
let x = Default::default(); // $ type=x:i64 target=default
let add_zero = |n: i64| n;
let _y = add_zero(x); // $ type=_y:i64

let _get_bool = || -> bool {
// The return type annotation on the closure lets us infer the type of `b`.
let b = Default::default(); // $ type=b:bool target=default
b
};

// The parameter type of `id` is inferred from the argument.
let id = |b| b; // $ type=b:bool
let _b = id(true); // $ type=_b:bool

// The return type of `id2` is inferred from the type of the call expression.
let id2 = |b| b;
let arg = Default::default(); // $ target=default type=arg:bool
let _b2: bool = id2(arg); // $ type=_b2:bool
}
}

mod fn_once_trait {
fn return_type<F: FnOnce(bool) -> i64>(f: F) {
let _return = f(true); // $ type=_return:i64
}

fn argument_type<F: FnOnce(bool) -> i64>(f: F) {
let arg = Default::default(); // $ target=default type=arg:bool
f(arg);
}

fn apply<A, B, F: FnOnce(A) -> B>(f: F, a: A) -> B {
f(a)
}

fn apply_two(f: impl FnOnce(i64) -> i64) -> i64 {
f(2)
}

fn test() {
let f = |x: bool| -> i64 {
if x {
1
} else {
0
}
};
let _r = apply(f, true); // $ target=apply type=_r:i64

let f = |x| x + 1; // $ MISSING: type=x:i64 target=add
let _r2 = apply_two(f); // $ target=apply_two type=_r2:i64
}
}

mod dyn_fn_once {
fn apply_boxed<A, B, F: FnOnce(A) -> B + ?Sized>(f: Box<F>, arg: A) -> B {
f(arg)
}

fn apply_boxed_dyn<A, B>(f: Box<dyn FnOnce(A) -> B>, arg: A) {
let _r1 = apply_boxed(f, arg); // $ target=apply_boxed type=_r1:B
let _r2 = apply_boxed(Box::new(|_: i64| true), 3); // $ target=apply_boxed target=new type=_r2:bool
}
}
42 changes: 1 addition & 41 deletions rust/ql/test/library-tests/type-inference/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2484,46 +2484,7 @@ pub mod pattern_matching_experimental {
}
}

mod closures {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why have these test cases been deleted?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the stuff being tested by these examples are covered by the new tests for closures. There was quite a lot of "extra" setup in these tests, for what was actually being tested, so I think the new tests are preferable.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes they were somewhat close to some real world examples I'd encountered. I guess I'd like reassuring that they would indeed work right now, whether or not they're part of our test suite.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I should have been more clear. They unfortunately still don't work, due to the "Another reason" that I mention in #20130 (comment).

The case is covered by the

        let f = |x| x + 1; // $ MISSING: type=x:i64 target=add
        let _r2 = apply_two(f); // $ target=apply_two type=_r2:i64

test, which is essentially a "minimal" variant of the existing test. It test passing an unannotated closure to a function with a known type.

struct Row {
data: i64,
}

impl Row {
fn get(&self) -> i64 {
self.data // $ fieldof=Row
}
}

struct Table {
rows: Vec<Row>,
}

impl Table {
fn new() -> Self {
Table { rows: Vec::new() } // $ target=new
}

fn count_with(&self, property: impl Fn(Row) -> bool) -> i64 {
0 // (not implemented)
}
}

pub fn f() {
Some(1).map(|x| {
let x = x; // $ MISSING: type=x:i32
println!("{x}");
}); // $ target=map

let table = Table::new(); // $ target=new type=table:Table
let result = table.count_with(|row| // $ type=result:i64
{
let v = row.get(); // $ MISSING: target=get type=v:i64
v > 0 // $ MISSING: target=gt
}); // $ target=count_with
}
}

mod closure;
mod dereference;
mod dyn_type;

Expand Down Expand Up @@ -2557,6 +2518,5 @@ fn main() {
dereference::test(); // $ target=test
pattern_matching::test_all_patterns(); // $ target=test_all_patterns
pattern_matching_experimental::box_patterns(); // $ target=box_patterns
closures::f(); // $ target=f
dyn_type::test(); // $ target=test
}
Loading