Skip to content
203 changes: 148 additions & 55 deletions ddt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
# https://github.com/txels/ddt/blob/master/LICENSE.md

import inspect
import itertools
import json
import os
import re
import sys
from functools import wraps
from collections import namedtuple

__version__ = '1.0.0'

Expand All @@ -19,35 +21,61 @@
# by the `ddt` class decorator.

DATA_ATTR = '%values' # store the data the test must run with
FILE_ATTR = '%file_path' # store the path to JSON file
UNPACK_ATTR = '%unpack' # remember that we have to unpack values


def unpack(func):
    """
    Method decorator to add unpack feature.

    Marks *func* so that each list/tuple test value is later expanded into
    positional arguments, and each dict value into keyword arguments.
    """
    setattr(func, UNPACK_ATTR, True)
    return func


def data(*values):
    """
    Method decorator to add data to your test methods.

    Should be added to methods of instances of ``unittest.TestCase``.

    All arguments to this function will be passed as arguments to the
    decorated test method, and each argument passed will generate a new test
    method.

    The ``@data`` and ``@file_data`` decorators can be nested. When nested,
    the arguments passed to ``data`` and stored in the files named in
    ``file_data`` will be combined in a Cartesian product before being passed
    to the decorated method. For example:

    .. code-block:: python

        @data(1, 2)
        @data(3, 4)
        def test_foo(self, fst, snd):
            pass

    would result in four test calls:

    .. code-block:: python

        test_foo(1, 3)
        test_foo(2, 3)
        test_foo(1, 4)
        test_foo(2, 4)
    """
    def wrapper(func):
        if not hasattr(func, DATA_ATTR):
            setattr(func, DATA_ATTR, [])
        # Prepend the new set of values, so that the stack of @data
        # decorators inserts arguments from top to bottom (decorators are
        # applied bottom-up, so the topmost one runs last and must insert
        # its values at the front).
        getattr(func, DATA_ATTR).insert(0, DataValues(values))
        return func
    return wrapper


def file_data(value):
    """
    Method decorator to add data from a file to your test methods.

    Should be added to methods of instances of ``unittest.TestCase``.

    ``value`` should be a path, relative to the directory of the file
    containing the decorated method, to a JSON file holding either a list
    (each element becomes one test) or a dict (each key/value pair becomes
    one named test).

    Can be nested with ``@data`` / ``@file_data``; nested decorators combine
    their values in a Cartesian product.
    """
    def wrapper(func):
        if not hasattr(func, DATA_ATTR):
            setattr(func, DATA_ATTR, [])
        # Prepend the new set of values, so that the stack of @data
        # decorators inserts arguments from top to bottom.
        getattr(func, DATA_ATTR).insert(0, FileValues(value))
        return func
    return wrapper


class DataValues(namedtuple("DataValues", ["values"])):
    """The argument sets supplied by one ``@data`` decorator application."""

    def test_values(self, cls, func):
        """
        Yield one TestValue per stored value.

        ``cls`` is unused here but kept so that DataValues and FileValues
        share the same ``test_values(cls, func)`` interface.  When *func*
        is marked with ``@unpack``, list/tuple values become positional
        arguments and dict values become keyword arguments.
        """
        for value in self.values:
            if not hasattr(func, UNPACK_ATTR):
                yield TestValue(func, value)
            elif isinstance(value, (tuple, list)):
                yield TestValue(func, *value)
            else:
                # unpack dictionary
                yield TestValue(func, **value)


class FileValues(namedtuple("FileValues", ["file_path"])):
    """The argument sets supplied by one ``@file_data`` decorator application."""

    def test_values(self, cls, func):
        """
        Yield one TestValue per entry in the JSON file.

        The path is resolved relative to the source file of *cls*.  A dict
        yields one named TestValue per key; a list yields one per element.
        A missing file yields a single TestError so the problem surfaces as
        a failing test rather than at class-decoration time.
        """
        cls_path = os.path.abspath(inspect.getsourcefile(cls))
        data_file_path = os.path.join(
            os.path.dirname(cls_path),
            self.file_path
        )

        if not os.path.exists(data_file_path):
            yield TestError(ValueError("%s does not exist" % self.file_path))
            return

        # Close the file promptly instead of relying on GC to do it.
        with open(data_file_path) as data_file:
            data = json.load(data_file)
        if isinstance(data, dict):
            for key, value in data.items():
                yield TestValue(func, value, _test_value_name=key)
        elif isinstance(data, list):
            for value in data:
                yield TestValue(func, value)


