mas_storage/upstream_oauth2/provider.rs
1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::marker::PhantomData;
8
9use async_trait::async_trait;
10use mas_data_model::{
11 UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
12 UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderResponseMode,
13 UpstreamOAuthProviderTokenAuthMethod,
14};
15use mas_iana::jose::JsonWebSignatureAlg;
16use oauth2_types::scope::Scope;
17use rand_core::RngCore;
18use ulid::Ulid;
19use url::Url;
20
21use crate::{Clock, Pagination, pagination::Page, repository_impl};
22
23/// Structure which holds parameters when inserting or updating an upstream
24/// OAuth 2.0 provider
25pub struct UpstreamOAuthProviderParams {
26 /// The OIDC issuer of the provider
27 pub issuer: Option<String>,
28
29 /// A human-readable name for the provider
30 pub human_name: Option<String>,
31
32 /// A brand identifier, e.g. "apple" or "google"
33 pub brand_name: Option<String>,
34
35 /// The scope to request during the authorization flow
36 pub scope: Scope,
37
38 /// The token endpoint authentication method
39 pub token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod,
40
41 /// The JWT signing algorithm to use when then `client_secret_jwt` or
42 /// `private_key_jwt` authentication methods are used
43 pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
44
45 /// Expected signature for the JWT payload returned by the token
46 /// authentication endpoint.
47 ///
48 /// Defaults to `RS256`.
49 pub id_token_signed_response_alg: JsonWebSignatureAlg,
50
51 /// Whether to fetch the user profile from the userinfo endpoint,
52 /// or to rely on the data returned in the `id_token` from the
53 /// `token_endpoint`.
54 pub fetch_userinfo: bool,
55
56 /// Expected signature for the JWT payload returned by the userinfo
57 /// endpoint.
58 ///
59 /// If not specified, the response is expected to be an unsigned JSON
60 /// payload. Defaults to `None`.
61 pub userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
62
63 /// The client ID to use when authenticating to the upstream
64 pub client_id: String,
65
66 /// The encrypted client secret to use when authenticating to the upstream
67 pub encrypted_client_secret: Option<String>,
68
69 /// How claims should be imported from the upstream provider
70 pub claims_imports: UpstreamOAuthProviderClaimsImports,
71
72 /// The URL to use as the authorization endpoint. If `None`, the URL will be
73 /// discovered
74 pub authorization_endpoint_override: Option<Url>,
75
76 /// The URL to use as the token endpoint. If `None`, the URL will be
77 /// discovered
78 pub token_endpoint_override: Option<Url>,
79
80 /// The URL to use as the userinfo endpoint. If `None`, the URL will be
81 /// discovered
82 pub userinfo_endpoint_override: Option<Url>,
83
84 /// The URL to use when fetching JWKS. If `None`, the URL will be discovered
85 pub jwks_uri_override: Option<Url>,
86
87 /// How the provider metadata should be discovered
88 pub discovery_mode: UpstreamOAuthProviderDiscoveryMode,
89
90 /// How should PKCE be used
91 pub pkce_mode: UpstreamOAuthProviderPkceMode,
92
93 /// What response mode it should ask
94 pub response_mode: Option<UpstreamOAuthProviderResponseMode>,
95
96 /// Additional parameters to include in the authorization request
97 pub additional_authorization_parameters: Vec<(String, String)>,
98
99 /// The position of the provider in the UI
100 pub ui_order: i32,
101}
102
/// Criteria used to narrow down a listing or count of upstream OAuth 2.0
/// providers
///
/// Start from [`UpstreamOAuthProviderFilter::new`] and refine it with the
/// builder-style methods; an untouched filter matches every provider.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub struct UpstreamOAuthProviderFilter<'a> {
    /// Constraint on the provider's enabled state: `Some(true)` keeps only
    /// enabled providers, `Some(false)` only disabled ones, and `None` does
    /// not filter on it at all
    enabled: Option<bool>,

    /// Anchors the otherwise unused `'a` lifetime parameter
    _lifetime: PhantomData<&'a ()>,
}

