|
| 1 | + |
| 2 | + |
| 3 | +import Foundation |
| 4 | + |
| 5 | +struct SamplePair { |
| 6 | + |
| 7 | + let index1: Int |
| 8 | + let index2: Int |
| 9 | + let distance: Double |
| 10 | + |
| 11 | + init(idx1: Int, idx2: Int, distance: Double = 1) { |
| 12 | + self.distance = distance |
| 13 | + if (idx1 < idx2) { |
| 14 | + index1 = idx1 |
| 15 | + index2 = idx2 |
| 16 | + } else { |
| 17 | + index1 = idx2 |
| 18 | + index2 = idx1 |
| 19 | + } |
| 20 | + } |
| 21 | +} |
| 22 | + |
| 23 | +struct OrderedSamplePair { |
| 24 | + |
| 25 | + let index1: Int |
| 26 | + let index2: Int |
| 27 | + let distance: Double |
| 28 | + |
| 29 | + init(idx1: Int, idx2: Int, distance: Double = 1) { |
| 30 | + self.distance = distance |
| 31 | + index1 = idx1 |
| 32 | + index2 = idx2 |
| 33 | + } |
| 34 | +} |
| 35 | + |
| 36 | +struct Pair { |
| 37 | + let index1: Int |
| 38 | + let index2: Int |
| 39 | + |
| 40 | + init(idx1: Int, idx2: Int) { |
| 41 | + index1 = idx1 |
| 42 | + index2 = idx2 |
| 43 | + } |
| 44 | +} |
| 45 | + |
| 46 | +class ChineseWhispers { |
| 47 | + |
| 48 | + func convertUnorderedToOrdered(edges: [SamplePair]) -> [OrderedSamplePair] { |
| 49 | + var orderdPairs: [OrderedSamplePair] = [] |
| 50 | + orderdPairs.reserveCapacity(edges.count*2) |
| 51 | + |
| 52 | + for (i, _) in edges.enumerated() { |
| 53 | + orderdPairs.append(OrderedSamplePair(idx1: edges[i].index1, idx2: edges[i].index2, distance: edges[i].distance)) |
| 54 | + if edges[i].index1 != edges[i].index2 { |
| 55 | + orderdPairs.append(OrderedSamplePair(idx1: edges[i].index2, idx2: edges[i].index1, distance: edges[i].distance)) |
| 56 | + } |
| 57 | + } |
| 58 | + return orderdPairs.sorted { (a, b) -> Bool in |
| 59 | + a.index1 < b.index1 || (a.index1 == b.index1 && a.index2 < b.index2) |
| 60 | + } |
| 61 | + } |
| 62 | + |
| 63 | + func findNeighborRanges(edges: [OrderedSamplePair]) -> [Pair] { |
| 64 | + let numNodes: Int = maxIndexPlusOne(pairs: edges) |
| 65 | + var neighbors: [Pair] = Array(repeating: Pair(idx1: 0, idx2: 0), count: Int(numNodes)) |
| 66 | + var cur_node: Int = 0 |
| 67 | + var start_idx: Int = 0 |
| 68 | + |
| 69 | + for (index, value) in edges.enumerated() { |
| 70 | + if value.index1 != cur_node { |
| 71 | + neighbors[Int(cur_node)] = Pair(idx1: start_idx, |
| 72 | + idx2: index) |
| 73 | + start_idx = index |
| 74 | + cur_node = value.index1 |
| 75 | + } |
| 76 | + } |
| 77 | + if !neighbors.isEmpty { |
| 78 | + neighbors[Int(cur_node)] = Pair(idx1: start_idx, |
| 79 | + idx2: edges.count) |
| 80 | + } |
| 81 | + return neighbors |
| 82 | + } |
| 83 | + |
| 84 | + func maxIndexPlusOne(pairs: [OrderedSamplePair]) -> Int { |
| 85 | + if pairs.count == 0 { |
| 86 | + return 0 |
| 87 | + }else { |
| 88 | + var max_idx: Int = 0 |
| 89 | + for (_, value) in pairs.enumerated() { |
| 90 | + if value.index1 > max_idx { |
| 91 | + max_idx = value.index1 |
| 92 | + } |
| 93 | + if value.index2 > max_idx { |
| 94 | + max_idx = value.index2 |
| 95 | + } |
| 96 | + } |
| 97 | + return max_idx + 1 |
| 98 | + } |
| 99 | + |
| 100 | + } |
| 101 | + |
| 102 | + |
| 103 | + func chinese_whispers<T>( |
| 104 | + objects: [T], |
| 105 | + distanceFunction: (T, T) -> Double, |
| 106 | + eps: Double, |
| 107 | + numIterations: Int) -> [Int] { |
| 108 | + var edges: [SamplePair] = [] |
| 109 | + guard !objects.isEmpty else { |
| 110 | + return [] |
| 111 | + } |
| 112 | + for i in 0...objects.count-1 { |
| 113 | + for j in i...objects.count-1 { |
| 114 | + let length = distanceFunction(objects[i], objects[j]) |
| 115 | + if length < eps { |
| 116 | + edges.append(SamplePair(idx1: i, idx2: j)) |
| 117 | + } |
| 118 | + } |
| 119 | + } |
| 120 | + return chinese_whispers(edges: edges, numIterations: numIterations) |
| 121 | + } |
| 122 | + |
| 123 | + func chinese_whispers( |
| 124 | + edges: [SamplePair], |
| 125 | + numIterations: Int) -> [Int] { |
| 126 | + let orderdEdges = convertUnorderedToOrdered(edges: edges) |
| 127 | + return chinese_whispers(edges: orderdEdges, numIterations: numIterations) |
| 128 | + } |
| 129 | + |
| 130 | + func chinese_whispers( |
| 131 | + edges: [OrderedSamplePair], |
| 132 | + numIterations: Int) -> [Int] { |
| 133 | + let neighbors = findNeighborRanges(edges: edges) |
| 134 | + var labels: [Int] = Array<Int>(repeating: 0, count: neighbors.count) |
| 135 | + for i in 0...neighbors.count-1 { |
| 136 | + labels[i] = i |
| 137 | + } |
| 138 | + |
| 139 | + for _ in 0..<(neighbors.count * numIterations) { |
| 140 | + |
| 141 | + // Pick a random node. |
| 142 | + let idx = Int.random(in: 0..<neighbors.count) |
| 143 | + var labels_to_counts: [Int: Double] = [:] |
| 144 | + let end: Int = Int(neighbors[idx].index2) |
| 145 | + |
| 146 | + for i in Int(neighbors[idx].index1)..<end { |
| 147 | + |
| 148 | + labels_to_counts[labels[Int(edges[i].index2)], default: 0] += Double(edges[i].distance) |
| 149 | + } |
| 150 | + |
| 151 | + var bestScore: Double = -Double.infinity |
| 152 | + var bestLabel = labels[idx] |
| 153 | + let sorted_labels_to_counts = labels_to_counts.sorted { (a, b) -> Bool in |
| 154 | + a.key < b.key |
| 155 | + } |
| 156 | + sorted_labels_to_counts.forEach { (i) in |
| 157 | + if i.value > bestScore { |
| 158 | + bestScore = i.value |
| 159 | + bestLabel = i.key |
| 160 | + } |
| 161 | + } |
| 162 | + labels[idx] = bestLabel |
| 163 | + } |
| 164 | + var label_remap: [Int: Int] = [:] |
| 165 | + for (i, _) in labels.enumerated() { |
| 166 | + let next_id = label_remap.count |
| 167 | + if label_remap[labels[i]] == nil { |
| 168 | + label_remap[labels[i]] = next_id |
| 169 | + } |
| 170 | + } |
| 171 | + |
| 172 | + for (i, _) in labels.enumerated() { |
| 173 | + labels[i] = label_remap[labels[i]] ?? 0 |
| 174 | + } |
| 175 | + return labels |
| 176 | + } |
| 177 | + |
| 178 | + func group<T>(objects: [T], labels: [Int]) -> [[T]] { |
| 179 | + var cluster: [Int : [T]] = [:] |
| 180 | + for (i, value) in labels.enumerated() { |
| 181 | + if cluster[value] != nil { |
| 182 | + cluster[value]!.append(objects[i]) |
| 183 | + }else { |
| 184 | + cluster[value] = [objects[i]] |
| 185 | + } |
| 186 | + } |
| 187 | + return cluster.map { (_ , value) -> [T] in |
| 188 | + value |
| 189 | + } |
| 190 | + } |
| 191 | +} |
| 192 | + |
| 193 | + |
| 194 | + |
| 195 | + |
| 196 | +let a = 1.1 |
| 197 | +let b = 1.2 |
| 198 | +let c = 1.3 |
| 199 | +let d = 4.0 |
| 200 | + |
| 201 | +let chineseWhispers = ChineseWhispers() |
| 202 | +let labels = chineseWhispers.chinese_whispers(objects: [a, b, c, d], distanceFunction: { (a, b) -> Double in |
| 203 | + abs(a-b) |
| 204 | +}, eps: 0.4, numIterations: 100) |
| 205 | +let group = chineseWhispers.group(objects: [a, b, c, d], labels: labels) |
0 commit comments