diff --git a/pgpasslib.py b/pgpasslib.py index 3cfd221..7f92ee1 100644 --- a/pgpasslib.py +++ b/pgpasslib.py @@ -26,7 +26,7 @@ import sys import platform -__version__ = '1.1.0' +__version__ = '1.1.1' LOGGER = logging.getLogger(__name__) @@ -78,6 +78,44 @@ def getpass(host=DEFAULT_HOST, port=DEFAULT_PORT, dbname=DEFAULT_DBNAME, return None +def getconnectionstring( + host=DEFAULT_HOST, + port=DEFAULT_PORT, + dbname=DEFAULT_DBNAME, + user=DEFAULT_USER, +): + """Return the connection string for the specified host, port, dbname and user. + :py:const:`None` will be returned if a password can not be found for the + specified connection parameters. + + :param str host: PostgreSQL hostname + :param port: PostgreSQL port + :type port: int or str + :param str dbname: Database name + :param str user: Database role/user + :rtype: str + :raises: FileNotFound + :raises: InvalidPermissions + :raises: InvalidEntry + + """ + password=getpass( + host=host, + port=port, + dbname=dbname, + user=user, + ) + if password: + return 'postgresql://{user}:{password}@{host}:{port}/{dbname}'.format( + host=host, + port=port, + dbname=dbname, + user=user, + password=password + ) + return '' + + class PgPassException(Exception): """Base exception for all pgpasslib exceptions""" MESSAGE = 'Base Exception: {}' diff --git a/tests.py b/tests.py index f4f68cd..516f867 100644 --- a/tests.py +++ b/tests.py @@ -2,7 +2,10 @@ Tests for pgpasslib """ -import mock +try: + import mock +except ImportError: + from unittest import mock import os from os import path import stat @@ -268,3 +271,40 @@ class InvalidPermissionsExceptionStrFormatting(unittest.TestCase): def test_str_matches_expectation(self): self.assertEqual(str(pgpasslib.InvalidPermissions('foo', '0x000')), 'Invalid Permissions for foo: 0x000') + + +class GetConnectionStringMatch1Test(unittest.TestCase): + + def test_getconnectionstring_returns_expected_result(self): + with mock.patch('pgpasslib._read_file') as read_file: + read_file.return_value = MOCK_CONTENT + self.assertEqual(pgpasslib.getconnectionstring('localhost', 5432, + 'foo', 'kermit'), '') + + +class GetConnectionStringMatch2Test(unittest.TestCase): + + def test_getconnectionstring_returns_expected_result(self): + with mock.patch('pgpasslib._read_file') as read_file: + read_file.return_value = MOCK_CONTENT + self.assertEqual(pgpasslib.getconnectionstring('bouncer', 6000, + 'bumpers', 'rubber'), 'postgresql://rubber:buggy@bouncer:6000/bumpers') + + +class GetConnectionStringMatch3Test(unittest.TestCase): + + def test_getconnectionstring_returns_expected_result(self): + with mock.patch('pgpasslib._read_file') as read_file: + read_file.return_value = MOCK_CONTENT + self.assertEqual(pgpasslib.getconnectionstring('foo.abjdite.us-east-1.' + 'redshift.amazonaws.com', 5439, + 'redshift', 'fonzy'), 'postgresql://fonzy:b3ar@foo.abjdite.us-east-1.redshift.amazonaws.com:5439/redshift') + + +class GetConnectionStringMatch4Test(unittest.TestCase): + + def test_getconnectionstring_returns_expected_result(self): + with mock.patch('pgpasslib._read_file') as read_file: + read_file.return_value = MOCK_CONTENT + self.assertEqual(pgpasslib.getconnectionstring('foo:bar', '6000', + 'corgie', 'baz'), 'postgresql://baz:qux@foo:bar:6000/corgie')