def is_hash_randomized():
return (((sys.hexversion >= 0x02070300 and
sys.hexversion < 0x03000000) or
Expand Down Expand Up @@ -148,41 +213,77 @@ def wrapper(self):
return wrapper


def add_test(cls, test_name, func, *args, **kwargs):
    """
    Add a test case to this class.

    The test will be based on an existing function but will give it a new
    name.
    """
    # feed_data wraps func so that it is invoked with the stored arguments.
    bound_test = feed_data(func, test_name, *args, **kwargs)
    setattr(cls, test_name, bound_test)


def process_file_data(cls, name, func, file_attr):
    """
    Process the parameter in the `file_data` decorator.

    Resolves *file_attr* relative to the source file of *cls* and adds one
    test per JSON entry.  If the file is missing, a single always-failing
    test is added so the problem shows up at test time.
    """
    cls_path = os.path.abspath(inspect.getsourcefile(cls))
    data_file_path = os.path.join(os.path.dirname(cls_path), file_attr)

    def _raise_ve(*args):  # pylint: disable-msg=W0613
        raise ValueError("%s does not exist" % file_attr)

    if not os.path.exists(data_file_path):
        test_name = mk_test_name(name, "error")
        add_test(cls, test_name, _raise_ve, None)
        return

    # Close the file promptly instead of relying on GC to do it.
    with open(data_file_path) as data_file:
        data = json.load(data_file)
    for i, elem in enumerate(data):
        if isinstance(data, dict):
            # Iterating a dict yields keys; the value is looked up.
            key, value = elem, data[elem]
            test_name = mk_test_name(name, key, i)
        elif isinstance(data, list):
            value = elem
            test_name = mk_test_name(name, value, i)
        add_test(cls, test_name, func, value)
class TestValue(object):
    """
    One concrete invocation of a data-driven test.

    Holds the wrapped test function together with the positional and
    keyword arguments it will be fed, plus a human-readable ``value_name``
    used when generating the new test method's name.
    """

    def __init__(self, func, *args, **kwargs):
        # An explicit name wins; otherwise derive one from the arguments.
        if '_test_value_name' in kwargs:
            self.value_name = kwargs.pop('_test_value_name')
        elif len(args) == 1 and not kwargs:
            sole = args[0]
            self.value_name = getattr(sole, '__name__', sole)
        elif args:
            self.value_name = args
        elif kwargs:
            self.value_name = kwargs
        else:
            raise Exception("unable to generate value names")

        self.func = func
        self.args = args
        self.kwargs = kwargs

    def add_test(self, cls, name, index):
        """
        Add a test case to this class.

        The test will be based on an existing function but will give it a new
        name.

        """
        test_name = mk_test_name(name, self.value_name, index)
        setattr(
            cls,
            test_name,
            feed_data(self.func, test_name, *self.args, **self.kwargs)
        )


class TestError(TestValue):
    """
    A sentinel test value whose generated test simply raises the stored
    exception, used to report data-loading problems as test failures.
    """

    def __init__(self, exception):
        # Store first: _raise_exception reads it when the test runs.
        self.exception = exception
        super(TestError, self).__init__(
            self._raise_exception, _test_value_name="error")

    def _raise_exception(self, test_cls):
        """Stand-in test body: fail by raising the stored exception."""
        raise self.exception


def combine(test_values):
    """
    Merge the TestValues produced by stacked decorators into one TestValue
    carrying the concatenated positional and keyword arguments.

    Returns the first TestError encountered unchanged, or a new TestError
    if the values do not all wrap the same function.
    """
    if len(test_values) == 1:
        return test_values[0]

    func = None
    merged_args = []
    merged_kwargs = {}
    names = []
    for tv in test_values:
        if isinstance(tv, TestError):
            # Propagate errors instead of combining them.
            return tv

        if func is None:
            func = tv.func
        elif tv.func != func:
            message = "{} is not the same function as {}".format(
                tv.func,
                func
            )
            return TestError(ValueError(message))

        merged_args.extend(tv.args)
        merged_kwargs.update(tv.kwargs)
        names.append(tv.value_name)

    return TestValue(func, _test_value_name=names, *merged_args,
                     **merged_kwargs)


def ddt(cls):
Expand Down Expand Up @@ -211,19 +312,11 @@ def ddt(cls):
"""
for name, func in list(cls.__dict__.items()):
if hasattr(func, DATA_ATTR):
for i, v in enumerate(getattr(func, DATA_ATTR)):
test_name = mk_test_name(name, getattr(v, "__name__", v), i)
if hasattr(func, UNPACK_ATTR):
if isinstance(v, tuple) or isinstance(v, list):
add_test(cls, test_name, func, *v)
else:
# unpack dictionary
add_test(cls, test_name, func, **v)
else:
add_test(cls, test_name, func, v)
delattr(cls, name)
elif hasattr(func, FILE_ATTR):
file_attr = getattr(func, FILE_ATTR)
process_file_data(cls, name, func, file_attr)
value_combinations = itertools.product(*(
value_set.test_values(cls, func)
for value_set in getattr(func, DATA_ATTR)
))
for idx, values in enumerate(value_combinations):
combine(values).add_test(cls, name, idx)
delattr(cls, name)
return cls
22 changes: 22 additions & 0 deletions test/test_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,28 @@ def test_list_extracted_into_arguments(self, first_value, second_value):
def test_dicts_extracted_into_kwargs(self, first, second, third):
    # Presumably decorated with @unpack/@data upstream (hidden in this
    # view) so each dict value is expanded into keyword arguments.
    self.assertTrue(first < third < second)

@data(1, 2, 3)
@data(4, 5, 6)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All your examples include 2 sets of @data. I haven't debugged your implementation to see if this is a limitation there, but it doesn't feel like it should. In that case, it would be nice to include some example case with more than 2 @data entries to make that explicit.

@data(7, 8, 9)
def test_products(self, first_value, second_value, third_value):
    # Stacked @data decorators combine their values in a Cartesian
    # product; each combination becomes one generated test.
    self.assertTrue(first_value < second_value < third_value)

@unpack
@data({'first': 1}, {'first': 2}, {'first': 3})
@data(
    {'second': 4, 'third': 5},
    {'second': 5, 'third': 6},
    {'second': 6, 'third': 7}
)
def test_dict_products(self, first, second, third):
    # @unpack turns each dict into keyword arguments; the two @data
    # stacks are combined as a Cartesian product of dicts.
    self.assertTrue(first < second < third)

@unpack
@data([1], [2], [3])
@data((4, 5), (5, 6), (6, 7))
def test_list_products(self, first, second, third):
    # @unpack expands each list/tuple into positional arguments; stacked
    # @data decorators contribute a Cartesian product of argument sets.
    self.assertTrue(first < second < third)

@data(u'ascii', u'non-ascii-\N{SNOWMAN}')
def test_unicode(self, value):
    # Each value is fed back unchanged; this guards against encoding
    # issues in generated test names and argument passing.
    self.assertIn(value, (u'ascii', u'non-ascii-\N{SNOWMAN}'))
66 changes: 62 additions & 4 deletions test/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

import six

from ddt import ddt, data, file_data, is_hash_randomized
from ddt import (
ddt, data, file_data, is_hash_randomized, DataValues, FileValues
)
from nose.tools import assert_equal, assert_is_not_none, assert_raises


Expand Down Expand Up @@ -71,7 +73,31 @@ def hello():
extra_attrs = dh_keys - keys
assert_equal(len(extra_attrs), 1)
extra_attr = extra_attrs.pop()
assert_equal(getattr(data_hello, extra_attr), (1, 2))
assert_equal(getattr(data_hello, extra_attr), [DataValues((1, 2))])


def test_multiple_data_decorators():
"""
Test the ``data`` method decorator with multiple applications
"""

def hello():
pass

pre_size = len(hello.__dict__)
keys = set(hello.__dict__.keys())
data_hello = data(1, 2)(data(3)(hello))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer adding a new test case for the new feature, rather than overloading an existing test case.

The main reason being, any new feature should strive to be backwards compatible... meaning all existing tests should pass unmodified. It is harder for me to figure out whether your contribution is backwards-compatible or not, if you modify existing tests.

dh_keys = set(data_hello.__dict__.keys())
post_size = len(data_hello.__dict__)

assert_equal(post_size, pre_size + 1)
extra_attrs = dh_keys - keys
assert_equal(len(extra_attrs), 1)
extra_attr = extra_attrs.pop()
assert_equal(
getattr(data_hello, extra_attr),
[DataValues((1, 2)), DataValues((3,))]
)


def test_file_data_decorator_with_dict():
Expand All @@ -84,7 +110,7 @@ def hello():

pre_size = len(hello.__dict__)
keys = set(hello.__dict__.keys())
data_hello = data("test_data_dict.json")(hello)
data_hello = file_data("test_data_dict.json")(hello)

dh_keys = set(data_hello.__dict__.keys())
post_size = len(data_hello.__dict__)
Expand All @@ -93,7 +119,39 @@ def hello():
extra_attrs = dh_keys - keys
assert_equal(len(extra_attrs), 1)
extra_attr = extra_attrs.pop()
assert_equal(getattr(data_hello, extra_attr), ("test_data_dict.json",))
assert_equal(
getattr(data_hello, extra_attr),
[FileValues("test_data_dict.json")]
)


def test_multiple_file_data_decorators_with_dict():
    """
    Test the ``file_data`` method decorator with multiple applications
    """

    def hello():
        pass

    original_attr_count = len(hello.__dict__)
    original_keys = set(hello.__dict__.keys())

    # Apply the decorator twice; both applications must share one
    # attribute, with the later (outer) decorator's values first.
    decorated = file_data("test_data_dict.json")(
        file_data("test_other_data.json")(hello)
    )

    assert_equal(len(decorated.__dict__), original_attr_count + 1)

    added_attrs = set(decorated.__dict__.keys()) - original_keys
    assert_equal(len(added_attrs), 1)
    assert_equal(
        getattr(decorated, added_attrs.pop()),
        [
            FileValues("test_data_dict.json"),
            FileValues("test_other_data.json"),
        ]
    )


is_test = lambda x: x.startswith('test_')
Expand Down
Loading