]> jcornell.net Git - ntbd-parcels.git/commitdiff
Add prototype of tool to produce an amenity library from OSM XML
authorJakob Cornell <jakob+gpg@jcornell.net>
Thu, 21 May 2026 06:06:03 +0000 (01:06 -0500)
committerJakob Cornell <jakob+gpg@jcornell.net>
Thu, 21 May 2026 06:06:03 +0000 (01:06 -0500)
extract_amenities.py [new file with mode: 0644]

diff --git a/extract_amenities.py b/extract_amenities.py
new file mode 100644 (file)
index 0000000..3502de1
--- /dev/null
@@ -0,0 +1,47 @@
+"""
+Given an OSM XML file on standard in, classify nodes and ways of particular
+amenity types and write the results to standard out as JSON. Each amenity way
+is reduced to a representative point and the points are grouped by amenity type
+in the output.
+"""
+
+from collections import defaultdict
+from decimal import Decimal
+from xml.etree import ElementTree
+from sys import stdin
+from typing import Optional
+import json
+
+
+def classify(node_or_way: ElementTree.Element) -> Optional[str]:
+    # TODO: finish this
+    if node_or_way.find("tag[@k='shop'][@v='supermarket']") is not None:
+        return "grocery"
+    else:
+        return None
+
+
+node_locations = {}
+locations_by_amenity = defaultdict(list)
+for (event_kind, el) in ElementTree.iterparse(stdin, events=["end"]):
+    assert event_kind == "end"
+    match el.tag:
+        case "node":
+            node_id = int(el.attrib["id"])
+            assert node_id not in node_locations
+            node_locations[node_id] = (Decimal(el.attrib["lat"]), Decimal(el.attrib["lon"]))
+
+            if classification := classify(el):
+                float_location = tuple(map(float, node_locations[node_id]))
+                locations_by_amenity[classification].append(float_location)
+        case "way":
+            if classification := classify(el):
+                way_node_locations = [
+                    node_locations[int(child.attrib["ref"])] for child in el.findall("nd")
+                ]
+                (lats, lons) = zip(*way_node_locations)
+                lat = (min(lats) + max(lats)) / 2
+                lon = (min(lons) + max(lons)) / 2
+                locations_by_amenity[classification].append((float(lat), float(lon)))
+
+print(json.dumps(locations_by_amenity))