From d802b0202a8a1a6aa12043a787e24378609510c4 Mon Sep 17 00:00:00 2001 From: Wesley Schwengle Date: Sat, 28 May 2022 13:39:48 -0400 Subject: [PATCH] Set defaults for host, port, dbname and user later in the invocation The problem is that if one uses sqlalchemy or some other tool that uses some kind of object that returns None for a value, eg the port you need to reimplement a lot of the code that is in this module already. from sqlalchemy.engine.url import make_url uri = make_url('postgres://localhost/foo'); # yields an error that a NoneType cannot be used for int(); password = getpass(uri.host, uri.port, uri.username, uri.database); In order to fix this a lot of the defaults need to be implemented at the caller level. Furthermore, the default database name is the name of the user that is supplied to the function, eg, a user 'foo' will default to the database 'foo', as seen in behaviour by `psql -U foo` where 'foo' isn't your username of the OS. Signed-off-by: Wesley Schwengle --- pgpasslib.py | 17 ++++++++++++++--- tests.py | 12 ++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/pgpasslib.py b/pgpasslib.py index 3cfd221..84c3cb6 100644 --- a/pgpasslib.py +++ b/pgpasslib.py @@ -38,13 +38,11 @@ DEFAULT_HOST = 'localhost' DEFAULT_PORT = 5432 DEFAULT_USER = stdlib_getpass.getuser() -DEFAULT_DBNAME = DEFAULT_USER PATTERN = re.compile(r'^(.*):(.*):(.*):(.*):(.*)$', re.MULTILINE) -def getpass(host=DEFAULT_HOST, port=DEFAULT_PORT, dbname=DEFAULT_DBNAME, - user=DEFAULT_USER): +def getpass(host, port, dbname, user): """Return the password for the specified host, port, dbname and user. :py:const:`None` will be returned if a password can not be found for the specified connection parameters. @@ -70,6 +68,19 @@ def getpass(host=DEFAULT_HOST, port=DEFAULT_PORT, dbname=DEFAULT_DBNAME, :raises: InvalidEntry """ + + if not host: + host = DEFAULT_HOST + + if not port: + port = DEFAULT_PORT + + if not user: + user = DEFAULT_USER + + if not dbname: + dbname = user + if not isinstance(port, int): port = int(port) for entry in _get_entries(): diff --git a/tests.py b/tests.py index f4f68cd..c0f8dc1 100644 --- a/tests.py +++ b/tests.py @@ -204,6 +204,18 @@ def test_no_match_on_host(self): self.assertFalse(self.entry.match('foo', 6000, 'bar', 'qux')) +class GetpassDefaultsTest(unittest.TestCase): + + def setUp(self): + self.entry = pgpasslib._Entry('localhost', 5432, '*', 'mydb', 'mypass') + + def test_match_on_port_none(self): + self.assertTrue(self.entry.match('localhost', None, 'mydb', 'mydb')) + + def test_match_on_default_db(self): + self.assertTrue(self.entry.match('localhost', None, None, 'mydb')) + + class GetPassMatch1Test(unittest.TestCase): def test_getpass_returns_expected_result(self):