mas_storage/upstream_oauth2/
provider.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::marker::PhantomData;
8
9use async_trait::async_trait;
10use mas_data_model::{
11    UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
12    UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderResponseMode,
13    UpstreamOAuthProviderTokenAuthMethod,
14};
15use mas_iana::jose::JsonWebSignatureAlg;
16use oauth2_types::scope::Scope;
17use rand_core::RngCore;
18use ulid::Ulid;
19use url::Url;
20
21use crate::{Clock, Pagination, pagination::Page, repository_impl};
22
23/// Structure which holds parameters when inserting or updating an upstream
24/// OAuth 2.0 provider
25pub struct UpstreamOAuthProviderParams {
26    /// The OIDC issuer of the provider
27    pub issuer: Option<String>,
28
29    /// A human-readable name for the provider
30    pub human_name: Option<String>,
31
32    /// A brand identifier, e.g. "apple" or "google"
33    pub brand_name: Option<String>,
34
35    /// The scope to request during the authorization flow
36    pub scope: Scope,
37
38    /// The token endpoint authentication method
39    pub token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod,
40
41    /// The JWT signing algorithm to use when then `client_secret_jwt` or
42    /// `private_key_jwt` authentication methods are used
43    pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
44
45    /// Expected signature for the JWT payload returned by the token
46    /// authentication endpoint.
47    ///
48    /// Defaults to `RS256`.
49    pub id_token_signed_response_alg: JsonWebSignatureAlg,
50
51    /// Whether to fetch the user profile from the userinfo endpoint,
52    /// or to rely on the data returned in the `id_token` from the
53    /// `token_endpoint`.
54    pub fetch_userinfo: bool,
55
56    /// Expected signature for the JWT payload returned by the userinfo
57    /// endpoint.
58    ///
59    /// If not specified, the response is expected to be an unsigned JSON
60    /// payload. Defaults to `None`.
61    pub userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
62
63    /// The client ID to use when authenticating to the upstream
64    pub client_id: String,
65
66    /// The encrypted client secret to use when authenticating to the upstream
67    pub encrypted_client_secret: Option<String>,
68
69    /// How claims should be imported from the upstream provider
70    pub claims_imports: UpstreamOAuthProviderClaimsImports,
71
72    /// The URL to use as the authorization endpoint. If `None`, the URL will be
73    /// discovered
74    pub authorization_endpoint_override: Option<Url>,
75
76    /// The URL to use as the token endpoint. If `None`, the URL will be
77    /// discovered
78    pub token_endpoint_override: Option<Url>,
79
80    /// The URL to use as the userinfo endpoint. If `None`, the URL will be
81    /// discovered
82    pub userinfo_endpoint_override: Option<Url>,
83
84    /// The URL to use when fetching JWKS. If `None`, the URL will be discovered
85    pub jwks_uri_override: Option<Url>,
86
87    /// How the provider metadata should be discovered
88    pub discovery_mode: UpstreamOAuthProviderDiscoveryMode,
89
90    /// How should PKCE be used
91    pub pkce_mode: UpstreamOAuthProviderPkceMode,
92
93    /// What response mode it should ask
94    pub response_mode: Option<UpstreamOAuthProviderResponseMode>,
95
96    /// Additional parameters to include in the authorization request
97    pub additional_authorization_parameters: Vec<(String, String)>,
98
99    /// The position of the provider in the UI
100    pub ui_order: i32,
101}
102
/// Options narrowing down which upstream OAuth 2.0 providers are listed or
/// counted by an [`UpstreamOAuthProviderRepository`]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct UpstreamOAuthProviderFilter<'a> {
    /// Restriction on the providers' enabled state
    ///
    /// `Some(true)` keeps only enabled providers, `Some(false)` keeps only
    /// disabled ones, and `None` applies no restriction at all.
    enabled: Option<bool>,

    // Zero-sized marker: it "uses" the lifetime parameter `'a`, which Rust
    // would otherwise reject as unused.
    _lifetime: PhantomData<&'a ()>,
}
113
114impl UpstreamOAuthProviderFilter<'_> {
115    /// Create a new [`UpstreamOAuthProviderFilter`] with default values
116    #[must_use]
117    pub fn new() -> Self {
118        Self::default()
119    }
120
121    /// Return only enabled providers
122    #[must_use]
123    pub const fn enabled_only(mut self) -> Self {
124        self.enabled = Some(true);
125        self
126    }
127
128    /// Return only disabled providers
129    #[must_use]
130    pub const fn disabled_only(mut self) -> Self {
131        self.enabled = Some(false);
132        self
133    }
134
135    /// Get the enabled filter
136    ///
137    /// Returns `None` if the filter is not set
138    #[must_use]
139    pub const fn enabled(&self) -> Option<bool> {
140        self.enabled
141    }
142}
143
144/// An [`UpstreamOAuthProviderRepository`] helps interacting with
145/// [`UpstreamOAuthProvider`] saved in the storage backend
146#[async_trait]
147pub trait UpstreamOAuthProviderRepository: Send + Sync {
148    /// The error type returned by the repository
149    type Error;
150
151    /// Lookup an upstream OAuth provider by its ID
152    ///
153    /// Returns `None` if the provider was not found
154    ///
155    /// # Parameters
156    ///
157    /// * `id`: The ID of the provider to lookup
158    ///
159    /// # Errors
160    ///
161    /// Returns [`Self::Error`] if the underlying repository fails
162    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;
163
164    /// Add a new upstream OAuth provider
165    ///
166    /// Returns the newly created provider
167    ///
168    /// # Parameters
169    ///
170    /// * `rng`: A random number generator
171    /// * `clock`: The clock used to generate timestamps
172    /// * `params`: The parameters of the provider to add
173    ///
174    /// # Errors
175    ///
176    /// Returns [`Self::Error`] if the underlying repository fails
177    async fn add(
178        &mut self,
179        rng: &mut (dyn RngCore + Send),
180        clock: &dyn Clock,
181        params: UpstreamOAuthProviderParams,
182    ) -> Result<UpstreamOAuthProvider, Self::Error>;
183
184    /// Delete an upstream OAuth provider
185    ///
186    /// # Parameters
187    ///
188    /// * `provider`: The provider to delete
189    ///
190    /// # Errors
191    ///
192    /// Returns [`Self::Error`] if the underlying repository fails
193    async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error> {
194        self.delete_by_id(provider.id).await
195    }
196
197    /// Delete an upstream OAuth provider by its ID
198    ///
199    /// # Parameters
200    ///
201    /// * `id`: The ID of the provider to delete
202    ///
203    /// # Errors
204    ///
205    /// Returns [`Self::Error`] if the underlying repository fails
206    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
207
208    /// Insert or update an upstream OAuth provider
209    ///
210    /// # Parameters
211    ///
212    /// * `clock`: The clock used to generate timestamps
213    /// * `id`: The ID of the provider to update
214    /// * `params`: The parameters of the provider to update
215    ///
216    /// # Errors
217    ///
218    /// Returns [`Self::Error`] if the underlying repository fails
219    async fn upsert(
220        &mut self,
221        clock: &dyn Clock,
222        id: Ulid,
223        params: UpstreamOAuthProviderParams,
224    ) -> Result<UpstreamOAuthProvider, Self::Error>;
225
226    /// Disable an upstream OAuth provider
227    ///
228    /// Returns the disabled provider
229    ///
230    /// # Parameters
231    ///
232    /// * `clock`: The clock used to generate timestamps
233    /// * `provider`: The provider to disable
234    ///
235    /// # Errors
236    ///
237    /// Returns [`Self::Error`] if the underlying repository fails
238    async fn disable(
239        &mut self,
240        clock: &dyn Clock,
241        provider: UpstreamOAuthProvider,
242    ) -> Result<UpstreamOAuthProvider, Self::Error>;
243
244    /// List [`UpstreamOAuthProvider`] with the given filter and pagination
245    ///
246    /// # Parameters
247    ///
248    /// * `filter`: The filter to apply
249    /// * `pagination`: The pagination parameters
250    ///
251    /// # Errors
252    ///
253    /// Returns [`Self::Error`] if the underlying repository fails
254    async fn list(
255        &mut self,
256        filter: UpstreamOAuthProviderFilter<'_>,
257        pagination: Pagination,
258    ) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;
259
260    /// Count the number of [`UpstreamOAuthProvider`] with the given filter
261    ///
262    /// # Parameters
263    ///
264    /// * `filter`: The filter to apply
265    ///
266    /// # Errors
267    ///
268    /// Returns [`Self::Error`] if the underlying repository fails
269    async fn count(
270        &mut self,
271        filter: UpstreamOAuthProviderFilter<'_>,
272    ) -> Result<usize, Self::Error>;
273
274    /// Get all enabled upstream OAuth providers
275    ///
276    /// # Errors
277    ///
278    /// Returns [`Self::Error`] if the underlying repository fails
279    async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
280}
281
282repository_impl!(UpstreamOAuthProviderRepository:
283    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;
284
285    async fn add(
286        &mut self,
287        rng: &mut (dyn RngCore + Send),
288        clock: &dyn Clock,
289        params: UpstreamOAuthProviderParams
290    ) -> Result<UpstreamOAuthProvider, Self::Error>;
291
292    async fn upsert(
293        &mut self,
294        clock: &dyn Clock,
295        id: Ulid,
296        params: UpstreamOAuthProviderParams
297    ) -> Result<UpstreamOAuthProvider, Self::Error>;
298
299    async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error>;
300
301    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
302
303    async fn disable(
304        &mut self,
305        clock: &dyn Clock,
306        provider: UpstreamOAuthProvider
307    ) -> Result<UpstreamOAuthProvider, Self::Error>;
308
309    async fn list(
310        &mut self,
311        filter: UpstreamOAuthProviderFilter<'_>,
312        pagination: Pagination
313    ) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;
314
315    async fn count(
316        &mut self,
317        filter: UpstreamOAuthProviderFilter<'_>
318    ) -> Result<usize, Self::Error>;
319
320    async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
321);