1use std::sync::{Arc, LazyLock};
8
9use axum::{
10 Form,
11 extract::{Path, State},
12 response::{Html, IntoResponse, Response},
13};
14use axum_extra::typed_header::TypedHeader;
15use hyper::StatusCode;
16use mas_axum_utils::{
17 FancyError, SessionInfoExt,
18 cookies::CookieJar,
19 csrf::{CsrfExt, ProtectedForm},
20 sentry::SentryEventID,
21};
22use mas_data_model::UserAgent;
23use mas_jose::jwt::Jwt;
24use mas_matrix::HomeserverConnection;
25use mas_policy::Policy;
26use mas_router::UrlBuilder;
27use mas_storage::{
28 BoxClock, BoxRepository, BoxRng, RepositoryAccess,
29 queue::{ProvisionUserJob, QueueJobRepositoryExt as _},
30 upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
31 user::{BrowserSessionRepository, UserEmailRepository, UserRepository},
32};
33use mas_templates::{
34 AccountInactiveContext, ErrorContext, FieldError, FormError, TemplateContext, Templates,
35 ToFormState, UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink,
36};
37use minijinja::Environment;
38use opentelemetry::{Key, KeyValue, metrics::Counter};
39use serde::{Deserialize, Serialize};
40use thiserror::Error;
41use tracing::warn;
42use ulid::Ulid;
43
44use super::{
45 UpstreamSessionsCookie,
46 template::{AttributeMappingContext, environment},
47};
48use crate::{
49 BoundActivityTracker, METER, PreferredLanguage, SiteConfig, impl_from_error_for_route,
50 views::shared::OptionalPostAuthAction,
51};
52
/// Counter `mas.upstream_oauth2.login`, lazily built from the crate-wide
/// [`METER`] on first use.
///
/// In this file it is incremented by the `get` handler, once a browser without
/// an active session has been logged in through a link that already belongs to
/// a user. The value is tagged with the [`PROVIDER`] attribute.
static LOGIN_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
    METER
        .u64_counter("mas.upstream_oauth2.login")
        .with_description("Successful upstream OAuth 2.0 login to existing accounts")
        .with_unit("{login}")
        .build()
});
/// Counter `mas.upstream_oauth2.registration`, lazily built from the crate-wide
/// [`METER`] on first use.
///
/// In this file it is incremented by the `post` handler when the registration
/// form passed validation. The value is tagged with the [`PROVIDER`] attribute.
static REGISTRATION_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
    METER
        .u64_counter("mas.upstream_oauth2.registration")
        .with_description("Successful upstream OAuth 2.0 registration")
        .with_unit("{registration}")
        .build()
});
/// Metric attribute key under which the upstream provider ID is recorded on
/// [`LOGIN_COUNTER`] and [`REGISTRATION_COUNTER`].
const PROVIDER: Key = Key::from_static_str("provider");

// Fallback attribute-mapping templates, used whenever the provider's
// `claims_imports` preference for that attribute has no `template` set.
/// Default template for the localpart (username) of the new user.
const DEFAULT_LOCALPART_TEMPLATE: &str = "{{ user.preferred_username }}";
/// Default template for the display name of the new user.
const DEFAULT_DISPLAYNAME_TEMPLATE: &str = "{{ user.name }}";
/// Default template for the email address of the new user.
const DEFAULT_EMAIL_TEMPLATE: &str = "{{ user.email }}";
72
/// Errors returned by the upstream OAuth 2.0 link handlers (`get` and `post`).
///
/// The [`IntoResponse`] implementation reports every value to Sentry before
/// turning it into an HTTP response.
#[derive(Debug, Error)]
pub(crate) enum RouteError {
    /// No upstream link exists for the ID given in the request path.
    #[error("Link not found")]
    LinkNotFound,

    /// The upstream session referenced by the cookie could not be found, or it
    /// is not attached to the requested link.
    #[error("Session not found")]
    SessionNotFound,

    /// The link references a user ID for which the user lookup returned nothing.
    #[error("User not found")]
    UserNotFound,

    /// The provider referenced by the link could not be found.
    #[error("Upstream provider not found")]
    ProviderNotFound,

    /// A required attribute template rendered successfully, but to an empty
    /// string.
    #[error("Template {template:?} rendered to an empty string")]
    RequiredAttributeEmpty { template: String },

