// game/edit/traffic_signals/gmns.rs

// TODO Move to map_model

use std::cmp::Ordering;
use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::io::Cursor;

use anyhow::Result;
use serde::{Deserialize, Deserializer};

use geom::{Angle, Duration, LonLat, Pt2D};
use map_model::{
    osm, ControlTrafficSignal, DirectedRoadID, DrivingSide, EditIntersectionControl,
    IntersectionID, Map, Movement, MovementID, Stage, StageType, TurnPriority, TurnType,
};
use widgetry::tools::PopupMsg;
use widgetry::{EventCtx, State};

use crate::edit::apply_map_edits;
use crate::App;

20/// This imports timing.csv from https://github.com/asu-trans-ai-lab/Vol2Timing. It operates in a
21/// best-effort / permissive mode, skipping over mismatched movements and other problems and should
22/// still be considered experimental.
23pub fn import(map: &Map, i: IntersectionID, bytes: &Vec<u8>) -> Result<ControlTrafficSignal> {
24    let i = map.get_i(i);
25    let mut matches_per_plan: BTreeMap<String, Vec<Record>> = BTreeMap::new();
26    for rec in csv::Reader::from_reader(Cursor::new(bytes)).deserialize() {
27        let rec: Record = rec?;
28        if !rec.osm_ids.contains(&i.orig_id) {
29            continue;
30        }
31        matches_per_plan
32            .entry(rec.timing_plan_id.clone())
33            .or_insert_with(Vec::new)
34            .push(rec);
35    }
36
37    // For now, just use any arbitrary plan
38    let mut records = matches_per_plan
39        .into_iter()
40        .next()
41        .ok_or_else(|| anyhow!("no matches for {}", i.orig_id))?
42        .1;
43    records.sort_by_key(|rec| rec.stage);
44
45    let snapper = Snapper::new(map, i.id)?;
46
47    let mut signal = ControlTrafficSignal::new(map, i.id);
48    signal.stages.clear();
49    for rec in records {
50        let stage_idx = rec.stage - 1;
51        match signal.stages.len().cmp(&stage_idx) {
52            std::cmp::Ordering::Equal => {
53                signal.stages.push(Stage {
54                    protected_movements: BTreeSet::new(),
55                    yield_movements: BTreeSet::new(),
56                    stage_type: StageType::Fixed(Duration::seconds(rec.green_time as f64)),
57                });
58            }
59            std::cmp::Ordering::Less => {
60                bail!("missing intermediate stage");
61            }
62            std::cmp::Ordering::Greater => {}
63        }
64        let stage = &mut signal.stages[stage_idx];
65
66        if stage.stage_type.simple_duration() != Duration::seconds(rec.green_time as f64) {
67            bail!(
68                "Stage {} has green_times {} and {}",
69                rec.stage,
70                stage.stage_type.simple_duration(),
71                rec.green_time
72            );
73        }
74
75        let mvmnt = match snapper.get_mvmnt(
76            (
77                rec.geometry.0.to_pt(map.get_gps_bounds()),
78                rec.geometry.1.to_pt(map.get_gps_bounds()),
79            ),
80            &rec.mvmt_txt_id,
81            map,
82        ) {
83            Ok(x) => x,
84            Err(err) => {
85                error!(
86                    "Skipping {} -> {} for stage {}: {}",
87                    rec.geometry.0, rec.geometry.1, rec.stage, err
88                );
89                continue;
90            }
91        };
92        if rec.protection == "protected" {
93            stage.protected_movements.insert(mvmnt);
94        } else {
95            stage.yield_movements.insert(mvmnt);
96        }
97    }
98
99    add_crosswalks(&mut signal, map);
100
101    Ok(signal)
102}
103
104pub fn import_all(
105    ctx: &mut EventCtx,
106    app: &mut App,
107    path: &str,
108    bytes: Vec<u8>,
109) -> Box<dyn State<App>> {
110    let all_signals: Vec<IntersectionID> = app
111        .primary
112        .map
113        .all_intersections()
114        .iter()
115        .filter_map(|i| {
116            if i.is_traffic_signal() {
117                Some(i.id)
118            } else {
119                None
120            }
121        })
122        .collect();
123    let mut successes = 0;
124    let mut failures_no_match = 0;
125    let mut failures_other = 0;
126    let mut edits = app.primary.map.get_edits().clone();
127
128    ctx.loading_screen("import signal timing", |_, timer| {
129        timer.start_iter("import", all_signals.len());
130        for i in all_signals {
131            timer.next();
132            match import(&app.primary.map, i, &bytes)
133                .and_then(|signal| signal.validate(app.primary.map.get_i(i)).map(|_| signal))
134            {
135                Ok(signal) => {
136                    info!("Success at {}", i);
137                    successes += 1;
138                    edits
139                        .commands
140                        .push(app.primary.map.edit_intersection_cmd(i, |new| {
141                            new.control = EditIntersectionControl::TrafficSignal(
142                                signal.export(&app.primary.map),
143                            );
144                        }));
145                }
146                Err(err) => {
147                    error!("Failure at {}: {}", i, err);
148                    if err.to_string().contains("no matches for") {
149                        failures_no_match += 1;
150                    } else {
151                        failures_other += 1;
152                    }
153                }
154            }
155        }
156    });
157
158    apply_map_edits(ctx, app, edits);
159
160    PopupMsg::new_state(
161        ctx,
162        &format!("Import from {}", path),
163        vec![
164            format!("{} traffic signals successfully imported", successes),
165            format!("{} intersections without any data", failures_no_match),
166            format!("{} other failures", failures_other),
167        ],
168    )
169}
170
171#[derive(Debug, Deserialize)]
172struct Record {
173    #[serde(deserialize_with = "parse_osm_ids", rename = "osm_node_id")]
174    osm_ids: Vec<osm::NodeID>,
175    timing_plan_id: String,
176    green_time: usize,
177    #[serde(rename = "stage_no")]
178    stage: usize,
179    #[serde(deserialize_with = "parse_linestring")]
180    geometry: (LonLat, LonLat),
181    protection: String,
182    // Something like EBL or NBT -- eastbound left, northbound through.
183    mvmt_txt_id: String,
184}
185
186fn parse_linestring<'de, D: Deserializer<'de>>(d: D) -> Result<(LonLat, LonLat), D::Error> {
187    let raw = <String>::deserialize(d)?;
188    let pts = LonLat::parse_wkt_linestring(&raw)
189        .ok_or_else(|| serde::de::Error::custom(format!("bad linestring {}", raw)))?;
190    if pts.len() != 2 {
191        return Err(serde::de::Error::custom(format!(
192            "{} points, expecting 2",
193            pts.len()
194        )));
195    }
196    Ok((pts[0], pts[1]))
197}
198
199fn parse_osm_ids<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<osm::NodeID>, D::Error> {
200    let raw = <String>::deserialize(d)?;
201    let mut ids = Vec::new();
202    for id in raw.split('_') {
203        ids.push(osm::NodeID(id.parse::<i64>().map_err(|_| {
204            serde::de::Error::custom(format!("bad ID {}", id))
205        })?));
206    }
207    Ok(ids)
208}
209
210/// Snaps a line to a vehicle movement across an intersection. It uses movement endpoints and a
211/// hint about turn type to match.
212///
213/// OSM IDs aren't used to snap, because GMNS and A/B Street may disagree about where a road
214/// segment begins/ends. This could happen from OSM IDs changing over time or from different rules
215/// about importing things like service roads.
216struct Snapper {
217    roads_incoming: HashMap<DirectedRoadID, Pt2D>,
218    roads_outgoing: HashMap<DirectedRoadID, Pt2D>,
219    movements: BTreeMap<MovementID, Movement>,
220}
221
222impl Snapper {
223    fn new(map: &Map, i: IntersectionID) -> Result<Snapper> {
224        let mut roads_incoming = HashMap::new();
225        let mut roads_outgoing = HashMap::new();
226        for r in &map.get_i(i).roads {
227            let r = map.get_r(*r);
228
229            let incoming_id = r.directed_id_to(i);
230            let outgoing_id = r.directed_id_from(i);
231
232            // TODO There are a few methods for finding the "middle" of a directed road; here's yet
233            // another.
234            let mut incoming_pts = Vec::new();
235            let mut outgoing_pts = Vec::new();
236
237            for l in &r.lanes {
238                if l.lane_type.is_walkable() {
239                    continue;
240                }
241                if l.dir == incoming_id.dir {
242                    incoming_pts.push(l.lane_center_pts.last_pt());
243                } else {
244                    outgoing_pts.push(l.lane_center_pts.first_pt());
245                }
246            }
247
248            if !incoming_pts.is_empty() {
249                roads_incoming.insert(incoming_id, Pt2D::center(&incoming_pts));
250            }
251            if !outgoing_pts.is_empty() {
252                roads_outgoing.insert(outgoing_id, Pt2D::center(&outgoing_pts));
253            }
254        }
255        if roads_incoming.is_empty() || roads_outgoing.is_empty() {
256            bail!("{} has no incoming or outgoing roads", i);
257        }
258
259        Ok(Snapper {
260            roads_incoming,
261            roads_outgoing,
262            movements: map
263                .get_i(i)
264                .movements
265                .iter()
266                .filter(|(id, _)| !id.crosswalk)
267                .map(|(k, v)| (*k, v.clone()))
268                .collect(),
269        })
270    }
271
272    fn get_mvmnt(&self, pair: (Pt2D, Pt2D), code: &str, map: &Map) -> Result<MovementID> {
273        // Code is something like "WBT", westbound through.
274        let code_turn_type = match code.chars().last() {
275            Some('T') => TurnType::Straight,
276            Some('L') => TurnType::Left,
277            Some('R') => TurnType::Right,
278            x => bail!("Weird movement_str {:?}", x),
279        };
280        let code_direction = &code[0..2];
281
282        let (id, mvmnt) = self
283            .movements
284            .iter()
285            .min_by_key(|(id, mvmnt)| {
286                let from_cost = pair.0.dist_to(self.roads_incoming[&id.from]);
287                let to_cost = pair.1.dist_to(self.roads_outgoing[&id.to]);
288                let direction = cardinal_direction(
289                    map.get_l(mvmnt.members[0].src)
290                        .lane_center_pts
291                        .overall_angle(),
292                );
293
294                // Arbitrary parameters, tuned to make weird geometry at University/Mill in Tempe
295                // work.
296                let type_cost = if mvmnt.turn_type == code_turn_type {
297                    1.0
298                } else {
299                    2.0
300                };
301                // TODO This one is way more important than the geometry! Maybe JUST use the code?
302                let direction_cost = if direction == code_direction {
303                    1.0
304                } else {
305                    10.0
306                };
307                type_cost * direction_cost * (from_cost + to_cost)
308            })
309            .unwrap();
310
311        // Debug if the we didn't agree
312        let direction = cardinal_direction(
313            map.get_l(mvmnt.members[0].src)
314                .lane_center_pts
315                .overall_angle(),
316        );
317        if mvmnt.turn_type != code_turn_type || direction != code_direction {
318            warn!(
319                "A {} snapped to a {} {:?}",
320                code, direction, mvmnt.turn_type
321            );
322        }
323
324        Ok(*id)
325    }
326}
327
328fn cardinal_direction(angle: Angle) -> &'static str {
329    // Note Y inversion, as usual
330    let deg = angle.normalized_degrees();
331    if deg >= 335.0 || deg <= 45.0 {
332        return "EB";
333    }
334    if (45.0..=135.0).contains(&deg) {
335        return "SB";
336    }
337    if (135.0..=225.0).contains(&deg) {
338        return "WB";
339    }
340    "NB"
341}
342
343// The GMNS input doesn't include crosswalks yet -- and even once it does, it's likely the two map
344// models will disagree about where sidewalks exist. Try to add all crosswalks to the stage where
345// they're compatible. Downgrade right turns from protected to permitted as needed.
346fn add_crosswalks(signal: &mut ControlTrafficSignal, map: &Map) {
347    let downgrade_type = if map.get_config().driving_side == DrivingSide::Right {
348        TurnType::Right
349    } else {
350        TurnType::Left
351    };
352
353    let i = map.get_i(signal.id);
354    let mut crosswalks: Vec<MovementID> = Vec::new();
355    for id in i.movements.keys() {
356        if id.crosswalk {
357            crosswalks.push(*id);
358        }
359    }
360
361    // We could try to look for straight turns parallel to the crosswalk, but... just brute-force
362    // it
363    for stage in &mut signal.stages {
364        crosswalks.retain(|id| {
365            if stage.could_be_protected(*id, i) {
366                stage.edit_movement(&i.movements[id], TurnPriority::Protected);
367                false
368            } else {
369                // There may be conflicting right turns that we can downgrade. Try that.
370                let mut stage_copy = stage.clone();
371                for maybe_right_turn in stage.protected_movements.clone() {
372                    if i.movements[&maybe_right_turn].turn_type == downgrade_type {
373                        stage.protected_movements.remove(&maybe_right_turn);
374                        stage.yield_movements.insert(maybe_right_turn);
375                    }
376                }
377                if stage_copy.could_be_protected(*id, i) {
378                    stage_copy.edit_movement(&i.movements[id], TurnPriority::Protected);
379                    *stage = stage_copy;
380                    false
381                } else {
382                    true
383                }
384            }
385        });
386    }
387}