Skip to content

Commit 13d9d8a

Browse files
authored
Merge pull request #20122 from paldepind/rust/type-inference-dyn-assoc
Rust: Fix type inference for trait objects for traits with associated types
2 parents 4b947db + b2ee625 commit 13d9d8a

File tree

7 files changed

+3457
-3191
lines changed

7 files changed

+3457
-3191
lines changed

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,27 @@ private import codeql.rust.internal.CachedStages
77
private import codeql.rust.elements.internal.generated.Raw
88
private import codeql.rust.elements.internal.generated.Synth
99

10+
/**
11+
* Holds if a dyn trait type should have a type parameter associated with `n`. A
12+
* dyn trait type inherits the type parameters of the trait it implements. That
13+
* includes the type parameters corresponding to associated types.
14+
*
15+
* For instance in
16+
* ```rust
17+
* trait SomeTrait<A> {
18+
* type AssociatedType;
19+
* }
20+
* ```
21+
* this predicate holds for the nodes `A` and `type AssociatedType`.
22+
*/
23+
private predicate dynTraitTypeParameter(Trait trait, AstNode n) {
24+
trait = any(DynTraitTypeRepr dt).getTrait() and
25+
(
26+
n = trait.getGenericParamList().getATypeParam() or
27+
n = trait.(TraitItemNode).getAnAssocItem().(TypeAlias)
28+
)
29+
}
30+
1031
cached
1132
newtype TType =
1233
TTuple(int arity) {
@@ -30,9 +51,7 @@ newtype TType =
3051
TTypeParamTypeParameter(TypeParam t) or
3152
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
3253
TArrayTypeParameter() or
33-
TDynTraitTypeParameter(TypeParam tp) {
34-
tp = any(DynTraitTypeRepr dt).getTrait().getGenericParamList().getATypeParam()
35-
} or
54+
TDynTraitTypeParameter(AstNode n) { dynTraitTypeParameter(_, n) } or
3655
TRefTypeParameter() or
3756
TSelfTypeParameter(Trait t) or
3857
TSliceTypeParameter()
@@ -406,15 +425,35 @@ class ArrayTypeParameter extends TypeParameter, TArrayTypeParameter {
406425
}
407426

408427
class DynTraitTypeParameter extends TypeParameter, TDynTraitTypeParameter {
409-
private TypeParam typeParam;
428+
private AstNode n;
410429

411-
DynTraitTypeParameter() { this = TDynTraitTypeParameter(typeParam) }
430+
DynTraitTypeParameter() { this = TDynTraitTypeParameter(n) }
412431

413-
TypeParam getTypeParam() { result = typeParam }
432+
Trait getTrait() { dynTraitTypeParameter(result, n) }
414433

415-
override string toString() { result = "dyn(" + typeParam.toString() + ")" }
434+
/** Gets the dyn trait type that this type parameter belongs to. */
435+
DynTraitType getDynTraitType() { result.getTrait() = this.getTrait() }
416436

417-
override Location getLocation() { result = typeParam.getLocation() }
437+
/** Gets the `TypeParam` of this dyn trait type parameter, if any. */
438+
TypeParam getTypeParam() { result = n }
439+
440+
/** Gets the `TypeAlias` of this dyn trait type parameter, if any. */
441+
TypeAlias getTypeAlias() { result = n }
442+
443+
/** Gets the trait type parameter that this dyn trait type parameter corresponds to. */
444+
TypeParameter getTraitTypeParameter() {
445+
result.(TypeParamTypeParameter).getTypeParam() = n
446+
or
447+
result.(AssociatedTypeTypeParameter).getTypeAlias() = n
448+
}
449+
450+
private string toStringInner() {
451+
result = [this.getTypeParam().toString(), this.getTypeAlias().getName().toString()]
452+
}
453+
454+
override string toString() { result = "dyn(" + this.toStringInner() + ")" }
455+
456+
override Location getLocation() { result = n.getLocation() }
418457
}
419458

420459
/** An implicit reference type parameter. */
@@ -503,8 +542,7 @@ final class ImplTypeAbstraction extends TypeAbstraction, Impl {
503542

504543
final class DynTypeAbstraction extends TypeAbstraction, DynTraitTypeRepr {
505544
override TypeParameter getATypeParameter() {
506-
result.(TypeParamTypeParameter).getTypeParam() =
507-
this.getTrait().getGenericParamList().getATypeParam()
545+
result = any(DynTraitTypeParameter tp | tp.getTrait() = this.getTrait()).getTraitTypeParameter()
508546
}
509547
}
510548

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,11 @@ private module Input1 implements InputSig1<Location> {
9797
id = 2
9898
or
9999
kind = 1 and
100-
id = idOfTypeParameterAstNode(tp0.(DynTraitTypeParameter).getTypeParam())
100+
id =
101+
idOfTypeParameterAstNode([
102+
tp0.(DynTraitTypeParameter).getTypeParam().(AstNode),
103+
tp0.(DynTraitTypeParameter).getTypeAlias()
104+
])
101105
or
102106
kind = 2 and
103107
exists(AstNode node | id = idOfTypeParameterAstNode(node) |

rust/ql/lib/codeql/rust/internal/TypeMention.qll

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -324,10 +324,10 @@ class DynTraitTypeReprMention extends TypeMention instanceof DynTraitTypeRepr {
324324
result = dynType
325325
or
326326
exists(DynTraitTypeParameter tp, TypePath path0, TypePath suffix |
327-
tp = dynType.getTypeParameter(_) and
327+
dynType = tp.getDynTraitType() and
328328
path = TypePath::cons(tp, suffix) and
329329
result = super.getTypeBoundList().getBound(0).getTypeRepr().(TypeMention).resolveTypeAt(path0) and
330-
path0.isCons(TTypeParamTypeParameter(tp.getTypeParam()), suffix)
330+
path0.isCons(tp.getTraitTypeParameter(), suffix)
331331
)
332332
}
333333
}
@@ -363,10 +363,10 @@ class DynTypeBoundListMention extends TypeMention instanceof TypeBoundList {
363363
path.isEmpty() and
364364
result.(DynTraitType).getTrait() = trait
365365
or
366-
exists(TypeParam param |
367-
param = trait.getGenericParamList().getATypeParam() and
368-
path = TypePath::singleton(TDynTraitTypeParameter(param)) and
369-
result = TTypeParamTypeParameter(param)
366+
exists(DynTraitTypeParameter tp |
367+
trait = tp.getTrait() and
368+
path = TypePath::singleton(tp) and
369+
result = tp.getTraitTypeParameter()
370370
)
371371
}
372372
}
Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
multipleCallTargets
22
| dereference.rs:61:15:61:24 | e1.deref() |
3-
| main.rs:2213:13:2213:31 | ...::from(...) |
4-
| main.rs:2214:13:2214:31 | ...::from(...) |
5-
| main.rs:2215:13:2215:31 | ...::from(...) |
6-
| main.rs:2221:13:2221:31 | ...::from(...) |
7-
| main.rs:2222:13:2222:31 | ...::from(...) |
8-
| main.rs:2223:13:2223:31 | ...::from(...) |
3+
| main.rs:2253:13:2253:31 | ...::from(...) |
4+
| main.rs:2254:13:2254:31 | ...::from(...) |
5+
| main.rs:2255:13:2255:31 | ...::from(...) |
6+
| main.rs:2261:13:2261:31 | ...::from(...) |
7+
| main.rs:2262:13:2262:31 | ...::from(...) |
8+
| main.rs:2263:13:2263:31 | ...::from(...) |

rust/ql/test/library-tests/type-inference/dyn_type.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ trait GenericGet<A> {
1212
fn get(&self) -> A;
1313
}
1414

15+
trait AssocTrait<GP> {
16+
type AP;
17+
// AssocTrait::get
18+
fn get(&self) -> (GP, Self::AP);
19+
}
20+
1521
#[derive(Clone, Debug)]
1622
struct MyStruct {
1723
value: i32,
@@ -36,6 +42,17 @@ impl<A: Clone + Debug> GenericGet<A> for GenStruct<A> {
3642
}
3743
}
3844

45+
impl<GGP> AssocTrait<GGP> for GenStruct<GGP>
46+
where
47+
GGP: Clone + Debug,
48+
{
49+
type AP = bool;
50+
// GenStruct<GGP>::get
51+
fn get(&self) -> (GGP, bool) {
52+
(self.value.clone(), true) // $ fieldof=GenStruct target=clone
53+
}
54+
}
55+
3956
fn get_a<A, G: GenericGet<A> + ?Sized>(a: &G) -> A {
4057
a.get() // $ target=GenericGet::get
4158
}
@@ -58,10 +75,34 @@ fn test_poly_dyn_trait() {
5875
let _result = (*obj).get(); // $ target=deref target=GenericGet::get type=_result:bool
5976
}
6077

78+
fn assoc_dyn_get<A, B>(a: &dyn AssocTrait<A, AP = B>) -> (A, B) {
79+
a.get() // $ target=AssocTrait::get
80+
}
81+
82+
fn assoc_get<A, B, T: AssocTrait<A, AP = B> + ?Sized>(a: &T) -> (A, B) {
83+
a.get() // $ target=AssocTrait::get
84+
}
85+
86+
fn test_assoc_type(obj: &dyn AssocTrait<i64, AP = bool>) {
87+
let (
88+
_gp, // $ type=_gp:i64
89+
_ap, // $ type=_ap:bool
90+
) = (*obj).get(); // $ target=deref target=AssocTrait::get
91+
let (
92+
_gp, // $ type=_gp:i64
93+
_ap, // $ type=_ap:bool
94+
) = assoc_dyn_get(obj); // $ target=assoc_dyn_get
95+
let (
96+
_gp, // $ type=_gp:i64
97+
_ap, // $ type=_ap:bool
98+
) = assoc_get(obj); // $ target=assoc_get
99+
}
100+
61101
pub fn test() {
62102
test_basic_dyn_trait(&MyStruct { value: 42 }); // $ target=test_basic_dyn_trait
63103
test_generic_dyn_trait(&GenStruct {
64104
value: "".to_string(),
65105
}); // $ target=test_generic_dyn_trait
66106
test_poly_dyn_trait(); // $ target=test_poly_dyn_trait
107+
test_assoc_type(&GenStruct { value: 100 }); // $ target=test_assoc_type
67108
}

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,7 @@ mod function_trait_bounds {
653653
}
654654
}
655655

656-
mod trait_associated_type {
656+
mod associated_type_in_trait {
657657
#[derive(Debug)]
658658
struct Wrapper<A> {
659659
field: A,
@@ -803,6 +803,46 @@ mod trait_associated_type {
803803
}
804804
}
805805

806+
mod associated_type_in_supertrait {
807+
trait Supertrait {
808+
type Content;
809+
fn insert(content: Self::Content);
810+
}
811+
812+
trait Subtrait: Supertrait {
813+
// Subtrait::get_content
814+
fn get_content(&self) -> Self::Content;
815+
}
816+
817+
struct MyType<T>(T);
818+
819+
impl<T> Supertrait for MyType<T> {
820+
type Content = T;
821+
fn insert(_content: Self::Content) {
822+
println!("Inserting content: ");
823+
}
824+
}
825+
826+
impl<T: Clone> Subtrait for MyType<T> {
827+
// MyType::get_content
828+
fn get_content(&self) -> Self::Content {
829+
(*self).0.clone() // $ fieldof=MyType target=clone target=deref
830+
}
831+
}
832+
833+
fn get_content<T: Subtrait>(item: &T) -> T::Content {
834+
item.get_content() // $ target=Subtrait::get_content
835+
}
836+
837+
fn test() {
838+
let item1 = MyType(42i64);
839+
let _content1 = item1.get_content(); // $ target=MyType::get_content MISSING: type=_content1:i64
840+
841+
let item2 = MyType(true);
842+
let _content2 = get_content(&item2); // $ target=get_content MISSING: type=_content2:bool
843+
}
844+
}
845+
806846
mod generic_enum {
807847
#[derive(Debug)]
808848
enum MyEnum<A> {
@@ -2469,7 +2509,7 @@ fn main() {
24692509
method_non_parametric_impl::f(); // $ target=f
24702510
method_non_parametric_trait_impl::f(); // $ target=f
24712511
function_trait_bounds::f(); // $ target=f
2472-
trait_associated_type::f(); // $ target=f
2512+
associated_type_in_trait::f(); // $ target=f
24732513
generic_enum::f(); // $ target=f
24742514
method_supertraits::f(); // $ target=f
24752515
function_trait_bounds_2::f(); // $ target=f

0 commit comments

Comments
 (0)