Skip to content

Commit 3ba285c

Browse files
committed
Rust: Implement certain type information for annotation and simple calls
1 parent c3349bb commit 3ba285c

File tree

7 files changed

+163
-405
lines changed

7 files changed

+163
-405
lines changed

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 150 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,13 @@ private module M2 = Make2<Input2>;
221221

222222
private import M2
223223

224-
module Consistency = M2::Consistency;
224+
module Consistency {
225+
import M2::Consistency
226+
227+
query predicate nonUniqueCertainType(AstNode n, TypePath path) {
228+
strictcount(CertainTypeInference::inferCertainType(n, path)) > 1
229+
}
230+
}
225231

226232
/** Gets the type annotation that applies to `n`, if any. */
227233
private TypeMention getTypeAnnotation(AstNode n) {
@@ -249,6 +255,134 @@ private Type inferAnnotatedType(AstNode n, TypePath path) {
249255
result = getTypeAnnotation(n).resolveTypeAt(path)
250256
}
251257

258+
/** Module for inferring certain type information. */
259+
private module CertainTypeInference {
260+
/** Holds if the type mention does not contain any inferred types `_`. */
261+
predicate typeMentionIsComplete(TypeMention tm) {
262+
not exists(InferTypeRepr t | t.getParentNode*() = tm)
263+
}
264+
265+
/**
266+
* Holds if `ce` is a call where we can infer the type with certainty and if
267+
* `f` is the target of the call and `p` the path invoked by the call.
268+
*
269+
* Necessary conditions for this are:
270+
* - We are certain of the call target (i.e., the call target can not depend on type information).
271+
* - The declared type of the function does not contain any generics that we
272+
* need to infer.
273+
* - The call does not contain any arguments, as arguments in calls are coercion sites.
274+
*
275+
* The current requirements are made to allow for call to `new` functions such
276+
* as `Vec<Foo>::new()` but not much more.
277+
*/
278+
predicate certainCallExprTarget(CallExpr ce, Function f, Path p) {
279+
p = CallExprImpl::getFunctionPath(ce) and
280+
f = resolvePath(p) and
281+
// The function is not in a trait
282+
not any(TraitItemNode t).getAnAssocItem() = f and
283+
// The function is not in a trait implementation
284+
not any(ImplItemNode impl | impl.(Impl).hasTrait()).getAnAssocItem() = f and
285+
// The function does not have parameters.
286+
not f.getParamList().hasSelfParam() and
287+
f.getParamList().getNumberOfParams() = 0 and
288+
// The function is not async.
289+
not f.isAsync() and
290+
// For now, exclude functions in macro expansions.
291+
not ce.isInMacroExpansion() and
292+
// The function has no type parameters.
293+
not f.hasGenericParamList() and
294+
// The function does not have `impl` types among its parameters (these are type parameters).
295+
not any(ImplTraitTypeRepr itt | not itt.isInReturnPos()).getFunction() = f and
296+
(
297+
not exists(ImplItemNode impl | impl.getAnAssocItem() = f)
298+
or
299+
// If the function is in an impl then the impl block has no type
300+
// parameters or all the type parameters are given explicitly.
301+
exists(ImplItemNode impl | impl.getAnAssocItem() = f |
302+
not impl.(Impl).hasGenericParamList() or
303+
impl.(Impl).getGenericParamList().getNumberOfGenericParams() =
304+
p.getQualifier().getSegment().getGenericArgList().getNumberOfGenericArgs()
305+
)
306+
)
307+
}
308+
309+
private ImplItemNode getFunctionImpl(FunctionItemNode f) { result.getAnAssocItem() = f }
310+
311+
Type inferCertainCallExprType(CallExpr ce, TypePath path) {
312+
exists(Function f, Type ty, TypePath prefix, Path p |
313+
certainCallExprTarget(ce, f, p) and
314+
ty = f.getRetType().getTypeRepr().(TypeMention).resolveTypeAt(prefix)
315+
|
316+
if ty.(TypeParamTypeParameter).getTypeParam() = getFunctionImpl(f).getTypeParam(_)
317+
then
318+
exists(TypePath pathToTp, TypePath suffix |
319+
// For type parameters of the `impl` block we must resolve their
320+
// instantiation from the path. For instance, for `impl<A> for Foo<A>`
321+
// and the path `Foo<i64>::bar` we must resolve `A` to `i64`.
322+
ty = getFunctionImpl(f).(Impl).getSelfTy().(TypeMention).resolveTypeAt(pathToTp) and
323+
result = p.getQualifier().(TypeMention).resolveTypeAt(pathToTp.appendInverse(suffix)) and
324+
path = prefix.append(suffix)
325+
)
326+
else (
327+
result = ty and path = prefix
328+
)
329+
)
330+
}
331+
332+
predicate certainTypeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) {
333+
prefix1.isEmpty() and
334+
prefix2.isEmpty() and
335+
(
336+
exists(Variable v | n1 = v.getAnAccess() |
337+
n2 = v.getPat().getName() or n2 = v.getParameter().(SelfParam)
338+
)
339+
or
340+
// A `let` statement with a type annotation is a coercion site and hence
341+
// is not a certain type equality.
342+
exists(LetStmt let | not let.hasTypeRepr() |
343+
let.getPat() = n1 and
344+
let.getInitializer() = n2
345+
)
346+
)
347+
or
348+
n1 =
349+
any(IdentPat ip |
350+
n2 = ip.getName() and
351+
prefix1.isEmpty() and
352+
if ip.isRef() then prefix2 = TypePath::singleton(TRefTypeParameter()) else prefix2.isEmpty()
353+
)
354+
}
355+
356+
pragma[nomagic]
357+
private Type inferCertainTypeEquality(AstNode n, TypePath path) {
358+
exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix |
359+
result = inferCertainType(n2, prefix2.appendInverse(suffix)) and
360+
path = prefix1.append(suffix)
361+
|
362+
certainTypeEquality(n, prefix1, n2, prefix2)
363+
or
364+
certainTypeEquality(n2, prefix2, n, prefix1)
365+
)
366+
}
367+
368+
/**
369+
* Holds if `n` has complete and certain type information and if `n` has the
370+
* resulting type at `path`.
371+
*/
372+
pragma[nomagic]
373+
Type inferCertainType(AstNode n, TypePath path) {
374+
exists(TypeMention tm |
375+
tm = getTypeAnnotation(n) and
376+
typeMentionIsComplete(tm) and
377+
result = tm.resolveTypeAt(path)
378+
)
379+
or
380+
result = inferCertainCallExprType(n, path)
381+
or
382+
result = inferCertainTypeEquality(n, path)
383+
}
384+
}
385+
252386
private Type inferLogicalOperationType(AstNode n, TypePath path) {
253387
exists(Builtins::BuiltinType t, BinaryLogicalOperation be |
254388
n = [be, be.getLhs(), be.getRhs()] and
@@ -288,15 +422,11 @@ private Struct getRangeType(RangeExpr re) {
288422
* through the type equality.
289423
*/
290424
private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) {
425+
CertainTypeInference::certainTypeEquality(n1, prefix1, n2, prefix2)
426+
or
291427
prefix1.isEmpty() and
292428
prefix2.isEmpty() and
293429
(
294-
exists(Variable v | n1 = v.getAnAccess() |
295-
n2 = v.getPat().getName()
296-
or
297-
n2 = v.getParameter().(SelfParam)
298-
)
299-
or
300430
exists(LetStmt let |
301431
let.getPat() = n1 and
302432
let.getInitializer() = n2
@@ -339,13 +469,6 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
339469
n1 = n2.(MacroPat).getMacroCall().getMacroCallExpansion()
340470
)
341471
or
342-
n1 =
343-
any(IdentPat ip |
344-
n2 = ip.getName() and
345-
prefix1.isEmpty() and
346-
if ip.isRef() then prefix2 = TypePath::singleton(TRefTypeParameter()) else prefix2.isEmpty()
347-
)
348-
or
349472
(
350473
n1 = n2.(RefExpr).getExpr() or
351474
n1 = n2.(RefPat).getPat()
@@ -408,6 +531,9 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
408531

409532
pragma[nomagic]
410533
private Type inferTypeEquality(AstNode n, TypePath path) {
534+
// Don't propagate type information into a node for which we already have
535+
// certain type information.
536+
not exists(CertainTypeInference::inferCertainType(n, _)) and
411537
exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix |
412538
result = inferType(n2, prefix2.appendInverse(suffix)) and
413539
path = prefix1.append(suffix)
@@ -818,6 +944,8 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
818944
}
819945

820946
final class Access extends Call {
947+
Access() { not CertainTypeInference::certainCallExprTarget(this, _, _) }
948+
821949
pragma[nomagic]
822950
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
823951
exists(TypeMention arg | result = arg.resolveTypeAt(path) |
@@ -2150,6 +2278,8 @@ private module Cached {
21502278
cached
21512279
Type inferType(AstNode n, TypePath path) {
21522280
Stages::TypeInferenceStage::ref() and
2281+
result = CertainTypeInference::inferCertainType(n, path)
2282+
or
21532283
result = inferAnnotatedType(n, path)
21542284
or
21552285
result = inferLogicalOperationType(n, path)
@@ -2305,4 +2435,10 @@ private module Debug {
23052435
c = countTypePaths(n, path, t) and
23062436
c = max(countTypePaths(_, _, _))
23072437
}
2438+
2439+
Type debugInferCertainNonUniqueType(AstNode n, TypePath path) {
2440+
n = getRelevantLocatable() and
2441+
Consistency::nonUniqueCertainType(n, path) and
2442+
result = CertainTypeInference::inferCertainType(n, path)
2443+
}
23082444
}

rust/ql/lib/codeql/rust/internal/TypeInferenceConsistency.qll

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
* Provides classes for recognizing type inference inconsistencies.
33
*/
44

5+
private import rust
56
private import Type
67
private import TypeMention
8+
private import TypeInference
79
private import TypeInference::Consistency as Consistency
810
import TypeInference::Consistency
911

@@ -27,4 +29,7 @@ int getTypeInferenceInconsistencyCounts(string type) {
2729
or
2830
type = "Ill-formed type mention" and
2931
result = count(TypeMention tm | illFormedTypeMention(tm) | tm)
32+
or
33+
type = "Non-unique certain type information" and
34+
result = count(AstNode n, TypePath path | nonUniqueCertainType(n, path) | n)
3035
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
nonUniqueCertainType
2+
| web_frameworks.rs:139:30:139:39 | ...::get(...) | |
3+
| web_frameworks.rs:140:34:140:43 | ...::get(...) | |
4+
| web_frameworks.rs:141:30:141:39 | ...::get(...) | |

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2352,7 +2352,7 @@ mod loops {
23522352
#[rustfmt::skip]
23532353
let _ = while a < 10 // $ target=lt type=a:i64
23542354
{
2355-
a += 1; // $ type=a:i64 target=add_assign
2355+
a += 1; // $ type=a:i64 MISSING: target=add_assign
23562356
};
23572357
}
23582358
}

0 commit comments

Comments
 (0)