Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions UI/auth-base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,20 @@ void Auth::Save()
auth->SaveInternal();
config_save_safe(main->Config(), "tmp", nullptr);
}

void Auth::Delete()
{
OBSBasic *main = OBSBasic::Get();
Auth *auth = main->auth.get();
if (!auth) {
if (config_has_user_value(main->Config(), "Auth", "Type")) {
config_remove_value(main->Config(), "Auth", "Type");
config_save_safe(main->Config(), "tmp", nullptr);
}
return;
}

config_remove_value(main->Config(), "Auth", "type");
auth->DeleteInternal();
config_save_safe(main->Config(), "tmp", nullptr);
}
2 changes: 2 additions & 0 deletions UI/auth-base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class Auth : public QObject {
protected:
virtual void SaveInternal() = 0;
virtual bool LoadInternal() = 0;
virtual void DeleteInternal() = 0;

bool firstLoad = true;

Expand Down Expand Up @@ -56,6 +57,7 @@ class Auth : public QObject {
static bool External(const std::string &service);
static void Load();
static void Save();
static void Delete();

protected:
static void RegisterAuth(const Def &d, create_cb create);
Expand Down
177 changes: 165 additions & 12 deletions UI/auth-oauth.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <QPushButton>
#include <QHBoxLayout>
#include <QVBoxLayout>
#include <QRandomGenerator>

#include <qt-wrappers.hpp>
#include <obs-app.hpp>
Expand Down Expand Up @@ -158,14 +159,21 @@ void OAuth::DeleteCookies(const std::string &service)
}
}

void OAuth::SaveInternal()
static constexpr char hexChars[] = "abcdef0123456789";
static constexpr int hexCount = sizeof(hexChars) - 1;
static constexpr int kSuffixLength = 8;

static std::string GetRandSuffix()
{
OBSBasic *main = OBSBasic::Get();
config_set_string(main->Config(), service(), "RefreshToken",
refresh_token.c_str());
config_set_string(main->Config(), service(), "Token", token.c_str());
config_set_uint(main->Config(), service(), "ExpireTime", expire_time);
config_set_int(main->Config(), service(), "ScopeVer", currentScopeVer);
char state[kSuffixLength + 1];
QRandomGenerator *rng = QRandomGenerator::system();
int i;

for (i = 0; i < kSuffixLength; i++)
state[i] = hexChars[rng->bounded(0, hexCount)];
state[i] = 0;

return state;
}

