5
5
Jessica Yung
6
6
Dec 2018
7
7
"""
8
+ import numpy as np
8
9
9
10
class TravellingSalesman :
10
11
11
- def __init__ (self , graph , start ):
12
+ def __init__ (self , graph , start = 0 ):
12
13
"""Initialise with graph and node you start from.
13
14
:param graph: takes the form of an adjacency matrix
14
15
(suitable since we are given a fully connected graph).
@@ -17,35 +18,61 @@ def __init__(self, graph, start):
17
18
"""
18
19
self .graph = graph
19
20
self .start = start
20
- self .nodes = np .arange (len (graph ))
21
+ self .nodes = list (np .arange (len (graph )))
22
+ self .cost_dict = {}
21
23
22
24
def cost (self , nodes , end ):
23
- if (nodes , end ) is in self .cost_dict .keys ():
24
- return self .cost_dict [(nodes , end )]
25
+ if (tuple ( nodes ) , end ) in self .cost_dict .keys ():
26
+ return self .cost_dict [(tuple ( nodes ) , end )]
25
27
else :
26
- self .cost_dict [nodes , end ] = self .calc_cost (nodes , end )
27
- return self .cost_dict [nodes , end ]
28
+ self .cost_dict [tuple ( nodes ) , end ] = self .calc_cost (nodes , end )
29
+ return self .cost_dict [tuple ( nodes ) , end ]
28
30
29
31
def calc_cost (self , nodes , end ):
30
32
if end not in nodes :
31
33
return Exception ("Endpoint not in nodes to visit." )
34
+ # print("Nodes: {}".format(nodes))
32
35
if len (nodes ) == 1 :
33
36
return 0
34
37
if len (nodes ) == 2 :
35
38
return self .graph [nodes [0 ], nodes [1 ]]
36
- non_end_nodes = nodes .copy ().remove (end )
37
- return min (self .cost (non_end_nodes , j ) + self .graph [j , end ] for j in non_end_nodes if j != self .start )
39
+ non_end_nodes = nodes .copy ()
40
+ non_end_nodes .remove (end )
41
+ temp = [self .cost (non_end_nodes , j ) + self .graph [j , end ] for j in non_end_nodes if j != self .start ]
42
+ # print("Non end nodes: {}".format(non_end_nodes))
43
+ # print("End: ", end)
44
+ # for j in non_end_nodes:
45
+ # if j != self.start:
46
+ # print(self.cost(non_end_nodes, j))
47
+ # print(self.graph[j, end])
48
+ # print("Graph: ", self.graph)
49
+ # print("j={}, end={}".format(j, end))
50
+ # print("cost candidates:", temp)
51
+ return min (temp )
38
52
39
53
def dp (self ):
40
54
"""Dynamic programming solution to Travelling Salesman problem."""
41
- return self .cost (self .nodes , self .start )
55
+ # calculate costs
56
+ return min (self .cost (self .nodes , i ) + self .graph [i , 0 ] for i in self .nodes [1 :])
57
+ # return self.cost(self.nodes, self.start)
42
58
43
59
44
60
# test case:
45
- def create_adj_matrix (dists ):
61
+ def create_adj_matrix (distances ):
46
62
"""dists: (n-1)x(n-1) matrix with (n-1)*n/2 entries
47
- dists from 0 to 1, 2,...n-1, then dists from 1 to 2,...,n-1.
48
- cells that don't represent dists are left as zeroes .
63
+ dists from 0 to 1, 2,...n-1, then dists from 1 to 2,...,n-1, up to dists from n-1 .
64
+ cells that don't represent dists in input may not exist or can exist but are ignored .
49
65
"""
50
- pass
66
+ n = len (distances ) + 1
67
+ mat = np .diag (np .ones (n )* np .inf )
68
+ for i in range (n - 1 ):
69
+ for j in range (n - i - 1 ):
70
+ mat [i , j + i + 1 ] = mat [j + i + 1 , i ] = distances [i ][j ]
71
+ return mat
51
72
73
+ dists = create_adj_matrix ([[4 , 3 ],[2 ]])
74
+ # print(dists)
75
+ ts = TravellingSalesman (dists , 0 )
76
+ soln = ts .dp ()
77
+ print ("Min dist:" , soln )
78
+ # print(ts.cost_dict)
0 commit comments