From 4829459b60e4b38d7981038673164ec77b8b1bdc Mon Sep 17 00:00:00 2001 From: erez Date: Wed, 17 Apr 2019 16:08:29 +0200 Subject: [PATCH] add the get user function --- pgpasslib.py | 33 ++++++++++++++++++++++++++++++++- tests.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/pgpasslib.py b/pgpasslib.py index 3cfd221..d09a9ff 100644 --- a/pgpasslib.py +++ b/pgpasslib.py @@ -42,6 +42,37 @@ PATTERN = re.compile(r'^(.*):(.*):(.*):(.*):(.*)$', re.MULTILINE) +def getuser(host=DEFAULT_HOST, port=DEFAULT_PORT, dbname=DEFAULT_DBNAME): + """Return the first matching user for the specified host, port, and dbname. + :py:const:`None` will be returned if a user can not be found for the + specified connection parameters. + + If the password file can not be located, a :py:class:`FileNotFound` + exception will be raised. + + If the password file is group or world readable, the file will not be read, + per the specification, and a :py:class:`InvalidPermissions` exception will + be raised. + + If an entry in the password file is not parsable, a + :py:class:`InvalidPermissions` exception will be raised. + + :param str host: PostgreSQL hostname + :param port: PostgreSQL port + :type port: int or str + :param str dbname: Database name + :rtype: str + :raises: FileNotFound + :raises: InvalidPermissions + :raises: InvalidEntry + + """ + if not isinstance(port, int): + port = int(port) + for entry in _get_entries(): + if entry.match(host, port, dbname, None): + return entry.user + return None def getpass(host=DEFAULT_HOST, port=DEFAULT_PORT, dbname=DEFAULT_DBNAME, user=DEFAULT_USER): @@ -147,7 +178,7 @@ def match(self, host, port, dbname, user): return all([any([self.host == '*', self.host == host]), any([self.port == '*', self.port == port]), any([self.dbname == '*', self.dbname == dbname]), - any([self.user == '*', self.user == user])]) + any([user is None and self.user != '*', self.user == '*', self.user == user])]) @staticmethod def _sanitize_port(value): diff --git a/tests.py b/tests.py index f4f68cd..8da5203 100644 --- a/tests.py +++ b/tests.py @@ -248,6 +248,50 @@ def test_getpass_returns_expected_result(self): read_file.return_value = MOCK_CONTENT self.assertIsNone(pgpasslib.getpass('fail', '5432', 'foo', 'bar')) +class GetUserMatch1Test(unittest.TestCase): + + def test_getuser_returns_expected_result(self): + with mock.patch('pgpasslib._read_file') as read_file: + read_file.return_value = MOCK_CONTENT + self.assertEqual(pgpasslib.getuser('localhost', 5432, + 'foo'), 'kermit') + + +class GetUserMatch2Test(unittest.TestCase): + + def test_getuser_returns_expected_result(self): + with mock.patch('pgpasslib._read_file') as read_file: + read_file.return_value = MOCK_CONTENT + self.assertEqual(pgpasslib.getuser('bouncer', 6000, + 'bumpers'), 'rubber') + + +class GetUserMatch3Test(unittest.TestCase): + + def test_getuser_returns_expected_result(self): + with mock.patch('pgpasslib._read_file') as read_file: + read_file.return_value = MOCK_CONTENT + self.assertEqual(pgpasslib.getuser('foo.abjdite.us-east-1.' + 'redshift.amazonaws.com', 5439, + 'redshift'), 'fonzy') + + +class GetUserMatch4Test(unittest.TestCase): + + def test_getuser_returns_expected_result(self): + with mock.patch('pgpasslib._read_file') as read_file: + read_file.return_value = MOCK_CONTENT + self.assertEqual(pgpasslib.getuser('foo:bar', '6000', + 'corgie'), 'baz') + + +class GetUserNoMatchTest(unittest.TestCase): + + def test_getuser_returns_expected_result(self): + with mock.patch('pgpasslib._read_file') as read_file: + read_file.return_value = MOCK_CONTENT + self.assertIsNone(pgpasslib.getuser('fail', '5432', 'foo')) + class FileNotFoundStrFormatting(unittest.TestCase):