1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
67use base64ct::{Base64UrlUnpadded, Encoding};
8use chrono::{DateTime, Duration, Utc};
9use mas_storage::Clock;
10use rand::{Rng, RngCore, distributions::Standard, prelude::Distribution as _};
11use serde::{Deserialize, Serialize};
12use serde_with::{TimestampSeconds, serde_as};
13use thiserror::Error;
1415use crate::cookies::{CookieDecodeError, CookieJar};
1617/// Failed to validate CSRF token
18#[derive(Debug, Error)]
19pub enum CsrfError {
20/// The token in the form did not match the token in the cookie
21#[error("CSRF token mismatch")]
22Mismatch,
2324/// The token in the form did not match the token in the cookie
25#[error("Missing CSRF cookie")]
26Missing,
2728/// Failed to decode the token
29#[error("could not decode CSRF cookie")]
30DecodeCookie(#[from] CookieDecodeError),
3132/// The token expired
33#[error("CSRF token expired")]
34Expired,
3536/// Failed to decode the token
37#[error("could not decode CSRF token")]
38Decode(#[from] base64ct::Error),
39}
4041/// A CSRF token
42#[serde_as]
43#[derive(Serialize, Deserialize, Debug)]
44pub struct CsrfToken {
45#[serde_as(as = "TimestampSeconds<i64>")]
46expiration: DateTime<Utc>,
47 token: [u8; 32],
48}
4950impl CsrfToken {
51/// Create a new token from a defined value valid for a specified duration
52fn new(token: [u8; 32], now: DateTime<Utc>, ttl: Duration) -> Self {
53let expiration = now + ttl;
54Self { expiration, token }
55 }
5657/// Generate a new random token valid for a specified duration
58fn generate(now: DateTime<Utc>, mut rng: impl Rng, ttl: Duration) -> Self {
59let token = Standard.sample(&mut rng);
60Self::new(token, now, ttl)
61 }
6263/// Generate a new token with the same value but an up to date expiration
64fn refresh(self, now: DateTime<Utc>, ttl: Duration) -> Self {
65Self::new(self.token, now, ttl)
66 }
6768/// Get the value to include in HTML forms
69#[must_use]
70pub fn form_value(&self) -> String {
71 Base64UrlUnpadded::encode_string(&self.token[..])
72 }
7374/// Verifies that the value got from an HTML form matches this token
75 ///
76 /// # Errors
77 ///
78 /// Returns an error if the value in the form does not match this token
79pub fn verify_form_value(&self, form_value: &str) -> Result<(), CsrfError> {
80let form_value = Base64UrlUnpadded::decode_vec(form_value)?;
81if self.token[..] == form_value {
82Ok(())
83 } else {
84Err(CsrfError::Mismatch)
85 }
86 }
8788fn verify_expiration(self, now: DateTime<Utc>) -> Result<Self, CsrfError> {
89if now < self.expiration {
90Ok(self)
91 } else {
92Err(CsrfError::Expired)
93 }
94 }
95}
9697// A CSRF-protected form
98#[derive(Deserialize)]
99pub struct ProtectedForm<T> {
100 csrf: String,
101102#[serde(flatten)]
103inner: T,
104}
105106pub trait CsrfExt {
107/// Get the current CSRF token out of the cookie jar, generating a new one
108 /// if necessary
109fn csrf_token<C, R>(self, clock: &C, rng: R) -> (CsrfToken, Self)
110where
111R: RngCore,
112 C: Clock;
113114/// Verify that the given CSRF-protected form is valid, returning the inner
115 /// value
116 ///
117 /// # Errors
118 ///
119 /// Returns an error if the CSRF cookie is missing or if the value in the
120 /// form is invalid
121fn verify_form<C, T>(&self, clock: &C, form: ProtectedForm<T>) -> Result<T, CsrfError>
122where
123C: Clock;
124}
125126impl CsrfExt for CookieJar {
127fn csrf_token<C, R>(self, clock: &C, rng: R) -> (CsrfToken, Self)
128where
129R: RngCore,
130 C: Clock,
131 {
132let now = clock.now();
133let maybe_token = match self.load::<CsrfToken>("csrf") {
134Ok(Some(token)) => {
135let token = token.verify_expiration(now);
136137// If the token is expired, just ignore it
138token.ok()
139 }
140Ok(None) => None,
141Err(e) => {
142tracing::warn!("Failed to decode CSRF cookie: {}", e);
143None
144}
145 };
146147let token = maybe_token.map_or_else(
148 || CsrfToken::generate(now, rng, Duration::try_hours(1).unwrap()),
149 |token| token.refresh(now, Duration::try_hours(1).unwrap()),
150 );
151152let jar = self.save("csrf", &token, false);
153 (token, jar)
154 }
155156fn verify_form<C, T>(&self, clock: &C, form: ProtectedForm<T>) -> Result<T, CsrfError>
157where
158C: Clock,
159 {
160let token: CsrfToken = self.load("csrf")?.ok_or(CsrfError::Missing)?;
161let token = token.verify_expiration(clock.now())?;
162 token.verify_form_value(&form.csrf)?;
163Ok(form.inner)
164 }
165}