mas_handlers/upstream_oauth2/
link.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::sync::{Arc, LazyLock};
8
9use axum::{
10    Form,
11    extract::{Path, State},
12    response::{Html, IntoResponse, Response},
13};
14use axum_extra::typed_header::TypedHeader;
15use hyper::StatusCode;
16use mas_axum_utils::{
17    FancyError, SessionInfoExt,
18    cookies::CookieJar,
19    csrf::{CsrfExt, ProtectedForm},
20    sentry::SentryEventID,
21};
22use mas_data_model::UserAgent;
23use mas_jose::jwt::Jwt;
24use mas_matrix::HomeserverConnection;
25use mas_policy::Policy;
26use mas_router::UrlBuilder;
27use mas_storage::{
28    BoxClock, BoxRepository, BoxRng, RepositoryAccess,
29    queue::{ProvisionUserJob, QueueJobRepositoryExt as _},
30    upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
31    user::{BrowserSessionRepository, UserEmailRepository, UserRepository},
32};
33use mas_templates::{
34    AccountInactiveContext, ErrorContext, FieldError, FormError, TemplateContext, Templates,
35    ToFormState, UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink,
36};
37use minijinja::Environment;
38use opentelemetry::{Key, KeyValue, metrics::Counter};
39use serde::{Deserialize, Serialize};
40use thiserror::Error;
41use tracing::warn;
42use ulid::Ulid;
43
44use super::{
45    UpstreamSessionsCookie,
46    template::{AttributeMappingContext, environment},
47};
48use crate::{
49    BoundActivityTracker, METER, PreferredLanguage, SiteConfig, impl_from_error_for_route,
50    views::shared::OptionalPostAuthAction,
51};
52
53static LOGIN_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
54    METER
55        .u64_counter("mas.upstream_oauth2.login")
56        .with_description("Successful upstream OAuth 2.0 login to existing accounts")
57        .with_unit("{login}")
58        .build()
59});
60static REGISTRATION_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
61    METER
62        .u64_counter("mas.upstream_oauth2.registration")
63        .with_description("Successful upstream OAuth 2.0 registration")
64        .with_unit("{registration}")
65        .build()
66});
67const PROVIDER: Key = Key::from_static_str("provider");
68
69const DEFAULT_LOCALPART_TEMPLATE: &str = "{{ user.preferred_username }}";
70const DEFAULT_DISPLAYNAME_TEMPLATE: &str = "{{ user.name }}";
71const DEFAULT_EMAIL_TEMPLATE: &str = "{{ user.email }}";
72
73#[derive(Debug, Error)]
74pub(crate) enum RouteError {
75    /// Couldn't find the link specified in the URL
76    #[error("Link not found")]
77    LinkNotFound,
78
79    /// Couldn't find the session on the link
80    #[error("Session not found")]
81    SessionNotFound,
82
83    /// Couldn't find the user
84    #[error("User not found")]
85    UserNotFound,
86
87    /// Couldn't find upstream provider
88    #[error("Upstream provider not found")]
89    ProviderNotFound,
90
91    /// Required attribute rendered to an empty string
92    #[error("Template {template:?} rendered to an empty string")]
93    RequiredAttributeEmpty { template: String },
94
95    /// Required claim was missing in `id_token`
96    #[error(
97        "Template {template:?} could not be rendered from the upstream provider's response for required claim"
98    )]
99    RequiredAttributeRender {
100        template: String,
101
102        #[source]
103        source: minijinja::Error,
104    },
105
106    /// Session was already consumed
107    #[error("Session already consumed")]
108    SessionConsumed,
109
110    #[error("Missing session cookie")]
111    MissingCookie,
112
113    #[error("Invalid form action")]
114    InvalidFormAction,
115
116    #[error("Homeserver connection error")]
117    HomeserverConnection(#[source] anyhow::Error),
118
119    #[error(transparent)]
120    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
121}
122
123impl_from_error_for_route!(mas_templates::TemplateError);
124impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
125impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
126impl_from_error_for_route!(mas_storage::RepositoryError);
127impl_from_error_for_route!(mas_policy::EvaluationError);
128impl_from_error_for_route!(mas_jose::jwt::JwtDecodeError);
129
130impl IntoResponse for RouteError {
131    fn into_response(self) -> axum::response::Response {
132        let event_id = sentry::capture_error(&self);
133        let response = match self {
134            Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(),
135            Self::Internal(e) => FancyError::from(e).into_response(),
136            e => FancyError::from(e).into_response(),
137        };
138
139        (SentryEventID::from(event_id), response).into_response()
140    }
141}
142
143/// Utility function to render an attribute template.
144///
145/// # Parameters
146///
147/// * `environment` - The minijinja environment to use to render the template
148/// * `template` - The template to use to render the claim
149/// * `required` - Whether the attribute is required or not
150///
151/// # Errors
152///
153/// Returns an error if the attribute is required but fails to render or is
154/// empty
155fn render_attribute_template(
156    environment: &Environment,
157    template: &str,
158    context: &minijinja::Value,
159    required: bool,
160) -> Result<Option<String>, RouteError> {
161    match environment.render_str(template, context) {
162        Ok(value) if value.is_empty() => {
163            if required {
164                return Err(RouteError::RequiredAttributeEmpty {
165                    template: template.to_owned(),
166                });
167            }
168
169            Ok(None)
170        }
171
172        Ok(value) => Ok(Some(value)),
173
174        Err(source) => {
175            if required {
176                return Err(RouteError::RequiredAttributeRender {
177                    template: template.to_owned(),
178                    source,
179                });
180            }
181
182            tracing::warn!(error = &source as &dyn std::error::Error, %template, "Error while rendering template");
183            Ok(None)
184        }
185    }
186}
187
188#[derive(Deserialize, Serialize)]
189#[serde(rename_all = "lowercase", tag = "action")]
190pub(crate) enum FormData {
191    Register {
192        #[serde(default)]
193        username: Option<String>,
194        #[serde(default)]
195        import_email: Option<String>,
196        #[serde(default)]
197        import_display_name: Option<String>,
198        #[serde(default)]
199        accept_terms: Option<String>,
200    },
201    Link,
202}
203
204impl ToFormState for FormData {
205    type Field = mas_templates::UpstreamRegisterFormField;
206}
207
208#[tracing::instrument(
209    name = "handlers.upstream_oauth2.link.get",
210    fields(upstream_oauth_link.id = %link_id),
211    skip_all,
212    err,
213)]
214pub(crate) async fn get(
215    mut rng: BoxRng,
216    clock: BoxClock,
217    mut repo: BoxRepository,
218    mut policy: Policy,
219    PreferredLanguage(locale): PreferredLanguage,
220    State(templates): State<Templates>,
221    State(url_builder): State<UrlBuilder>,
222    State(homeserver): State<Arc<dyn HomeserverConnection>>,
223    cookie_jar: CookieJar,
224    activity_tracker: BoundActivityTracker,
225    user_agent: Option<TypedHeader<headers::UserAgent>>,
226    Path(link_id): Path<Ulid>,
227) -> Result<impl IntoResponse, RouteError> {
228    let user_agent = user_agent.map(|ua| UserAgent::parse(ua.as_str().to_owned()));
229    let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
230    let (session_id, post_auth_action) = sessions_cookie
231        .lookup_link(link_id)
232        .map_err(|_| RouteError::MissingCookie)?;
233
234    let post_auth_action = OptionalPostAuthAction {
235        post_auth_action: post_auth_action.cloned(),
236    };
237
238    let link = repo
239        .upstream_oauth_link()
240        .lookup(link_id)
241        .await?
242        .ok_or(RouteError::LinkNotFound)?;
243
244    let upstream_session = repo
245        .upstream_oauth_session()
246        .lookup(session_id)
247        .await?
248        .ok_or(RouteError::SessionNotFound)?;
249
250    // This checks that we're in a browser session which is allowed to consume this
251    // link: the upstream auth session should have been started in this browser.
252    if upstream_session.link_id() != Some(link.id) {
253        return Err(RouteError::SessionNotFound);
254    }
255
256    if upstream_session.is_consumed() {
257        return Err(RouteError::SessionConsumed);
258    }
259
260    let (user_session_info, cookie_jar) = cookie_jar.session_info();
261    let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
262    let maybe_user_session = user_session_info.load_active_session(&mut repo).await?;
263
264    let response = match (maybe_user_session, link.user_id) {
265        (Some(session), Some(user_id)) if session.user.id == user_id => {
266            // Session already linked, and link matches the currently logged
267            // user. Mark the session as consumed and renew the authentication.
268            let upstream_session = repo
269                .upstream_oauth_session()
270                .consume(&clock, upstream_session)
271                .await?;
272
273            repo.browser_session()
274                .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
275                .await?;
276
277            cookie_jar = cookie_jar.set_session(&session);
278
279            repo.save().await?;
280
281            post_auth_action.go_next(&url_builder).into_response()
282        }
283
284        (Some(user_session), Some(user_id)) => {
285            // Session already linked, but link doesn't match the currently
286            // logged user. Suggest logging out of the current user
287            // and logging in with the new one
288            let user = repo
289                .user()
290                .lookup(user_id)
291                .await?
292                .ok_or(RouteError::UserNotFound)?;
293
294            let ctx = UpstreamExistingLinkContext::new(user)
295                .with_session(user_session)
296                .with_csrf(csrf_token.form_value())
297                .with_language(locale);
298
299            Html(templates.render_upstream_oauth2_link_mismatch(&ctx)?).into_response()
300        }
301
302        (Some(user_session), None) => {
303            // Session not linked, but user logged in: suggest linking account
304            let ctx = UpstreamSuggestLink::new(&link)
305                .with_session(user_session)
306                .with_csrf(csrf_token.form_value())
307                .with_language(locale);
308
309            Html(templates.render_upstream_oauth2_suggest_link(&ctx)?).into_response()
310        }
311
312        (None, Some(user_id)) => {
313            // Session linked, but user not logged in: do the login
314            let user = repo
315                .user()
316                .lookup(user_id)
317                .await?
318                .ok_or(RouteError::UserNotFound)?;
319
320            // Check that the user is not locked or deactivated
321            if user.deactivated_at.is_some() {
322                // The account is deactivated, show the 'account deactivated' fallback
323                let ctx = AccountInactiveContext::new(user)
324                    .with_csrf(csrf_token.form_value())
325                    .with_language(locale);
326                let fallback = templates.render_account_deactivated(&ctx)?;
327                return Ok((cookie_jar, Html(fallback).into_response()));
328            }
329
330            if user.locked_at.is_some() {
331                // The account is locked, show the 'account locked' fallback
332                let ctx = AccountInactiveContext::new(user)
333                    .with_csrf(csrf_token.form_value())
334                    .with_language(locale);
335                let fallback = templates.render_account_locked(&ctx)?;
336                return Ok((cookie_jar, Html(fallback).into_response()));
337            }
338
339            let session = repo
340                .browser_session()
341                .add(&mut rng, &clock, &user, user_agent)
342                .await?;
343
344            let upstream_session = repo
345                .upstream_oauth_session()
346                .consume(&clock, upstream_session)
347                .await?;
348
349            repo.browser_session()
350                .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
351                .await?;
352
353            cookie_jar = sessions_cookie
354                .consume_link(link_id)?
355                .save(cookie_jar, &clock);
356            cookie_jar = cookie_jar.set_session(&session);
357
358            repo.save().await?;
359
360            LOGIN_COUNTER.add(
361                1,
362                &[KeyValue::new(
363                    PROVIDER,
364                    upstream_session.provider_id.to_string(),
365                )],
366            );
367
368            post_auth_action.go_next(&url_builder).into_response()
369        }
370
371        (None, None) => {
372            // Session not linked and used not logged in: suggest creating an
373            // account or logging in an existing user
374            let id_token = upstream_session.id_token().map(Jwt::try_from).transpose()?;
375
376            let provider = repo
377                .upstream_oauth_provider()
378                .lookup(link.provider_id)
379                .await?
380                .ok_or(RouteError::ProviderNotFound)?;
381
382            let ctx = UpstreamRegister::new(link.clone(), provider.clone());
383
384            let env = environment();
385
386            let mut context = AttributeMappingContext::new();
387            if let Some(id_token) = id_token {
388                let (_, payload) = id_token.into_parts();
389                context = context.with_id_token_claims(payload);
390            }
391            if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
392                context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
393            }
394            if let Some(userinfo) = upstream_session.userinfo() {
395                context = context.with_userinfo_claims(userinfo.clone());
396            }
397            let context = context.build();
398
399            let ctx = if provider.claims_imports.displayname.ignore() {
400                ctx
401            } else {
402                let template = provider
403                    .claims_imports
404                    .displayname
405                    .template
406                    .as_deref()
407                    .unwrap_or(DEFAULT_DISPLAYNAME_TEMPLATE);
408
409                match render_attribute_template(
410                    &env,
411                    template,
412                    &context,
413                    provider.claims_imports.displayname.is_required(),
414                )? {
415                    Some(value) => ctx
416                        .with_display_name(value, provider.claims_imports.displayname.is_forced()),
417                    None => ctx,
418                }
419            };
420
421            let ctx = if provider.claims_imports.email.ignore() {
422                ctx
423            } else {
424                let template = provider
425                    .claims_imports
426                    .email
427                    .template
428                    .as_deref()
429                    .unwrap_or(DEFAULT_EMAIL_TEMPLATE);
430
431                match render_attribute_template(
432                    &env,
433                    template,
434                    &context,
435                    provider.claims_imports.email.is_required(),
436                )? {
437                    Some(value) => ctx.with_email(value, provider.claims_imports.email.is_forced()),
438                    None => ctx,
439                }
440            };
441
442            let ctx = if provider.claims_imports.localpart.ignore() {
443                ctx
444            } else {
445                let template = provider
446                    .claims_imports
447                    .localpart
448                    .template
449                    .as_deref()
450                    .unwrap_or(DEFAULT_LOCALPART_TEMPLATE);
451
452                match render_attribute_template(
453                    &env,
454                    template,
455                    &context,
456                    provider.claims_imports.localpart.is_required(),
457                )? {
458                    Some(localpart) => {
459                        // We could run policy & existing user checks when the user submits the
460                        // form, but this lead to poor UX. This is why we do
461                        // it ahead of time here.
462                        let maybe_existing_user = repo.user().find_by_username(&localpart).await?;
463                        let is_available = homeserver
464                            .is_localpart_available(&localpart)
465                            .await
466                            .map_err(RouteError::HomeserverConnection)?;
467
468                        if maybe_existing_user.is_some() || !is_available {
469                            if let Some(existing_user) = maybe_existing_user {
470                                // The mapper returned a username which already exists, but isn't
471                                // linked to this upstream user.
472                                warn!(username = %localpart, user_id = %existing_user.id, "Localpart template returned an existing username");
473                            }
474
475                            // TODO: translate
476                            let ctx = ErrorContext::new()
477                                .with_code("User exists")
478                                .with_description(format!(
479                                    r"Upstream account provider returned {localpart:?} as username,
480                                    which is not linked to that upstream account"
481                                ))
482                                .with_language(&locale);
483
484                            return Ok((
485                                cookie_jar,
486                                Html(templates.render_error(&ctx)?).into_response(),
487                            ));
488                        }
489
490                        let res = policy
491                            .evaluate_register(mas_policy::RegisterInput {
492                                registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
493                                username: &localpart,
494                                email: None,
495                                requester: mas_policy::Requester {
496                                    ip_address: activity_tracker.ip(),
497                                    user_agent: user_agent.clone().map(|ua| ua.raw),
498                                },
499                            })
500                            .await?;
501
502                        if res.valid() {
503                            // The username passes the policy check, add it to the context
504                            ctx.with_localpart(
505                                localpart,
506                                provider.claims_imports.localpart.is_forced(),
507                            )
508                        } else if provider.claims_imports.localpart.is_forced() {
509                            // If the username claim is 'forced' but doesn't pass the policy check,
510                            // we display an error message.
511                            // TODO: translate
512                            let ctx = ErrorContext::new()
513                                .with_code("Policy error")
514                                .with_description(format!(
515                                    r"Upstream account provider returned {localpart:?} as username,
516                                    which does not pass the policy check: {res}"
517                                ))
518                                .with_language(&locale);
519
520                            return Ok((
521                                cookie_jar,
522                                Html(templates.render_error(&ctx)?).into_response(),
523                            ));
524                        } else {
525                            // Else, we just ignore it when it doesn't pass the policy check.
526                            ctx
527                        }
528                    }
529                    None => ctx,
530                }
531            };
532
533            let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale);
534
535            Html(templates.render_upstream_oauth2_do_register(&ctx)?).into_response()
536        }
537    };
538
539    Ok((cookie_jar, response))
540}
541
542#[tracing::instrument(
543    name = "handlers.upstream_oauth2.link.post",
544    fields(upstream_oauth_link.id = %link_id),
545    skip_all,
546    err,
547)]
548pub(crate) async fn post(
549    mut rng: BoxRng,
550    clock: BoxClock,
551    mut repo: BoxRepository,
552    cookie_jar: CookieJar,
553    user_agent: Option<TypedHeader<headers::UserAgent>>,
554    mut policy: Policy,
555    PreferredLanguage(locale): PreferredLanguage,
556    activity_tracker: BoundActivityTracker,
557    State(templates): State<Templates>,
558    State(homeserver): State<Arc<dyn HomeserverConnection>>,
559    State(url_builder): State<UrlBuilder>,
560    State(site_config): State<SiteConfig>,
561    Path(link_id): Path<Ulid>,
562    Form(form): Form<ProtectedForm<FormData>>,
563) -> Result<Response, RouteError> {
564    let user_agent = user_agent.map(|ua| UserAgent::parse(ua.as_str().to_owned()));
565    let form = cookie_jar.verify_form(&clock, form)?;
566
567    let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
568    let (session_id, post_auth_action) = sessions_cookie
569        .lookup_link(link_id)
570        .map_err(|_| RouteError::MissingCookie)?;
571
572    let post_auth_action = OptionalPostAuthAction {
573        post_auth_action: post_auth_action.cloned(),
574    };
575
576    let link = repo
577        .upstream_oauth_link()
578        .lookup(link_id)
579        .await?
580        .ok_or(RouteError::LinkNotFound)?;
581
582    let upstream_session = repo
583        .upstream_oauth_session()
584        .lookup(session_id)
585        .await?
586        .ok_or(RouteError::SessionNotFound)?;
587
588    // This checks that we're in a browser session which is allowed to consume this
589    // link: the upstream auth session should have been started in this browser.
590    if upstream_session.link_id() != Some(link.id) {
591        return Err(RouteError::SessionNotFound);
592    }
593
594    if upstream_session.is_consumed() {
595        return Err(RouteError::SessionConsumed);
596    }
597
598    let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
599    let (user_session_info, cookie_jar) = cookie_jar.session_info();
600    let maybe_user_session = user_session_info.load_active_session(&mut repo).await?;
601    let form_state = form.to_form_state();
602
603    let session = match (maybe_user_session, link.user_id, form) {
604        (Some(session), None, FormData::Link) => {
605            // The user is already logged in, the link is not linked to any user, and the
606            // user asked to link their account.
607            repo.upstream_oauth_link()
608                .associate_to_user(&link, &session.user)
609                .await?;
610
611            session
612        }
613
614        (
615            None,
616            None,
617            FormData::Register {
618                username,
619                import_email,
620                import_display_name,
621                accept_terms,
622            },
623        ) => {
624            // The user got the form to register a new account, and is not logged in.
625            // Depending on the claims_imports, we've let the user choose their username,
626            // choose whether they want to import the email and display name, or
627            // not.
628
629            // Those fields are Some("on") if the checkbox is checked
630            let import_email = import_email.is_some();
631            let import_display_name = import_display_name.is_some();
632            let accept_terms = accept_terms.is_some();
633
634            let id_token = upstream_session.id_token().map(Jwt::try_from).transpose()?;
635
636            let provider = repo
637                .upstream_oauth_provider()
638                .lookup(link.provider_id)
639                .await?
640                .ok_or(RouteError::ProviderNotFound)?;
641
642            // Let's try to import the claims from the ID token
643            let env = environment();
644
645            let mut context = AttributeMappingContext::new();
646            if let Some(id_token) = id_token {
647                let (_, payload) = id_token.into_parts();
648                context = context.with_id_token_claims(payload);
649            }
650            if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
651                context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
652            }
653            if let Some(userinfo) = upstream_session.userinfo() {
654                context = context.with_userinfo_claims(userinfo.clone());
655            }
656            let context = context.build();
657
658            // Create a template context in case we need to re-render because of an error
659            let ctx = UpstreamRegister::new(link.clone(), provider.clone());
660
661            let display_name = if provider
662                .claims_imports
663                .displayname
664                .should_import(import_display_name)
665            {
666                let template = provider
667                    .claims_imports
668                    .displayname
669                    .template
670                    .as_deref()
671                    .unwrap_or(DEFAULT_DISPLAYNAME_TEMPLATE);
672
673                render_attribute_template(
674                    &env,
675                    template,
676                    &context,
677                    provider.claims_imports.displayname.is_required(),
678                )?
679            } else {
680                None
681            };
682
683            let ctx = if let Some(ref display_name) = display_name {
684                ctx.with_display_name(
685                    display_name.clone(),
686                    provider.claims_imports.email.is_forced(),
687                )
688            } else {
689                ctx
690            };
691
692            let email = if provider.claims_imports.email.should_import(import_email) {
693                let template = provider
694                    .claims_imports
695                    .email
696                    .template
697                    .as_deref()
698                    .unwrap_or(DEFAULT_EMAIL_TEMPLATE);
699
700                render_attribute_template(
701                    &env,
702                    template,
703                    &context,
704                    provider.claims_imports.email.is_required(),
705                )?
706            } else {
707                None
708            };
709
710            let ctx = if let Some(ref email) = email {
711                ctx.with_email(email.clone(), provider.claims_imports.email.is_forced())
712            } else {
713                ctx
714            };
715
716            let username = if provider.claims_imports.localpart.is_forced() {
717                let template = provider
718                    .claims_imports
719                    .localpart
720                    .template
721                    .as_deref()
722                    .unwrap_or(DEFAULT_LOCALPART_TEMPLATE);
723
724                render_attribute_template(&env, template, &context, true)?
725            } else {
726                // If there is no forced username, we can use the one the user entered
727                username
728            }
729            .unwrap_or_default();
730
731            let ctx = ctx.with_localpart(
732                username.clone(),
733                provider.claims_imports.localpart.is_forced(),
734            );
735
736            // Validate the form
737            let form_state = {
738                let mut form_state = form_state;
739                let mut homeserver_denied_username = false;
740                if username.is_empty() {
741                    form_state.add_error_on_field(
742                        mas_templates::UpstreamRegisterFormField::Username,
743                        FieldError::Required,
744                    );
745                } else if repo.user().exists(&username).await? {
746                    form_state.add_error_on_field(
747                        mas_templates::UpstreamRegisterFormField::Username,
748                        FieldError::Exists,
749                    );
750                } else if !homeserver
751                    .is_localpart_available(&username)
752                    .await
753                    .map_err(RouteError::HomeserverConnection)?
754                {
755                    // The user already exists on the homeserver
756                    tracing::warn!(
757                        %username,
758                        "Homeserver denied username provided by user"
759                    );
760
761                    // We defer adding the error on the field, until we know whether we had another
762                    // error from the policy, to avoid showing both
763                    homeserver_denied_username = true;
764                }
765
766                // If we have a TOS in the config, make sure the user has accepted it
767                if site_config.tos_uri.is_some() && !accept_terms {
768                    form_state.add_error_on_field(
769                        mas_templates::UpstreamRegisterFormField::AcceptTerms,
770                        FieldError::Required,
771                    );
772                }
773
774                // Policy check
775                let res = policy
776                    .evaluate_register(mas_policy::RegisterInput {
777                        registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
778                        username: &username,
779                        email: email.as_deref(),
780                        requester: mas_policy::Requester {
781                            ip_address: activity_tracker.ip(),
782                            user_agent: user_agent.clone().map(|ua| ua.raw),
783                        },
784                    })
785                    .await?;
786
787                for violation in res.violations {
788                    match violation.field.as_deref() {
789                        Some("username") => {
790                            // If the homeserver denied the username, but we also had an error on
791                            // the policy side, we don't want to show
792                            // both, so we reset the state here
793                            homeserver_denied_username = false;
794                            form_state.add_error_on_field(
795                                mas_templates::UpstreamRegisterFormField::Username,
796                                FieldError::Policy {
797                                    code: violation.code.map(|c| c.as_str()),
798                                    message: violation.msg,
799                                },
800                            );
801                        }
802                        _ => form_state.add_error_on_form(FormError::Policy {
803                            code: violation.code.map(|c| c.as_str()),
804                            message: violation.msg,
805                        }),
806                    }
807                }
808
809                if homeserver_denied_username {
810                    // XXX: we may want to return different errors like "this username is reserved"
811                    form_state.add_error_on_field(
812                        mas_templates::UpstreamRegisterFormField::Username,
813                        FieldError::Exists,
814                    );
815                }
816
817                form_state
818            };
819
820            if !form_state.is_valid() {
821                let ctx = ctx
822                    .with_form_state(form_state)
823                    .with_csrf(csrf_token.form_value())
824                    .with_language(locale);
825
826                return Ok((
827                    cookie_jar,
828                    Html(templates.render_upstream_oauth2_do_register(&ctx)?),
829                )
830                    .into_response());
831            }
832
833            REGISTRATION_COUNTER.add(1, &[KeyValue::new(PROVIDER, provider.id.to_string())]);
834
835            // Now we can create the user
836            let user = repo.user().add(&mut rng, &clock, username).await?;
837
838            if let Some(terms_url) = &site_config.tos_uri {
839                repo.user_terms()
840                    .accept_terms(&mut rng, &clock, &user, terms_url.clone())
841                    .await?;
842            }
843
844            // And schedule the job to provision it
845            let mut job = ProvisionUserJob::new(&user);
846
847            // If we have a display name, set it during provisioning
848            if let Some(name) = display_name {
849                job = job.set_display_name(name);
850            }
851
852            repo.queue_job().schedule_job(&mut rng, &clock, job).await?;
853
854            // If we have an email, add it to the user
855            if let Some(email) = email {
856                repo.user_email()
857                    .add(&mut rng, &clock, &user, email)
858                    .await?;
859            }
860
861            repo.upstream_oauth_link()
862                .associate_to_user(&link, &user)
863                .await?;
864
865            repo.browser_session()
866                .add(&mut rng, &clock, &user, user_agent)
867                .await?
868        }
869
870        _ => return Err(RouteError::InvalidFormAction),
871    };
872
873    let upstream_session = repo
874        .upstream_oauth_session()
875        .consume(&clock, upstream_session)
876        .await?;
877
878    repo.browser_session()
879        .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
880        .await?;
881
882    let cookie_jar = sessions_cookie
883        .consume_link(link_id)?
884        .save(cookie_jar, &clock);
885    let cookie_jar = cookie_jar.set_session(&session);
886
887    repo.save().await?;
888
889    Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
890}
891
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode, header::CONTENT_TYPE};
    use mas_data_model::{
        UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderImportPreference,
        UpstreamOAuthProviderTokenAuthMethod,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use mas_jose::jwt::{JsonWebSignatureHeader, Jwt};
    use mas_router::Route;
    use mas_storage::{
        Pagination, upstream_oauth2::UpstreamOAuthProviderParams, user::UserEmailFilter,
    };
    use oauth2_types::scope::{OPENID, Scope};
    use sqlx::PgPool;

    use super::UpstreamSessionsCookie;
    use crate::test_utils::{CookieHelper, RequestBuilderExt, ResponseExt, TestState, setup};

    /// Text that immediately precedes the CSRF token value in the rendered form.
    const CSRF_FIELD_MARKER: &str = "name=\"csrf\" value=\"";

    /// Returns the value of the hidden `csrf` input found in a rendered HTML page.
    ///
    /// # Panics
    ///
    /// Panics if the page does not contain the `csrf` input marker.
    fn extract_csrf_token(html: &str) -> &str {
        let (_, after_marker) = html
            .split_once(CSRF_FIELD_MARKER)
            .expect("page should contain a CSRF field");
        // The token runs up to the closing quote of the `value` attribute
        after_marker.split('"').next().unwrap_or(after_marker)
    }

    /// Registering through a pending upstream link should create the user,
    /// associate the link with it, and import the email from the upstream claims.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_register(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        let mut rng = state.rng();
        let cookies = CookieHelper::new();

        // Both the localpart and the email are forcibly imported from the claims
        let forced_import = || UpstreamOAuthProviderImportPreference {
            action: mas_data_model::UpstreamOAuthProviderImportAction::Force,
            template: None,
        };
        let claims_imports = UpstreamOAuthProviderClaimsImports {
            localpart: forced_import(),
            email: forced_import(),
            ..UpstreamOAuthProviderClaimsImports::default()
        };

        let id_token_claims = serde_json::json!({
            "preferred_username": "john",
            "email": "john@example.com",
            "email_verified": true,
        });

        // Reuse the RS256 key already present in the test key store to sign the
        // id_token, instead of generating a dedicated one
        let signing_key = state
            .key_store
            .signing_key_for_algorithm(&JsonWebSignatureAlg::Rs256)
            .unwrap();
        let signer = signing_key
            .params()
            .signing_key_for_alg(&JsonWebSignatureAlg::Rs256)
            .unwrap();
        let id_token = Jwt::sign_with_rng(
            &mut rng,
            JsonWebSignatureHeader::new(JsonWebSignatureAlg::Rs256),
            id_token_claims,
            &signer,
        )
        .unwrap();

        // Seed a provider, an upstream session and a link, then complete the
        // session with that link and the signed id_token
        let mut repo = state.repository().await.unwrap();
        let provider_params = UpstreamOAuthProviderParams {
            issuer: Some("https://example.com/".to_owned()),
            human_name: Some("Example Ltd.".to_owned()),
            brand_name: None,
            scope: Scope::from_iter([OPENID]),
            token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
            token_endpoint_signing_alg: None,
            id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
            client_id: "client".to_owned(),
            encrypted_client_secret: None,
            claims_imports,
            authorization_endpoint_override: None,
            token_endpoint_override: None,
            userinfo_endpoint_override: None,
            fetch_userinfo: false,
            userinfo_signed_response_alg: None,
            jwks_uri_override: None,
            discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
            pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
            response_mode: None,
            additional_authorization_parameters: Vec::new(),
            ui_order: 0,
        };
        let provider = repo
            .upstream_oauth_provider()
            .add(&mut rng, &state.clock, provider_params)
            .await
            .unwrap();

        let pending_session = repo
            .upstream_oauth_session()
            .add(
                &mut rng,
                &state.clock,
                &provider,
                "state".to_owned(),
                None,
                "nonce".to_owned(),
            )
            .await
            .unwrap();

        let link = repo
            .upstream_oauth_link()
            .add(&mut rng, &state.clock, &provider, "subject".to_owned(), None)
            .await
            .unwrap();

        let completed_session = repo
            .upstream_oauth_session()
            .complete_with_link(
                &state.clock,
                pending_session,
                &link,
                Some(id_token.into_string()),
                None,
                None,
            )
            .await
            .unwrap();

        repo.save().await.unwrap();

        // Record the completed session and its link in the upstream sessions cookie
        let cookie_jar = state.cookie_jar();
        let upstream_sessions = UpstreamSessionsCookie::default()
            .add(completed_session.id, provider.id, "state".to_owned(), None)
            .add_link_to_session(completed_session.id, link.id)
            .unwrap();
        cookies.import(upstream_sessions.save(cookie_jar, &state.clock));

        let link_route = mas_router::UpstreamOAuth2Link::new(link.id);
        let link_path = link_route.path();

        // The link page should render an HTML form carrying a CSRF token
        let response = state
            .request(cookies.with_cookies(Request::get(&*link_path).empty()))
            .await;
        cookies.save_cookies(&response);
        response.assert_status(StatusCode::OK);
        response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");
        let csrf_token = extract_csrf_token(response.body());

        // Submit the registration form
        let form = serde_json::json!({
            "csrf": csrf_token,
            "action": "register",
            "import_email": "on",
            "accept_terms": "on",
        });
        let response = state
            .request(cookies.with_cookies(Request::post(&*link_path).form(form)))
            .await;
        cookies.save_cookies(&response);
        response.assert_status(StatusCode::SEE_OTHER);

        // The user should exist, own the link, and have the imported email
        let mut repo = state.repository().await.unwrap();
        let registered_user = repo
            .user()
            .find_by_username("john")
            .await
            .unwrap()
            .expect("user exists");

        let stored_link = repo
            .upstream_oauth_link()
            .find_by_subject(&provider, "subject")
            .await
            .unwrap()
            .expect("link exists");
        assert_eq!(stored_link.user_id, Some(registered_user.id));

        let emails = repo
            .user_email()
            .list(
                UserEmailFilter::new().for_user(&registered_user),
                Pagination::first(1),
            )
            .await
            .unwrap();
        let imported_email = emails.edges.first().expect("email exists");
        assert_eq!(imported_email.email, "john@example.com");
    }
}