Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions backend/migrations/20260226140000_create_user_2fa.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
-- Migration to create user_2fa table
CREATE TABLE IF NOT EXISTS user_2fa (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
otp_hash VARCHAR(255) NOT NULL,
expires_at TIMESTAMP WITH TIME ZONE NOT NULL,
attempts INTEGER DEFAULT 0 NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL
);

CREATE INDEX IF NOT EXISTS idx_user_2fa_user_id ON user_2fa(user_id);

-- Trigger to update updated_at timestamp
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_trigger WHERE tgname = 'update_user_2fa_updated_at') THEN
CREATE TRIGGER update_user_2fa_updated_at
BEFORE UPDATE ON user_2fa
FOR EACH ROW EXECUTE PROCEDURE update_updated_at_column();
END IF;
END $$;
14 changes: 6 additions & 8 deletions backend/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ pub async fn create_app(db: PgPool, config: Config) -> Result<Router, ApiError>
)
.route("/api/auth/web3-login", post(crate::auth::web3_login))
.route("/api/auth/wallet-login", post(crate::auth::web3_login))
.route("/user/send-2fa", post(crate::auth::send_2fa))
.route("/user/verify-2fa", post(crate::auth::verify_2fa))
.layer(
ServiceBuilder::new()
.layer(axum::middleware::from_fn_with_state(
Expand Down Expand Up @@ -161,10 +163,8 @@ async fn create_plan(
return Err(ApiError::Forbidden("KYC not approved".to_string()));
}

// Require 2FA verification (stub, replace with actual logic)
// if !verify_2fa(user.user_id, req.2fa_code) {
// return Err(ApiError::Forbidden("2FA verification failed".to_string()));
// }
// Require 2FA verification
crate::auth::verify_2fa_internal(&state.db, user.user_id, &req.two_fa_code).await?;

// Validate input amounts
crate::safe_math::SafeMath::ensure_non_negative(req.net_amount, "net_amount")?;
Expand Down Expand Up @@ -294,10 +294,8 @@ async fn claim_plan(
return Err(ApiError::Forbidden("KYC not approved".to_string()));
}

// Require 2FA verification (stub, replace with actual logic)
// if !verify_2fa(user.user_id, req.2fa_code) {
// return Err(ApiError::Forbidden("2FA verification failed".to_string()));
// }
// Require 2FA verification
crate::auth::verify_2fa_internal(&state.db, user.user_id, &req.two_fa_code).await?;

let plan = PlanService::claim_plan(&state.db, plan_id, user.user_id, &req).await?;

Expand Down
155 changes: 154 additions & 1 deletion backend/src/auth.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use crate::api_error::ApiError;
use crate::app::AppState;
use crate::config::Config;
use crate::notifications::AuditLogService;
use axum::{extract::State, Json};
use bcrypt::verify;
use chrono::{Duration, Utc};
use chrono::{DateTime, Duration, Utc};
use hex;
use jsonwebtoken::{encode, EncodingKey, Header};
use ring::signature;
Expand Down Expand Up @@ -41,6 +42,22 @@ pub struct LoginResponse {
pub token: String,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct Send2faRequest {
pub user_id: Uuid,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct Verify2faRequest {
pub user_id: Uuid,
pub otp: String,
}

#[derive(Debug, Serialize)]
pub struct TwoFaResponse {
pub message: String,
}

pub async fn get_nonce(
State(state): State<Arc<AppState>>,
Json(payload): Json<NonceRequest>,
Expand Down Expand Up @@ -308,6 +325,142 @@ pub async fn wallet_login(
web3_login(State(state), Json(payload)).await
}

pub async fn send_2fa(
State(state): State<Arc<AppState>>,
Json(payload): Json<Send2faRequest>,
) -> Result<Json<TwoFaResponse>, ApiError> {
// 1. Check if user exists
let user_exists =
sqlx::query_scalar::<_, bool>("SELECT EXISTS(SELECT 1 FROM users WHERE id = $1)")
.bind(payload.user_id)
.fetch_one(&state.db)
.await?;

if !user_exists {
return Err(ApiError::NotFound("User not found".to_string()));
}

// 2. Generate 6-digit OTP
use ring::rand::SecureRandom;
let rng = ring::rand::SystemRandom::new();
let mut bytes = [0u8; 4];
rng.fill(&mut bytes)
.map_err(|_| ApiError::Internal(anyhow::anyhow!("Failed to generate random bytes")))?;

// Generate a number between 100,000 and 999,999
let otp_num = (u32::from_be_bytes(bytes) % 900_000) + 100_000;
let otp = otp_num.to_string();

// 3. Hash OTP
let otp_hash = bcrypt::hash(&otp, bcrypt::DEFAULT_COST)
.map_err(|e| ApiError::Internal(anyhow::anyhow!("Failed to hash OTP: {}", e)))?;

let expires_at = Utc::now() + Duration::minutes(5);

// 4. Store/Update OTP in user_2fa
sqlx::query(
r#"
INSERT INTO user_2fa (user_id, otp_hash, expires_at, attempts)
VALUES ($1, $2, $3, 0)
ON CONFLICT (user_id) DO UPDATE
SET otp_hash = EXCLUDED.otp_hash,
expires_at = EXCLUDED.expires_at,
attempts = 0,
updated_at = NOW()
"#,
)
.bind(payload.user_id)
.bind(&otp_hash)
.bind(expires_at)
.execute(&state.db)
.await?;

// 5. Mock Email Notification
tracing::info!("--- [2FA OTP] ---");
tracing::info!("User ID: {}", payload.user_id);
tracing::info!("OTP Code: {}", otp);
tracing::info!("-----------------");

// Optional: Log to audit logs and notifications
AuditLogService::log(
&state.db,
Some(payload.user_id),
"2fa_sent",
Some(payload.user_id),
Some("user"),
)
.await?;

Ok(Json(TwoFaResponse {
message: "OTP sent successfully".to_string(),
}))
}

pub async fn verify_2fa(
State(state): State<Arc<AppState>>,
Json(payload): Json<Verify2faRequest>,
) -> Result<Json<TwoFaResponse>, ApiError> {
verify_2fa_internal(&state.db, payload.user_id, &payload.otp).await?;

Ok(Json(TwoFaResponse {
message: "OTP verified successfully".to_string(),
}))
}

pub async fn verify_2fa_internal(db: &PgPool, user_id: Uuid, otp: &str) -> Result<(), ApiError> {
let mut tx = db.begin().await?;

// 1. Retrieve OTP record
let row: Option<(String, DateTime<Utc>, i32)> = sqlx::query_as(
"SELECT otp_hash, expires_at, attempts FROM user_2fa WHERE user_id = $1 FOR UPDATE",
)
.bind(user_id)
.fetch_optional(&mut *tx)
.await?;

let (otp_hash, expires_at, attempts) =
row.ok_or_else(|| ApiError::BadRequest("No pending OTP found".to_string()))?;

// 2. Check attempts
if attempts >= 3 {
return Err(ApiError::BadRequest(
"Too many verification attempts. Please request a new OTP.".to_string(),
));
}

// 3. Check expiry
if expires_at < Utc::now() {
return Err(ApiError::BadRequest("OTP has expired".to_string()));
}

// 4. Verify OTP
let valid = bcrypt::verify(otp, &otp_hash)
.map_err(|e| ApiError::Internal(anyhow::anyhow!("Failed to verify OTP: {}", e)))?;

if !valid {
// Increment attempts
sqlx::query(
"UPDATE user_2fa SET attempts = attempts + 1, updated_at = NOW() WHERE user_id = $1",
)
.bind(user_id)
.execute(&mut *tx)
.await?;

tx.commit().await?;
return Err(ApiError::Unauthorized);
}

// 5. Successful verification - Clear OTP
sqlx::query("DELETE FROM user_2fa WHERE user_id = $1")
.bind(user_id)
.execute(&mut *tx)
.await?;

tx.commit().await?;

Ok(())
}

use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use sqlx::PgPool;
Expand Down
4 changes: 2 additions & 2 deletions backend/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,13 @@ pub struct CreatePlanRequest {
pub bank_account_number: Option<String>,
pub bank_name: Option<String>,
pub currency_preference: String,
pub two_fa_code: String,
}

#[derive(Debug, Deserialize)]
pub struct ClaimPlanRequest {
pub beneficiary_email: String,
#[allow(dead_code)]
pub claim_code: Option<u32>,
pub two_fa_code: String,
}

#[derive(sqlx::FromRow)]
Expand Down
43 changes: 36 additions & 7 deletions backend/tests/claim_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ async fn test_claim_before_maturity_returns_400() {
};

let pool = test_context.pool.clone();
let app = test_context.app;
let app = test_context.app.clone();

let user_id = Uuid::new_v4();
let email = format!("test_{}@example.com", user_id);
Expand Down Expand Up @@ -165,12 +165,16 @@ async fn test_claim_before_maturity_returns_400() {
.expect("Server failed");
});

let otp = test_context.prepare_2fa(user_id, "123456").await;
let token = generate_test_token(user_id, &email);
let client = reqwest::Client::new();
let response = client
.post(format!("http://{}/api/plans/{}/claim", addr, plan_id))
.header("Authorization", format!("Bearer {}", token))
.json(&json!({ "beneficiary_email": "beneficiary@example.com" }))
.json(&json!({
"beneficiary_email": "beneficiary@example.com",
"two_fa_code": otp
}))
.send()
.await
.expect("Failed to send request");
Expand Down Expand Up @@ -198,9 +202,14 @@ async fn test_claim_plan_is_due() {
approve_kyc_direct(&ctx.pool, user_id).await;
let plan_id = insert_due_plan(&ctx.pool, user_id).await;

let body = serde_json::json!({ "beneficiary_email": "beneficiary@example.com" });
let otp = ctx.prepare_2fa(user_id, "123456").await;
let body = serde_json::json!({
"beneficiary_email": "beneficiary@example.com",
"two_fa_code": otp
});
let response = ctx
.app
.clone()
.oneshot(
Request::builder()
.method("POST")
Expand Down Expand Up @@ -233,9 +242,14 @@ async fn test_claim_requires_kyc_approved() {
let token = generate_user_token(user_id);
let plan_id = insert_due_plan(&ctx.pool, user_id).await;

let body = serde_json::json!({ "beneficiary_email": "beneficiary@example.com" });
let otp = ctx.prepare_2fa(user_id, "111111").await;
let body = serde_json::json!({
"beneficiary_email": "beneficiary@example.com",
"two_fa_code": otp
});
let response = ctx
.app
.clone()
.oneshot(
Request::builder()
.method("POST")
Expand Down Expand Up @@ -265,9 +279,14 @@ async fn test_claim_recorded_on_success() {
approve_kyc_direct(&ctx.pool, user_id).await;
let plan_id = insert_due_plan(&ctx.pool, user_id).await;

let body = serde_json::json!({ "beneficiary_email": "claim-record@example.com" });
let otp = ctx.prepare_2fa(user_id, "123456").await;
let body = serde_json::json!({
"beneficiary_email": "claim-record@example.com",
"two_fa_code": otp
});
let response = ctx
.app
.clone()
.oneshot(
Request::builder()
.method("POST")
Expand Down Expand Up @@ -311,9 +330,14 @@ async fn test_claim_audit_log_inserted() {
approve_kyc_direct(&ctx.pool, user_id).await;
let plan_id = insert_due_plan(&ctx.pool, user_id).await;

let body = serde_json::json!({ "beneficiary_email": "audit-test@example.com" });
let otp = ctx.prepare_2fa(user_id, "123456").await;
let body = serde_json::json!({
"beneficiary_email": "audit-test@example.com",
"two_fa_code": otp
});
let response = ctx
.app
.clone()
.oneshot(
Request::builder()
.method("POST")
Expand Down Expand Up @@ -368,9 +392,14 @@ async fn test_claim_notification_created() {
.await
.expect("Failed to count notifications before claim");

let body = serde_json::json!({ "beneficiary_email": "notify-test@example.com" });
let otp = ctx.prepare_2fa(user_id, "123456").await;
let body = serde_json::json!({
"beneficiary_email": "notify-test@example.com",
"two_fa_code": otp
});
let response = ctx
.app
.clone()
.oneshot(
Request::builder()
.method("POST")
Expand Down
24 changes: 23 additions & 1 deletion backend/tests/helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,32 @@ impl TestContext {
jwt_secret: env::var("JWT_SECRET").unwrap_or_else(|_| "test-jwt-secret".to_string()),
};

// Run migrations
inheritx_backend::db::run_migrations(&pool)
.await
.expect("failed to run migrations");

let app = create_app(pool.clone(), config)
.await
.expect("failed to create app");

Some(Self { app, pool })
}

#[allow(dead_code)]
pub async fn prepare_2fa(&self, user_id: uuid::Uuid, otp: &str) -> String {
let otp_hash = bcrypt::hash(otp, bcrypt::DEFAULT_COST).unwrap();
let expires_at = chrono::Utc::now() + chrono::Duration::minutes(5);

sqlx::query(
"INSERT INTO user_2fa (user_id, otp_hash, expires_at) VALUES ($1, $2, $3) ON CONFLICT (user_id) DO UPDATE SET otp_hash = $2, expires_at = $3"
)
.bind(user_id)
.bind(otp_hash)
.bind(expires_at)
.execute(&self.pool)
.await
.unwrap();

otp.to_string()
}
}
Loading