diff --git a/offchain/concurrency.py b/offchain/concurrency.py index 1cc4c03..9597612 100644 --- a/offchain/concurrency.py +++ b/offchain/concurrency.py @@ -1,9 +1,12 @@ import multiprocessing from concurrent.futures import ThreadPoolExecutor -from typing import Any, Callable, Sequence +from typing import Any, Callable, Sequence, TypeVar from offchain.logger.logging import logger +T = TypeVar("T") +U = TypeVar("U") + MAX_PROCS = (multiprocessing.cpu_count() * 2) + 1 @@ -39,15 +42,11 @@ def parmap(fn: Callable, args: list) -> list: # type: ignore[type-arg] return list(parallelize_with_threads(*map(lambda i: lambda: fn(i), args))) # type: ignore[arg-type] # noqa: E501 -def batched_parmap(fn: Callable, args: list, batch_size: int = 10) -> list: # type: ignore[type-arg] # noqa: E501 +def batched_parmap(fn: Callable[[T], U], args: list[T], batch_size: int = 10) -> list[U]: # noqa: E501 results = [] - i, j = 0, 0 - while i < len(args): - i, j = i + batch_size, i - if len(args) > i: - batch = args[j:i] - else: - batch = args[j:] + for i in range(0, len(args), batch_size): + batch_end = i + batch_size + batch = args[i:batch_end] res = parmap(fn, batch) - results += res + results.extend(res) return results diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py new file mode 100644 index 0000000..87b2567 --- /dev/null +++ b/tests/test_concurrency.py @@ -0,0 +1,14 @@ +import pytest + +from offchain.concurrency import batched_parmap + + +@pytest.mark.parametrize("batch_size", range(1, 11)) +def test_batched_parmap(batch_size): + def square(x): + return x * x + + args = list(range(0, 10)) + expected = [square(x) for x in args] + result = batched_parmap(square, args, batch_size=batch_size) + assert result == expected