"""
Given an OSM XML file on standard in, classify nodes and ways of particular
amenity types and write the results to standard out as JSON. Each amenity way
is reduced to a representative point (the center of the bounding box of its
nodes) and the points are grouped by amenity type in the output.
"""

from collections import defaultdict
from decimal import Decimal
from xml.etree import ElementTree
from sys import stdin
from typing import IO, Iterable, Optional
import json


# Maps an OSM (key, value) tag pair to the amenity type it identifies.
# TODO: finish this; only supermarkets are recognised so far.
_AMENITY_BY_TAG = {
    ("shop", "supermarket"): "grocery",
}


def classify(node_or_way: ElementTree.Element) -> Optional[str]:
    """Return the amenity type of an OSM node or way, or None if it has none.

    Only the element's direct ``<tag>`` children are inspected; the first one
    whose (k, v) pair appears in ``_AMENITY_BY_TAG`` decides the result.
    """
    for tag in node_or_way.iterfind("tag"):
        amenity = _AMENITY_BY_TAG.get((tag.get("k"), tag.get("v")))
        if amenity is not None:
            return amenity
    return None


def representative_point(
    locations: Iterable[tuple[Decimal, Decimal]],
) -> Optional[tuple[float, float]]:
    """Return the bounding-box center of *locations* as a (lat, lon) of floats.

    Returns None when *locations* is empty, since there is no box to center.
    """
    points = list(locations)
    if not points:
        return None
    lats, lons = zip(*points)
    # Average in Decimal and convert to float only once, at the very end.
    return (
        float((min(lats) + max(lats)) / 2),
        float((min(lons) + max(lons)) / 2),
    )


def collect_amenity_locations(
    source: "IO[str] | IO[bytes] | str",
) -> dict[str, list[tuple[float, float]]]:
    """Parse OSM XML from *source* and group amenity locations by type.

    *source* is anything ``ElementTree.iterparse`` accepts (a file object or
    a filename). Classified nodes contribute their own location; classified
    ways contribute the bounding-box center of the nodes they reference. A
    classified way without any ``<nd>`` children is skipped.

    A way can only be resolved against nodes that were parsed before it.

    Raises:
        ValueError: if the same node id occurs more than once.
        KeyError: if a classified way references a node id not seen so far.
    """
    node_locations: dict[int, tuple[Decimal, Decimal]] = {}
    locations_by_amenity = defaultdict(list)

    for _, el in ElementTree.iterparse(source, events=["end"]):
        if el.tag == "node":
            node_id = int(el.attrib["id"])
            # Raise explicitly rather than assert: asserts vanish under -O.
            if node_id in node_locations:
                raise ValueError(f"duplicate node id {node_id}")
            location = (Decimal(el.attrib["lat"]), Decimal(el.attrib["lon"]))
            node_locations[node_id] = location

            if amenity := classify(el):
                lat, lon = location
                locations_by_amenity[amenity].append((float(lat), float(lon)))
        elif el.tag == "way":
            if amenity := classify(el):
                way_node_locations = [
                    node_locations[int(nd.attrib["ref"])]
                    for nd in el.iterfind("nd")
                ]
                point = representative_point(way_node_locations)
                if point is not None:
                    locations_by_amenity[amenity].append(point)
        else:
            # Children such as <tag>/<nd> must stay intact until their parent
            # node or way has been handled, so leave every other element alone.
            continue

        # Everything needed from this node/way has been extracted; drop its
        # attributes and children so the partially built tree stays small.
        el.clear()

    return dict(locations_by_amenity)


def main() -> None:
    """Read OSM XML from standard in and print the grouped locations as JSON."""
    print(json.dumps(collect_amenity_locations(stdin)))


if __name__ == "__main__":
    main()