diff --git a/ddt.py b/ddt.py index a084695..80b52cc 100644 --- a/ddt.py +++ b/ddt.py @@ -6,11 +6,13 @@ # https://github.com/txels/ddt/blob/master/LICENSE.md import inspect +import itertools import json import os import re import sys from functools import wraps +from collections import namedtuple __version__ = '1.0.0' @@ -19,14 +21,12 @@ # by the `ddt` class decorator. DATA_ATTR = '%values' # store the data the test must run with -FILE_ATTR = '%file_path' # store the path to JSON file UNPACK_ATTR = '%unpack' # remember that we have to unpack values def unpack(func): """ Method decorator to add unpack feature. - """ setattr(func, UNPACK_ATTR, True) return func @@ -34,20 +34,48 @@ def unpack(func): def data(*values): """ - Method decorator to add to your test methods. + Method decorator to add data to your test methods. Should be added to methods of instances of ``unittest.TestCase``. + All arguments to this function will be passed as arguments to the + decorated test method, and each argument passed will generate a new test + method. + + The ``@data`` and ``@file_data`` decorators can be nested. When nested, + the arguments passed to ``data`` and stored in the files named in + ``file_data`` will be combined in a Cartesian product before being passed + to the decorated method. For example: + + .. code-block:: python + + @data(1, 2) + @data(3, 4) + def test_foo(self, fst, snd): + pass + + would result in four test calls: + + .. code-block:: python + + test_foo(1, 3) + test_foo(2, 3) + test_foo(1, 4) + test_foo(2, 4) """ def wrapper(func): - setattr(func, DATA_ATTR, values) + if not hasattr(func, DATA_ATTR): + setattr(func, DATA_ATTR, []) + # Prepend the new set of values, so that the the stack of @data + # decorators inserts arguments from top to bottom + getattr(func, DATA_ATTR).insert(0, DataValues(values)) return func return wrapper def file_data(value): """ - Method decorator to add to your test methods. + Method decorator to add data from a file to your test methods. Should be added to methods of instances of ``unittest.TestCase``. @@ -65,11 +93,48 @@ def file_data(value): """ def wrapper(func): - setattr(func, FILE_ATTR, value) + if not hasattr(func, DATA_ATTR): + setattr(func, DATA_ATTR, []) + # Prepend the new set of values, so that the the stack of @data + # decorators inserts arguments from top to bottom + getattr(func, DATA_ATTR).insert(0, FileValues(value)) return func return wrapper +class DataValues(namedtuple("DataValues", ["values"])): + def test_values(self, cls, func): + for v in self.values: + if hasattr(func, UNPACK_ATTR): + if isinstance(v, tuple) or isinstance(v, list): + yield TestValue(func, *v) + else: + # unpack dictionary + yield TestValue(func, **v) + else: + yield TestValue(func, v) + + +class FileValues(namedtuple("FileValues", ["file_path"])): + def test_values(self, cls, func): + cls_path = os.path.abspath(inspect.getsourcefile(cls)) + data_file_path = os.path.join( + os.path.dirname(cls_path), + self.file_path + ) + + if os.path.exists(data_file_path) is False: + yield TestError(ValueError("%s does not exist" % self.file_path)) + else: + data = json.loads(open(data_file_path).read()) + if isinstance(data, dict): + for key, value in data.items(): + yield TestValue(func, value, _test_value_name=key) + elif isinstance(data, list): + for value in data: + yield TestValue(func, value) + + def is_hash_randomized(): return (((sys.hexversion >= 0x02070300 and sys.hexversion < 0x03000000) or @@ -148,41 +213,77 @@ def wrapper(self): return wrapper -def add_test(cls, test_name, func, *args, **kwargs): - """ - Add a test case to this class. - - The test will be based on an existing function but will give it a new - name. - - """ - setattr(cls, test_name, feed_data(func, test_name, *args, **kwargs)) - - -def process_file_data(cls, name, func, file_attr): - """ - Process the parameter in the `file_data` decorator. - - """ - cls_path = os.path.abspath(inspect.getsourcefile(cls)) - data_file_path = os.path.join(os.path.dirname(cls_path), file_attr) - - def _raise_ve(*args): # pylint: disable-msg=W0613 - raise ValueError("%s does not exist" % file_attr) - - if os.path.exists(data_file_path) is False: - test_name = mk_test_name(name, "error") - add_test(cls, test_name, _raise_ve, None) - else: - data = json.loads(open(data_file_path).read()) - for i, elem in enumerate(data): - if isinstance(data, dict): - key, value = elem, data[elem] - test_name = mk_test_name(name, key, i) - elif isinstance(data, list): - value = elem - test_name = mk_test_name(name, value, i) - add_test(cls, test_name, func, value) +class TestValue(object): + def __init__(self, func, *args, **kwargs): + if '_test_value_name' in kwargs: + self.value_name = kwargs.pop('_test_value_name') + elif len(args) == 1 and not kwargs: + self.value_name = getattr(args[0], '__name__', args[0]) + elif args: + self.value_name = args + elif kwargs: + self.value_name = kwargs + else: + raise Exception("unable to generate value names") + + self.func = func + self.args = args + self.kwargs = kwargs + + def add_test(self, cls, name, index): + """ + Add a test case to this class. + + The test will be based on an existing function but will give it a new + name. + + """ + test_name = mk_test_name(name, self.value_name, index) + test_data = feed_data(self.func, test_name, *self.args, **self.kwargs) + setattr(cls, test_name, test_data) + + +class TestError(TestValue): + def __init__(self, exception): + self.exception = exception + super(TestError, self).__init__( + self._raise_exception, + _test_value_name="error" + ) + + def _raise_exception(self, test_cls): + raise self.exception + + +def combine(test_values): + if len(test_values) == 1: + return test_values[0] + + func = None + args = [] + kwargs = {} + names = [] + for test_value in test_values: + if isinstance(test_value, TestError): + return test_value + + if func is None: + func = test_value.func + elif test_value.func != func: + return TestError( + ValueError( + "{} is not the same function as {}".format( + test_value.func, + func + ) + ) + ) + + args.extend(test_value.args) + kwargs.update(test_value.kwargs) + names.append(test_value.value_name) + + return TestValue(func, _test_value_name=names, *args, **kwargs) def ddt(cls): @@ -211,19 +312,11 @@ def ddt(cls): """ for name, func in list(cls.__dict__.items()): if hasattr(func, DATA_ATTR): - for i, v in enumerate(getattr(func, DATA_ATTR)): - test_name = mk_test_name(name, getattr(v, "__name__", v), i) - if hasattr(func, UNPACK_ATTR): - if isinstance(v, tuple) or isinstance(v, list): - add_test(cls, test_name, func, *v) - else: - # unpack dictionary - add_test(cls, test_name, func, **v) - else: - add_test(cls, test_name, func, v) - delattr(cls, name) - elif hasattr(func, FILE_ATTR): - file_attr = getattr(func, FILE_ATTR) - process_file_data(cls, name, func, file_attr) + value_combinations = itertools.product(*( + value_set.test_values(cls, func) + for value_set in getattr(func, DATA_ATTR) + )) + for idx, values in enumerate(value_combinations): + combine(values).add_test(cls, name, idx) delattr(cls, name) return cls diff --git a/test/test_example.py b/test/test_example.py index 35fc002..d7bea4b 100644 --- a/test/test_example.py +++ b/test/test_example.py @@ -55,6 +55,28 @@ def test_list_extracted_into_arguments(self, first_value, second_value): def test_dicts_extracted_into_kwargs(self, first, second, third): self.assertTrue(first < third < second) + @data(1, 2, 3) + @data(4, 5, 6) + @data(7, 8, 9) + def test_products(self, first_value, second_value, third_value): + self.assertTrue(first_value < second_value < third_value) + + @unpack + @data({'first': 1}, {'first': 2}, {'first': 3}) + @data( + {'second': 4, 'third': 5}, + {'second': 5, 'third': 6}, + {'second': 6, 'third': 7} + ) + def test_dict_products(self, first, second, third): + self.assertTrue(first < second < third) + + @unpack + @data([1], [2], [3]) + @data((4, 5), (5, 6), (6, 7)) + def test_list_products(self, first, second, third): + self.assertTrue(first < second < third) + @data(u'ascii', u'non-ascii-\N{SNOWMAN}') def test_unicode(self, value): self.assertIn(value, (u'ascii', u'non-ascii-\N{SNOWMAN}')) diff --git a/test/test_functional.py b/test/test_functional.py index b0e8cfd..64e8150 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -3,7 +3,9 @@ import six -from ddt import ddt, data, file_data, is_hash_randomized +from ddt import ( + ddt, data, file_data, is_hash_randomized, DataValues, FileValues +) from nose.tools import assert_equal, assert_is_not_none, assert_raises @@ -71,7 +73,31 @@ def hello(): extra_attrs = dh_keys - keys assert_equal(len(extra_attrs), 1) extra_attr = extra_attrs.pop() - assert_equal(getattr(data_hello, extra_attr), (1, 2)) + assert_equal(getattr(data_hello, extra_attr), [DataValues((1, 2))]) + + +def test_multiple_data_decorators(): + """ + Test the ``data`` method decorator with multiple applications + """ + + def hello(): + pass + + pre_size = len(hello.__dict__) + keys = set(hello.__dict__.keys()) + data_hello = data(1, 2)(data(3)(hello)) + dh_keys = set(data_hello.__dict__.keys()) + post_size = len(data_hello.__dict__) + + assert_equal(post_size, pre_size + 1) + extra_attrs = dh_keys - keys + assert_equal(len(extra_attrs), 1) + extra_attr = extra_attrs.pop() + assert_equal( + getattr(data_hello, extra_attr), + [DataValues((1, 2)), DataValues((3,))] + ) def test_file_data_decorator_with_dict(): @@ -84,7 +110,7 @@ def hello(): pre_size = len(hello.__dict__) keys = set(hello.__dict__.keys()) - data_hello = data("test_data_dict.json")(hello) + data_hello = file_data("test_data_dict.json")(hello) dh_keys = set(data_hello.__dict__.keys()) post_size = len(data_hello.__dict__) @@ -93,7 +119,39 @@ def hello(): extra_attrs = dh_keys - keys assert_equal(len(extra_attrs), 1) extra_attr = extra_attrs.pop() - assert_equal(getattr(data_hello, extra_attr), ("test_data_dict.json",)) + assert_equal( + getattr(data_hello, extra_attr), + [FileValues("test_data_dict.json")] + ) + + +def test_multiple_file_data_decorators_with_dict(): + """ + Test the ``file_data`` method decorator with multiple applications + """ + + def hello(): + pass + + pre_size = len(hello.__dict__) + keys = set(hello.__dict__.keys()) + data_hello = file_data("test_other_data.json")(hello) + data_hello = file_data("test_data_dict.json")(data_hello) + + dh_keys = set(data_hello.__dict__.keys()) + post_size = len(data_hello.__dict__) + + assert_equal(post_size, pre_size + 1) + extra_attrs = dh_keys - keys + assert_equal(len(extra_attrs), 1) + extra_attr = extra_attrs.pop() + assert_equal( + getattr(data_hello, extra_attr), + [ + FileValues("test_data_dict.json"), + FileValues("test_other_data.json"), + ] + ) is_test = lambda x: x.startswith('test_') diff --git a/tox.ini b/tox.ini index b8c1edf..2103baa 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,7 @@ envlist = py27, py33, py34 [testenv] deps = nose - pep8 + pep8>=1.5.7,<1.6 # Versions limited to work with flake8 coverage flake8 six