Skip to content

Commit 69aa6a0

Browse files
authored
[mlir][quant] Fix quantization example. (#151518)
Fix and improve code example [email protected]
1 parent af0be76 commit 69aa6a0

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,11 @@ class QuantizedType : public Type {
143143
/// Casts from a type based on the storageType to a corresponding type based
144144
/// on this type (returns nullptr if the cast is not valid).
145145
/// Examples:
146+
/// `candidate type` -> `return type`
146147
/// i8 -> !quant.uniform<i8:f32, 1.0>
147148
/// tensor<4xi8> -> tensor<4x!quant.uniform<i8:f32, 1.0}>>
148149
/// vector<4xi8> -> vector<4x!quant.uniform<i8:f32, 1.0>>
150+
/// It is assumed above that this type's quantization is `<i8:f32, 1.0>`.
149151
Type castFromStorageType(Type candidateType);
150152

151153
/// Casts from a type based on a QuantizedType to a corresponding type based

mlir/lib/Dialect/Quant/IR/QuantTypes.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) {
127127

128128
Type QuantizedType::castFromStorageType(Type candidateType) {
129129
if (candidateType == getStorageType()) {
130-
// i.e. i32 -> quant<"uniform[i8:f32]{1.0}">
130+
// i.e. i8 -> quant<"uniform[i8:f32]{1.0}">
131131
return *this;
132132
}
133133
if (llvm::isa<RankedTensorType>(candidateType)) {
@@ -137,11 +137,11 @@ Type QuantizedType::castFromStorageType(Type candidateType) {
137137
getStorageType());
138138
}
139139
if (llvm::isa<UnrankedTensorType>(candidateType)) {
140-
// i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">>
140+
// i.e. tensor<xi8> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
141141
return UnrankedTensorType::get(getStorageType());
142142
}
143143
if (llvm::isa<VectorType>(candidateType)) {
144-
// i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
144+
// i.e. vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
145145
return VectorType::get(llvm::cast<VectorType>(candidateType).getShape(),
146146
getStorageType());
147147
}

0 commit comments

Comments
 (0)