summaryrefslogtreecommitdiffstats
path: root/store.py
diff options
context:
space:
mode:
authorGustav Sörnäs <gusso230@student.liu.se>2020-11-19 15:38:33 +0100
committerGustav Sörnäs <gusso230@student.liu.se>2020-11-20 13:39:05 +0100
commitcdcdf5de8846b0fe1cb91cc0c6ec41f5842c7c98 (patch)
tree8abf0ac2a810a036b861890a717336d6b4a67ab5 /store.py
parent10e3d6874745283b545d3b8029a3e5e184932c5d (diff)
downloadtdde25-cdcdf5de8846b0fe1cb91cc0c6ec41f5842c7c98.tar.gz
implement union-find
Diffstat (limited to 'store.py')
-rw-r--r--store.py41
1 files changed, 39 insertions, 2 deletions
diff --git a/store.py b/store.py
index 9e6d4f7..f98bb25 100644
--- a/store.py
+++ b/store.py
@@ -1,4 +1,5 @@
from osm_parser import get_default_parser
+from collections import defaultdict
class Node:
@@ -8,6 +9,27 @@ class Node:
self.lng = float(lng)
self.neighbours = []
+ self.parent = None
+ self.size = 1
+
+ def find_root(self):
+ if not self.parent:
+ return self
+ else:
+ return self.parent.find_root()
+
+ def union(self, other):
+ this = self.find_root()
+ other = other.find_root()
+
+ if this == other:
+ return
+
+ if this.size < other.size:
+ this, other = other, this
+
+ other.parent = this
+ this.size += other.size
def coord_tuple(self):
return self.lat, self.lng
@@ -29,6 +51,7 @@ def add_neighbours(nodes):
nodes[node1].neighbours.append(nodes[node2])
nodes[node2].neighbours.append(nodes[node1])
+ nodes[node1].union(nodes[node2])
return nodes
@@ -48,11 +71,25 @@ def extract_osm_nodes(f_name):
if not node.neighbours:
del nodes[node_id]
- return nodes
+ # construct a forest of disjoint trees using union-find
+ forest = defaultdict(dict) # contains {root: tree}
+ for node_id, node in nodes.items():
+ forest[node.find_root().id][node_id] = node
+
+ # find the largest disjoin tree
+ best_size = 0
+ best_tree = None
+ for root in forest:
+ tree = forest[root]
+ size = len(tree)
+ if size > best_size:
+ best_size = size
+ best_tree = tree
+
+ return best_tree
def select_nodes_in_rectangle(nodes, min_lat, max_lat, min_long, max_long):
return [node for node in nodes.values()
if min_lat <= node.lat <= max_lat
and min_long <= node.lng <= max_long]
-