    /// A required attribute template failed to render.
    #[error(
        "Template {template:?} could not be rendered from the upstream provider's response for required claim"
    )]
    RequiredAttributeRender {
        /// The template source that failed to render.
        template: String,

        /// The underlying rendering error.
        #[source]
        source: minijinja::Error,
    },

    /// The upstream session was already consumed.
    #[error("Session already consumed")]
    SessionConsumed,

    /// The upstream sessions cookie has no entry for the requested link.
    #[error("Missing session cookie")]
    MissingCookie,

    /// The submitted form action does not fit the current state, i.e. the
    /// combination of (browser session, link owner, action) is not handled by
    /// `post`.
    #[error("Invalid form action")]
    InvalidFormAction,

    /// The homeserver could not be asked whether a localpart is available.
    #[error("Homeserver connection error")]
    HomeserverConnection(#[source] anyhow::Error),

    /// Any other error, boxed and displayed transparently.
    #[error(transparent)]
    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
}
122
// Conversions from lower-level error types so the handlers can use `?` on them.
// NOTE(review): `impl_from_error_for_route!` is defined elsewhere in the crate;
// presumably it generates `From<T> for RouteError` wrapping the error into
// `RouteError::Internal` — confirm against the macro definition.
impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(mas_jose::jwt::JwtDecodeError);
129
130impl IntoResponse for RouteError {
131 fn into_response(self) -> axum::response::Response {
132 let event_id = sentry::capture_error(&self);
133 let response = match self {
134 Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(),
135 Self::Internal(e) => FancyError::from(e).into_response(),
136 e => FancyError::from(e).into_response(),
137 };
138
139 (SentryEventID::from(event_id), response).into_response()
140 }
141}
142
143fn render_attribute_template(
156 environment: &Environment,
157 template: &str,
158 context: &minijinja::Value,
159 required: bool,
160) -> Result<Option<String>, RouteError> {
161 match environment.render_str(template, context) {
162 Ok(value) if value.is_empty() => {
163 if required {
164 return Err(RouteError::RequiredAttributeEmpty {
165 template: template.to_owned(),
166 });
167 }
168
169 Ok(None)
170 }
171
172 Ok(value) => Ok(Some(value)),
173
174 Err(source) => {
175 if required {
176 return Err(RouteError::RequiredAttributeRender {
177 template: template.to_owned(),
178 source,
179 });
180 }
181
182 tracing::warn!(error = &source as &dyn std::error::Error, %template, "Error while rendering template");
183 Ok(None)
184 }
185 }
186}
187
/// Form submitted to the `post` handler.
///
/// The variant is selected by the `action` field, whose value is the
/// lowercased variant name (`register` or `link`).
#[derive(Deserialize, Serialize)]
#[serde(rename_all = "lowercase", tag = "action")]
pub(crate) enum FormData {
    /// Create a new user from the upstream account.
    Register {
        /// Username chosen by the user. `post` ignores it when the provider
        /// forces the localpart.
        #[serde(default)]
        username: Option<String>,
        /// Checkbox-style field: `post` only checks whether it is present.
        #[serde(default)]
        import_email: Option<String>,
        /// Checkbox-style field: `post` only checks whether it is present.
        #[serde(default)]
        import_display_name: Option<String>,
        /// Checkbox-style field: `post` only checks whether it is present.
        #[serde(default)]
        accept_terms: Option<String>,
    },
    /// Associate the upstream link with the currently logged-in user.
    Link,
}
203
// Lets `post` call `form.to_form_state()` and attach errors keyed by
// `UpstreamRegisterFormField`.
impl ToFormState for FormData {
    type Field = mas_templates::UpstreamRegisterFormField;
}
207
/// `GET` handler for an upstream OAuth 2.0 link.
///
/// Looks up the link from the path and the upstream session referenced by the
/// upstream sessions cookie, then branches on whether a browser session is
/// active and whether the link already belongs to a user:
///
/// | browser session | link owner      | outcome                                      |
/// |-----------------|-----------------|----------------------------------------------|
/// | yes             | same user       | consume session, go to post-auth action      |
/// | yes             | different user  | render the "link mismatch" page              |
/// | yes             | none            | render the "suggest link" page               |
/// | no              | some user       | log that user in, go to post-auth action     |
/// | no              | none            | render the registration form                 |
///
/// # Errors
///
/// Returns a [`RouteError`] when the cookie, link, session, user or provider is
/// missing, when the session is already consumed, when a required attribute
/// template is empty or fails to render, or when a repository, template,
/// policy, JWT or homeserver call fails.
#[tracing::instrument(
    name = "handlers.upstream_oauth2.link.get",
    fields(upstream_oauth_link.id = %link_id),
    skip_all,
    err,
)]
pub(crate) async fn get(
    mut rng: BoxRng,
    clock: BoxClock,
    mut repo: BoxRepository,
    mut policy: Policy,
    PreferredLanguage(locale): PreferredLanguage,
    State(templates): State<Templates>,
    State(url_builder): State<UrlBuilder>,
    State(homeserver): State<Arc<dyn HomeserverConnection>>,
    cookie_jar: CookieJar,
    activity_tracker: BoundActivityTracker,
    user_agent: Option<TypedHeader<headers::UserAgent>>,
    Path(link_id): Path<Ulid>,
) -> Result<impl IntoResponse, RouteError> {
    let user_agent = user_agent.map(|ua| UserAgent::parse(ua.as_str().to_owned()));
    // The cookie maps the link to the upstream session that produced it.
    let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
    let (session_id, post_auth_action) = sessions_cookie
        .lookup_link(link_id)
        .map_err(|_| RouteError::MissingCookie)?;

    let post_auth_action = OptionalPostAuthAction {
        post_auth_action: post_auth_action.cloned(),
    };

    let link = repo
        .upstream_oauth_link()
        .lookup(link_id)
        .await?
        .ok_or(RouteError::LinkNotFound)?;

    let upstream_session = repo
        .upstream_oauth_session()
        .lookup(session_id)
        .await?
        .ok_or(RouteError::SessionNotFound)?;

    // The session from the cookie must be the one attached to this link;
    // a mismatch is reported the same way as a missing session.
    if upstream_session.link_id() != Some(link.id) {
        return Err(RouteError::SessionNotFound);
    }

    if upstream_session.is_consumed() {
        return Err(RouteError::SessionConsumed);
    }

    let (user_session_info, cookie_jar) = cookie_jar.session_info();
    let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
    let maybe_user_session = user_session_info.load_active_session(&mut repo).await?;

    let response = match (maybe_user_session, link.user_id) {
        // Logged in, and the link already belongs to the logged-in user:
        // consume the upstream session and record the upstream authentication
        // on the existing browser session.
        (Some(session), Some(user_id)) if session.user.id == user_id => {
            let upstream_session = repo
                .upstream_oauth_session()
                .consume(&clock, upstream_session)
                .await?;

            repo.browser_session()
                .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
                .await?;

            cookie_jar = cookie_jar.set_session(&session);

            repo.save().await?;

            post_auth_action.go_next(&url_builder).into_response()
        }

        // Logged in, but the link belongs to another user (the guard above
        // did not match): show the mismatch page.
        (Some(user_session), Some(user_id)) => {
            let user = repo
                .user()
                .lookup(user_id)
                .await?
                .ok_or(RouteError::UserNotFound)?;

            let ctx = UpstreamExistingLinkContext::new(user)
                .with_session(user_session)
                .with_csrf(csrf_token.form_value())
                .with_language(locale);

            Html(templates.render_upstream_oauth2_link_mismatch(&ctx)?).into_response()
        }

        // Logged in, and the link has no owner yet: offer to link it to the
        // current user.
        (Some(user_session), None) => {
            let ctx = UpstreamSuggestLink::new(&link)
                .with_session(user_session)
                .with_csrf(csrf_token.form_value())
                .with_language(locale);

            Html(templates.render_upstream_oauth2_suggest_link(&ctx)?).into_response()
        }

        // Not logged in, and the link has an owner: log that user in.
        (None, Some(user_id)) => {
            let user = repo
                .user()
                .lookup(user_id)
                .await?
                .ok_or(RouteError::UserNotFound)?;

            // Deactivated accounts get a dedicated page instead of a session.
            if user.deactivated_at.is_some() {
                let ctx = AccountInactiveContext::new(user)
                    .with_csrf(csrf_token.form_value())
                    .with_language(locale);
                let fallback = templates.render_account_deactivated(&ctx)?;
                return Ok((cookie_jar, Html(fallback).into_response()));
            }

            // Same for locked accounts.
            if user.locked_at.is_some() {
                let ctx = AccountInactiveContext::new(user)
                    .with_csrf(csrf_token.form_value())
                    .with_language(locale);
                let fallback = templates.render_account_locked(&ctx)?;
                return Ok((cookie_jar, Html(fallback).into_response()));
            }

            let session = repo
                .browser_session()
                .add(&mut rng, &clock, &user, user_agent)
                .await?;

            let upstream_session = repo
                .upstream_oauth_session()
                .consume(&clock, upstream_session)
                .await?;

            repo.browser_session()
                .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
                .await?;

            // Drop the link from the upstream sessions cookie and store the
            // new browser session in the jar.
            cookie_jar = sessions_cookie
                .consume_link(link_id)?
                .save(cookie_jar, &clock);
            cookie_jar = cookie_jar.set_session(&session);

            repo.save().await?;

            // Only counted once the repository changes were saved.
            LOGIN_COUNTER.add(
                1,
                &[KeyValue::new(
                    PROVIDER,
                    upstream_session.provider_id.to_string(),
                )],
            );

            post_auth_action.go_next(&url_builder).into_response()
        }

        // Not logged in, and the link has no owner: show the registration
        // form, pre-filled from the attributes mapped from the upstream data.
        (None, None) => {
            let id_token = upstream_session.id_token().map(Jwt::try_from).transpose()?;

            let provider = repo
                .upstream_oauth_provider()
                .lookup(link.provider_id)
                .await?
                .ok_or(RouteError::ProviderNotFound)?;

            let ctx = UpstreamRegister::new(link.clone(), provider.clone());

            let env = environment();

            // Build the template context from whatever the upstream session
            // holds: ID token claims, extra callback parameters and userinfo.
            let mut context = AttributeMappingContext::new();
            if let Some(id_token) = id_token {
                let (_, payload) = id_token.into_parts();
                context = context.with_id_token_claims(payload);
            }
            if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
                context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
            }
            if let Some(userinfo) = upstream_session.userinfo() {
                context = context.with_userinfo_claims(userinfo.clone());
            }
            let context = context.build();

            // Display name, unless the provider is configured to ignore it.
            let ctx = if provider.claims_imports.displayname.ignore() {
                ctx
            } else {
                let template = provider
                    .claims_imports
                    .displayname
                    .template
                    .as_deref()
                    .unwrap_or(DEFAULT_DISPLAYNAME_TEMPLATE);

                match render_attribute_template(
                    &env,
                    template,
                    &context,
                    provider.claims_imports.displayname.is_required(),
                )? {
                    Some(value) => ctx
                        .with_display_name(value, provider.claims_imports.displayname.is_forced()),
                    None => ctx,
                }
            };

            // Email, unless the provider is configured to ignore it.
            let ctx = if provider.claims_imports.email.ignore() {
                ctx
            } else {
                let template = provider
                    .claims_imports
                    .email
                    .template
                    .as_deref()
                    .unwrap_or(DEFAULT_EMAIL_TEMPLATE);

                match render_attribute_template(
                    &env,
                    template,
                    &context,
                    provider.claims_imports.email.is_required(),
                )? {
                    Some(value) => ctx.with_email(value, provider.claims_imports.email.is_forced()),
                    None => ctx,
                }
            };

            // Localpart, unless the provider is configured to ignore it.
            let ctx = if provider.claims_imports.localpart.ignore() {
                ctx
            } else {
                let template = provider
                    .claims_imports
                    .localpart
                    .template
                    .as_deref()
                    .unwrap_or(DEFAULT_LOCALPART_TEMPLATE);

                match render_attribute_template(
                    &env,
                    template,
                    &context,
                    provider.claims_imports.localpart.is_required(),
                )? {
                    Some(localpart) => {
                        // The localpart must be free both in our own user
                        // table and on the homeserver.
                        let maybe_existing_user = repo.user().find_by_username(&localpart).await?;
                        let is_available = homeserver
                            .is_localpart_available(&localpart)
                            .await
                            .map_err(RouteError::HomeserverConnection)?;

                        if maybe_existing_user.is_some() || !is_available {
                            if let Some(existing_user) = maybe_existing_user {
                                warn!(username = %localpart, user_id = %existing_user.id, "Localpart template returned an existing username");
                            }

                            // Render an error page rather than the form.
                            let ctx = ErrorContext::new()
                                .with_code("User exists")
                                .with_description(format!(
                                    r"Upstream account provider returned {localpart:?} as username,
                                    which is not linked to that upstream account"
                                ))
                                .with_language(&locale);

                            return Ok((
                                cookie_jar,
                                Html(templates.render_error(&ctx)?).into_response(),
                            ));
                        }

                        // Check the mapped localpart against the registration
                        // policy. No email is passed at this stage.
                        let res = policy
                            .evaluate_register(mas_policy::RegisterInput {
                                registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
                                username: &localpart,
                                email: None,
                                requester: mas_policy::Requester {
                                    ip_address: activity_tracker.ip(),
                                    user_agent: user_agent.clone().map(|ua| ua.raw),
                                },
                            })
                            .await?;

                        if res.valid() {
                            ctx.with_localpart(
                                localpart,
                                provider.claims_imports.localpart.is_forced(),
                            )
                        } else if provider.claims_imports.localpart.is_forced() {
                            // The localpart is forced and rejected by policy:
                            // render an error page.
                            let ctx = ErrorContext::new()
                                .with_code("Policy error")
                                .with_description(format!(
                                    r"Upstream account provider returned {localpart:?} as username,
                                    which does not pass the policy check: {res}"
                                ))
                                .with_language(&locale);

                            return Ok((
                                cookie_jar,
                                Html(templates.render_error(&ctx)?).into_response(),
                            ));
                        } else {
                            // Not forced and rejected by policy: leave the
                            // localpart out of the form context.
                            ctx
                        }
                    }
                    None => ctx,
                }
            };

            let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale);

            Html(templates.render_upstream_oauth2_do_register(&ctx)?).into_response()
        }
    };

    Ok((cookie_jar, response))
}
541
542#[tracing::instrument(
543 name = "handlers.upstream_oauth2.link.post",
544 fields(upstream_oauth_link.id = %link_id),
545 skip_all,
546 err,
547)]
548pub(crate) async fn post(
549 mut rng: BoxRng,
550 clock: BoxClock,
551 mut repo: BoxRepository,
552 cookie_jar: CookieJar,
553 user_agent: Option<TypedHeader<headers::UserAgent>>,
554 mut policy: Policy,
555 PreferredLanguage(locale): PreferredLanguage,
556 activity_tracker: BoundActivityTracker,
557 State(templates): State<Templates>,
558 State(homeserver): State<Arc<dyn HomeserverConnection>>,
559 State(url_builder): State<UrlBuilder>,
560 State(site_config): State<SiteConfig>,
561 Path(link_id): Path<Ulid>,
562 Form(form): Form<ProtectedForm<FormData>>,
563) -> Result<Response, RouteError> {
564 let user_agent = user_agent.map(|ua| UserAgent::parse(ua.as_str().to_owned()));
565 let form = cookie_jar.verify_form(&clock, form)?;
566
567 let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
568 let (session_id, post_auth_action) = sessions_cookie
569 .lookup_link(link_id)
570 .map_err(|_| RouteError::MissingCookie)?;
571
572 let post_auth_action = OptionalPostAuthAction {
573 post_auth_action: post_auth_action.cloned(),
574 };
575
576 let link = repo
577 .upstream_oauth_link()
578 .lookup(link_id)
579 .await?
580 .ok_or(RouteError::LinkNotFound)?;
581
582 let upstream_session = repo
583 .upstream_oauth_session()
584 .lookup(session_id)
585 .await?
586 .ok_or(RouteError::SessionNotFound)?;
587
588 if upstream_session.link_id() != Some(link.id) {
591 return Err(RouteError::SessionNotFound);
592 }
593
594 if upstream_session.is_consumed() {
595 return Err(RouteError::SessionConsumed);
596 }
597
598 let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
599 let (user_session_info, cookie_jar) = cookie_jar.session_info();
600 let maybe_user_session = user_session_info.load_active_session(&mut repo).await?;
601 let form_state = form.to_form_state();
602
603 let session = match (maybe_user_session, link.user_id, form) {
604 (Some(session), None, FormData::Link) => {
605 repo.upstream_oauth_link()
608 .associate_to_user(&link, &session.user)
609 .await?;
610
611 session
612 }
613
614 (
615 None,
616 None,
617 FormData::Register {
618 username,
619 import_email,
620 import_display_name,
621 accept_terms,
622 },
623 ) => {
624 let import_email = import_email.is_some();
631 let import_display_name = import_display_name.is_some();
632 let accept_terms = accept_terms.is_some();
633
634 let id_token = upstream_session.id_token().map(Jwt::try_from).transpose()?;
635
636 let provider = repo
637 .upstream_oauth_provider()
638 .lookup(link.provider_id)
639 .await?
640 .ok_or(RouteError::ProviderNotFound)?;
641
642 let env = environment();
644
645 let mut context = AttributeMappingContext::new();
646 if let Some(id_token) = id_token {
647 let (_, payload) = id_token.into_parts();
648 context = context.with_id_token_claims(payload);
649 }
650 if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
651 context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
652 }
653 if let Some(userinfo) = upstream_session.userinfo() {
654 context = context.with_userinfo_claims(userinfo.clone());
655 }
656 let context = context.build();
657
658 let ctx = UpstreamRegister::new(link.clone(), provider.clone());
660
661 let display_name = if provider
662 .claims_imports
663 .displayname
664 .should_import(import_display_name)
665 {
666 let template = provider
667 .claims_imports
668 .displayname
669 .template
670 .as_deref()
671 .unwrap_or(DEFAULT_DISPLAYNAME_TEMPLATE);
672
673 render_attribute_template(
674 &env,
675 template,
676 &context,
677 provider.claims_imports.displayname.is_required(),
678 )?
679 } else {
680 None
681 };
682
683 let ctx = if let Some(ref display_name) = display_name {
684 ctx.with_display_name(
685 display_name.clone(),
686 provider.claims_imports.email.is_forced(),
687 )
688 } else {
689 ctx
690 };
691
692 let email = if provider.claims_imports.email.should_import(import_email) {
693 let template = provider
694 .claims_imports
695 .email
696 .template
697 .as_deref()
698 .unwrap_or(DEFAULT_EMAIL_TEMPLATE);
699
700 render_attribute_template(
701 &env,
702 template,
703 &context,
704 provider.claims_imports.email.is_required(),
705 )?
706 } else {
707 None
708 };
709
710 let ctx = if let Some(ref email) = email {
711 ctx.with_email(email.clone(), provider.claims_imports.email.is_forced())
712 } else {
713 ctx
714 };
715
716 let username = if provider.claims_imports.localpart.is_forced() {
717 let template = provider
718 .claims_imports
719 .localpart
720 .template
721 .as_deref()
722 .unwrap_or(DEFAULT_LOCALPART_TEMPLATE);
723
724 render_attribute_template(&env, template, &context, true)?
725 } else {
726 username
728 }
729 .unwrap_or_default();
730
731 let ctx = ctx.with_localpart(
732 username.clone(),
733 provider.claims_imports.localpart.is_forced(),
734 );
735
736 let form_state = {
738 let mut form_state = form_state;
739 let mut homeserver_denied_username = false;
740 if username.is_empty() {
741 form_state.add_error_on_field(
742 mas_templates::UpstreamRegisterFormField::Username,
743 FieldError::Required,
744 );
745 } else if repo.user().exists(&username).await? {
746 form_state.add_error_on_field(
747 mas_templates::UpstreamRegisterFormField::Username,
748 FieldError::Exists,
749 );
750 } else if !homeserver
751 .is_localpart_available(&username)
752 .await
753 .map_err(RouteError::HomeserverConnection)?
754 {
755 tracing::warn!(
757 %username,
758 "Homeserver denied username provided by user"
759 );
760
761 homeserver_denied_username = true;
764 }
765
766 if site_config.tos_uri.is_some() && !accept_terms {
768 form_state.add_error_on_field(
769 mas_templates::UpstreamRegisterFormField::AcceptTerms,
770 FieldError::Required,
771 );
772 }
773
774 let res = policy
776 .evaluate_register(mas_policy::RegisterInput {
777 registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
778 username: &username,
779 email: email.as_deref(),
780 requester: mas_policy::Requester {
781 ip_address: activity_tracker.ip(),
782 user_agent: user_agent.clone().map(|ua| ua.raw),
783 },
784 })
785 .await?;
786
787 for violation in res.violations {
788 match violation.field.as_deref() {
789 Some("username") => {
790 homeserver_denied_username = false;
794 form_state.add_error_on_field(
795 mas_templates::UpstreamRegisterFormField::Username,
796 FieldError::Policy {
797 code: violation.code.map(|c| c.as_str()),
798 message: violation.msg,
799 },
800 );
801 }
802 _ => form_state.add_error_on_form(FormError::Policy {
803 code: violation.code.map(|c| c.as_str()),
804 message: violation.msg,
805 }),
806 }
807 }
808
809 if homeserver_denied_username {
810 form_state.add_error_on_field(
812 mas_templates::UpstreamRegisterFormField::Username,
813 FieldError::Exists,
814 );
815 }
816
817 form_state
818 };
819
820 if !form_state.is_valid() {
821 let ctx = ctx
822 .with_form_state(form_state)
823 .with_csrf(csrf_token.form_value())
824 .with_language(locale);
825
826 return Ok((
827 cookie_jar,
828 Html(templates.render_upstream_oauth2_do_register(&ctx)?),
829 )
830 .into_response());
831 }
832
833 REGISTRATION_COUNTER.add(1, &[KeyValue::new(PROVIDER, provider.id.to_string())]);
834
835 let user = repo.user().add(&mut rng, &clock, username).await?;
837
838 if let Some(terms_url) = &site_config.tos_uri {
839 repo.user_terms()
840 .accept_terms(&mut rng, &clock, &user, terms_url.clone())
841 .await?;
842 }
843
844 let mut job = ProvisionUserJob::new(&user);
846
847 if let Some(name) = display_name {
849 job = job.set_display_name(name);
850 }
851
852 repo.queue_job().schedule_job(&mut rng, &clock, job).await?;
853
854 if let Some(email) = email {
856 repo.user_email()
857 .add(&mut rng, &clock, &user, email)
858 .await?;
859 }
860
861 repo.upstream_oauth_link()
862 .associate_to_user(&link, &user)
863 .await?;
864
865 repo.browser_session()
866 .add(&mut rng, &clock, &user, user_agent)
867 .await?
868 }
869
870 _ => return Err(RouteError::InvalidFormAction),
871 };
872
873 let upstream_session = repo
874 .upstream_oauth_session()
875 .consume(&clock, upstream_session)
876 .await?;
877
878 repo.browser_session()
879 .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
880 .await?;
881
882 let cookie_jar = sessions_cookie
883 .consume_link(link_id)?
884 .save(cookie_jar, &clock);
885 let cookie_jar = cookie_jar.set_session(&session);
886
887 repo.save().await?;
888
889 Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
890}
891
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode, header::CONTENT_TYPE};
    use mas_data_model::{
        UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderImportPreference,
        UpstreamOAuthProviderTokenAuthMethod,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use mas_jose::jwt::{JsonWebSignatureHeader, Jwt};
    use mas_router::Route;
    use mas_storage::{
        Pagination, upstream_oauth2::UpstreamOAuthProviderParams, user::UserEmailFilter,
    };
    use oauth2_types::scope::{OPENID, Scope};
    use sqlx::PgPool;

    use super::UpstreamSessionsCookie;
    use crate::test_utils::{CookieHelper, RequestBuilderExt, ResponseExt, TestState, setup};

    /// End-to-end registration through an upstream link: `GET` the form, then
    /// `POST` it, and check that the user, the link association and the
    /// imported email were all stored.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_register(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        let mut rng = state.rng();
        let cookies = CookieHelper::new();

        // Force both the localpart and the email, using the default templates.
        let claims_imports = UpstreamOAuthProviderClaimsImports {
            localpart: UpstreamOAuthProviderImportPreference {
                action: mas_data_model::UpstreamOAuthProviderImportAction::Force,
                template: None,
            },
            email: UpstreamOAuthProviderImportPreference {
                action: mas_data_model::UpstreamOAuthProviderImportAction::Force,
                template: None,
            },
            ..UpstreamOAuthProviderClaimsImports::default()
        };

        // Claims the default templates will read from.
        let id_token = serde_json::json!({
            "preferred_username": "john",
            "email": "john@example.com",
            "email_verified": true,
        });

        // Sign the ID token with an RS256 key from the test key store. The
        // handlers under test only decode the token to read its claims.
        let key = state
            .key_store
            .signing_key_for_algorithm(&JsonWebSignatureAlg::Rs256)
            .unwrap();

        let signer = key
            .params()
            .signing_key_for_alg(&JsonWebSignatureAlg::Rs256)
            .unwrap();
        let header = JsonWebSignatureHeader::new(JsonWebSignatureAlg::Rs256);
        let id_token = Jwt::sign_with_rng(&mut rng, header, id_token, &signer).unwrap();

        // Provision a provider, an upstream session, and a link with no owner.
        let mut repo = state.repository().await.unwrap();
        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &state.clock,
                UpstreamOAuthProviderParams {
                    issuer: Some("https://example.com/".to_owned()),
                    human_name: Some("Example Ltd.".to_owned()),
                    brand_name: None,
                    scope: Scope::from_iter([OPENID]),
                    token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
                    token_endpoint_signing_alg: None,
                    id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
                    client_id: "client".to_owned(),
                    encrypted_client_secret: None,
                    claims_imports,
                    authorization_endpoint_override: None,
                    token_endpoint_override: None,
                    userinfo_endpoint_override: None,
                    fetch_userinfo: false,
                    userinfo_signed_response_alg: None,
                    jwks_uri_override: None,
                    discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
                    pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
                    response_mode: None,
                    additional_authorization_parameters: Vec::new(),
                    ui_order: 0,
                },
            )
            .await
            .unwrap();

        let session = repo
            .upstream_oauth_session()
            .add(
                &mut rng,
                &state.clock,
                &provider,
                "state".to_owned(),
                None,
                "nonce".to_owned(),
            )
            .await
            .unwrap();

        let link = repo
            .upstream_oauth_link()
            .add(
                &mut rng,
                &state.clock,
                &provider,
                "subject".to_owned(),
                None,
            )
            .await
            .unwrap();

        // Attach the link and the signed ID token to the session.
        let session = repo
            .upstream_oauth_session()
            .complete_with_link(
                &state.clock,
                session,
                &link,
                Some(id_token.into_string()),
                None,
                None,
            )
            .await
            .unwrap();

        repo.save().await.unwrap();

        // Build the upstream sessions cookie the handlers look the link up in.
        let cookie_jar = state.cookie_jar();
        let upstream_sessions = UpstreamSessionsCookie::default()
            .add(session.id, provider.id, "state".to_owned(), None)
            .add_link_to_session(session.id, link.id)
            .unwrap();
        let cookie_jar = upstream_sessions.save(cookie_jar, &state.clock);
        cookies.import(cookie_jar);

        // GET: no browser session and no link owner, so the registration form
        // is rendered.
        let request = Request::get(&*mas_router::UpstreamOAuth2Link::new(link.id).path()).empty();
        let request = cookies.with_cookies(request);
        let response = state.request(request).await;
        cookies.save_cookies(&response);
        response.assert_status(StatusCode::OK);
        response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");

        // Scrape the CSRF token out of the rendered form.
        let csrf_token = response
            .body()
            .split("name=\"csrf\" value=\"")
            .nth(1)
            .unwrap()
            .split('\"')
            .next()
            .unwrap();

        // POST: submit the registration. No username is sent because the
        // localpart is forced by the provider.
        let request = Request::post(&*mas_router::UpstreamOAuth2Link::new(link.id).path()).form(
            serde_json::json!({
                "csrf": csrf_token,
                "action": "register",
                "import_email": "on",
                "accept_terms": "on",
            }),
        );
        let request = cookies.with_cookies(request);
        let response = state.request(request).await;
        cookies.save_cookies(&response);
        response.assert_status(StatusCode::SEE_OTHER);

        // The user was created with the localpart taken from the ID token.
        let mut repo = state.repository().await.unwrap();
        let user = repo
            .user()
            .find_by_username("john")
            .await
            .unwrap()
            .expect("user exists");

        // The link now belongs to that user.
        let link = repo
            .upstream_oauth_link()
            .find_by_subject(&provider, "subject")
            .await
            .unwrap()
            .expect("link exists");

        assert_eq!(link.user_id, Some(user.id));

        // The email from the ID token was imported.
        let page = repo
            .user_email()
            .list(UserEmailFilter::new().for_user(&user), Pagination::first(1))
            .await
            .unwrap();
        let email = page.edges.first().expect("email exists");

        assert_eq!(email.email, "john@example.com");
    }
}