impl UpstreamOAuthProviderFilter<'_> {
    /// Build a filter that places no constraint on the providers
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Narrow the filter down to providers that are enabled
    #[must_use]
    pub const fn enabled_only(self) -> Self {
        Self {
            enabled: Some(true),
            ..self
        }
    }

    /// Narrow the filter down to providers that are disabled
    #[must_use]
    pub const fn disabled_only(self) -> Self {
        Self {
            enabled: Some(false),
            ..self
        }
    }

    /// The enabled-state constraint carried by this filter
    ///
    /// `None` means providers are not filtered on their enabled state
    #[must_use]
    pub const fn enabled(&self) -> Option<bool> {
        self.enabled
    }
}
143
144/// An [`UpstreamOAuthProviderRepository`] helps interacting with
145/// [`UpstreamOAuthProvider`] saved in the storage backend
146#[async_trait]
147pub trait UpstreamOAuthProviderRepository: Send + Sync {
148 /// The error type returned by the repository
149 type Error;
150
151 /// Lookup an upstream OAuth provider by its ID
152 ///
153 /// Returns `None` if the provider was not found
154 ///
155 /// # Parameters
156 ///
157 /// * `id`: The ID of the provider to lookup
158 ///
159 /// # Errors
160 ///
161 /// Returns [`Self::Error`] if the underlying repository fails
162 async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;
163
164 /// Add a new upstream OAuth provider
165 ///
166 /// Returns the newly created provider
167 ///
168 /// # Parameters
169 ///
170 /// * `rng`: A random number generator
171 /// * `clock`: The clock used to generate timestamps
172 /// * `params`: The parameters of the provider to add
173 ///
174 /// # Errors
175 ///
176 /// Returns [`Self::Error`] if the underlying repository fails
177 async fn add(
178 &mut self,
179 rng: &mut (dyn RngCore + Send),
180 clock: &dyn Clock,
181 params: UpstreamOAuthProviderParams,
182 ) -> Result<UpstreamOAuthProvider, Self::Error>;
183
184 /// Delete an upstream OAuth provider
185 ///
186 /// # Parameters
187 ///
188 /// * `provider`: The provider to delete
189 ///
190 /// # Errors
191 ///
192 /// Returns [`Self::Error`] if the underlying repository fails
193 async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error> {
194 self.delete_by_id(provider.id).await
195 }
196
197 /// Delete an upstream OAuth provider by its ID
198 ///
199 /// # Parameters
200 ///
201 /// * `id`: The ID of the provider to delete
202 ///
203 /// # Errors
204 ///
205 /// Returns [`Self::Error`] if the underlying repository fails
206 async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
207
208 /// Insert or update an upstream OAuth provider
209 ///
210 /// # Parameters
211 ///
212 /// * `clock`: The clock used to generate timestamps
213 /// * `id`: The ID of the provider to update
214 /// * `params`: The parameters of the provider to update
215 ///
216 /// # Errors
217 ///
218 /// Returns [`Self::Error`] if the underlying repository fails
219 async fn upsert(
220 &mut self,
221 clock: &dyn Clock,
222 id: Ulid,
223 params: UpstreamOAuthProviderParams,
224 ) -> Result<UpstreamOAuthProvider, Self::Error>;
225
226 /// Disable an upstream OAuth provider
227 ///
228 /// Returns the disabled provider
229 ///
230 /// # Parameters
231 ///
232 /// * `clock`: The clock used to generate timestamps
233 /// * `provider`: The provider to disable
234 ///
235 /// # Errors
236 ///
237 /// Returns [`Self::Error`] if the underlying repository fails
238 async fn disable(
239 &mut self,
240 clock: &dyn Clock,
241 provider: UpstreamOAuthProvider,
242 ) -> Result<UpstreamOAuthProvider, Self::Error>;
243
244 /// List [`UpstreamOAuthProvider`] with the given filter and pagination
245 ///
246 /// # Parameters
247 ///
248 /// * `filter`: The filter to apply
249 /// * `pagination`: The pagination parameters
250 ///
251 /// # Errors
252 ///
253 /// Returns [`Self::Error`] if the underlying repository fails
254 async fn list(
255 &mut self,
256 filter: UpstreamOAuthProviderFilter<'_>,
257 pagination: Pagination,
258 ) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;
259
260 /// Count the number of [`UpstreamOAuthProvider`] with the given filter
261 ///
262 /// # Parameters
263 ///
264 /// * `filter`: The filter to apply
265 ///
266 /// # Errors
267 ///
268 /// Returns [`Self::Error`] if the underlying repository fails
269 async fn count(
270 &mut self,
271 filter: UpstreamOAuthProviderFilter<'_>,
272 ) -> Result<usize, Self::Error>;
273
274 /// Get all enabled upstream OAuth providers
275 ///
276 /// # Errors
277 ///
278 /// Returns [`Self::Error`] if the underlying repository fails
279 async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
280}
281
// Invoke the crate's `repository_impl!` macro for
// `UpstreamOAuthProviderRepository`. The signatures listed below mirror the
// trait declaration above, although the order differs.
// NOTE(review): this presumably generates forwarding implementations of the
// trait for wrapper/boxed repository types — confirm against the macro
// definition in the crate root.
// `delete` is listed even though the trait gives it a default body.
// Presumably this makes wrappers forward it to the inner repository instead
// of falling back to the default — verify.
// The macro input deliberately has no trailing commas after the last
// parameter. Keep that shape unless the macro matcher is known to accept them.
repository_impl!(UpstreamOAuthProviderRepository:
    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;

    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        params: UpstreamOAuthProviderParams
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    async fn upsert(
        &mut self,
        clock: &dyn Clock,
        id: Ulid,
        params: UpstreamOAuthProviderParams
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error>;

    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;

    async fn disable(
        &mut self,
        clock: &dyn Clock,
        provider: UpstreamOAuthProvider
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    async fn list(
        &mut self,
        filter: UpstreamOAuthProviderFilter<'_>,
        pagination: Pagination
    ) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;

    async fn count(
        &mut self,
        filter: UpstreamOAuthProviderFilter<'_>
    ) -> Result<usize, Self::Error>;

    async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
);