static inline std::string get_config_str(OBSBasic *main, const char *section,
Expand All @@ -175,17 +183,101 @@ static inline std::string get_config_str(OBSBasic *main, const char *section,
return val ? val : "";
}

void OAuth::SaveInternal()
{
OBSBasic *main = OBSBasic::Get();

bool keychain_success = false;

if (!App()->IsPortableMode()) {
std::string keychain_key =
get_config_str(main, service(), "KeychainItem");

if (keychain_key.empty()) {
keychain_key = service();
keychain_key += "::" + GetRandSuffix();
}

Json data = Json::object{{"refresh_token", refresh_token},
{"token", token},
{"expire_time",
static_cast<int>(expire_time)},
{"scope_ver", currentScopeVer}};
std::string json = data.dump();

if (os_keychain_save(GetKeychainLabel(), keychain_key.c_str(),
json.c_str())) {
config_set_string(main->Config(), service(),
"KeychainItem", keychain_key.c_str());
keychain_success = true;
}
}

if (!keychain_success) {
config_set_string(main->Config(), service(), "RefreshToken",
refresh_token.c_str());
config_set_string(main->Config(), service(), "Token",
token.c_str());
config_set_uint(main->Config(), service(), "ExpireTime",
expire_time);
config_set_int(main->Config(), service(), "ScopeVer",
currentScopeVer);
}
}

bool OAuth::LoadInternal()
{
OBSBasic *main = OBSBasic::Get();
refresh_token = get_config_str(main, service(), "RefreshToken");
token = get_config_str(main, service(), "Token");
expire_time = config_get_uint(main->Config(), service(), "ExpireTime");
currentScopeVer =
(int)config_get_int(main->Config(), service(), "ScopeVer");

bool keychain_success = false;

if (!App()->IsPortableMode()) {
const char *keychain_key = config_get_string(
main->Config(), service(), "KeychainItem");

BPtr<char> data;
if (keychain_key &&
os_keychain_load(GetKeychainLabel(), keychain_key, &data) &&
data) {
std::string err;
Json parsed = Json::parse(data, err);
if (err.empty()) {
refresh_token =
parsed["refresh_token"].string_value();
token = parsed["token"].string_value();
expire_time = parsed["expire_time"].int_value();
currentScopeVer =
parsed["scope_ver"].int_value();
keychain_success = true;
}
}
}

if (!keychain_success) {
refresh_token = get_config_str(main, service(), "RefreshToken");
token = get_config_str(main, service(), "Token");
expire_time = config_get_uint(main->Config(), service(),
"ExpireTime");
currentScopeVer = (int)config_get_int(main->Config(), service(),
"ScopeVer");
}

return implicit ? !token.empty() : !refresh_token.empty();
}

void OAuth::DeleteInternal()
{
OBSBasic *main = OBSBasic::Get();

/* Delete keychain item (if it exists) */
os_keychain_delete(GetKeychainLabel(),
config_get_string(main->Config(), service(),
"KeychainItem"));

/* Delete OAuth config section */
config_remove_section(main->Config(), service());
}

bool OAuth::TokenExpired()
{
if (token.empty())
Expand Down Expand Up @@ -321,6 +413,67 @@ try {
return false;
}

bool OAuth::InvalidateToken(const char *url)
{
return InvalidateTokenInternal(url, "", true);
}

bool OAuth::InvalidateToken(const char *url, const std::string &client_id)
{
return InvalidateTokenInternal(url, client_id);
}

bool OAuth::InvalidateTokenInternal(const char *base_url,
const std::string &client_id,
const bool token_as_parameter)
try {
std::string url(base_url);
std::string output;
std::string error;
std::string desc;
std::string post_data;

if (token.empty()) {
return true;
}

/* Google wants the token as a parameter, but still wants us to POST... */
if (token_as_parameter) {
url += "?token=" + token;
} else {
post_data += "token=" + token;
}

/* Only required for Twitch as far as I can tell */
if (!client_id.empty()) {
post_data += "&client_id=" + client_id;
}

bool success = false;

auto func = [&]() {
success = GetRemoteFile(url.c_str(), output, error, nullptr,
"application/x-www-form-urlencoded",
"POST", post_data.c_str(),
std::vector<std::string>(), nullptr, 5,
false);
};

ExecThreadedWithoutBlocking(func, QTStr("Auth.Revoking.Title"),
QTStr("Auth.Revoking.Text").arg(service()));
if (!success)
throw ErrorInfo("Failed to revoke token", error);

/* We don't really care about the result here, just assume it either
* succeeded or didn't matter. */
return true;

} catch (ErrorInfo &info) {
blog(LOG_WARNING, "%s: %s: %s", __FUNCTION__, info.message.c_str(),
info.error.c_str());
return false;
}

void OAuthStreamKey::OnStreamConfig()
{
if (key_.empty())
Expand Down
12 changes: 12 additions & 0 deletions UI/auth-oauth.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class OAuth : public Auth {

virtual void SaveInternal() override;
virtual bool LoadInternal() override;
virtual void DeleteInternal() override;

virtual bool RetryLogin() = 0;
bool TokenExpired();
Expand All @@ -70,12 +71,23 @@ class OAuth : public Auth {
const std::string &secret,
const std::string &redirect_uri, int scope_ver,
const std::string &auth_code, bool retry);
bool InvalidateToken(const char *url);
bool InvalidateToken(const char *url, const std::string &client_id);

static const char *GetKeychainLabel()
{
return "OBS Studio OAuth Credentials";
}

private:
bool GetTokenInternal(const char *url, const std::string &client_id,
const std::string &secret,
const std::string &redirect_uri, int scope_ver,
const std::string &auth_code, bool retry);

bool InvalidateTokenInternal(const char *base_url,
const std::string &client_id,
bool token_as_parameter = false);
};

class OAuthStreamKey : public OAuth {
Expand Down
7 changes: 7 additions & 0 deletions UI/auth-restream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ using namespace json11;

#define RESTREAM_AUTH_URL OAUTH_BASE_URL "v1/restream/redirect"
#define RESTREAM_TOKEN_URL OAUTH_BASE_URL "v1/restream/token"
#define RESTREAM_REVOKE_URL "https://api.restream.io/oauth/revoke"
#define RESTREAM_STREAMKEY_URL "https://api.restream.io/v2/user/streamKey"
#define RESTREAM_SCOPE_VERSION 1

Expand Down Expand Up @@ -129,6 +130,12 @@ bool RestreamAuth::LoadInternal()
return OAuthStreamKey::LoadInternal();
}

void RestreamAuth::DeleteInternal()
{
InvalidateToken(RESTREAM_REVOKE_URL, "");
OAuthStreamKey::DeleteInternal();
}

void RestreamAuth::LoadUI()
{
if (uiLoaded)
Expand Down
1 change: 1 addition & 0 deletions UI/auth-restream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class RestreamAuth : public OAuthStreamKey {

virtual void SaveInternal() override;
virtual bool LoadInternal() override;
virtual void DeleteInternal() override;

bool GetChannelInfo();

Expand Down
10 changes: 10 additions & 0 deletions UI/auth-twitch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ using namespace json11;

#define TWITCH_AUTH_URL OAUTH_BASE_URL "v1/twitch/redirect"
#define TWITCH_TOKEN_URL OAUTH_BASE_URL "v1/twitch/token"
#define TWITCH_REOKVE_URL "https://id.twitch.tv/oauth2/revoke"

#define TWITCH_SCOPE_VERSION 1

Expand Down Expand Up @@ -499,6 +500,15 @@ std::shared_ptr<Auth> TwitchAuth::Login(QWidget *parent, const std::string &)
return nullptr;
}

void TwitchAuth::DeleteInternal()
{
std::string client_id = TWITCH_CLIENTID;
deobfuscate_str(&client_id[0], TWITCH_HASH);

InvalidateToken(TWITCH_REOKVE_URL, client_id);
OAuthStreamKey::DeleteInternal();
}

static std::shared_ptr<Auth> CreateTwitchAuth()
{
return std::make_shared<TwitchAuth>(twitchDef);
Expand Down
1 change: 1 addition & 0 deletions UI/auth-twitch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class TwitchAuth : public OAuthStreamKey {

virtual void SaveInternal() override;
virtual bool LoadInternal() override;
virtual void DeleteInternal() override;

bool MakeApiRequest(const char *path, json11::Json &json_out);
bool GetChannelInfo();
Expand Down
Loading