mas_storage_pg/compat/
session.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{
12    BrowserSession, CompatSession, CompatSessionState, CompatSsoLogin, CompatSsoLoginState, Device,
13    User, UserAgent,
14};
15use mas_storage::{
16    Clock, Page, Pagination,
17    compat::{CompatSessionFilter, CompatSessionRepository},
18};
19use rand::RngCore;
20use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
21use sea_query_binder::SqlxBinder;
22use sqlx::PgConnection;
23use ulid::Ulid;
24use url::Url;
25use uuid::Uuid;
26
27use crate::{
28    DatabaseError, DatabaseInconsistencyError,
29    filter::{Filter, StatementExt, StatementWithJoinsExt},
30    iden::{CompatSessions, CompatSsoLogins},
31    pagination::QueryBuilderExt,
32    tracing::ExecuteExt,
33};
34
35/// An implementation of [`CompatSessionRepository`] for a PostgreSQL connection
36pub struct PgCompatSessionRepository<'c> {
37    conn: &'c mut PgConnection,
38}
39
40impl<'c> PgCompatSessionRepository<'c> {
41    /// Create a new [`PgCompatSessionRepository`] from an active PostgreSQL
42    /// connection
43    pub fn new(conn: &'c mut PgConnection) -> Self {
44        Self { conn }
45    }
46}
47
/// Row shape returned by the single-session `lookup` query on
/// `compat_sessions`.
///
/// `sqlx::query_as!` maps the selected columns onto these fields by name, so
/// every field name must match a column (or alias) selected by that query.
struct CompatSessionLookup {
    compat_session_id: Uuid,
    /// Converted into a `Device` when present
    device_id: Option<String>,
    human_name: Option<String>,
    user_id: Uuid,
    /// ID of the browser session this compat session was created from, if any
    user_session_id: Option<Uuid>,
    created_at: DateTime<Utc>,
    /// `None` while the session is still valid; set once it has been finished
    finished_at: Option<DateTime<Utc>>,
    is_synapse_admin: bool,
    /// Raw user agent string, parsed with `UserAgent::parse` on conversion
    user_agent: Option<String>,
    last_active_at: Option<DateTime<Utc>>,
    /// Decoded from the column via the `"last_active_ip: IpAddr"` type override
    last_active_ip: Option<IpAddr>,
}
61
62impl From<CompatSessionLookup> for CompatSession {
63    fn from(value: CompatSessionLookup) -> Self {
64        let id = value.compat_session_id.into();
65
66        let state = match value.finished_at {
67            None => CompatSessionState::Valid,
68            Some(finished_at) => CompatSessionState::Finished { finished_at },
69        };
70
71        CompatSession {
72            id,
73            state,
74            user_id: value.user_id.into(),
75            user_session_id: value.user_session_id.map(Ulid::from),
76            device: value.device_id.map(Device::from),
77            human_name: value.human_name,
78            created_at: value.created_at,
79            is_synapse_admin: value.is_synapse_admin,
80            user_agent: value.user_agent.map(UserAgent::parse),
81            last_active_at: value.last_active_at,
82            last_active_ip: value.last_active_ip,
83        }
84    }
85}
86
/// Row shape returned by the `list` query.
///
/// That query LEFT JOINs `compat_sso_logins` onto `compat_sessions`.
/// `#[enum_def]` generates `CompatSessionAndSsoLoginLookupIden` from these
/// fields. `list` uses that enum as column aliases, and `sqlx::FromRow` maps
/// the aliased columns back by name, so field names double as SQL aliases.
#[derive(sqlx::FromRow)]
#[enum_def]
struct CompatSessionAndSsoLoginLookup {
    // Columns from `compat_sessions`
    compat_session_id: Uuid,
    device_id: Option<String>,
    human_name: Option<String>,
    user_id: Uuid,
    user_session_id: Option<Uuid>,
    created_at: DateTime<Utc>,
    finished_at: Option<DateTime<Utc>>,
    is_synapse_admin: bool,
    user_agent: Option<String>,
    last_active_at: Option<DateTime<Utc>>,
    last_active_ip: Option<IpAddr>,
    // Columns from the LEFT JOINed `compat_sso_logins`: all NULL when the
    // session has no matching SSO login
    compat_sso_login_id: Option<Uuid>,
    compat_sso_login_token: Option<String>,
    compat_sso_login_redirect_uri: Option<String>,
    compat_sso_login_created_at: Option<DateTime<Utc>>,
    compat_sso_login_fulfilled_at: Option<DateTime<Utc>>,
    compat_sso_login_exchanged_at: Option<DateTime<Utc>>,
}
108
109impl TryFrom<CompatSessionAndSsoLoginLookup> for (CompatSession, Option<CompatSsoLogin>) {
110    type Error = DatabaseInconsistencyError;
111
112    fn try_from(value: CompatSessionAndSsoLoginLookup) -> Result<Self, Self::Error> {
113        let id = value.compat_session_id.into();
114
115        let state = match value.finished_at {
116            None => CompatSessionState::Valid,
117            Some(finished_at) => CompatSessionState::Finished { finished_at },
118        };
119
120        let session = CompatSession {
121            id,
122            state,
123            user_id: value.user_id.into(),
124            device: value.device_id.map(Device::from),
125            human_name: value.human_name,
126            user_session_id: value.user_session_id.map(Ulid::from),
127            created_at: value.created_at,
128            is_synapse_admin: value.is_synapse_admin,
129            user_agent: value.user_agent.map(UserAgent::parse),
130            last_active_at: value.last_active_at,
131            last_active_ip: value.last_active_ip,
132        };
133
134        match (
135            value.compat_sso_login_id,
136            value.compat_sso_login_token,
137            value.compat_sso_login_redirect_uri,
138            value.compat_sso_login_created_at,
139            value.compat_sso_login_fulfilled_at,
140            value.compat_sso_login_exchanged_at,
141        ) {
142            (None, None, None, None, None, None) => Ok((session, None)),
143            (
144                Some(id),
145                Some(login_token),
146                Some(redirect_uri),
147                Some(created_at),
148                fulfilled_at,
149                exchanged_at,
150            ) => {
151                let id = id.into();
152                let redirect_uri = Url::parse(&redirect_uri).map_err(|e| {
153                    DatabaseInconsistencyError::on("compat_sso_logins")
154                        .column("redirect_uri")
155                        .row(id)
156                        .source(e)
157                })?;
158
159                let state = match (fulfilled_at, exchanged_at) {
160                    (Some(fulfilled_at), Some(exchanged_at)) => CompatSsoLoginState::Exchanged {
161                        fulfilled_at,
162                        exchanged_at,
163                        compat_session_id: session.id,
164                    },
165                    _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
166                };
167
168                let login = CompatSsoLogin {
169                    id,
170                    redirect_uri,
171                    login_token,
172                    created_at,
173                    state,
174                };
175
176                Ok((session, Some(login)))
177            }
178            _ => Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
179        }
180    }
181}
182
183impl Filter for CompatSessionFilter<'_> {
184    fn generate_condition(&self, has_joins: bool) -> impl sea_query::IntoCondition {
185        sea_query::Condition::all()
186            .add_option(self.user().map(|user| {
187                Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
188            }))
189            .add_option(self.browser_session().map(|browser_session| {
190                Expr::col((CompatSessions::Table, CompatSessions::UserSessionId))
191                    .eq(Uuid::from(browser_session.id))
192            }))
193            .add_option(self.state().map(|state| {
194                if state.is_active() {
195                    Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
196                } else {
197                    Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
198                }
199            }))
200            .add_option(self.auth_type().map(|auth_type| {
201                // In in the SELECT to list sessions, we can rely on the JOINed table, whereas
202                // in other queries we need to do a subquery
203                if has_joins {
204                    if auth_type.is_sso_login() {
205                        Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId))
206                            .is_not_null()
207                    } else {
208                        Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId))
209                            .is_null()
210                    }
211                } else {
212                    // This builds either a:
213                    // `WHERE compat_session_id = ANY(...)`
214                    // or a `WHERE compat_session_id <> ALL(...)`
215                    let compat_sso_logins = Query::select()
216                        .expr(Expr::col((
217                            CompatSsoLogins::Table,
218                            CompatSsoLogins::CompatSessionId,
219                        )))
220                        .from(CompatSsoLogins::Table)
221                        .take();
222
223                    if auth_type.is_sso_login() {
224                        Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
225                            .eq(Expr::any(compat_sso_logins))
226                    } else {
227                        Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
228                            .ne(Expr::all(compat_sso_logins))
229                    }
230                }
231            }))
232            .add_option(self.last_active_after().map(|last_active_after| {
233                Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt))
234                    .gt(last_active_after)
235            }))
236            .add_option(self.last_active_before().map(|last_active_before| {
237                Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt))
238                    .lt(last_active_before)
239            }))
240            .add_option(self.device().map(|device| {
241                Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.as_str())
242            }))
243    }
244}
245
246#[async_trait]
247impl CompatSessionRepository for PgCompatSessionRepository<'_> {
248    type Error = DatabaseError;
249
250    #[tracing::instrument(
251        name = "db.compat_session.lookup",
252        skip_all,
253        fields(
254            db.query.text,
255            compat_session.id = %id,
256        ),
257        err,
258    )]
259    async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSession>, Self::Error> {
260        let res = sqlx::query_as!(
261            CompatSessionLookup,
262            r#"
263                SELECT compat_session_id
264                     , device_id
265                     , human_name
266                     , user_id
267                     , user_session_id
268                     , created_at
269                     , finished_at
270                     , is_synapse_admin
271                     , user_agent
272                     , last_active_at
273                     , last_active_ip as "last_active_ip: IpAddr"
274                FROM compat_sessions
275                WHERE compat_session_id = $1
276            "#,
277            Uuid::from(id),
278        )
279        .traced()
280        .fetch_optional(&mut *self.conn)
281        .await?;
282
283        let Some(res) = res else { return Ok(None) };
284
285        Ok(Some(res.into()))
286    }
287
288    #[tracing::instrument(
289        name = "db.compat_session.add",
290        skip_all,
291        fields(
292            db.query.text,
293            compat_session.id,
294            %user.id,
295            %user.username,
296            compat_session.device.id = device.as_str(),
297        ),
298        err,
299    )]
300    async fn add(
301        &mut self,
302        rng: &mut (dyn RngCore + Send),
303        clock: &dyn Clock,
304        user: &User,
305        device: Device,
306        browser_session: Option<&BrowserSession>,
307        is_synapse_admin: bool,
308    ) -> Result<CompatSession, Self::Error> {
309        let created_at = clock.now();
310        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
311        tracing::Span::current().record("compat_session.id", tracing::field::display(id));
312
313        sqlx::query!(
314            r#"
315                INSERT INTO compat_sessions
316                    (compat_session_id, user_id, device_id,
317                     user_session_id, created_at, is_synapse_admin)
318                VALUES ($1, $2, $3, $4, $5, $6)
319            "#,
320            Uuid::from(id),
321            Uuid::from(user.id),
322            device.as_str(),
323            browser_session.map(|s| Uuid::from(s.id)),
324            created_at,
325            is_synapse_admin,
326        )
327        .traced()
328        .execute(&mut *self.conn)
329        .await?;
330
331        Ok(CompatSession {
332            id,
333            state: CompatSessionState::default(),
334            user_id: user.id,
335            device: Some(device),
336            human_name: None,
337            user_session_id: browser_session.map(|s| s.id),
338            created_at,
339            is_synapse_admin,
340            user_agent: None,
341            last_active_at: None,
342            last_active_ip: None,
343        })
344    }
345
346    #[tracing::instrument(
347        name = "db.compat_session.finish",
348        skip_all,
349        fields(
350            db.query.text,
351            %compat_session.id,
352            user.id = %compat_session.user_id,
353            compat_session.device.id = compat_session.device.as_ref().map(mas_data_model::Device::as_str),
354        ),
355        err,
356    )]
357    async fn finish(
358        &mut self,
359        clock: &dyn Clock,
360        compat_session: CompatSession,
361    ) -> Result<CompatSession, Self::Error> {
362        let finished_at = clock.now();
363
364        let res = sqlx::query!(
365            r#"
366                UPDATE compat_sessions cs
367                SET finished_at = $2
368                WHERE compat_session_id = $1
369            "#,
370            Uuid::from(compat_session.id),
371            finished_at,
372        )
373        .traced()
374        .execute(&mut *self.conn)
375        .await?;
376
377        DatabaseError::ensure_affected_rows(&res, 1)?;
378
379        let compat_session = compat_session
380            .finish(finished_at)
381            .map_err(DatabaseError::to_invalid_operation)?;
382
383        Ok(compat_session)
384    }
385
386    #[tracing::instrument(
387        name = "db.compat_session.finish_bulk",
388        skip_all,
389        fields(db.query.text),
390        err,
391    )]
392    async fn finish_bulk(
393        &mut self,
394        clock: &dyn Clock,
395        filter: CompatSessionFilter<'_>,
396    ) -> Result<usize, Self::Error> {
397        let finished_at = clock.now();
398        let (sql, arguments) = Query::update()
399            .table(CompatSessions::Table)
400            .value(CompatSessions::FinishedAt, finished_at)
401            .apply_filter(filter)
402            .build_sqlx(PostgresQueryBuilder);
403
404        let res = sqlx::query_with(&sql, arguments)
405            .traced()
406            .execute(&mut *self.conn)
407            .await?;
408
409        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
410    }
411
412    #[tracing::instrument(
413        name = "db.compat_session.list",
414        skip_all,
415        fields(
416            db.query.text,
417        ),
418        err,
419    )]
420    async fn list(
421        &mut self,
422        filter: CompatSessionFilter<'_>,
423        pagination: Pagination,
424    ) -> Result<Page<(CompatSession, Option<CompatSsoLogin>)>, Self::Error> {
425        let (sql, arguments) = Query::select()
426            .expr_as(
427                Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)),
428                CompatSessionAndSsoLoginLookupIden::CompatSessionId,
429            )
430            .expr_as(
431                Expr::col((CompatSessions::Table, CompatSessions::DeviceId)),
432                CompatSessionAndSsoLoginLookupIden::DeviceId,
433            )
434            .expr_as(
435                Expr::col((CompatSessions::Table, CompatSessions::HumanName)),
436                CompatSessionAndSsoLoginLookupIden::HumanName,
437            )
438            .expr_as(
439                Expr::col((CompatSessions::Table, CompatSessions::UserId)),
440                CompatSessionAndSsoLoginLookupIden::UserId,
441            )
442            .expr_as(
443                Expr::col((CompatSessions::Table, CompatSessions::UserSessionId)),
444                CompatSessionAndSsoLoginLookupIden::UserSessionId,
445            )
446            .expr_as(
447                Expr::col((CompatSessions::Table, CompatSessions::CreatedAt)),
448                CompatSessionAndSsoLoginLookupIden::CreatedAt,
449            )
450            .expr_as(
451                Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)),
452                CompatSessionAndSsoLoginLookupIden::FinishedAt,
453            )
454            .expr_as(
455                Expr::col((CompatSessions::Table, CompatSessions::IsSynapseAdmin)),
456                CompatSessionAndSsoLoginLookupIden::IsSynapseAdmin,
457            )
458            .expr_as(
459                Expr::col((CompatSessions::Table, CompatSessions::UserAgent)),
460                CompatSessionAndSsoLoginLookupIden::UserAgent,
461            )
462            .expr_as(
463                Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt)),
464                CompatSessionAndSsoLoginLookupIden::LastActiveAt,
465            )
466            .expr_as(
467                Expr::col((CompatSessions::Table, CompatSessions::LastActiveIp)),
468                CompatSessionAndSsoLoginLookupIden::LastActiveIp,
469            )
470            .expr_as(
471                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)),
472                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginId,
473            )
474            .expr_as(
475                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::LoginToken)),
476                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginToken,
477            )
478            .expr_as(
479                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::RedirectUri)),
480                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginRedirectUri,
481            )
482            .expr_as(
483                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CreatedAt)),
484                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginCreatedAt,
485            )
486            .expr_as(
487                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)),
488                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginFulfilledAt,
489            )
490            .expr_as(
491                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)),
492                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginExchangedAt,
493            )
494            .from(CompatSessions::Table)
495            .left_join(
496                CompatSsoLogins::Table,
497                Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
498                    .equals((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)),
499            )
500            .apply_filter_with_joins(filter)
501            .generate_pagination(
502                (CompatSessions::Table, CompatSessions::CompatSessionId),
503                pagination,
504            )
505            .build_sqlx(PostgresQueryBuilder);
506
507        let edges: Vec<CompatSessionAndSsoLoginLookup> = sqlx::query_as_with(&sql, arguments)
508            .traced()
509            .fetch_all(&mut *self.conn)
510            .await?;
511
512        let page = pagination.process(edges).try_map(TryFrom::try_from)?;
513
514        Ok(page)
515    }
516
517    #[tracing::instrument(
518        name = "db.compat_session.count",
519        skip_all,
520        fields(
521            db.query.text,
522        ),
523        err,
524    )]
525    async fn count(&mut self, filter: CompatSessionFilter<'_>) -> Result<usize, Self::Error> {
526        let (sql, arguments) = sea_query::Query::select()
527            .expr(Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)).count())
528            .from(CompatSessions::Table)
529            .apply_filter(filter)
530            .build_sqlx(PostgresQueryBuilder);
531
532        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
533            .traced()
534            .fetch_one(&mut *self.conn)
535            .await?;
536
537        count
538            .try_into()
539            .map_err(DatabaseError::to_invalid_operation)
540    }
541
542    #[tracing::instrument(
543        name = "db.compat_session.record_batch_activity",
544        skip_all,
545        fields(
546            db.query.text,
547        ),
548        err,
549    )]
550    async fn record_batch_activity(
551        &mut self,
552        activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
553    ) -> Result<(), Self::Error> {
554        let mut ids = Vec::with_capacity(activity.len());
555        let mut last_activities = Vec::with_capacity(activity.len());
556        let mut ips = Vec::with_capacity(activity.len());
557
558        for (id, last_activity, ip) in activity {
559            ids.push(Uuid::from(id));
560            last_activities.push(last_activity);
561            ips.push(ip);
562        }
563
564        let res = sqlx::query!(
565            r#"
566                UPDATE compat_sessions
567                SET last_active_at = GREATEST(t.last_active_at, compat_sessions.last_active_at)
568                  , last_active_ip = COALESCE(t.last_active_ip, compat_sessions.last_active_ip)
569                FROM (
570                    SELECT *
571                    FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
572                        AS t(compat_session_id, last_active_at, last_active_ip)
573                ) AS t
574                WHERE compat_sessions.compat_session_id = t.compat_session_id
575            "#,
576            &ids,
577            &last_activities,
578            &ips as &[Option<IpAddr>],
579        )
580        .traced()
581        .execute(&mut *self.conn)
582        .await?;
583
584        DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
585
586        Ok(())
587    }
588
589    #[tracing::instrument(
590        name = "db.compat_session.record_user_agent",
591        skip_all,
592        fields(
593            db.query.text,
594            %compat_session.id,
595        ),
596        err,
597    )]
598    async fn record_user_agent(
599        &mut self,
600        mut compat_session: CompatSession,
601        user_agent: UserAgent,
602    ) -> Result<CompatSession, Self::Error> {
603        let res = sqlx::query!(
604            r#"
605            UPDATE compat_sessions
606            SET user_agent = $2
607            WHERE compat_session_id = $1
608        "#,
609            Uuid::from(compat_session.id),
610            &*user_agent,
611        )
612        .traced()
613        .execute(&mut *self.conn)
614        .await?;
615
616        compat_session.user_agent = Some(user_agent);
617
618        DatabaseError::ensure_affected_rows(&res, 1)?;
619
620        Ok(compat_session)
621    }
622}