Skip to content

Commit f5f7b61

Browse files
committed
Rust: Implement certain type information for annotation and simple calls
1 parent ddad971 commit f5f7b61

File tree

7 files changed

+163
-405
lines changed

7 files changed

+163
-405
lines changed

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 150 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,13 @@ private module M2 = Make2<Input2>;
217217

218218
private import M2
219219

220-
module Consistency = M2::Consistency;
220+
module Consistency {
221+
import M2::Consistency
222+
223+
query predicate nonUniqueCertainType(AstNode n, TypePath path) {
224+
strictcount(CertainTypeInference::inferCertainType(n, path)) > 1
225+
}
226+
}
221227

222228
/** Gets the type annotation that applies to `n`, if any. */
223229
private TypeMention getTypeAnnotation(AstNode n) {
@@ -245,6 +251,134 @@ private Type inferAnnotatedType(AstNode n, TypePath path) {
245251
result = getTypeAnnotation(n).resolveTypeAt(path)
246252
}
247253

254+
/** Module for inferring certain type information. */
255+
private module CertainTypeInference {
256+
/** Holds if the type mention does not contain any inferred types `_`. */
257+
predicate typeMentionIsComplete(TypeMention tm) {
258+
not exists(InferTypeRepr t | t.getParentNode*() = tm)
259+
}
260+
261+
/**
262+
* Holds if `ce` is a call where we can infer the type with certainty and if
263+
* `f` is the target of the call and `p` the path invoked by the call.
264+
*
265+
* Necessary conditions for this are:
266+
* - We are certain of the call target (i.e., the call target can not depend on type information).
267+
* - The declared type of the function does not contain any generics that we
268+
* need to infer.
269+
* - The call does not contain any arguments, as arguments in calls are coercion sites.
270+
*
271+
* The current requirements are made to allow for call to `new` functions such
272+
* as `Vec<Foo>::new()` but not much more.
273+
*/
274+
predicate certainCallExprTarget(CallExpr ce, Function f, Path p) {
275+
p = CallExprImpl::getFunctionPath(ce) and
276+
f = resolvePath(p) and
277+
// The function is not in a trait
278+
not any(TraitItemNode t).getAnAssocItem() = f and
279+
// The function is not in a trait implementation
280+
not any(ImplItemNode impl | impl.(Impl).hasTrait()).getAnAssocItem() = f and
281+
// The function does not have parameters.
282+
not f.getParamList().hasSelfParam() and
283+
f.getParamList().getNumberOfParams() = 0 and
284+
// The function is not async.
285+
not f.isAsync() and
286+
// For now, exclude functions in macro expansions.
287+
not ce.isInMacroExpansion() and
288+
// The function has no type parameters.
289+
not f.hasGenericParamList() and
290+
// The function does not have `impl` types among its parameters (these are type parameters).
291+
not any(ImplTraitTypeRepr itt | not itt.isInReturnPos()).getFunction() = f and
292+
(
293+
not exists(ImplItemNode impl | impl.getAnAssocItem() = f)
294+
or
295+
// If the function is in an impl then the impl block has no type
296+
// parameters or all the type parameters are given explicitly.
297+
exists(ImplItemNode impl | impl.getAnAssocItem() = f |
298+
not impl.(Impl).hasGenericParamList() or
299+
impl.(Impl).getGenericParamList().getNumberOfGenericParams() =
300+
p.getQualifier().getSegment().getGenericArgList().getNumberOfGenericArgs()
301+
)
302+
)
303+
}
304+
305+
private ImplItemNode getFunctionImpl(FunctionItemNode f) { result.getAnAssocItem() = f }
306+
307+
Type inferCertainCallExprType(CallExpr ce, TypePath path) {
308+
exists(Function f, Type ty, TypePath prefix, Path p |
309+
certainCallExprTarget(ce, f, p) and
310+
ty = f.getRetType().getTypeRepr().(TypeMention).resolveTypeAt(prefix)
311+
|
312+
if ty.(TypeParamTypeParameter).getTypeParam() = getFunctionImpl(f).getTypeParam(_)
313+
then
314+
exists(TypePath pathToTp, TypePath suffix |
315+
// For type parameters of the `impl` block we must resolve their
316+
// instantiation from the path. For instance, for `impl<A> for Foo<A>`
317+
// and the path `Foo<i64>::bar` we must resolve `A` to `i64`.
318+
ty = getFunctionImpl(f).(Impl).getSelfTy().(TypeMention).resolveTypeAt(pathToTp) and
319+
result = p.getQualifier().(TypeMention).resolveTypeAt(pathToTp.appendInverse(suffix)) and
320+
path = prefix.append(suffix)
321+
)
322+
else (
323+
result = ty and path = prefix
324+
)
325+
)
326+
}
327+
328+
predicate certainTypeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) {
329+
prefix1.isEmpty() and
330+
prefix2.isEmpty() and
331+
(
332+
exists(Variable v | n1 = v.getAnAccess() |
333+
n2 = v.getPat().getName() or n2 = v.getParameter().(SelfParam)
334+
)
335+
or
336+
// A `let` statement with a type annotation is a coercion site and hence
337+
// is not a certain type equality.
338+
exists(LetStmt let | not let.hasTypeRepr() |
339+
let.getPat() = n1 and
340+
let.getInitializer() = n2
341+
)
342+
)
343+
or
344+
n1 =
345+
any(IdentPat ip |
346+
n2 = ip.getName() and
347+
prefix1.isEmpty() and
348+
if ip.isRef() then prefix2 = TypePath::singleton(TRefTypeParameter()) else prefix2.isEmpty()
349+
)
350+
}
351+
352+
pragma[nomagic]
353+
private Type inferCertainTypeEquality(AstNode n, TypePath path) {
354+
exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix |
355+
result = inferCertainType(n2, prefix2.appendInverse(suffix)) and
356+
path = prefix1.append(suffix)
357+
|
358+
certainTypeEquality(n, prefix1, n2, prefix2)
359+
or
360+
certainTypeEquality(n2, prefix2, n, prefix1)
361+
)
362+
}
363+
364+
/**
365+
* Holds if `n` has complete and certain type information and if `n` has the
366+
* resulting type at `path`.
367+
*/
368+
pragma[nomagic]
369+
Type inferCertainType(AstNode n, TypePath path) {
370+
exists(TypeMention tm |
371+
tm = getTypeAnnotation(n) and
372+
typeMentionIsComplete(tm) and
373+
result = tm.resolveTypeAt(path)
374+
)
375+
or
376+
result = inferCertainCallExprType(n, path)
377+
or
378+
result = inferCertainTypeEquality(n, path)
379+
}
380+
}
381+
248382
private Type inferLogicalOperationType(AstNode n, TypePath path) {
249383
exists(Builtins::BuiltinType t, BinaryLogicalOperation be |
250384
n = [be, be.getLhs(), be.getRhs()] and
@@ -284,15 +418,11 @@ private Struct getRangeType(RangeExpr re) {
284418
* through the type equality.
285419
*/
286420
private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) {
421+
CertainTypeInference::certainTypeEquality(n1, prefix1, n2, prefix2)
422+
or
287423
prefix1.isEmpty() and
288424
prefix2.isEmpty() and
289425
(
290-
exists(Variable v | n1 = v.getAnAccess() |
291-
n2 = v.getPat().getName()
292-
or
293-
n2 = v.getParameter().(SelfParam)
294-
)
295-
or
296426
exists(LetStmt let |
297427
let.getPat() = n1 and
298428
let.getInitializer() = n2
@@ -335,13 +465,6 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
335465
n1 = n2.(MacroPat).getMacroCall().getMacroCallExpansion()
336466
)
337467
or
338-
n1 =
339-
any(IdentPat ip |
340-
n2 = ip.getName() and
341-
prefix1.isEmpty() and
342-
if ip.isRef() then prefix2 = TypePath::singleton(TRefTypeParameter()) else prefix2.isEmpty()
343-
)
344-
or
345468
(
346469
n1 = n2.(RefExpr).getExpr() or
347470
n1 = n2.(RefPat).getPat()
@@ -404,6 +527,9 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
404527

405528
pragma[nomagic]
406529
private Type inferTypeEquality(AstNode n, TypePath path) {
530+
// Don't propagate type information into a node for which we already have
531+
// certain type information.
532+
not exists(CertainTypeInference::inferCertainType(n, _)) and
407533
exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix |
408534
result = inferType(n2, prefix2.appendInverse(suffix)) and
409535
path = prefix1.append(suffix)
@@ -814,6 +940,8 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
814940
}
815941

816942
final class Access extends Call {
943+
Access() { not CertainTypeInference::certainCallExprTarget(this, _, _) }
944+
817945
pragma[nomagic]
818946
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
819947
exists(TypeMention arg | result = arg.resolveTypeAt(path) |
@@ -2146,6 +2274,8 @@ private module Cached {
21462274
cached
21472275
Type inferType(AstNode n, TypePath path) {
21482276
Stages::TypeInferenceStage::ref() and
2277+
result = CertainTypeInference::inferCertainType(n, path)
2278+
or
21492279
result = inferAnnotatedType(n, path)
21502280
or
21512281
result = inferLogicalOperationType(n, path)
@@ -2291,4 +2421,10 @@ private module Debug {
22912421
c = countTypePaths(n, path, t) and
22922422
c = max(countTypePaths(_, _, _))
22932423
}
2424+
2425+
Type debugInferCertainNonUniqueType(AstNode n, TypePath path) {
2426+
n = getRelevantLocatable() and
2427+
Consistency::nonUniqueCertainType(n, path) and
2428+
result = CertainTypeInference::inferCertainType(n, path)
2429+
}
22942430
}

rust/ql/lib/codeql/rust/internal/TypeInferenceConsistency.qll

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
* Provides classes for recognizing type inference inconsistencies.
33
*/
44

5+
private import rust
56
private import Type
67
private import TypeMention
8+
private import TypeInference
79
private import TypeInference::Consistency as Consistency
810
import TypeInference::Consistency
911

@@ -27,4 +29,7 @@ int getTypeInferenceInconsistencyCounts(string type) {
2729
or
2830
type = "Ill-formed type mention" and
2931
result = count(TypeMention tm | illFormedTypeMention(tm) | tm)
32+
or
33+
type = "Non-unique certain type information" and
34+
result = count(AstNode n, TypePath path | nonUniqueCertainType(n, path) | n)
3035
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
nonUniqueCertainType
2+
| web_frameworks.rs:139:30:139:39 | ...::get(...) | |
3+
| web_frameworks.rs:140:34:140:43 | ...::get(...) | |
4+
| web_frameworks.rs:141:30:141:39 | ...::get(...) | |

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2352,7 +2352,7 @@ mod loops {
23522352
#[rustfmt::skip]
23532353
let _ = while a < 10 // $ target=lt type=a:i64
23542354
{
2355-
a += 1; // $ type=a:i64 target=add_assign
2355+
a += 1; // $ type=a:i64 MISSING: target=add_assign
23562356
};
23572357
}
23582358
}

0 commit comments

Comments
 (0)