Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions common/include/common/IAuthenticationManager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@

// Standard imports
#include <string>

namespace SDMS {

struct LogContext;
/**
* Interface class for managing authenticating
*
Expand All @@ -26,7 +25,7 @@ class IAuthenticationManager {
* Increments the number of times that the key has been accessed, this is
*useful information when deciding if a key should be purged.
**/
virtual void incrementKeyAccessCounter(const std::string &public_key) = 0;
virtual void incrementKeyAccessCounter(const std::string &public_key, LogContext log_context) = 0;

/**
* Will return true if the public key is known. This is also dependent on the
Expand All @@ -39,7 +38,7 @@ class IAuthenticationManager {
* - SESSION
* - PERSISTENT
**/
virtual bool hasKey(const std::string &pub_key) const = 0;
virtual bool hasKey(const std::string &pub_key, LogContext log_context) const = 0;

/**
* Will get the unique id or throw an error
Expand All @@ -49,7 +48,7 @@ class IAuthenticationManager {
* - SESSION
* - PERSISTENT - user or repo
**/
virtual std::string getUID(const std::string &pub_key) const = 0;
virtual std::string getUID(const std::string &pub_key, LogContext log_context) const = 0;

/**
* Purge keys if needed
Expand Down
15 changes: 12 additions & 3 deletions common/source/operators/AuthenticationOperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@

// Local public includes
#include "common/TraceException.hpp"
#include "common/DynaLog.hpp"

// Standard includes
#include <any>
#include <iostream>
#include <boost/uuid/uuid.hpp>
#include <boost/uuid/uuid_generators.hpp>
#include <boost/uuid/uuid_io.hpp>

namespace SDMS {

Expand All @@ -25,17 +29,22 @@ void AuthenticationOperator::execute(IMessage &message) {
if (message.exists(MessageAttribute::KEY) == 0) {
EXCEPT(1, "'KEY' attribute not defined.");
}
// 🔹 Generate correlation ID for this request
boost::uuids::random_generator generator;
boost::uuids::uuid uuid = generator();

LogContext log_context;
log_context.correlation_id = boost::uuids::to_string(uuid);
m_authentication_manager->purge();

std::string key = std::get<std::string>(message.get(MessageAttribute::KEY));

std::string uid = "anon";
if (m_authentication_manager->hasKey(key)) {
m_authentication_manager->incrementKeyAccessCounter(key);
if (m_authentication_manager->hasKey(key, log_context)) {
m_authentication_manager->incrementKeyAccessCounter(key, log_context);

try {
uid = m_authentication_manager->getUID(key);
uid = m_authentication_manager->getUID(key, log_context);
} catch (const std::exception& e) {
// Log the exception to help diagnose authentication issues
std::cerr << "[AuthenticationOperator] Failed to get UID for key: "
Expand Down
7 changes: 4 additions & 3 deletions common/tests/unit/test_OperatorFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "common/MessageFactory.hpp"
#include "common/OperatorFactory.hpp"
#include "common/OperatorTypes.hpp"
#include "common/DynaLog.hpp"

// Third party includes
#include <google/protobuf/stubs/common.h>
Expand Down Expand Up @@ -38,15 +39,15 @@ class DummyAuthManager : public IAuthenticationManager {
/**
* Methods only available via the interface
**/
virtual void incrementKeyAccessCounter(const std::string &pub_key) final {
virtual void incrementKeyAccessCounter(const std::string &pub_key, LogContext log_context) final {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Improve Operator/Authentication integration tests to validate LogContext usage and correlation-id generation

DummyAuthManager now accepts a LogContext, but the tests don't verify how it's used. Since AuthenticationOperator::execute generates a correlation ID and passes the same LogContext to hasKey, incrementKeyAccessCounter, and getUID, please enhance the tests to:

  • Have DummyAuthManager store the last LogContext it receives, and
  • Assert that execute() provides a non-empty correlation ID and reuses the same LogContext across all three calls.

This will validate the new wiring rather than just exercising the updated signature.

Suggested implementation:

  /**
   * Methods only available via the interface
   **/
  virtual void incrementKeyAccessCounter(const std::string &pub_key, LogContext log_context) final {
    ++m_counters.at(pub_key);
    m_log_contexts.push_back(log_context);
  }

  virtual bool hasKey(const std::string &pub_key, LogContext log_context) const {
    m_log_contexts.push_back(log_context);
    return m_counters.count(pub_key);
  }
  // Just assume all keys map to the anon_uid
  virtual std::string getUID(const std::string &, LogContext log_context) const {
    m_log_contexts.push_back(log_context);
    return "authenticated_uid";
  }

  const std::vector<LogContext> &logContexts() const {
    return m_log_contexts;
  }

To fully implement the test coverage you described, you will also need to:

  1. Add a member field to DummyAuthManager (in the same class where m_counters is defined):

    mutable std::vector<LogContext> m_log_contexts;

    Make sure the header <vector> is included if it is not already.

  2. In the relevant Operator/Authentication integration test that invokes AuthenticationOperator::execute(...) (or uses OperatorFactory to create the authentication operator with DummyAuthManager), add assertions such as:

    const auto &log_contexts = dummy_auth_manager.logContexts();
    ASSERT_EQ(log_contexts.size(), 3u); // hasKey, incrementKeyAccessCounter, getUID
    
    // All three calls must receive the same LogContext instance
    EXPECT_TRUE(log_contexts[0] == log_contexts[1]);
    EXPECT_TRUE(log_contexts[1] == log_contexts[2]);
    
    // And the correlation id must be non-empty
    const auto &ctx = log_contexts[0];
    // Adjust the accessor below to however correlation-id is exposed by LogContext
    EXPECT_FALSE(ctx.correlation_id().empty());

    If LogContext does not support operator== or correlation_id(), adapt the checks accordingly (for example, compare pointer identity if LogContext is a smart pointer type, or use the appropriate getter for the correlation ID as defined in common/DynaLog.hpp).

  3. If the same DummyAuthManager instance is reused across tests, ensure you clear m_log_contexts (e.g., via a clearLogContexts() helper) between test cases or create a fresh instance per test so that logContexts().size() is deterministic.

++m_counters.at(pub_key);
}

virtual bool hasKey(const std::string &pub_key) const {
virtual bool hasKey(const std::string &pub_key, LogContext log_context) const {
return m_counters.count(pub_key);
}
// Just assume all keys map to the anon_uid
virtual std::string getUID(const std::string &) const {
virtual std::string getUID(const std::string &, LogContext log_context) const {
return "authenticated_uid";
}

Expand Down
18 changes: 11 additions & 7 deletions core/server/AuthMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ size_t AuthMap::size(const PublicKeyType pub_key_type) const {
}

void AuthMap::incrementKeyAccessCounter(const PublicKeyType pub_key_type,
const std::string &public_key) {
const std::string &public_key,
LogContext log_context) {
if (pub_key_type == PublicKeyType::TRANSIENT) {
lock_guard<mutex> lock(m_trans_clients_mtx);
if (m_trans_auth_clients.count(public_key)) {
Expand All @@ -183,7 +184,8 @@ void AuthMap::incrementKeyAccessCounter(const PublicKeyType pub_key_type,
}

bool AuthMap::hasKey(const PublicKeyType pub_key_type,
const std::string &public_key) const {
const std::string &public_key,
LogContext log_context) const {
if (pub_key_type == PublicKeyType::TRANSIENT) {
lock_guard<mutex> lock(m_trans_clients_mtx);
return m_trans_auth_clients.count(public_key) > 0;
Expand All @@ -203,7 +205,7 @@ bool AuthMap::hasKey(const PublicKeyType pub_key_type,
try {
DatabaseAPI db(m_db_url, m_db_user, m_db_pass);
std::string uid;
if (db.uidByPubKey(public_key, uid)) {
if (db.uidByPubKey(public_key, uid, log_context)) {
return true;
}
} catch (const std::exception& e) {
Expand All @@ -217,9 +219,10 @@ bool AuthMap::hasKey(const PublicKeyType pub_key_type,
}

std::string AuthMap::getUID(const PublicKeyType pub_key_type,
const std::string &public_key) const {
const std::string &public_key,
LogContext log_context) const {

std::string uid = getUIDSafe(pub_key_type, public_key);
std::string uid = getUIDSafe(pub_key_type, public_key, log_context);

if (uid.empty()) {
if (pub_key_type == PublicKeyType::TRANSIENT) {
Expand All @@ -238,7 +241,8 @@ std::string AuthMap::getUID(const PublicKeyType pub_key_type,
}

std::string AuthMap::getUIDSafe(const PublicKeyType pub_key_type,
const std::string &public_key) const {
const std::string &public_key,
LogContext log_context) const {
if (pub_key_type == PublicKeyType::TRANSIENT) {
lock_guard<mutex> lock(m_trans_clients_mtx);
if (m_trans_auth_clients.count(public_key)) {
Expand All @@ -261,7 +265,7 @@ std::string AuthMap::getUIDSafe(const PublicKeyType pub_key_type,
// Check database for user keys
DatabaseAPI db(m_db_url, m_db_user, m_db_pass);
std::string uid;
if (db.uidByPubKey(public_key, uid)) {
if (db.uidByPubKey(public_key, uid, log_context)) {
return uid;
}
}
Expand Down
13 changes: 9 additions & 4 deletions core/server/AuthMap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

// Local common includes
#include "common/IAuthenticationManager.hpp"
#include "common/DynaLog.hpp"

// Standard includes
#include <map>
Expand Down Expand Up @@ -113,13 +114,15 @@ class AuthMap {
*does not exist. Best to call hasKey first.
**/
std::string getUID(const PublicKeyType pub_key_type,
const std::string &public_key) const;
const std::string &public_key,
LogContext log_context) const;

/**
* Safe version that returns empty string if key not found
**/
std::string getUIDSafe(const PublicKeyType pub_key_type,
const std::string &public_key) const;
const std::string &public_key,
LogContext log_context) const;

/**
* Will return the number of keys of the provided type. Does not currently
Expand All @@ -128,7 +131,8 @@ class AuthMap {
size_t size(const PublicKeyType pub_key_type) const;

bool hasKey(const PublicKeyType pub_key_type,
const std::string &public_key) const;
const std::string &public_key,
LogContext log_context) const;

/***********************************************************************************
* Manipulators
Expand All @@ -138,7 +142,8 @@ class AuthMap {
* Increase the recorded times the the public key has been accessed by one.
**/
void incrementKeyAccessCounter(const PublicKeyType pub_key_type,
const std::string &public_key);
const std::string &public_key,
LogContext log_context);

/**
* Adds the key to the AuthMap object
Expand Down
47 changes: 25 additions & 22 deletions core/server/AuthenticationManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

// Common includes
#include "common/TraceException.hpp"
#include "common/DynaLog.hpp"

// Standard includes
#include <iostream>
Expand Down Expand Up @@ -69,46 +70,47 @@ void AuthenticationManager::purge(const PublicKeyType pub_key_type) {
}

void AuthenticationManager::incrementKeyAccessCounter(
const std::string &public_key) {
const std::string &public_key,
LogContext log_context) {
std::lock_guard<std::mutex> lock(m_lock);
if (m_auth_mapper.hasKey(PublicKeyType::TRANSIENT, public_key)) {
if (m_auth_mapper.hasKey(PublicKeyType::TRANSIENT, public_key, log_context)) {
m_auth_mapper.incrementKeyAccessCounter(PublicKeyType::TRANSIENT,
public_key);
} else if (m_auth_mapper.hasKey(PublicKeyType::SESSION, public_key)) {
m_auth_mapper.incrementKeyAccessCounter(PublicKeyType::SESSION, public_key);
public_key, log_context);
} else if (m_auth_mapper.hasKey(PublicKeyType::SESSION, public_key, log_context)) {
m_auth_mapper.incrementKeyAccessCounter(PublicKeyType::SESSION, public_key, log_context);
}
// Ignore persistent cases because counter does nothing for them
}

bool AuthenticationManager::hasKey(const std::string &public_key) const {
bool AuthenticationManager::hasKey(const std::string &public_key, LogContext log_context) const {
std::lock_guard<std::mutex> lock(m_lock);

if (m_auth_mapper.hasKey(PublicKeyType::TRANSIENT, public_key)) {
if (m_auth_mapper.hasKey(PublicKeyType::TRANSIENT, public_key, log_context)) {
return true;
}

if (m_auth_mapper.hasKey(PublicKeyType::SESSION, public_key)) {
if (m_auth_mapper.hasKey(PublicKeyType::SESSION, public_key, log_context)) {
return true;
}

if (m_auth_mapper.hasKey(PublicKeyType::PERSISTENT, public_key)) {
if (m_auth_mapper.hasKey(PublicKeyType::PERSISTENT, public_key, log_context)) {
return true;
}

return false;
}

std::string AuthenticationManager::getUID(const std::string &public_key) const {
std::string AuthenticationManager::getUID(const std::string &public_key, LogContext log_context) const {
std::lock_guard<std::mutex> lock(m_lock);

if (m_auth_mapper.hasKey(PublicKeyType::TRANSIENT, public_key)) {
return m_auth_mapper.getUID(PublicKeyType::TRANSIENT, public_key);
if (m_auth_mapper.hasKey(PublicKeyType::TRANSIENT, public_key, log_context)) {
return m_auth_mapper.getUID(PublicKeyType::TRANSIENT, public_key, log_context);
}
if (m_auth_mapper.hasKey(PublicKeyType::SESSION, public_key)) {
return m_auth_mapper.getUID(PublicKeyType::SESSION, public_key);
if (m_auth_mapper.hasKey(PublicKeyType::SESSION, public_key, log_context)) {
return m_auth_mapper.getUID(PublicKeyType::SESSION, public_key, log_context);
}
if (m_auth_mapper.hasKey(PublicKeyType::PERSISTENT, public_key)) {
return m_auth_mapper.getUID(PublicKeyType::PERSISTENT, public_key);
if (m_auth_mapper.hasKey(PublicKeyType::PERSISTENT, public_key, log_context)) {
return m_auth_mapper.getUID(PublicKeyType::PERSISTENT, public_key, log_context);
}

EXCEPT(1, "Unrecognized public_key during execution of getUID.");
Expand All @@ -122,9 +124,10 @@ void AuthenticationManager::addKey(const PublicKeyType &pub_key_type,
}

bool AuthenticationManager::hasKey(const PublicKeyType &pub_key_type,
const std::string &public_key) const {
const std::string &public_key,
LogContext log_context) const {
std::lock_guard<std::mutex> lock(m_lock);
return m_auth_mapper.hasKey(pub_key_type, public_key);
return m_auth_mapper.hasKey(pub_key_type, public_key, log_context);
}

void AuthenticationManager::migrateKey(const PublicKeyType &from_type,
Expand All @@ -150,21 +153,21 @@ void AuthenticationManager::clearAllNonPersistentKeys() {
m_auth_mapper.clearAllNonPersistentKeys();
}

std::string AuthenticationManager::getUIDSafe(const std::string &public_key) const {
std::string AuthenticationManager::getUIDSafe(const std::string &public_key, LogContext log_context) const {
std::lock_guard<std::mutex> lock(m_lock);

// Try each key type in order
std::string uid = m_auth_mapper.getUIDSafe(PublicKeyType::TRANSIENT, public_key);
std::string uid = m_auth_mapper.getUIDSafe(PublicKeyType::TRANSIENT, public_key, log_context);
if (!uid.empty()) {
return uid;
}

uid = m_auth_mapper.getUIDSafe(PublicKeyType::SESSION, public_key);
uid = m_auth_mapper.getUIDSafe(PublicKeyType::SESSION, public_key, log_context);
if (!uid.empty()) {
return uid;
}

uid = m_auth_mapper.getUIDSafe(PublicKeyType::PERSISTENT, public_key);
uid = m_auth_mapper.getUIDSafe(PublicKeyType::PERSISTENT, public_key, log_context);
if (!uid.empty()) {
return uid;
}
Expand Down
10 changes: 5 additions & 5 deletions core/server/AuthenticationManager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class AuthenticationManager : public IAuthenticationManager {
*allotted purge time frame. If the count is above one then the session key
*not be purged.
**/
virtual void incrementKeyAccessCounter(const std::string &public_key) final;
virtual void incrementKeyAccessCounter(const std::string &public_key, LogContext log_context) final;

/**
* This will purge all keys of a particular type that have expired.
Comment on lines +57 to 60
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (performance): Consider passing LogContext by const reference instead of by value through the authentication APIs.

All public AuthenticationManager / IAuthenticationManager methods now take LogContext by value and forward it through to AuthMap and DatabaseAPI. To avoid repeated copies as LogContext grows, you could change these to accept const LogContext& instead and propagate that through the implementation and call sites, e.g.:

virtual void incrementKeyAccessCounter(const std::string &public_key, const LogContext &log_context) final;
virtual bool hasKey(const std::string &pub_key, const LogContext &log_context) const final;
virtual std::string getUID(const std::string &pub_key, const LogContext &log_context) const final;
std::string getUIDSafe(const std::string &pub_key, const LogContext &log_context) const;

Suggested implementation:

   *allotted purge time frame. If the count is above one then the session key
   *not be purged.
   **/
  virtual void incrementKeyAccessCounter(const std::string &public_key, const LogContext &log_context) final;

  /**
   * This will purge all keys of a particular type that have expired.
   * - SESSION
   * - PERSISTENT
   **/
  virtual bool hasKey(const std::string &pub_key, const LogContext &log_context) const final;

To fully implement the suggestion, you should also:

  1. Update all other AuthenticationManager / IAuthenticationManager method declarations in this header that currently take LogContext by value to instead take const LogContext & (e.g. getUID, getUIDSafe, and any similar methods).
  2. Update the corresponding method definitions in the .cpp files (e.g. AuthenticationManager.cpp, AuthMap.cpp, DatabaseAPI implementations) to match the new signatures (const LogContext &).
  3. Adjust all call sites that pass a LogContext to these methods so they pass by reference; typically no call-site syntax change is required, but ensure there are no temporary rvalues that would now bind to a const LogContext &.
  4. If there are overridden methods in derived classes, make sure their signatures are also updated to use const LogContext & so they still correctly override the interface.

Expand All @@ -79,15 +79,15 @@ class AuthenticationManager : public IAuthenticationManager {
* - SESSION
* - PERSISTENT
**/
virtual bool hasKey(const std::string &pub_key) const final;
virtual bool hasKey(const std::string &pub_key, LogContext log_context) const final;

void addKey(const PublicKeyType &pub_key_type, const std::string &public_key,
const std::string &uid);

/**
* Check if a specific key exists in a specific map type
**/
bool hasKey(const PublicKeyType &pub_key_type, const std::string &public_key) const;
bool hasKey(const PublicKeyType &pub_key_type, const std::string &public_key, LogContext log_context) const;

/**
* Migrate a key from one type to another
Expand Down Expand Up @@ -121,13 +121,13 @@ class AuthenticationManager : public IAuthenticationManager {
* - SESSION
* - PERSISTENT
**/
virtual std::string getUID(const std::string &pub_key) const final;
virtual std::string getUID(const std::string &pub_key, LogContext log_context) const final;

/**
* Safe version that returns empty string if key not found
* instead of throwing an exception
**/
std::string getUIDSafe(const std::string &pub_key) const;
std::string getUIDSafe(const std::string &pub_key, LogContext log_context) const;
};

} // namespace Core
Expand Down
11 changes: 10 additions & 1 deletion core/server/Condition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,26 @@

// Standard includes
#include <iostream>
#include <boost/uuid/uuid.hpp>
#include <boost/uuid/uuid_generators.hpp>
#include <boost/uuid/uuid_io.hpp>

namespace SDMS {
namespace Core {

void Promote::enforce(AuthMap &auth_map, const std::string &public_key) {
if (auth_map.hasKeyType(m_promote_from, public_key)) {
size_t access_count = auth_map.getAccessCount(m_promote_from, public_key);
boost::uuids::random_generator generator;
boost::uuids::uuid uuid = generator();

LogContext log_context;
log_context.correlation_id = boost::uuids::to_string(uuid);

if (access_count >= m_transient_to_session_count_threshold) {
// Convert transient key to session key if has been accessed more than the
// threshold
std::string uid = auth_map.getUID(m_promote_from, public_key);
std::string uid = auth_map.getUID(m_promote_from, public_key, log_context);
auth_map.addKey(m_promote_to, public_key, uid);
}
// Remove expired short lived transient key
Expand Down
Loading