diff --git a/pgpasslib.py b/pgpasslib.py index 3cfd221..84c3cb6 100644 --- a/pgpasslib.py +++ b/pgpasslib.py @@ -38,13 +38,11 @@ DEFAULT_HOST = 'localhost' DEFAULT_PORT = 5432 DEFAULT_USER = stdlib_getpass.getuser() -DEFAULT_DBNAME = DEFAULT_USER PATTERN = re.compile(r'^(.*):(.*):(.*):(.*):(.*)$', re.MULTILINE) -def getpass(host=DEFAULT_HOST, port=DEFAULT_PORT, dbname=DEFAULT_DBNAME, - user=DEFAULT_USER): +def getpass(host, port, dbname, user): """Return the password for the specified host, port, dbname and user. :py:const:`None` will be returned if a password can not be found for the specified connection parameters. @@ -70,6 +68,19 @@ def getpass(host=DEFAULT_HOST, port=DEFAULT_PORT, dbname=DEFAULT_DBNAME, :raises: InvalidEntry """ + + if not host: + host = DEFAULT_HOST + + if not port: + port = DEFAULT_PORT + + if not user: + user = DEFAULT_USER + + if not dbname: + dbname = user + if not isinstance(port, int): port = int(port) for entry in _get_entries(): diff --git a/tests.py b/tests.py index f4f68cd..c0f8dc1 100644 --- a/tests.py +++ b/tests.py @@ -204,6 +204,18 @@ def test_no_match_on_host(self): self.assertFalse(self.entry.match('foo', 6000, 'bar', 'qux')) +class GetpassDefaultsTest(unittest.TestCase): + + def setUp(self): + self.entry = pgpasslib._Entry('localhost', 5432, '*', 'mydb', 'mypass') + + def test_match_on_port_none(self): + self.assertTrue(self.entry.match('localhost', None, 'mydb', 'mydb')) + + def test_match_on_default_db(self): + self.assertTrue(self.entry.match('localhost', None, None, 'mydb')) + + class GetPassMatch1Test(unittest.TestCase): def test_getpass_returns_expected_result(self):