game/edit/traffic_signals/
gmns.rs1use std::collections::{BTreeMap, BTreeSet, HashMap};
4use std::io::Cursor;
5
6use anyhow::Result;
7use serde::{Deserialize, Deserializer};
8
9use geom::{Angle, Duration, LonLat, Pt2D};
10use map_model::{
11 osm, ControlTrafficSignal, DirectedRoadID, DrivingSide, EditIntersectionControl,
12 IntersectionID, Map, Movement, MovementID, Stage, StageType, TurnPriority, TurnType,
13};
14use widgetry::tools::PopupMsg;
15use widgetry::{EventCtx, State};
16
17use crate::edit::apply_map_edits;
18use crate::App;
19
20pub fn import(map: &Map, i: IntersectionID, bytes: &Vec<u8>) -> Result<ControlTrafficSignal> {
24 let i = map.get_i(i);
25 let mut matches_per_plan: BTreeMap<String, Vec<Record>> = BTreeMap::new();
26 for rec in csv::Reader::from_reader(Cursor::new(bytes)).deserialize() {
27 let rec: Record = rec?;
28 if !rec.osm_ids.contains(&i.orig_id) {
29 continue;
30 }
31 matches_per_plan
32 .entry(rec.timing_plan_id.clone())
33 .or_insert_with(Vec::new)
34 .push(rec);
35 }
36
37 let mut records = matches_per_plan
39 .into_iter()
40 .next()
41 .ok_or_else(|| anyhow!("no matches for {}", i.orig_id))?
42 .1;
43 records.sort_by_key(|rec| rec.stage);
44
45 let snapper = Snapper::new(map, i.id)?;
46
47 let mut signal = ControlTrafficSignal::new(map, i.id);
48 signal.stages.clear();
49 for rec in records {
50 let stage_idx = rec.stage - 1;
51 match signal.stages.len().cmp(&stage_idx) {
52 std::cmp::Ordering::Equal => {
53 signal.stages.push(Stage {
54 protected_movements: BTreeSet::new(),
55 yield_movements: BTreeSet::new(),
56 stage_type: StageType::Fixed(Duration::seconds(rec.green_time as f64)),
57 });
58 }
59 std::cmp::Ordering::Less => {
60 bail!("missing intermediate stage");
61 }
62 std::cmp::Ordering::Greater => {}
63 }
64 let stage = &mut signal.stages[stage_idx];
65
66 if stage.stage_type.simple_duration() != Duration::seconds(rec.green_time as f64) {
67 bail!(
68 "Stage {} has green_times {} and {}",
69 rec.stage,
70 stage.stage_type.simple_duration(),
71 rec.green_time
72 );
73 }
74
75 let mvmnt = match snapper.get_mvmnt(
76 (
77 rec.geometry.0.to_pt(map.get_gps_bounds()),
78 rec.geometry.1.to_pt(map.get_gps_bounds()),
79 ),
80 &rec.mvmt_txt_id,
81 map,
82 ) {
83 Ok(x) => x,
84 Err(err) => {
85 error!(
86 "Skipping {} -> {} for stage {}: {}",
87 rec.geometry.0, rec.geometry.1, rec.stage, err
88 );
89 continue;
90 }
91 };
92 if rec.protection == "protected" {
93 stage.protected_movements.insert(mvmnt);
94 } else {
95 stage.yield_movements.insert(mvmnt);
96 }
97 }
98
99 add_crosswalks(&mut signal, map);
100
101 Ok(signal)
102}
103
104pub fn import_all(
105 ctx: &mut EventCtx,
106 app: &mut App,
107 path: &str,
108 bytes: Vec<u8>,
109) -> Box<dyn State<App>> {
110 let all_signals: Vec<IntersectionID> = app
111 .primary
112 .map
113 .all_intersections()
114 .iter()
115 .filter_map(|i| {
116 if i.is_traffic_signal() {
117 Some(i.id)
118 } else {
119 None
120 }
121 })
122 .collect();
123 let mut successes = 0;
124 let mut failures_no_match = 0;
125 let mut failures_other = 0;
126 let mut edits = app.primary.map.get_edits().clone();
127
128 ctx.loading_screen("import signal timing", |_, timer| {
129 timer.start_iter("import", all_signals.len());
130 for i in all_signals {
131 timer.next();
132 match import(&app.primary.map, i, &bytes)
133 .and_then(|signal| signal.validate(app.primary.map.get_i(i)).map(|_| signal))
134 {
135 Ok(signal) => {
136 info!("Success at {}", i);
137 successes += 1;
138 edits
139 .commands
140 .push(app.primary.map.edit_intersection_cmd(i, |new| {
141 new.control = EditIntersectionControl::TrafficSignal(
142 signal.export(&app.primary.map),
143 );
144 }));
145 }
146 Err(err) => {
147 error!("Failure at {}: {}", i, err);
148 if err.to_string().contains("no matches for") {
149 failures_no_match += 1;
150 } else {
151 failures_other += 1;
152 }
153 }
154 }
155 }
156 });
157
158 apply_map_edits(ctx, app, edits);
159
160 PopupMsg::new_state(
161 ctx,
162 &format!("Import from {}", path),
163 vec![
164 format!("{} traffic signals successfully imported", successes),
165 format!("{} intersections without any data", failures_no_match),
166 format!("{} other failures", failures_other),
167 ],
168 )
169}
170
171#[derive(Debug, Deserialize)]
172struct Record {
173 #[serde(deserialize_with = "parse_osm_ids", rename = "osm_node_id")]
174 osm_ids: Vec<osm::NodeID>,
175 timing_plan_id: String,
176 green_time: usize,
177 #[serde(rename = "stage_no")]
178 stage: usize,
179 #[serde(deserialize_with = "parse_linestring")]
180 geometry: (LonLat, LonLat),
181 protection: String,
182 mvmt_txt_id: String,
184}
185
186fn parse_linestring<'de, D: Deserializer<'de>>(d: D) -> Result<(LonLat, LonLat), D::Error> {
187 let raw = <String>::deserialize(d)?;
188 let pts = LonLat::parse_wkt_linestring(&raw)
189 .ok_or_else(|| serde::de::Error::custom(format!("bad linestring {}", raw)))?;
190 if pts.len() != 2 {
191 return Err(serde::de::Error::custom(format!(
192 "{} points, expecting 2",
193 pts.len()
194 )));
195 }
196 Ok((pts[0], pts[1]))
197}
198
199fn parse_osm_ids<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<osm::NodeID>, D::Error> {
200 let raw = <String>::deserialize(d)?;
201 let mut ids = Vec::new();
202 for id in raw.split('_') {
203 ids.push(osm::NodeID(id.parse::<i64>().map_err(|_| {
204 serde::de::Error::custom(format!("bad ID {}", id))
205 })?));
206 }
207 Ok(ids)
208}
209
210struct Snapper {
217 roads_incoming: HashMap<DirectedRoadID, Pt2D>,
218 roads_outgoing: HashMap<DirectedRoadID, Pt2D>,
219 movements: BTreeMap<MovementID, Movement>,
220}
221
222impl Snapper {
223 fn new(map: &Map, i: IntersectionID) -> Result<Snapper> {
224 let mut roads_incoming = HashMap::new();
225 let mut roads_outgoing = HashMap::new();
226 for r in &map.get_i(i).roads {
227 let r = map.get_r(*r);
228
229 let incoming_id = r.directed_id_to(i);
230 let outgoing_id = r.directed_id_from(i);
231
232 let mut incoming_pts = Vec::new();
235 let mut outgoing_pts = Vec::new();
236
237 for l in &r.lanes {
238 if l.lane_type.is_walkable() {
239 continue;
240 }
241 if l.dir == incoming_id.dir {
242 incoming_pts.push(l.lane_center_pts.last_pt());
243 } else {
244 outgoing_pts.push(l.lane_center_pts.first_pt());
245 }
246 }
247
248 if !incoming_pts.is_empty() {
249 roads_incoming.insert(incoming_id, Pt2D::center(&incoming_pts));
250 }
251 if !outgoing_pts.is_empty() {
252 roads_outgoing.insert(outgoing_id, Pt2D::center(&outgoing_pts));
253 }
254 }
255 if roads_incoming.is_empty() || roads_outgoing.is_empty() {
256 bail!("{} has no incoming or outgoing roads", i);
257 }
258
259 Ok(Snapper {
260 roads_incoming,
261 roads_outgoing,
262 movements: map
263 .get_i(i)
264 .movements
265 .iter()
266 .filter(|(id, _)| !id.crosswalk)
267 .map(|(k, v)| (*k, v.clone()))
268 .collect(),
269 })
270 }
271
272 fn get_mvmnt(&self, pair: (Pt2D, Pt2D), code: &str, map: &Map) -> Result<MovementID> {
273 let code_turn_type = match code.chars().last() {
275 Some('T') => TurnType::Straight,
276 Some('L') => TurnType::Left,
277 Some('R') => TurnType::Right,
278 x => bail!("Weird movement_str {:?}", x),
279 };
280 let code_direction = &code[0..2];
281
282 let (id, mvmnt) = self
283 .movements
284 .iter()
285 .min_by_key(|(id, mvmnt)| {
286 let from_cost = pair.0.dist_to(self.roads_incoming[&id.from]);
287 let to_cost = pair.1.dist_to(self.roads_outgoing[&id.to]);
288 let direction = cardinal_direction(
289 map.get_l(mvmnt.members[0].src)
290 .lane_center_pts
291 .overall_angle(),
292 );
293
294 let type_cost = if mvmnt.turn_type == code_turn_type {
297 1.0
298 } else {
299 2.0
300 };
301 let direction_cost = if direction == code_direction {
303 1.0
304 } else {
305 10.0
306 };
307 type_cost * direction_cost * (from_cost + to_cost)
308 })
309 .unwrap();
310
311 let direction = cardinal_direction(
313 map.get_l(mvmnt.members[0].src)
314 .lane_center_pts
315 .overall_angle(),
316 );
317 if mvmnt.turn_type != code_turn_type || direction != code_direction {
318 warn!(
319 "A {} snapped to a {} {:?}",
320 code, direction, mvmnt.turn_type
321 );
322 }
323
324 Ok(*id)
325 }
326}
327
328fn cardinal_direction(angle: Angle) -> &'static str {
329 let deg = angle.normalized_degrees();
331 if deg >= 335.0 || deg <= 45.0 {
332 return "EB";
333 }
334 if (45.0..=135.0).contains(°) {
335 return "SB";
336 }
337 if (135.0..=225.0).contains(°) {
338 return "WB";
339 }
340 "NB"
341}
342
343fn add_crosswalks(signal: &mut ControlTrafficSignal, map: &Map) {
347 let downgrade_type = if map.get_config().driving_side == DrivingSide::Right {
348 TurnType::Right
349 } else {
350 TurnType::Left
351 };
352
353 let i = map.get_i(signal.id);
354 let mut crosswalks: Vec<MovementID> = Vec::new();
355 for id in i.movements.keys() {
356 if id.crosswalk {
357 crosswalks.push(*id);
358 }
359 }
360
361 for stage in &mut signal.stages {
364 crosswalks.retain(|id| {
365 if stage.could_be_protected(*id, i) {
366 stage.edit_movement(&i.movements[id], TurnPriority::Protected);
367 false
368 } else {
369 let mut stage_copy = stage.clone();
371 for maybe_right_turn in stage.protected_movements.clone() {
372 if i.movements[&maybe_right_turn].turn_type == downgrade_type {
373 stage.protected_movements.remove(&maybe_right_turn);
374 stage.yield_movements.insert(maybe_right_turn);
375 }
376 }
377 if stage_copy.could_be_protected(*id, i) {
378 stage_copy.edit_movement(&i.movements[id], TurnPriority::Protected);
379 *stage = stage_copy;
380 false
381 } else {
382 true
383 }
384 }
385 });
386 }
387}