AxumとAWS Cognitoで認証トークン検証付きAPIを作る

2024年05月02日

やること

AWS Cognitoから発行されたIDトークンをAxumでデコードする。 良いライブラリが見つからなかったので自前で実装します。

Claim

IDトークンのデコード表現を定義します。IDトークンの仕様は公式Docを参照してください。今回は以下の定義とします。

#[derive(Debug, Deserialize)]
pub struct Claim {
    #[serde(with = "ts_seconds")]
    pub exp: DateTime<Utc>,
    #[serde(with = "ts_seconds")]
    pub iat: DateTime<Utc>,
    pub aud: String,
    pub iss: String,
    pub sub: String,
    #[serde(with = "ts_seconds")]
    pub auth_time: DateTime<Utc>,
}

Routeハンドラでの受け取り方は以下のようなイメージです。Claimのデコードに失敗するとextractorの中でエラーとし、ハンドラに入る前に401 Unauthorizedを返すようにします。

async fn handler(claim: Claim, State(state): State<AppState>) -> Result<impl IntoResponse, CustomError> {
    let mut conn = state.acquire().await?;
    let user: User = get_user(&mut conn, &claim.sub).await?;
    ...
}

extractor

Routeハンドラでclaimを受け取れるようにするにはAxumの作法に従ってextractorを実装する必要があります。

全体の実装イメージから。 FromRequestPartsトレイトの実装です。HTTPヘッダのAuthorizationヘッダからBearerトークンを取得して、JWT検証をしてClaim表現にデコードして返却します。

#[async_trait]
impl<S> FromRequestParts<S> for Claim
where
    AppState: FromRef<S>,
    S: Send + Sync,
{
    type Rejection = CustomError;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let TypedHeader(Authorization(bearer)) = parts
            .extract::<TypedHeader<Authorization<Bearer>>>()
            .await
            .map_err(|_| CustomError::Unauthorized)?;
        let app_state = AppState::from_ref(state);
        let mut tx = app_state.pool.acquire().await?;
        let claim = verify_token(&mut tx, bearer.token()).await?;
        Ok(claim)
    }
}

AxumのRouteハンドラで独自の型を引数展開をさせるには axum::extract::FromRequestParts か axum::extract::FromRequest を実装します。 FromRequestPartsはHTTPヘッダのみ参照する場合、FromRequestはHTTPのリクエストボディも参照したい場合でケース毎に使い分けます。

今回はCognitoのIDトークンはクライアントからHTTPヘッダに乗せる前提とするのでFromRequestPartsを実装します。

エンドポイントへのリクエストは以下のようにAuthorizationヘッダにIDトークンを乗せて、extratorでトークン検証とデコードを行います。

curl -X GET -H 'Authorization: Bearer <IDトークン>` http://localhost:8000/

HTTPリクエストは上記のイメージです。

IDトークンを検証してデコードする

IDトークンの検証にはCognitoの公開鍵でJWTを検証します。

async fn verify_token(conn: &mut MySqlConnection, token_str: &str) -> Result<Claim, CustomError> {
    let header = decode_header(token_str).map_err(|_| CustomError::Unauthorized)?;
    let kid = header.kid.clone().ok_or(CustomError::Unauthorized)?;

    // DBにキャッシュした公開鍵の有効期間をチェック、期限切れならJWKを再取得する
    let public_key = find_cognito_public_key(conn, &kid)
        .await
        .and_then(|public_key| {
            // DBにキャッシュエントリが残っててもMaxAgeを過ぎていたら再取得する
            if public_key.expires_at < Utc::now() {
                tracing::info!(
                    "公開鍵のMaxAgeが切れているため再取得: now = {}, expires_at = {}",
                    Utc::now().with_timezone(&Tokyo),
                    public_key.expires_at.with_timezone(&Tokyo),
                );
                Err(CustomError::KeyExpired)
            } else {
                Ok(public_key)
            }
        });

    let public_key = match public_key {
        Ok(public_key) => public_key,
        Err(_) => {
            let (keys, cache_control) = fetch_cognito_jwks(&header).await?;
            let key = keys.get_key(&kid).ok_or(CustomError::Unauthorized)?;
            save_cognito_public_key(
                conn,
                &kid,
                &key.n,
                &key.e,
                Utc::now() + Duration::seconds(cache_control as i64),
            )
            .await?
        }
    };

    // 公開キーでIDトークンを検証&デコードする
    let token = decode_token(token_str, &header, &public_key)?;

    Ok(token.claims)
}

CognitoのJWKはユーザープールの概要にトークン署名キーURLとして公開URLがアサインされているので、このURLを環境変数に切り出しておきます。

cognito jwk

JWKを取得するコード例は以下です。 毎回JWKをネットワーク越しに取得しにいくと通信のオーバーヘッド分パフォーマンスに影響するので、JWKはDBやRedisなどにキーをキャッシュしておくと良いです。取得した公開キーの有効期間はレスポンスヘッダのCache-Controlのmax-ageに有効秒数が返されてくるので、その期間はキャッシュのキーを利用するように実装します。

/// JWKを取得して有効期間とともに返す
async fn fetch_cognito_jwks() -> Result<(Keys, usize), CustomError> {
    let jwk_url = std::env::var("COGNITO_JWK_URL").map_err(|_| CustomError::InternalServerError)?;
    let response = reqwest::get(jwk_url).await.map_err(|e| {
        CustomError::CantConnectToCognitoJwkUrl
    })?;

    let cache_control = response
        .headers()
        .get("Cache-Control")
        .and_then(|header_val| match header_val.to_str() {
            Ok(s) => get_max_age(s),
            Err(_) => None,
        })
        .ok_or(CustomError::MaxAgeParseError)?;

    let val = response
        .json::<Keys>()
        .await
        .map_err(|_| CustomError::JwkBodyParseError)?;

    Ok((val, cache_control))
}

取得した公開キーを用いてIDトークンを実際に検証&デコードします。 IDトークンはJWTのフォーマットで署名とエンコードされているので、デコードにはjsonwebtokenを使います。以下コードはjsonwebtokenの8.3.0で動作実績がありますが、最新の9.x.xだとAPIの互換がなくなっていて動かないかも。

fn decode_token(
    token_str: &str,
    header: &Header,
    public_key: &CognitoRotatingPublicKey,
) -> Result<TokenData<Claim>, Error> {
    let token = decode::<Claim>(
        token_str,
        &DecodingKey::from_rsa_components(&public_key.n, &public_key.e).map_err(|e| {
            tracing::error!("公開鍵の生成に失敗: {:?}", e);
            Error::UnrecoverableError("Cannot decode cognito rsa key".to_string())
        })?,
        &Validation::new(header.alg),
    )
    .map_err(|e| {
        tracing::info!("Claimのverifyデコードに失敗: {:?}", e);
        Error::Unauthorized
    })?;
    Ok(token)
}

さいごに

CognitoのIDトークンの検証方法をご紹介しました。 Cognitoの認証認可はOIDCのプロトコルで実装されているので、今回のコード例はFirebase Authや他のOIDC認証システムにも応用可能です。Firebase AuthのIDトークン検証はほぼコピペで動きました。

それでは今日はここまで。良きRustライフを!


Profile picture

Written by なまちゃ Web系エンジニアPython好き。バックエンド/フロントエンド問わずマルチな方面でエンジニアリングしています。