mas_handlers/upstream_oauth2/
callback.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::sync::LazyLock;
8
9use axum::{
10    Form,
11    extract::{Path, State},
12    http::Method,
13    response::{Html, IntoResponse, Response},
14};
15use hyper::StatusCode;
16use mas_axum_utils::{cookies::CookieJar, sentry::SentryEventID};
17use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderResponseMode};
18use mas_jose::claims::TokenHash;
19use mas_keystore::{Encrypter, Keystore};
20use mas_oidc_client::requests::jose::JwtVerificationData;
21use mas_router::UrlBuilder;
22use mas_storage::{
23    BoxClock, BoxRepository, BoxRng, Clock,
24    upstream_oauth2::{
25        UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
26        UpstreamOAuthSessionRepository,
27    },
28};
29use mas_templates::{FormPostContext, Templates};
30use oauth2_types::{errors::ClientErrorCode, requests::AccessTokenRequest};
31use opentelemetry::{Key, KeyValue, metrics::Counter};
32use serde::{Deserialize, Serialize};
33use serde_json::json;
34use thiserror::Error;
35use ulid::Ulid;
36
37use super::{
38    UpstreamSessionsCookie,
39    cache::LazyProviderInfos,
40    client_credentials_for_provider,
41    template::{AttributeMappingContext, environment},
42};
43use crate::{
44    METER, PreferredLanguage, impl_from_error_for_route, upstream_oauth2::cache::MetadataCache,
45};
46
47static CALLBACK_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
48    METER
49        .u64_counter("mas.upstream_oauth2.callback")
50        .with_description("Number of requests to the upstream OAuth2 callback endpoint")
51        .build()
52});
53const PROVIDER: Key = Key::from_static_str("provider");
54const RESULT: Key = Key::from_static_str("result");
55
56#[derive(Serialize, Deserialize)]
57pub struct Params {
58    #[serde(skip_serializing_if = "Option::is_none")]
59    state: Option<String>,
60
61    /// An extra parameter to track whether the POST request was re-made by us
62    /// to the same URL to escape Same-Site cookies restrictions
63    #[serde(default)]
64    did_mas_repost_to_itself: bool,
65
66    #[serde(skip_serializing_if = "Option::is_none")]
67    code: Option<String>,
68
69    #[serde(skip_serializing_if = "Option::is_none")]
70    error: Option<ClientErrorCode>,
71    #[serde(skip_serializing_if = "Option::is_none")]
72    error_description: Option<String>,
73    #[serde(skip_serializing_if = "Option::is_none")]
74    error_uri: Option<String>,
75
76    #[serde(flatten)]
77    extra_callback_parameters: Option<serde_json::Value>,
78}
79
80impl Params {
81    /// Returns true if none of the fields are set
82    pub fn is_empty(&self) -> bool {
83        self.state.is_none()
84            && self.code.is_none()
85            && self.error.is_none()
86            && self.error_description.is_none()
87            && self.error_uri.is_none()
88    }
89}
90
91#[derive(Debug, Error)]
92pub(crate) enum RouteError {
93    #[error("Session not found")]
94    SessionNotFound,
95
96    #[error("Provider not found")]
97    ProviderNotFound,
98
99    #[error("Provider mismatch")]
100    ProviderMismatch,
101
102    #[error("Session already completed")]
103    AlreadyCompleted,
104
105    #[error("State parameter mismatch")]
106    StateMismatch,
107
108    #[error("Missing state parameter")]
109    MissingState,
110
111    #[error("Missing code parameter")]
112    MissingCode,
113
114    #[error("Could not extract subject from ID token")]
115    ExtractSubject(#[source] minijinja::Error),
116
117    #[error("Subject is empty")]
118    EmptySubject,
119
120    #[error("Error from the provider: {error}")]
121    ClientError {
122        error: ClientErrorCode,
123        error_description: Option<String>,
124    },
125
126    #[error("Missing session cookie")]
127    MissingCookie,
128
129    #[error("Missing query parameters")]
130    MissingQueryParams,
131
132    #[error("Missing form parameters")]
133    MissingFormParams,
134
135    #[error("Invalid response mode, expected '{expected}'")]
136    InvalidResponseMode {
137        expected: UpstreamOAuthProviderResponseMode,
138    },
139
140    #[error(transparent)]
141    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
142}
143
144impl_from_error_for_route!(mas_templates::TemplateError);
145impl_from_error_for_route!(mas_storage::RepositoryError);
146impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
147impl_from_error_for_route!(mas_oidc_client::error::JwksError);
148impl_from_error_for_route!(mas_oidc_client::error::TokenRequestError);
149impl_from_error_for_route!(mas_oidc_client::error::IdTokenError);
150impl_from_error_for_route!(mas_oidc_client::error::UserInfoError);
151impl_from_error_for_route!(super::ProviderCredentialsError);
152impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
153
154impl IntoResponse for RouteError {
155    fn into_response(self) -> axum::response::Response {
156        let event_id = sentry::capture_error(&self);
157        let response = match self {
158            Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(),
159            Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(),
160            Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
161            e => (StatusCode::BAD_REQUEST, e.to_string()).into_response(),
162        };
163
164        (SentryEventID::from(event_id), response).into_response()
165    }
166}
167
168#[tracing::instrument(
169    name = "handlers.upstream_oauth2.callback.handler",
170    fields(upstream_oauth_provider.id = %provider_id),
171    skip_all,
172    err,
173)]
174#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
175pub(crate) async fn handler(
176    mut rng: BoxRng,
177    clock: BoxClock,
178    State(metadata_cache): State<MetadataCache>,
179    mut repo: BoxRepository,
180    State(url_builder): State<UrlBuilder>,
181    State(encrypter): State<Encrypter>,
182    State(keystore): State<Keystore>,
183    State(client): State<reqwest::Client>,
184    State(templates): State<Templates>,
185    method: Method,
186    PreferredLanguage(locale): PreferredLanguage,
187    cookie_jar: CookieJar,
188    Path(provider_id): Path<Ulid>,
189    Form(params): Form<Params>,
190) -> Result<Response, RouteError> {
191    let provider = repo
192        .upstream_oauth_provider()
193        .lookup(provider_id)
194        .await?
195        .filter(UpstreamOAuthProvider::enabled)
196        .ok_or(RouteError::ProviderNotFound)?;
197
198    let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
199
200    if params.is_empty() {
201        if let Method::GET = method {
202            return Err(RouteError::MissingQueryParams);
203        }
204
205        return Err(RouteError::MissingFormParams);
206    }
207
208    // The `Form` extractor will use the body of the request for POST requests and
209    // the query parameters for GET requests. We need to then look at the method do
210    // make sure it matches the expected `response_mode`
211    match (provider.response_mode, method) {
212        (Some(UpstreamOAuthProviderResponseMode::FormPost) | None, Method::POST) => {
213            // We set the cookies with a `Same-Site` policy set to `Lax`, so because this is
214            // usually a cross-site form POST, we need to render a form with the
215            // same values, which posts back to the same URL. However, there are
216            // other valid reasons for the cookie to be missing, so to track whether we did
217            // this POST ourselves, we set a flag.
218            if sessions_cookie.is_empty() && !params.did_mas_repost_to_itself {
219                let params = Params {
220                    did_mas_repost_to_itself: true,
221                    ..params
222                };
223                let context = FormPostContext::new_for_current_url(params).with_language(&locale);
224                let html = templates.render_form_post(&context)?;
225                return Ok(Html(html).into_response());
226            }
227        }
228        (None, _) | (Some(UpstreamOAuthProviderResponseMode::Query), Method::GET) => {}
229        (Some(expected), _) => return Err(RouteError::InvalidResponseMode { expected }),
230    }
231
232    if let Some(error) = params.error {
233        CALLBACK_COUNTER.add(
234            1,
235            &[
236                KeyValue::new(PROVIDER, provider_id.to_string()),
237                KeyValue::new(RESULT, "error"),
238            ],
239        );
240
241        return Err(RouteError::ClientError {
242            error,
243            error_description: params.error_description.clone(),
244        });
245    }
246
247    let Some(state) = params.state else {
248        return Err(RouteError::MissingState);
249    };
250
251    let (session_id, _post_auth_action) = sessions_cookie
252        .find_session(provider_id, &state)
253        .map_err(|_| RouteError::MissingCookie)?;
254
255    let session = repo
256        .upstream_oauth_session()
257        .lookup(session_id)
258        .await?
259        .ok_or(RouteError::SessionNotFound)?;
260
261    if provider.id != session.provider_id {
262        // The provider in the session cookie should match the one from the URL
263        return Err(RouteError::ProviderMismatch);
264    }
265
266    if state != session.state_str {
267        // The state in the session cookie should match the one from the params
268        return Err(RouteError::StateMismatch);
269    }
270
271    if !session.is_pending() {
272        // The session was already completed
273        return Err(RouteError::AlreadyCompleted);
274    }
275
276    // Let's extract the code from the params, and return if there was an error
277    let Some(code) = params.code else {
278        return Err(RouteError::MissingCode);
279    };
280
281    CALLBACK_COUNTER.add(
282        1,
283        &[
284            KeyValue::new(PROVIDER, provider_id.to_string()),
285            KeyValue::new(RESULT, "success"),
286        ],
287    );
288
289    let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &client);
290
291    // Figure out the client credentials
292    let client_credentials = client_credentials_for_provider(
293        &provider,
294        lazy_metadata.token_endpoint().await?,
295        &keystore,
296        &encrypter,
297    )?;
298
299    let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
300
301    let token_response = mas_oidc_client::requests::token::request_access_token(
302        &client,
303        client_credentials,
304        lazy_metadata.token_endpoint().await?,
305        AccessTokenRequest::AuthorizationCode(oauth2_types::requests::AuthorizationCodeGrant {
306            code: code.clone(),
307            redirect_uri: Some(redirect_uri),
308            code_verifier: session.code_challenge_verifier.clone(),
309        }),
310        clock.now(),
311        &mut rng,
312    )
313    .await?;
314
315    let mut jwks = None;
316
317    let mut context = AttributeMappingContext::new();
318    if let Some(id_token) = token_response.id_token.as_ref() {
319        jwks = Some(
320            mas_oidc_client::requests::jose::fetch_jwks(&client, lazy_metadata.jwks_uri().await?)
321                .await?,
322        );
323
324        let id_token_verification_data = JwtVerificationData {
325            issuer: provider.issuer.as_deref(),
326            jwks: jwks.as_ref().unwrap(),
327            signing_algorithm: &provider.id_token_signed_response_alg,
328            client_id: &provider.client_id,
329        };
330
331        // Decode and verify the ID token
332        let id_token = mas_oidc_client::requests::jose::verify_id_token(
333            id_token,
334            id_token_verification_data,
335            None,
336            clock.now(),
337        )?;
338
339        let (_headers, mut claims) = id_token.into_parts();
340
341        // Access token hash must match.
342        mas_jose::claims::AT_HASH
343            .extract_optional_with_options(
344                &mut claims,
345                TokenHash::new(
346                    id_token_verification_data.signing_algorithm,
347                    &token_response.access_token,
348                ),
349            )
350            .map_err(mas_oidc_client::error::IdTokenError::from)?;
351
352        // Code hash must match.
353        mas_jose::claims::C_HASH
354            .extract_optional_with_options(
355                &mut claims,
356                TokenHash::new(id_token_verification_data.signing_algorithm, &code),
357            )
358            .map_err(mas_oidc_client::error::IdTokenError::from)?;
359
360        // Nonce must match.
361        mas_jose::claims::NONCE
362            .extract_required_with_options(&mut claims, session.nonce.as_str())
363            .map_err(mas_oidc_client::error::IdTokenError::from)?;
364
365        context = context.with_id_token_claims(claims);
366    }
367
368    if let Some(extra_callback_parameters) = params.extra_callback_parameters.clone() {
369        context = context.with_extra_callback_parameters(extra_callback_parameters);
370    }
371
372    let userinfo = if provider.fetch_userinfo {
373        Some(json!(match &provider.userinfo_signed_response_alg {
374            Some(signing_algorithm) => {
375                let jwks = match jwks {
376                    Some(jwks) => jwks,
377                    None => {
378                        mas_oidc_client::requests::jose::fetch_jwks(
379                            &client,
380                            lazy_metadata.jwks_uri().await?,
381                        )
382                        .await?
383                    }
384                };
385
386                mas_oidc_client::requests::userinfo::fetch_userinfo(
387                    &client,
388                    lazy_metadata.userinfo_endpoint().await?,
389                    token_response.access_token.as_str(),
390                    Some(JwtVerificationData {
391                        issuer: provider.issuer.as_deref(),
392                        jwks: &jwks,
393                        signing_algorithm,
394                        client_id: &provider.client_id,
395                    }),
396                )
397                .await?
398            }
399            None => {
400                mas_oidc_client::requests::userinfo::fetch_userinfo(
401                    &client,
402                    lazy_metadata.userinfo_endpoint().await?,
403                    token_response.access_token.as_str(),
404                    None,
405                )
406                .await?
407            }
408        }))
409    } else {
410        None
411    };
412
413    if let Some(userinfo) = userinfo.clone() {
414        context = context.with_userinfo_claims(userinfo);
415    }
416
417    let context = context.build();
418
419    let env = environment();
420
421    let template = provider
422        .claims_imports
423        .subject
424        .template
425        .as_deref()
426        .unwrap_or("{{ user.sub }}");
427    let subject = env
428        .render_str(template, context.clone())
429        .map_err(RouteError::ExtractSubject)?;
430
431    if subject.is_empty() {
432        return Err(RouteError::EmptySubject);
433    }
434
435    // Look for an existing link
436    let maybe_link = repo
437        .upstream_oauth_link()
438        .find_by_subject(&provider, &subject)
439        .await?;
440
441    let link = if let Some(link) = maybe_link {
442        link
443    } else {
444        // Try to render the human account name if we have one,
445        // but just log if it fails
446        let human_account_name = provider
447            .claims_imports
448            .account_name
449            .template
450            .as_deref()
451            .and_then(|template| match env.render_str(template, context) {
452                Ok(name) => Some(name),
453                Err(e) => {
454                    tracing::warn!(
455                        error = &e as &dyn std::error::Error,
456                        "Failed to render account name"
457                    );
458                    None
459                }
460            });
461
462        repo.upstream_oauth_link()
463            .add(&mut rng, &clock, &provider, subject, human_account_name)
464            .await?
465    };
466
467    let session = repo
468        .upstream_oauth_session()
469        .complete_with_link(
470            &clock,
471            session,
472            &link,
473            token_response.id_token,
474            params.extra_callback_parameters,
475            userinfo,
476        )
477        .await?;
478
479    let cookie_jar = sessions_cookie
480        .add_link_to_session(session.id, link.id)?
481        .save(cookie_jar, &clock);
482
483    repo.save().await?;
484
485    Ok((
486        cookie_jar,
487        url_builder.redirect(&mas_router::UpstreamOAuth2Link::new(link.id)),
488    )
489        .into_response())
490}