| 1 |
|
| 2 |
|
| 3 |
use axum::{ |
| 4 |
extract::{Path, State}, |
| 5 |
response::{IntoResponse, Response}, |
| 6 |
Form, Json, |
| 7 |
}; |
| 8 |
use serde::Deserialize; |
| 9 |
use tower_sessions::Session; |
| 10 |
use webauthn_rs::prelude::*; |
| 11 |
|
| 12 |
use crate::{ |
| 13 |
auth::{verify_password, AuthUser}, |
| 14 |
db::{self, PasskeyId}, |
| 15 |
error::{AppError, Result, ResultExt}, |
| 16 |
helpers::hx_toast, |
| 17 |
templates::{PasskeyListTemplate, PasskeyDisplay}, |
| 18 |
AppState, |
| 19 |
}; |
| 20 |
|
| 21 |
|
| 22 |
/// Session key under which `register_start` stores the pending
/// `PasskeyRegistration`, later read (and removed) by `register_finish`.
const PASSKEY_REG_STATE_KEY: &str = "passkey_reg_state";

/// Upper bound on the number of passkeys a single user may have registered.
const MAX_PASSKEYS_PER_USER: i64 = 20;
| 26 |
|
| 27 |
|
| 28 |
#[derive(Deserialize)] |
| 29 |
pub struct RegisterStartForm { |
| 30 |
password: String, |
| 31 |
} |
| 32 |
|
| 33 |
|
| 34 |
|
| 35 |
#[tracing::instrument(skip_all, name = "passkeys::register_start")] |
| 36 |
pub(super) async fn register_start( |
| 37 |
State(state): State<AppState>, |
| 38 |
AuthUser(user): AuthUser, |
| 39 |
session: Session, |
| 40 |
Form(form): Form<RegisterStartForm>, |
| 41 |
) -> Result<Response> { |
| 42 |
user.check_not_sandbox()?; |
| 43 |
|
| 44 |
|
| 45 |
let db_user = db::users::get_user_by_id(&state.db, user.id) |
| 46 |
.await? |
| 47 |
.ok_or(AppError::Unauthorized)?; |
| 48 |
if !verify_password(&form.password, &db_user.password_hash)? { |
| 49 |
return Err(AppError::BadRequest("Incorrect password".to_string())); |
| 50 |
} |
| 51 |
|
| 52 |
let count = db::passkeys::count_passkeys(&state.db, user.id).await?; |
| 53 |
if count >= MAX_PASSKEYS_PER_USER { |
| 54 |
return Err(AppError::BadRequest(format!( |
| 55 |
"Maximum of {} passkeys reached", |
| 56 |
MAX_PASSKEYS_PER_USER |
| 57 |
))); |
| 58 |
} |
| 59 |
|
| 60 |
|
| 61 |
let existing_json = db::passkeys::get_passkey_credentials(&state.db, user.id).await?; |
| 62 |
let exclude_creds: Vec<CredentialID> = existing_json |
| 63 |
.iter() |
| 64 |
.filter_map(|j| serde_json::from_value::<Passkey>(j.clone()).ok()) |
| 65 |
.map(|pk| pk.cred_id().clone()) |
| 66 |
.collect(); |
| 67 |
|
| 68 |
let exclude = if exclude_creds.is_empty() { |
| 69 |
None |
| 70 |
} else { |
| 71 |
Some(exclude_creds) |
| 72 |
}; |
| 73 |
|
| 74 |
let (ccr, reg_state) = state |
| 75 |
.webauthn |
| 76 |
.start_passkey_registration( |
| 77 |
*user.id.as_uuid(), |
| 78 |
user.username.as_ref(), |
| 79 |
user.username.as_ref(), |
| 80 |
exclude, |
| 81 |
) |
| 82 |
.context("webauthn registration start")?; |
| 83 |
|
| 84 |
|
| 85 |
session |
| 86 |
.insert(PASSKEY_REG_STATE_KEY, ®_state) |
| 87 |
.await |
| 88 |
.context("session error")?; |
| 89 |
|
| 90 |
Ok(Json(ccr).into_response()) |
| 91 |
} |
| 92 |
|
| 93 |
|
| 94 |
#[tracing::instrument(skip_all, name = "passkeys::register_finish")] |
| 95 |
pub(super) async fn register_finish( |
| 96 |
State(state): State<AppState>, |
| 97 |
AuthUser(user): AuthUser, |
| 98 |
session: Session, |
| 99 |
Json(reg): Json<RegisterPublicKeyCredential>, |
| 100 |
) -> Result<Response> { |
| 101 |
let reg_state: PasskeyRegistration = session |
| 102 |
.get(PASSKEY_REG_STATE_KEY) |
| 103 |
.await |
| 104 |
.context("session error")? |
| 105 |
.ok_or_else(|| AppError::BadRequest("No pending registration".to_string()))?; |
| 106 |
|
| 107 |
|
| 108 |
session |
| 109 |
.remove::<PasskeyRegistration>(PASSKEY_REG_STATE_KEY) |
| 110 |
.await |
| 111 |
.ok(); |
| 112 |
|
| 113 |
let passkey = state |
| 114 |
.webauthn |
| 115 |
.finish_passkey_registration(®, ®_state) |
| 116 |
.map_err(|e| AppError::BadRequest(format!("Registration failed: {}", e)))?; |
| 117 |
|
| 118 |
let credential_json = serde_json::to_value(&passkey) |
| 119 |
.context("serialize passkey")?; |
| 120 |
let credential_id = passkey.cred_id().to_vec(); |
| 121 |
|
| 122 |
db::passkeys::create_passkey(&state.db, user.id, "Passkey", &credential_json, &credential_id) |
| 123 |
.await?; |
| 124 |
|
| 125 |
tracing::info!(user_id = %user.id, event = "passkey_registered", "Passkey registered"); |
| 126 |
|
| 127 |
Ok(( |
| 128 |
[("HX-Trigger", hx_toast("Passkey registered", "success"))], |
| 129 |
list_inner(&state, user.id).await?, |
| 130 |
) |
| 131 |
.into_response()) |
| 132 |
} |
| 133 |
|
| 134 |
|
| 135 |
#[tracing::instrument(skip_all, name = "passkeys::list")] |
| 136 |
pub(super) async fn list( |
| 137 |
State(state): State<AppState>, |
| 138 |
AuthUser(user): AuthUser, |
| 139 |
) -> Result<Response> { |
| 140 |
Ok(list_inner(&state, user.id).await?.into_response()) |
| 141 |
} |
| 142 |
|
| 143 |
|
| 144 |
async fn list_inner(state: &AppState, user_id: db::UserId) -> Result<PasskeyListTemplate> { |
| 145 |
let passkeys = db::passkeys::list_passkeys(&state.db, user_id).await?; |
| 146 |
let passkeys = passkeys |
| 147 |
.into_iter() |
| 148 |
.map(|p| PasskeyDisplay { |
| 149 |
id: p.id.to_string(), |
| 150 |
name: p.name, |
| 151 |
created_at: p.created_at.format("%Y-%m-%d").to_string(), |
| 152 |
last_used_at: p.last_used_at.map(|d| d.format("%Y-%m-%d").to_string()), |
| 153 |
}) |
| 154 |
.collect(); |
| 155 |
|
| 156 |
Ok(PasskeyListTemplate { passkeys }) |
| 157 |
} |
| 158 |
|
| 159 |
|
| 160 |
#[derive(Deserialize)] |
| 161 |
pub struct RenameForm { |
| 162 |
name: String, |
| 163 |
} |
| 164 |
|
| 165 |
#[tracing::instrument(skip_all, name = "passkeys::rename")] |
| 166 |
pub(super) async fn rename( |
| 167 |
State(state): State<AppState>, |
| 168 |
AuthUser(user): AuthUser, |
| 169 |
Path(id): Path<PasskeyId>, |
| 170 |
Form(form): Form<RenameForm>, |
| 171 |
) -> Result<Response> { |
| 172 |
let name = form.name.trim(); |
| 173 |
if name.is_empty() || name.len() > 100 { |
| 174 |
return Err(AppError::validation("Name must be 1-100 characters".to_string())); |
| 175 |
} |
| 176 |
|
| 177 |
if !db::passkeys::rename_passkey(&state.db, id, user.id, name).await? { |
| 178 |
return Err(AppError::NotFound); |
| 179 |
} |
| 180 |
|
| 181 |
Ok(( |
| 182 |
[("HX-Trigger", hx_toast("Passkey renamed", "success"))], |
| 183 |
list_inner(&state, user.id).await?, |
| 184 |
) |
| 185 |
.into_response()) |
| 186 |
} |
| 187 |
|
| 188 |
|
| 189 |
#[derive(Deserialize)] |
| 190 |
pub struct DeleteForm { |
| 191 |
password: String, |
| 192 |
} |
| 193 |
|
| 194 |
#[tracing::instrument(skip_all, name = "passkeys::delete")] |
| 195 |
pub(super) async fn delete( |
| 196 |
State(state): State<AppState>, |
| 197 |
AuthUser(user): AuthUser, |
| 198 |
Path(id): Path<PasskeyId>, |
| 199 |
Form(form): Form<DeleteForm>, |
| 200 |
) -> Result<Response> { |
| 201 |
let db_user = db::users::get_user_by_id(&state.db, user.id) |
| 202 |
.await? |
| 203 |
.ok_or(AppError::Unauthorized)?; |
| 204 |
|
| 205 |
if !verify_password(&form.password, &db_user.password_hash)? { |
| 206 |
return Err(AppError::BadRequest("Incorrect password".to_string())); |
| 207 |
} |
| 208 |
|
| 209 |
if !db::passkeys::delete_passkey(&state.db, id, user.id).await? { |
| 210 |
return Err(AppError::NotFound); |
| 211 |
} |
| 212 |
|
| 213 |
tracing::info!(user_id = %user.id, passkey_id = %id, event = "passkey_deleted", "Passkey deleted"); |
| 214 |
|
| 215 |
Ok(( |
| 216 |
[("HX-Trigger", hx_toast("Passkey deleted", "success"))], |
| 217 |
list_inner(&state, user.id).await?, |
| 218 |
) |
| 219 |
.into_response()) |
| 220 |
} |
| 221 |
|