1use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{
12 BrowserSession, CompatSession, CompatSessionState, CompatSsoLogin, CompatSsoLoginState, Device,
13 User, UserAgent,
14};
15use mas_storage::{
16 Clock, Page, Pagination,
17 compat::{CompatSessionFilter, CompatSessionRepository},
18};
19use rand::RngCore;
20use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
21use sea_query_binder::SqlxBinder;
22use sqlx::PgConnection;
23use ulid::Ulid;
24use url::Url;
25use uuid::Uuid;
26
27use crate::{
28 DatabaseError, DatabaseInconsistencyError,
29 filter::{Filter, StatementExt, StatementWithJoinsExt},
30 iden::{CompatSessions, CompatSsoLogins},
31 pagination::QueryBuilderExt,
32 tracing::ExecuteExt,
33};
34
35pub struct PgCompatSessionRepository<'c> {
37 conn: &'c mut PgConnection,
38}
39
40impl<'c> PgCompatSessionRepository<'c> {
41 pub fn new(conn: &'c mut PgConnection) -> Self {
44 Self { conn }
45 }
46}
47
/// Raw `compat_sessions` row, as selected by the single-session lookup query.
///
/// Field names must match the selected column names, since the struct is filled by
/// `sqlx::query_as!`; it is turned into a [`CompatSession`] through its `From` impl.
struct CompatSessionLookup {
    compat_session_id: Uuid,
    // Optional device ID, mapped to `CompatSession::device`
    device_id: Option<String>,
    human_name: Option<String>,
    user_id: Uuid,
    // Browser session this compat session is linked to, if any
    user_session_id: Option<Uuid>,
    created_at: DateTime<Utc>,
    // NULL while the session is valid, set once it has been finished
    finished_at: Option<DateTime<Utc>>,
    is_synapse_admin: bool,
    // Raw user agent string; parsed with `UserAgent::parse` during conversion
    user_agent: Option<String>,
    last_active_at: Option<DateTime<Utc>>,
    last_active_ip: Option<IpAddr>,
}
61
62impl From<CompatSessionLookup> for CompatSession {
63 fn from(value: CompatSessionLookup) -> Self {
64 let id = value.compat_session_id.into();
65
66 let state = match value.finished_at {
67 None => CompatSessionState::Valid,
68 Some(finished_at) => CompatSessionState::Finished { finished_at },
69 };
70
71 CompatSession {
72 id,
73 state,
74 user_id: value.user_id.into(),
75 user_session_id: value.user_session_id.map(Ulid::from),
76 device: value.device_id.map(Device::from),
77 human_name: value.human_name,
78 created_at: value.created_at,
79 is_synapse_admin: value.is_synapse_admin,
80 user_agent: value.user_agent.map(UserAgent::parse),
81 last_active_at: value.last_active_at,
82 last_active_ip: value.last_active_ip,
83 }
84 }
85}
86
/// One row of the `compat_sessions LEFT JOIN compat_sso_logins` query built by `list`.
///
/// Because of the LEFT JOIN, every `compat_sso_login_*` column is NULL when no SSO login
/// matched the session. `#[enum_def]` generates `CompatSessionAndSsoLoginLookupIden`,
/// whose variants `list` uses as column aliases so that `sqlx::FromRow` can map the
/// result back onto these fields by name.
#[derive(sqlx::FromRow)]
#[enum_def]
struct CompatSessionAndSsoLoginLookup {
    // Columns coming from `compat_sessions`
    compat_session_id: Uuid,
    device_id: Option<String>,
    human_name: Option<String>,
    user_id: Uuid,
    user_session_id: Option<Uuid>,
    created_at: DateTime<Utc>,
    finished_at: Option<DateTime<Utc>>,
    is_synapse_admin: bool,
    user_agent: Option<String>,
    last_active_at: Option<DateTime<Utc>>,
    last_active_ip: Option<IpAddr>,
    // Columns coming from the LEFT JOINed `compat_sso_logins` row, if any
    compat_sso_login_id: Option<Uuid>,
    compat_sso_login_token: Option<String>,
    compat_sso_login_redirect_uri: Option<String>,
    compat_sso_login_created_at: Option<DateTime<Utc>>,
    compat_sso_login_fulfilled_at: Option<DateTime<Utc>>,
    compat_sso_login_exchanged_at: Option<DateTime<Utc>>,
}
108
109impl TryFrom<CompatSessionAndSsoLoginLookup> for (CompatSession, Option<CompatSsoLogin>) {
110 type Error = DatabaseInconsistencyError;
111
112 fn try_from(value: CompatSessionAndSsoLoginLookup) -> Result<Self, Self::Error> {
113 let id = value.compat_session_id.into();
114
115 let state = match value.finished_at {
116 None => CompatSessionState::Valid,
117 Some(finished_at) => CompatSessionState::Finished { finished_at },
118 };
119
120 let session = CompatSession {
121 id,
122 state,
123 user_id: value.user_id.into(),
124 device: value.device_id.map(Device::from),
125 human_name: value.human_name,
126 user_session_id: value.user_session_id.map(Ulid::from),
127 created_at: value.created_at,
128 is_synapse_admin: value.is_synapse_admin,
129 user_agent: value.user_agent.map(UserAgent::parse),
130 last_active_at: value.last_active_at,
131 last_active_ip: value.last_active_ip,
132 };
133
134 match (
135 value.compat_sso_login_id,
136 value.compat_sso_login_token,
137 value.compat_sso_login_redirect_uri,
138 value.compat_sso_login_created_at,
139 value.compat_sso_login_fulfilled_at,
140 value.compat_sso_login_exchanged_at,
141 ) {
142 (None, None, None, None, None, None) => Ok((session, None)),
143 (
144 Some(id),
145 Some(login_token),
146 Some(redirect_uri),
147 Some(created_at),
148 fulfilled_at,
149 exchanged_at,
150 ) => {
151 let id = id.into();
152 let redirect_uri = Url::parse(&redirect_uri).map_err(|e| {
153 DatabaseInconsistencyError::on("compat_sso_logins")
154 .column("redirect_uri")
155 .row(id)
156 .source(e)
157 })?;
158
159 let state = match (fulfilled_at, exchanged_at) {
160 (Some(fulfilled_at), Some(exchanged_at)) => CompatSsoLoginState::Exchanged {
161 fulfilled_at,
162 exchanged_at,
163 compat_session_id: session.id,
164 },
165 _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
166 };
167
168 let login = CompatSsoLogin {
169 id,
170 redirect_uri,
171 login_token,
172 created_at,
173 state,
174 };
175
176 Ok((session, Some(login)))
177 }
178 _ => Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
179 }
180 }
181}
182
183impl Filter for CompatSessionFilter<'_> {
184 fn generate_condition(&self, has_joins: bool) -> impl sea_query::IntoCondition {
185 sea_query::Condition::all()
186 .add_option(self.user().map(|user| {
187 Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
188 }))
189 .add_option(self.browser_session().map(|browser_session| {
190 Expr::col((CompatSessions::Table, CompatSessions::UserSessionId))
191 .eq(Uuid::from(browser_session.id))
192 }))
193 .add_option(self.state().map(|state| {
194 if state.is_active() {
195 Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
196 } else {
197 Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
198 }
199 }))
200 .add_option(self.auth_type().map(|auth_type| {
201 if has_joins {
204 if auth_type.is_sso_login() {
205 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId))
206 .is_not_null()
207 } else {
208 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId))
209 .is_null()
210 }
211 } else {
212 let compat_sso_logins = Query::select()
216 .expr(Expr::col((
217 CompatSsoLogins::Table,
218 CompatSsoLogins::CompatSessionId,
219 )))
220 .from(CompatSsoLogins::Table)
221 .take();
222
223 if auth_type.is_sso_login() {
224 Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
225 .eq(Expr::any(compat_sso_logins))
226 } else {
227 Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
228 .ne(Expr::all(compat_sso_logins))
229 }
230 }
231 }))
232 .add_option(self.last_active_after().map(|last_active_after| {
233 Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt))
234 .gt(last_active_after)
235 }))
236 .add_option(self.last_active_before().map(|last_active_before| {
237 Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt))
238 .lt(last_active_before)
239 }))
240 .add_option(self.device().map(|device| {
241 Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.as_str())
242 }))
243 }
244}
245
246#[async_trait]
247impl CompatSessionRepository for PgCompatSessionRepository<'_> {
248 type Error = DatabaseError;
249
250 #[tracing::instrument(
251 name = "db.compat_session.lookup",
252 skip_all,
253 fields(
254 db.query.text,
255 compat_session.id = %id,
256 ),
257 err,
258 )]
259 async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSession>, Self::Error> {
260 let res = sqlx::query_as!(
261 CompatSessionLookup,
262 r#"
263 SELECT compat_session_id
264 , device_id
265 , human_name
266 , user_id
267 , user_session_id
268 , created_at
269 , finished_at
270 , is_synapse_admin
271 , user_agent
272 , last_active_at
273 , last_active_ip as "last_active_ip: IpAddr"
274 FROM compat_sessions
275 WHERE compat_session_id = $1
276 "#,
277 Uuid::from(id),
278 )
279 .traced()
280 .fetch_optional(&mut *self.conn)
281 .await?;
282
283 let Some(res) = res else { return Ok(None) };
284
285 Ok(Some(res.into()))
286 }
287
288 #[tracing::instrument(
289 name = "db.compat_session.add",
290 skip_all,
291 fields(
292 db.query.text,
293 compat_session.id,
294 %user.id,
295 %user.username,
296 compat_session.device.id = device.as_str(),
297 ),
298 err,
299 )]
300 async fn add(
301 &mut self,
302 rng: &mut (dyn RngCore + Send),
303 clock: &dyn Clock,
304 user: &User,
305 device: Device,
306 browser_session: Option<&BrowserSession>,
307 is_synapse_admin: bool,
308 ) -> Result<CompatSession, Self::Error> {
309 let created_at = clock.now();
310 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
311 tracing::Span::current().record("compat_session.id", tracing::field::display(id));
312
313 sqlx::query!(
314 r#"
315 INSERT INTO compat_sessions
316 (compat_session_id, user_id, device_id,
317 user_session_id, created_at, is_synapse_admin)
318 VALUES ($1, $2, $3, $4, $5, $6)
319 "#,
320 Uuid::from(id),
321 Uuid::from(user.id),
322 device.as_str(),
323 browser_session.map(|s| Uuid::from(s.id)),
324 created_at,
325 is_synapse_admin,
326 )
327 .traced()
328 .execute(&mut *self.conn)
329 .await?;
330
331 Ok(CompatSession {
332 id,
333 state: CompatSessionState::default(),
334 user_id: user.id,
335 device: Some(device),
336 human_name: None,
337 user_session_id: browser_session.map(|s| s.id),
338 created_at,
339 is_synapse_admin,
340 user_agent: None,
341 last_active_at: None,
342 last_active_ip: None,
343 })
344 }
345
346 #[tracing::instrument(
347 name = "db.compat_session.finish",
348 skip_all,
349 fields(
350 db.query.text,
351 %compat_session.id,
352 user.id = %compat_session.user_id,
353 compat_session.device.id = compat_session.device.as_ref().map(mas_data_model::Device::as_str),
354 ),
355 err,
356 )]
357 async fn finish(
358 &mut self,
359 clock: &dyn Clock,
360 compat_session: CompatSession,
361 ) -> Result<CompatSession, Self::Error> {
362 let finished_at = clock.now();
363
364 let res = sqlx::query!(
365 r#"
366 UPDATE compat_sessions cs
367 SET finished_at = $2
368 WHERE compat_session_id = $1
369 "#,
370 Uuid::from(compat_session.id),
371 finished_at,
372 )
373 .traced()
374 .execute(&mut *self.conn)
375 .await?;
376
377 DatabaseError::ensure_affected_rows(&res, 1)?;
378
379 let compat_session = compat_session
380 .finish(finished_at)
381 .map_err(DatabaseError::to_invalid_operation)?;
382
383 Ok(compat_session)
384 }
385
386 #[tracing::instrument(
387 name = "db.compat_session.finish_bulk",
388 skip_all,
389 fields(db.query.text),
390 err,
391 )]
392 async fn finish_bulk(
393 &mut self,
394 clock: &dyn Clock,
395 filter: CompatSessionFilter<'_>,
396 ) -> Result<usize, Self::Error> {
397 let finished_at = clock.now();
398 let (sql, arguments) = Query::update()
399 .table(CompatSessions::Table)
400 .value(CompatSessions::FinishedAt, finished_at)
401 .apply_filter(filter)
402 .build_sqlx(PostgresQueryBuilder);
403
404 let res = sqlx::query_with(&sql, arguments)
405 .traced()
406 .execute(&mut *self.conn)
407 .await?;
408
409 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
410 }
411
412 #[tracing::instrument(
413 name = "db.compat_session.list",
414 skip_all,
415 fields(
416 db.query.text,
417 ),
418 err,
419 )]
420 async fn list(
421 &mut self,
422 filter: CompatSessionFilter<'_>,
423 pagination: Pagination,
424 ) -> Result<Page<(CompatSession, Option<CompatSsoLogin>)>, Self::Error> {
425 let (sql, arguments) = Query::select()
426 .expr_as(
427 Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)),
428 CompatSessionAndSsoLoginLookupIden::CompatSessionId,
429 )
430 .expr_as(
431 Expr::col((CompatSessions::Table, CompatSessions::DeviceId)),
432 CompatSessionAndSsoLoginLookupIden::DeviceId,
433 )
434 .expr_as(
435 Expr::col((CompatSessions::Table, CompatSessions::HumanName)),
436 CompatSessionAndSsoLoginLookupIden::HumanName,
437 )
438 .expr_as(
439 Expr::col((CompatSessions::Table, CompatSessions::UserId)),
440 CompatSessionAndSsoLoginLookupIden::UserId,
441 )
442 .expr_as(
443 Expr::col((CompatSessions::Table, CompatSessions::UserSessionId)),
444 CompatSessionAndSsoLoginLookupIden::UserSessionId,
445 )
446 .expr_as(
447 Expr::col((CompatSessions::Table, CompatSessions::CreatedAt)),
448 CompatSessionAndSsoLoginLookupIden::CreatedAt,
449 )
450 .expr_as(
451 Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)),
452 CompatSessionAndSsoLoginLookupIden::FinishedAt,
453 )
454 .expr_as(
455 Expr::col((CompatSessions::Table, CompatSessions::IsSynapseAdmin)),
456 CompatSessionAndSsoLoginLookupIden::IsSynapseAdmin,
457 )
458 .expr_as(
459 Expr::col((CompatSessions::Table, CompatSessions::UserAgent)),
460 CompatSessionAndSsoLoginLookupIden::UserAgent,
461 )
462 .expr_as(
463 Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt)),
464 CompatSessionAndSsoLoginLookupIden::LastActiveAt,
465 )
466 .expr_as(
467 Expr::col((CompatSessions::Table, CompatSessions::LastActiveIp)),
468 CompatSessionAndSsoLoginLookupIden::LastActiveIp,
469 )
470 .expr_as(
471 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)),
472 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginId,
473 )
474 .expr_as(
475 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::LoginToken)),
476 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginToken,
477 )
478 .expr_as(
479 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::RedirectUri)),
480 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginRedirectUri,
481 )
482 .expr_as(
483 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CreatedAt)),
484 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginCreatedAt,
485 )
486 .expr_as(
487 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)),
488 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginFulfilledAt,
489 )
490 .expr_as(
491 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)),
492 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginExchangedAt,
493 )
494 .from(CompatSessions::Table)
495 .left_join(
496 CompatSsoLogins::Table,
497 Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
498 .equals((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)),
499 )
500 .apply_filter_with_joins(filter)
501 .generate_pagination(
502 (CompatSessions::Table, CompatSessions::CompatSessionId),
503 pagination,
504 )
505 .build_sqlx(PostgresQueryBuilder);
506
507 let edges: Vec<CompatSessionAndSsoLoginLookup> = sqlx::query_as_with(&sql, arguments)
508 .traced()
509 .fetch_all(&mut *self.conn)
510 .await?;
511
512 let page = pagination.process(edges).try_map(TryFrom::try_from)?;
513
514 Ok(page)
515 }
516
517 #[tracing::instrument(
518 name = "db.compat_session.count",
519 skip_all,
520 fields(
521 db.query.text,
522 ),
523 err,
524 )]
525 async fn count(&mut self, filter: CompatSessionFilter<'_>) -> Result<usize, Self::Error> {
526 let (sql, arguments) = sea_query::Query::select()
527 .expr(Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)).count())
528 .from(CompatSessions::Table)
529 .apply_filter(filter)
530 .build_sqlx(PostgresQueryBuilder);
531
532 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
533 .traced()
534 .fetch_one(&mut *self.conn)
535 .await?;
536
537 count
538 .try_into()
539 .map_err(DatabaseError::to_invalid_operation)
540 }
541
542 #[tracing::instrument(
543 name = "db.compat_session.record_batch_activity",
544 skip_all,
545 fields(
546 db.query.text,
547 ),
548 err,
549 )]
550 async fn record_batch_activity(
551 &mut self,
552 activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
553 ) -> Result<(), Self::Error> {
554 let mut ids = Vec::with_capacity(activity.len());
555 let mut last_activities = Vec::with_capacity(activity.len());
556 let mut ips = Vec::with_capacity(activity.len());
557
558 for (id, last_activity, ip) in activity {
559 ids.push(Uuid::from(id));
560 last_activities.push(last_activity);
561 ips.push(ip);
562 }
563
564 let res = sqlx::query!(
565 r#"
566 UPDATE compat_sessions
567 SET last_active_at = GREATEST(t.last_active_at, compat_sessions.last_active_at)
568 , last_active_ip = COALESCE(t.last_active_ip, compat_sessions.last_active_ip)
569 FROM (
570 SELECT *
571 FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
572 AS t(compat_session_id, last_active_at, last_active_ip)
573 ) AS t
574 WHERE compat_sessions.compat_session_id = t.compat_session_id
575 "#,
576 &ids,
577 &last_activities,
578 &ips as &[Option<IpAddr>],
579 )
580 .traced()
581 .execute(&mut *self.conn)
582 .await?;
583
584 DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
585
586 Ok(())
587 }
588
589 #[tracing::instrument(
590 name = "db.compat_session.record_user_agent",
591 skip_all,
592 fields(
593 db.query.text,
594 %compat_session.id,
595 ),
596 err,
597 )]
598 async fn record_user_agent(
599 &mut self,
600 mut compat_session: CompatSession,
601 user_agent: UserAgent,
602 ) -> Result<CompatSession, Self::Error> {
603 let res = sqlx::query!(
604 r#"
605 UPDATE compat_sessions
606 SET user_agent = $2
607 WHERE compat_session_id = $1
608 "#,
609 Uuid::from(compat_session.id),
610 &*user_agent,
611 )
612 .traced()
613 .execute(&mut *self.conn)
614 .await?;
615
616 compat_session.user_agent = Some(user_agent);
617
618 DatabaseError::ensure_affected_rows(&res, 1)?;
619
620 Ok(compat_session)
621 }
622}