From f174fc0c00b6fafe41172a26a1ca3cb85df61c58 Mon Sep 17 00:00:00 2001 From: Calen Pennington Date: Fri, 17 Apr 2015 08:26:44 -0400 Subject: [PATCH 1/8] Limit the range of pep8 to versions compatible with the latest flake8 --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index b8c1edf..2103baa 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,7 @@ envlist = py27, py33, py34 [testenv] deps = nose - pep8 + pep8>=1.5.7,<1.6 # Versions limited to work with flake8 coverage flake8 six From eecee22acf1cb7f3435a0558a7d225a21b0350b6 Mon Sep 17 00:00:00 2001 From: Calen Pennington Date: Wed, 4 Mar 2015 08:32:25 -0500 Subject: [PATCH 2/8] Extract value storage into classes --- ddt.py | 12 ++++++++---- test/test_functional.py | 11 ++++++++--- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/ddt.py b/ddt.py index a084695..95b5c3c 100644 --- a/ddt.py +++ b/ddt.py @@ -11,6 +11,7 @@ import re import sys from functools import wraps +from collections import namedtuple __version__ = '1.0.0' @@ -40,7 +41,7 @@ def data(*values): """ def wrapper(func): - setattr(func, DATA_ATTR, values) + setattr(func, DATA_ATTR, DataValues(values)) return func return wrapper @@ -65,11 +66,14 @@ def file_data(value): """ def wrapper(func): - setattr(func, FILE_ATTR, value) + setattr(func, FILE_ATTR, FileValues(value)) return func return wrapper +DataValues = namedtuple("DataValues", ["values"]) +FileValues = namedtuple("FileValues", ["file_path"]) + def is_hash_randomized(): return (((sys.hexversion >= 0x02070300 and sys.hexversion < 0x03000000) or @@ -211,7 +215,7 @@ def ddt(cls): """ for name, func in list(cls.__dict__.items()): if hasattr(func, DATA_ATTR): - for i, v in enumerate(getattr(func, DATA_ATTR)): + for i, v in enumerate(getattr(func, DATA_ATTR).values): test_name = mk_test_name(name, getattr(v, "__name__", v), i) if hasattr(func, UNPACK_ATTR): if isinstance(v, tuple) or isinstance(v, list): @@ -224,6 +228,6 @@ def ddt(cls): delattr(cls, name) elif hasattr(func, FILE_ATTR): file_attr = getattr(func, FILE_ATTR) - process_file_data(cls, name, func, file_attr) + process_file_data(cls, name, func, file_attr.file_path) delattr(cls, name) return cls diff --git a/test/test_functional.py b/test/test_functional.py index b0e8cfd..aa1ae4f 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -3,7 +3,9 @@ import six -from ddt import ddt, data, file_data, is_hash_randomized +from ddt import ( + ddt, data, file_data, is_hash_randomized, DataValues +) from nose.tools import assert_equal, assert_is_not_none, assert_raises @@ -71,7 +73,7 @@ def hello(): extra_attrs = dh_keys - keys assert_equal(len(extra_attrs), 1) extra_attr = extra_attrs.pop() - assert_equal(getattr(data_hello, extra_attr), (1, 2)) + assert_equal(getattr(data_hello, extra_attr), DataValues((1, 2))) def test_file_data_decorator_with_dict(): @@ -93,7 +95,10 @@ def hello(): extra_attrs = dh_keys - keys assert_equal(len(extra_attrs), 1) extra_attr = extra_attrs.pop() - assert_equal(getattr(data_hello, extra_attr), ("test_data_dict.json",)) + assert_equal( + getattr(data_hello, extra_attr), + DataValues(("test_data_dict.json",)) + ) is_test = lambda x: x.startswith('test_') From a6e0fc2c0cf47d28775901ddd2aa0c1ad61cae7c Mon Sep 17 00:00:00 2001 From: Calen Pennington Date: Wed, 4 Mar 2015 08:45:08 -0500 Subject: [PATCH 3/8] Extract the add_tests method onto the value holding objects --- ddt.py | 82 ++++++++++++++++++++++++++++++---------------------------- 1 file changed, 42 insertions(+), 40 deletions(-) diff --git a/ddt.py b/ddt.py index 95b5c3c..2863cc8 100644 --- a/ddt.py +++ b/ddt.py @@ -71,8 +71,46 @@ def wrapper(func): return wrapper -DataValues = namedtuple("DataValues", ["values"]) -FileValues = namedtuple("FileValues", ["file_path"]) +class DataValues(namedtuple("DataValues", ["values"])): + + def add_tests(self, cls, name, func): + for i, v in enumerate(self.values): + test_name = mk_test_name(name, getattr(v, "__name__", v), i) + if hasattr(func, UNPACK_ATTR): + if isinstance(v, tuple) or isinstance(v, list): + add_test(cls, test_name, func, *v) + else: + # unpack dictionary + add_test(cls, test_name, func, **v) + else: + add_test(cls, test_name, func, v) + + +class FileValues(namedtuple("FileValues", ["file_path"])): + def add_tests(self, cls, name, func): + cls_path = os.path.abspath(inspect.getsourcefile(cls)) + data_file_path = os.path.join( + os.path.dirname(cls_path), + self.file_path + ) + + def _raise_ve(*args): # pylint: disable-msg=W0613 + raise ValueError("%s does not exist" % self.file_path) + + if os.path.exists(data_file_path) is False: + test_name = mk_test_name(name, "error") + add_test(cls, test_name, _raise_ve, None) + else: + data = json.loads(open(data_file_path).read()) + for i, elem in enumerate(data): + if isinstance(data, dict): + key, value = elem, data[elem] + test_name = mk_test_name(name, key, i) + elif isinstance(data, list): + value = elem + test_name = mk_test_name(name, value, i) + add_test(cls, test_name, func, value) + def is_hash_randomized(): return (((sys.hexversion >= 0x02070300 and @@ -163,32 +201,6 @@ def add_test(cls, test_name, func, *args, **kwargs): setattr(cls, test_name, feed_data(func, test_name, *args, **kwargs)) -def process_file_data(cls, name, func, file_attr): - """ - Process the parameter in the `file_data` decorator. - - """ - cls_path = os.path.abspath(inspect.getsourcefile(cls)) - data_file_path = os.path.join(os.path.dirname(cls_path), file_attr) - - def _raise_ve(*args): # pylint: disable-msg=W0613 - raise ValueError("%s does not exist" % file_attr) - - if os.path.exists(data_file_path) is False: - test_name = mk_test_name(name, "error") - add_test(cls, test_name, _raise_ve, None) - else: - data = json.loads(open(data_file_path).read()) - for i, elem in enumerate(data): - if isinstance(data, dict): - key, value = elem, data[elem] - test_name = mk_test_name(name, key, i) - elif isinstance(data, list): - value = elem - test_name = mk_test_name(name, value, i) - add_test(cls, test_name, func, value) - - def ddt(cls): """ Class decorator for subclasses of ``unittest.TestCase``. @@ -215,19 +227,9 @@ def ddt(cls): """ for name, func in list(cls.__dict__.items()): if hasattr(func, DATA_ATTR): - for i, v in enumerate(getattr(func, DATA_ATTR).values): - test_name = mk_test_name(name, getattr(v, "__name__", v), i) - if hasattr(func, UNPACK_ATTR): - if isinstance(v, tuple) or isinstance(v, list): - add_test(cls, test_name, func, *v) - else: - # unpack dictionary - add_test(cls, test_name, func, **v) - else: - add_test(cls, test_name, func, v) + getattr(func, DATA_ATTR).add_tests(cls, name, func) delattr(cls, name) elif hasattr(func, FILE_ATTR): - file_attr = getattr(func, FILE_ATTR) - process_file_data(cls, name, func, file_attr.file_path) + getattr(func, FILE_ATTR).add_tests(cls, name, func) delattr(cls, name) return cls From a66d954adf9ddad2ecfbdcf4c323ed0dae00249e Mon Sep 17 00:00:00 2001 From: Calen Pennington Date: Wed, 4 Mar 2015 08:46:08 -0500 Subject: [PATCH 4/8] Remove the distinct attributes for FileValuse and DataValues --- ddt.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/ddt.py b/ddt.py index 2863cc8..3fc6e03 100644 --- a/ddt.py +++ b/ddt.py @@ -20,7 +20,6 @@ # by the `ddt` class decorator. DATA_ATTR = '%values' # store the data the test must run with -FILE_ATTR = '%file_path' # store the path to JSON file UNPACK_ATTR = '%unpack' # remember that we have to unpack values @@ -66,7 +65,7 @@ def file_data(value): """ def wrapper(func): - setattr(func, FILE_ATTR, FileValues(value)) + setattr(func, DATA_ATTR, FileValues(value)) return func return wrapper @@ -229,7 +228,4 @@ def ddt(cls): if hasattr(func, DATA_ATTR): getattr(func, DATA_ATTR).add_tests(cls, name, func) delattr(cls, name) - elif hasattr(func, FILE_ATTR): - getattr(func, FILE_ATTR).add_tests(cls, name, func) - delattr(cls, name) return cls From 7b0930692c9f88a5145743f3f6238c9369d54dd4 Mon Sep 17 00:00:00 2001 From: Calen Pennington Date: Wed, 4 Mar 2015 09:05:43 -0500 Subject: [PATCH 5/8] Extract test value generation from class modification --- ddt.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/ddt.py b/ddt.py index 3fc6e03..ae360aa 100644 --- a/ddt.py +++ b/ddt.py @@ -72,21 +72,21 @@ def wrapper(func): class DataValues(namedtuple("DataValues", ["values"])): - def add_tests(self, cls, name, func): + def generate_tests(self, cls, name, func): for i, v in enumerate(self.values): test_name = mk_test_name(name, getattr(v, "__name__", v), i) if hasattr(func, UNPACK_ATTR): if isinstance(v, tuple) or isinstance(v, list): - add_test(cls, test_name, func, *v) + yield (test_name, func, v, {}) else: # unpack dictionary - add_test(cls, test_name, func, **v) + yield (test_name, func, [], v) else: - add_test(cls, test_name, func, v) + yield (test_name, func, [v], {}) class FileValues(namedtuple("FileValues", ["file_path"])): - def add_tests(self, cls, name, func): + def generate_tests(self, cls, name, func): cls_path = os.path.abspath(inspect.getsourcefile(cls)) data_file_path = os.path.join( os.path.dirname(cls_path), @@ -98,7 +98,7 @@ def _raise_ve(*args): # pylint: disable-msg=W0613 if os.path.exists(data_file_path) is False: test_name = mk_test_name(name, "error") - add_test(cls, test_name, _raise_ve, None) + yield (test_name, _raise_ve, [None], {}) else: data = json.loads(open(data_file_path).read()) for i, elem in enumerate(data): @@ -108,7 +108,7 @@ def _raise_ve(*args): # pylint: disable-msg=W0613 elif isinstance(data, list): value = elem test_name = mk_test_name(name, value, i) - add_test(cls, test_name, func, value) + yield (test_name, func, [value], {}) def is_hash_randomized(): @@ -226,6 +226,12 @@ def ddt(cls): """ for name, func in list(cls.__dict__.items()): if hasattr(func, DATA_ATTR): - getattr(func, DATA_ATTR).add_tests(cls, name, func) + test_specs = getattr(func, DATA_ATTR).generate_tests( + cls, + name, + func + ) + for (test_name, test_func, args, kwargs) in test_specs: + add_test(cls, test_name, test_func, *args, **kwargs) delattr(cls, name) return cls From b4001ffc4f7116b2ee37995ed84e5c87589741a4 Mon Sep 17 00:00:00 2001 From: Calen Pennington Date: Wed, 4 Mar 2015 09:42:15 -0500 Subject: [PATCH 6/8] Reify test values and errors getting test values under a single interface --- ddt.py | 65 ++++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/ddt.py b/ddt.py index ae360aa..3bd4c7e 100644 --- a/ddt.py +++ b/ddt.py @@ -71,34 +71,33 @@ def wrapper(func): class DataValues(namedtuple("DataValues", ["values"])): - - def generate_tests(self, cls, name, func): + def test_values(self, cls, name, func): for i, v in enumerate(self.values): test_name = mk_test_name(name, getattr(v, "__name__", v), i) if hasattr(func, UNPACK_ATTR): if isinstance(v, tuple) or isinstance(v, list): - yield (test_name, func, v, {}) + yield TestValue(test_name, func, *v) else: # unpack dictionary - yield (test_name, func, [], v) + yield TestValue(test_name, func, **v) else: - yield (test_name, func, [v], {}) + yield TestValue(test_name, func, v) class FileValues(namedtuple("FileValues", ["file_path"])): - def generate_tests(self, cls, name, func): + def test_values(self, cls, name, func): cls_path = os.path.abspath(inspect.getsourcefile(cls)) data_file_path = os.path.join( os.path.dirname(cls_path), self.file_path ) - def _raise_ve(*args): # pylint: disable-msg=W0613 - raise ValueError("%s does not exist" % self.file_path) - if os.path.exists(data_file_path) is False: test_name = mk_test_name(name, "error") - yield (test_name, _raise_ve, [None], {}) + yield TestError( + test_name, + ValueError("%s does not exist" % self.file_path) + ) else: data = json.loads(open(data_file_path).read()) for i, elem in enumerate(data): @@ -108,7 +107,7 @@ def _raise_ve(*args): # pylint: disable-msg=W0613 elif isinstance(data, list): value = elem test_name = mk_test_name(name, value, i) - yield (test_name, func, [value], {}) + yield TestValue(test_name, func, value) def is_hash_randomized(): @@ -189,15 +188,35 @@ def wrapper(self): return wrapper -def add_test(cls, test_name, func, *args, **kwargs): - """ - Add a test case to this class. +class TestValue(object): + def __init__(self, test_name, func, *args, **kwargs): + self.test_name = test_name + self.func = func + self.args = args + self.kwargs = kwargs - The test will be based on an existing function but will give it a new - name. + def add_test(self, cls): + """ + Add a test case to this class. - """ - setattr(cls, test_name, feed_data(func, test_name, *args, **kwargs)) + The test will be based on an existing function but will give it a new + name. + + """ + setattr( + cls, + self.test_name, + feed_data(self.func, self.test_name, *self.args, **self.kwargs) + ) + + +class TestError(TestValue): + def __init__(self, test_name, exception): + self.exception = exception + super(TestError, self).__init__(test_name, self._raise_exception) + + def _raise_exception(self, test_cls): + raise self.exception def ddt(cls): @@ -226,12 +245,8 @@ def ddt(cls): """ for name, func in list(cls.__dict__.items()): if hasattr(func, DATA_ATTR): - test_specs = getattr(func, DATA_ATTR).generate_tests( - cls, - name, - func - ) - for (test_name, test_func, args, kwargs) in test_specs: - add_test(cls, test_name, test_func, *args, **kwargs) + test_values = getattr(func, DATA_ATTR).test_values(cls, name, func) + for test_value in test_values: + test_value.add_test(cls) delattr(cls, name) return cls From 9e757549e234b045e0768d154fc873e93a3ac70a Mon Sep 17 00:00:00 2001 From: Calen Pennington Date: Wed, 4 Mar 2015 10:33:13 -0500 Subject: [PATCH 7/8] Separate test name creation from value generation --- ddt.py | 70 +++++++++++++++++++++++++++++++--------------------------- 1 file changed, 37 insertions(+), 33 deletions(-) diff --git a/ddt.py b/ddt.py index 3bd4c7e..78650fa 100644 --- a/ddt.py +++ b/ddt.py @@ -71,21 +71,20 @@ def wrapper(func): class DataValues(namedtuple("DataValues", ["values"])): - def test_values(self, cls, name, func): - for i, v in enumerate(self.values): - test_name = mk_test_name(name, getattr(v, "__name__", v), i) + def test_values(self, cls, func): + for v in self.values: if hasattr(func, UNPACK_ATTR): if isinstance(v, tuple) or isinstance(v, list): - yield TestValue(test_name, func, *v) + yield TestValue(func, *v) else: # unpack dictionary - yield TestValue(test_name, func, **v) + yield TestValue(func, **v) else: - yield TestValue(test_name, func, v) + yield TestValue(func, v) class FileValues(namedtuple("FileValues", ["file_path"])): - def test_values(self, cls, name, func): + def test_values(self, cls, func): cls_path = os.path.abspath(inspect.getsourcefile(cls)) data_file_path = os.path.join( os.path.dirname(cls_path), @@ -93,21 +92,15 @@ def test_values(self, cls, name, func): ) if os.path.exists(data_file_path) is False: - test_name = mk_test_name(name, "error") - yield TestError( - test_name, - ValueError("%s does not exist" % self.file_path) - ) + yield TestError(ValueError("%s does not exist" % self.file_path)) else: data = json.loads(open(data_file_path).read()) - for i, elem in enumerate(data): - if isinstance(data, dict): - key, value = elem, data[elem] - test_name = mk_test_name(name, key, i) - elif isinstance(data, list): - value = elem - test_name = mk_test_name(name, value, i) - yield TestValue(test_name, func, value) + if isinstance(data, dict): + for key, value in data.items(): + yield TestValue(func, value, _test_value_name=key) + elif isinstance(data, list): + for value in data: + yield TestValue(func, value) def is_hash_randomized(): @@ -189,13 +182,23 @@ def wrapper(self): class TestValue(object): - def __init__(self, test_name, func, *args, **kwargs): - self.test_name = test_name + def __init__(self, func, *args, **kwargs): + if '_test_value_name' in kwargs: + self.value_name = kwargs.pop('_test_value_name') + elif len(args) == 1 and not kwargs: + self.value_name = getattr(args[0], '__name__', args[0]) + elif args: + self.value_name = args + elif kwargs: + self.value_name = kwargs + else: + raise Exception("unable to generate value names") + self.func = func self.args = args self.kwargs = kwargs - def add_test(self, cls): + def add_test(self, cls, name, index): """ Add a test case to this class. @@ -203,17 +206,18 @@ def add_test(self, cls): name. """ - setattr( - cls, - self.test_name, - feed_data(self.func, self.test_name, *self.args, **self.kwargs) - ) + test_name = mk_test_name(name, self.value_name, index) + test_data = feed_data(self.func, test_name, *self.args, **self.kwargs) + setattr(cls, test_name, test_data) class TestError(TestValue): - def __init__(self, test_name, exception): + def __init__(self, exception): self.exception = exception - super(TestError, self).__init__(test_name, self._raise_exception) + super(TestError, self).__init__( + self._raise_exception, + _test_value_name="error" + ) def _raise_exception(self, test_cls): raise self.exception @@ -245,8 +249,8 @@ def ddt(cls): """ for name, func in list(cls.__dict__.items()): if hasattr(func, DATA_ATTR): - test_values = getattr(func, DATA_ATTR).test_values(cls, name, func) - for test_value in test_values: - test_value.add_test(cls) + values = getattr(func, DATA_ATTR).test_values(cls, func) + for idx, test_value in enumerate(values): + test_value.add_test(cls, name, idx) delattr(cls, name) return cls From 9b44abbba92bf2c2392a3d523ab535a31488e157 Mon Sep 17 00:00:00 2001 From: Calen Pennington Date: Wed, 4 Mar 2015 11:37:47 -0500 Subject: [PATCH 8/8] Allow multiple invocations of @data and @file_data to produce data cross-products --- ddt.py | 82 +++++++++++++++++++++++++++++++++++++---- test/test_example.py | 22 +++++++++++ test/test_functional.py | 61 ++++++++++++++++++++++++++++-- 3 files changed, 153 insertions(+), 12 deletions(-) diff --git a/ddt.py b/ddt.py index 78650fa..80b52cc 100644 --- a/ddt.py +++ b/ddt.py @@ -6,6 +6,7 @@ # https://github.com/txels/ddt/blob/master/LICENSE.md import inspect +import itertools import json import os import re @@ -26,7 +27,6 @@ def unpack(func): """ Method decorator to add unpack feature. - """ setattr(func, UNPACK_ATTR, True) return func @@ -34,20 +34,48 @@ def unpack(func): def data(*values): """ - Method decorator to add to your test methods. + Method decorator to add data to your test methods. Should be added to methods of instances of ``unittest.TestCase``. + All arguments to this function will be passed as arguments to the + decorated test method, and each argument passed will generate a new test + method. + + The ``@data`` and ``@file_data`` decorators can be nested. When nested, + the arguments passed to ``data`` and stored in the files named in + ``file_data`` will be combined in a Cartesian product before being passed + to the decorated method. For example: + + .. code-block:: python + + @data(1, 2) + @data(3, 4) + def test_foo(self, fst, snd): + pass + + would result in four test calls: + + .. code-block:: python + + test_foo(1, 3) + test_foo(2, 3) + test_foo(1, 4) + test_foo(2, 4) """ def wrapper(func): - setattr(func, DATA_ATTR, DataValues(values)) + if not hasattr(func, DATA_ATTR): + setattr(func, DATA_ATTR, []) + # Prepend the new set of values, so that the the stack of @data + # decorators inserts arguments from top to bottom + getattr(func, DATA_ATTR).insert(0, DataValues(values)) return func return wrapper def file_data(value): """ - Method decorator to add to your test methods. + Method decorator to add data from a file to your test methods. Should be added to methods of instances of ``unittest.TestCase``. @@ -65,7 +93,11 @@ def file_data(value): """ def wrapper(func): - setattr(func, DATA_ATTR, FileValues(value)) + if not hasattr(func, DATA_ATTR): + setattr(func, DATA_ATTR, []) + # Prepend the new set of values, so that the the stack of @data + # decorators inserts arguments from top to bottom + getattr(func, DATA_ATTR).insert(0, FileValues(value)) return func return wrapper @@ -223,6 +255,37 @@ def _raise_exception(self, test_cls): raise self.exception +def combine(test_values): + if len(test_values) == 1: + return test_values[0] + + func = None + args = [] + kwargs = {} + names = [] + for test_value in test_values: + if isinstance(test_value, TestError): + return test_value + + if func is None: + func = test_value.func + elif test_value.func != func: + return TestError( + ValueError( + "{} is not the same function as {}".format( + test_value.func, + func + ) + ) + ) + + args.extend(test_value.args) + kwargs.update(test_value.kwargs) + names.append(test_value.value_name) + + return TestValue(func, _test_value_name=names, *args, **kwargs) + + def ddt(cls): """ Class decorator for subclasses of ``unittest.TestCase``. @@ -249,8 +312,11 @@ def ddt(cls): """ for name, func in list(cls.__dict__.items()): if hasattr(func, DATA_ATTR): - values = getattr(func, DATA_ATTR).test_values(cls, func) - for idx, test_value in enumerate(values): - test_value.add_test(cls, name, idx) + value_combinations = itertools.product(*( + value_set.test_values(cls, func) + for value_set in getattr(func, DATA_ATTR) + )) + for idx, values in enumerate(value_combinations): + combine(values).add_test(cls, name, idx) delattr(cls, name) return cls diff --git a/test/test_example.py b/test/test_example.py index 35fc002..d7bea4b 100644 --- a/test/test_example.py +++ b/test/test_example.py @@ -55,6 +55,28 @@ def test_list_extracted_into_arguments(self, first_value, second_value): def test_dicts_extracted_into_kwargs(self, first, second, third): self.assertTrue(first < third < second) + @data(1, 2, 3) + @data(4, 5, 6) + @data(7, 8, 9) + def test_products(self, first_value, second_value, third_value): + self.assertTrue(first_value < second_value < third_value) + + @unpack + @data({'first': 1}, {'first': 2}, {'first': 3}) + @data( + {'second': 4, 'third': 5}, + {'second': 5, 'third': 6}, + {'second': 6, 'third': 7} + ) + def test_dict_products(self, first, second, third): + self.assertTrue(first < second < third) + + @unpack + @data([1], [2], [3]) + @data((4, 5), (5, 6), (6, 7)) + def test_list_products(self, first, second, third): + self.assertTrue(first < second < third) + @data(u'ascii', u'non-ascii-\N{SNOWMAN}') def test_unicode(self, value): self.assertIn(value, (u'ascii', u'non-ascii-\N{SNOWMAN}')) diff --git a/test/test_functional.py b/test/test_functional.py index aa1ae4f..64e8150 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -4,7 +4,7 @@ import six from ddt import ( - ddt, data, file_data, is_hash_randomized, DataValues + ddt, data, file_data, is_hash_randomized, DataValues, FileValues ) from nose.tools import assert_equal, assert_is_not_none, assert_raises @@ -73,7 +73,31 @@ def hello(): extra_attrs = dh_keys - keys assert_equal(len(extra_attrs), 1) extra_attr = extra_attrs.pop() - assert_equal(getattr(data_hello, extra_attr), DataValues((1, 2))) + assert_equal(getattr(data_hello, extra_attr), [DataValues((1, 2))]) + + +def test_multiple_data_decorators(): + """ + Test the ``data`` method decorator with multiple applications + """ + + def hello(): + pass + + pre_size = len(hello.__dict__) + keys = set(hello.__dict__.keys()) + data_hello = data(1, 2)(data(3)(hello)) + dh_keys = set(data_hello.__dict__.keys()) + post_size = len(data_hello.__dict__) + + assert_equal(post_size, pre_size + 1) + extra_attrs = dh_keys - keys + assert_equal(len(extra_attrs), 1) + extra_attr = extra_attrs.pop() + assert_equal( + getattr(data_hello, extra_attr), + [DataValues((1, 2)), DataValues((3,))] + ) def test_file_data_decorator_with_dict(): @@ -86,7 +110,33 @@ def hello(): pre_size = len(hello.__dict__) keys = set(hello.__dict__.keys()) - data_hello = data("test_data_dict.json")(hello) + data_hello = file_data("test_data_dict.json")(hello) + + dh_keys = set(data_hello.__dict__.keys()) + post_size = len(data_hello.__dict__) + + assert_equal(post_size, pre_size + 1) + extra_attrs = dh_keys - keys + assert_equal(len(extra_attrs), 1) + extra_attr = extra_attrs.pop() + assert_equal( + getattr(data_hello, extra_attr), + [FileValues("test_data_dict.json")] + ) + + +def test_multiple_file_data_decorators_with_dict(): + """ + Test the ``file_data`` method decorator with multiple applications + """ + + def hello(): + pass + + pre_size = len(hello.__dict__) + keys = set(hello.__dict__.keys()) + data_hello = file_data("test_other_data.json")(hello) + data_hello = file_data("test_data_dict.json")(data_hello) dh_keys = set(data_hello.__dict__.keys()) post_size = len(data_hello.__dict__) @@ -97,7 +147,10 @@ def hello(): extra_attr = extra_attrs.pop() assert_equal( getattr(data_hello, extra_attr), - DataValues(("test_data_dict.json",)) + [ + FileValues("test_data_dict.json"), + FileValues("test_other_data.json"), + ] )