Skip to content

Commit 398b471

Browse files
committed
fix(travelling-salesman): debugged, now works on 3-node graph
1 parent a34c3d5 commit 398b471

File tree

1 file changed

+40
-13
lines changed

1 file changed

+40
-13
lines changed

algorithms/travelling_salesman.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
Jessica Yung
66
Dec 2018
77
"""
8+
import numpy as np
89

910
class TravellingSalesman:
1011

11-
def __init__(self, graph, start):
12+
def __init__(self, graph, start=0):
1213
"""Initialise with graph and node you start from.
1314
:param graph: takes the form of an adjacency matrix
1415
(suitable since we are given a fully connected graph).
@@ -17,35 +18,61 @@ def __init__(self, graph, start):
1718
"""
1819
self.graph = graph
1920
self.start = start
20-
self.nodes = np.arange(len(graph))
21+
self.nodes = list(np.arange(len(graph)))
22+
self.cost_dict = {}
2123

2224
def cost(self, nodes, end):
23-
if (nodes, end) is in self.cost_dict.keys():
24-
return self.cost_dict[(nodes, end)]
25+
if (tuple(nodes), end) in self.cost_dict.keys():
26+
return self.cost_dict[(tuple(nodes), end)]
2527
else:
26-
self.cost_dict[nodes, end] = self.calc_cost(nodes, end)
27-
return self.cost_dict[nodes, end]
28+
self.cost_dict[tuple(nodes), end] = self.calc_cost(nodes, end)
29+
return self.cost_dict[tuple(nodes), end]
2830

2931
def calc_cost(self, nodes, end):
3032
if end not in nodes:
3133
return Exception("Endpoint not in nodes to visit.")
34+
# print("Nodes: {}".format(nodes))
3235
if len(nodes) == 1:
3336
return 0
3437
if len(nodes) == 2:
3538
return self.graph[nodes[0], nodes[1]]
36-
non_end_nodes = nodes.copy().remove(end)
37-
return min(self.cost(non_end_nodes, j) + self.graph[j, end] for j in non_end_nodes if j != self.start)
39+
non_end_nodes = nodes.copy()
40+
non_end_nodes.remove(end)
41+
temp = [self.cost(non_end_nodes, j) + self.graph[j, end] for j in non_end_nodes if j != self.start]
42+
# print("Non end nodes: {}".format(non_end_nodes))
43+
# print("End: ", end)
44+
# for j in non_end_nodes:
45+
# if j != self.start:
46+
# print(self.cost(non_end_nodes, j))
47+
# print(self.graph[j, end])
48+
# print("Graph: ", self.graph)
49+
# print("j={}, end={}".format(j, end))
50+
# print("cost candidates:", temp)
51+
return min(temp)
3852

3953
def dp(self):
4054
"""Dynamic programming solution to Travelling Salesman problem."""
41-
return self.cost(self.nodes, self.start)
55+
# calculate costs
56+
return min(self.cost(self.nodes, i) + self.graph[i, 0] for i in self.nodes[1:])
57+
# return self.cost(self.nodes, self.start)
4258

4359

4460
# test case:
45-
def create_adj_matrix(dists):
61+
def create_adj_matrix(distances):
4662
"""dists: (n-1)x(n-1) matrix with (n-1)*n/2 entries
47-
dists from 0 to 1, 2,...n-1, then dists from 1 to 2,...,n-1.
48-
cells that don't represent dists are left as zeroes.
63+
dists from 0 to 1, 2,...n-1, then dists from 1 to 2,...,n-1, up to dists from n-1.
64+
cells that don't represent dists in input may not exist or can exist but are ignored.
4965
"""
50-
pass
66+
n = len(distances) + 1
67+
mat = np.diag(np.ones(n)*np.inf)
68+
for i in range(n-1):
69+
for j in range(n-i-1):
70+
mat[i, j+i+1] = mat[j+i+1, i] = distances[i][j]
71+
return mat
5172

73+
dists = create_adj_matrix([[4, 3],[2]])
74+
# print(dists)
75+
ts = TravellingSalesman(dists, 0)
76+
soln = ts.dp()
77+
print("Min dist:", soln)
78+
# print(ts.cost_dict)

0 commit comments

Comments
 (0)