diff --git a/std/algorithm.d b/std/algorithm.d
index 792b6002635..bb1417af569 100644
--- a/std/algorithm.d
+++ b/std/algorithm.d
@@ -3681,7 +3681,7 @@ needle = The element searched for.
 
 Constraints:
 
-$(D isInputRange!R && is(typeof(binaryFun!pred(haystack.front, needle)
+$(D isInputRange!InputRange && is(typeof(binaryFun!pred(haystack.front, needle)
 : bool)))
 
 Returns:
@@ -3710,28 +3710,147 @@ assert(!find!("toLower(a) == b")(s, "hello").empty);
 See_Also:
      $(WEB sgi.com/tech/stl/_find.html, STL's _find)
  */
+InputRange find(alias pred = "a == b", InputRange, Element)(InputRange haystack, Element needle)
+if (isInputRange!InputRange &&
+    is (typeof(binaryFun!pred(haystack.front, needle)) : bool))
+{
+    alias R = InputRange;
+    alias E = Element;
+    alias predFun = binaryFun!pred;
+    static if (is(typeof(pred == "a == b")))
+        enum isDefaultPred = pred == "a == b";
+    else
+        enum isDefaultPred = false;
+    enum isIntegralNeedle = isSomeChar!E || isIntegral!E || isBoolean!E;
 
-R find(alias pred = "a == b", R, E)(R haystack, E needle)
-if (isInputRange!R &&
-        is(typeof(binaryFun!pred(haystack.front, needle)) : bool))
-{
-    static if (isNarrowString!R && isSomeChar!E && is(typeof(pred == "a == b")) && pred == "a == b")
+    alias EType = ElementType!R;
+
+    static if (isNarrowString!R)
+    {
+        alias EEType = ElementEncodingType!R;
+        alias UEEType = Unqual!EEType;
+
+        //These are two special cases which can search without decoding the UTF stream.
+        static if (isDefaultPred && isIntegralNeedle)
+        {
+            //This special case deals with UTF8 search, when the needle
+            //is represented by a single code unit (needle <= 0x7F).
+            //Note: "needle <= 0x7F" properly handles sign via unsigned promotion
+            static if (is(UEEType == char))
+            {
+                if (!__ctfe && needle <= 0x7F)
+                {
+                    static R trustedMemchr(ref R haystack, ref E needle) @trusted nothrow pure
+                    {
+                        auto ptr = memchr(haystack.ptr, needle, haystack.length);
+                        return ptr ?
+                             haystack[ptr - haystack.ptr .. $] :
+                             haystack[$ .. $];
+                    }
+                    return trustedMemchr(haystack, needle);
+                }
+            }
+
+            //Ditto, but for UTF16: needle must be one code unit outside 0xD800 .. 0xDFFF.
+            static if (is(UEEType == wchar))
+            {
+                if (needle <= 0xD7FF || (0xE000 <= needle && needle <= 0xFFFF))
+                {
+                    foreach (i, ref EEType e; haystack)
+                    {
+                        if (e == needle)
+                            return haystack[i .. $];
+                    }
+                    return haystack[$ .. $];
+                }
+            }
+        }
+
+        //The previous conditional optimizations did not apply. Fall back to
+        //the unconditional implementations.
+        static if (isDefaultPred)
+        {
+            //In case of default pred, it is faster to do string/string search.
+            UEEType[is(UEEType == char) ? 4 : 2] buf;
+
+            size_t len = encode(buf, needle);
+            //TODO: Make find!(R, R) @safe
+            R trustedFindRR(ref R haystack, UEEType[] needle) @trusted pure
+            {
+                return cast(R) std.algorithm.find(haystack, needle);
+            }
+            return trustedFindRR(haystack, buf[0 .. len]);
+        }
+        else
+        {
+            //Explicit pred: we must test each character by the book.
+            //We choose a manual decoding approach, because it is faster than
+            //the built-in foreach, or doing a front/popFront for-loop.
+            immutable len = haystack.length;
+            size_t i = 0, next = 0;
+            while (next < len)
+            {
+                if (predFun(decode(haystack, next), needle))
+                    return haystack[i .. $];
+                i = next;
+            }
+            return haystack[$ .. $];
+        }
+    }
+    else static if (isArray!R)
     {
-        alias Unqual!(ElementEncodingType!R) EEType;
-        EEType[EEType.sizeof == 1 ? 4 : 2] buf;
+        //10403 optimization: memchr fast path for 1-byte integral elements.
+        static if (isDefaultPred && isIntegral!EType && EType.sizeof == 1 && isIntegralNeedle)
+        {
+            R findHelper(ref R haystack, ref E needle) @trusted nothrow pure
+            {
+                EType* ptr = null;
+                //Note: min/max handle sign mismatch; memchr runs only if needle fits in EType.
+                if (min(EType.min, needle) == EType.min && max(EType.max, needle) == EType.max)
+                    ptr = cast(EType*) memchr(haystack.ptr, needle, haystack.length);
 
-        size_t len = encode(buf, needle);
-        return () @trusted {return std.algorithm.find!pred(haystack, cast(R)buf[0 .. len]);}();
+                return ptr ?
+                    haystack[ptr - haystack.ptr .. $] :
+                    haystack[$ .. $];
+            }
+
+            if (!__ctfe)
+                return findHelper(haystack, needle);
+        }
+
+        //Default implementation.
+        foreach (i, ref e; haystack)
+            if (predFun(e, needle))
+                return haystack[i .. $];
+        return haystack[$ .. $];
     }
     else
     {
-        for (; !haystack.empty; haystack.popFront())
+        //Everything else. Walk.
+        for ( ; !haystack.empty; haystack.popFront() )
         {
-            if (binaryFun!pred(haystack.front, needle)) break;
+            if (predFun(haystack.front, needle))
+                break;
         }
         return haystack;
     }
 }
+///
+unittest
+{
+    assert(find("hello, world", ',') == ", world");
+    assert(find([1, 2, 3, 5], 4) == []);
+    assert(equal(find(SList!int(1, 2, 3, 4, 5)[], 4), SList!int(4, 5)[]));
+    assert(find!"a > b"([1, 2, 3, 5], 2) == [3, 5]);
+
+    auto a = [ 1, 2, 3 ];
+    assert(find(a, 5).empty);       // not found
+    assert(!find(a, 2).empty);      // found
+
+    // Case-insensitive find of a string
+    string[] s = [ "Hello", "world", "!" ];
+    assert(!find!("toLower(a) == b")(s, "hello").empty);
+}
 
 unittest
 {
@@ -3744,6 +3863,75 @@ unittest
     assert(find([1, 2, 3, 5], 4).empty);
     assert(equal(find!"a>b"("hello", 'k'), "llo"));
 }
+@safe pure nothrow unittest
+{
+    int[] a1 = [1, 2, 3];
+    assert(find ([1, 2, 3], 2));
+    assert(find!((a,b)=>a==b)([1, 2, 3], 2));
+    ubyte[] a2 = [1, 2, 3];
+    ubyte b2 = 2;
+    assert(find ([1, 2, 3], 2));
+    assert(find!((a,b)=>a==b)([1, 2, 3], 2));
+}
+@safe pure unittest
+{
+    foreach(R; TypeTuple!(string, wstring, dstring))
+    {
+        foreach(E; TypeTuple!(char, wchar, dchar))
+        {
+            R r1 = "hello world";
+            E e1 = 'w';
+            assert(find ("hello world", 'w') == "world");
+            assert(find!((a,b)=>a==b)("hello world", 'w') == "world");
+            R r2 = "日c語";
+            E e2 = 'c';
+            assert(find ("日c語", 'c') == "c語");
+            assert(find!((a,b)=>a==b)("日c語", 'c') == "c語");
+            static if (E.sizeof >= 2)
+            {
+                R r3 = "hello world";
+                E e3 = 'w';
+                assert(find ("日本語", '本') == "本語");
+                assert(find!((a,b)=>a==b)("日本語", '本') == "本語");
+            }
+        }
+    }
+}
+unittest
+{
+    //CTFE
+    static assert (find("abc", 'b') == "bc");
+    static assert (find("日b語", 'b') == "b語");
+    static assert (find("日本語", '本') == "本語");
+    static assert (find([1, 2, 3], 2)  == [2, 3]);
+
+    int[] a1 = [1, 2, 3];
+    static assert(find ([1, 2, 3], 2));
+    static assert(find!((a,b)=>a==b)([1, 2, 3], 2));
+    ubyte[] a2 = [1, 2, 3];
+    ubyte b2 = 2;
+    static assert(find ([1, 2, 3], 2));
+    static assert(find!((a,b)=>a==b)([1, 2, 3], 2));
+}
+unittest
+{
+    void dg() pure @safe nothrow
+    {
+        byte[]  sarr = [1, 2, 3, 4];
+        ubyte[] uarr = [1, 2, 3, 4];
+        foreach(arr; TypeTuple!(sarr, uarr))
+        {
+            foreach(T; TypeTuple!(byte, ubyte, int, uint))
+            {
+                assert(find(arr, cast(T) 3) == arr[2 .. $]);
+                assert(find(arr, cast(T) 9) == arr[$ .. $]);
+            }
+            assert(find(arr, 256) == arr[$ .. $]);
+        }
+    }
+    dg();
+    assertCTFEable!dg;
+}
 
 /**
 Finds a forward range in another. Elements are compared for
@@ -4360,22 +4548,67 @@ assert(find!(pred)(arr) == arr);
 See_Also:
      $(WEB sgi.com/tech/stl/find_if.html, STL's find_if)
 */
-Range find(alias pred, Range)(Range haystack) if (isInputRange!(Range))
+InputRange find(alias pred, InputRange)(InputRange haystack)
+if (isInputRange!InputRange)
 {
-    alias unaryFun!(pred) predFun;
-    for (; !haystack.empty && !predFun(haystack.front); haystack.popFront())
+    alias R = InputRange;
+    alias predFun = unaryFun!pred;
+    static if (isNarrowString!R)
+    {
+        immutable len = haystack.length;
+        size_t i = 0, next = 0;
+        while (next < len)
+        {
+            if (predFun(decode(haystack, next)))
+                return haystack[i .. $];
+            i = next;
+        }
+        return haystack[$ .. $];
+    }
+    else static if (!isInfinite!R && hasSlicing!R && is(typeof(haystack[cast(size_t)0 .. $])))
     {
+        size_t i = 0;
+        foreach (ref e; haystack)
+        {
+            if (predFun(e))
+                return haystack[i .. $];
+            ++i;
+        }
+        return haystack[$ .. $];
+    }
+    else
+    {
+        //standard range
+        for ( ; !haystack.empty; haystack.popFront() )
+        {
+            if (predFun(haystack.front))
+                break;
+        }
+        return haystack;
     }
-    return haystack;
 }
 
+///
 unittest
+{
+    auto arr = [ 1, 2, 3, 4, 1 ];
+    assert(find!("a > 2")(arr) == [ 3, 4, 1 ]);
+
+    // with predicate alias
+    bool pred(int x) { return x + 1 > 1.5; }
+    assert(find!(pred)(arr) == arr);
+}
+
+@safe pure unittest
 {
     //scope(success) writeln("unittest @", __FILE__, ":", __LINE__, " done.");
-    int[] a = [ 1, 2, 3 ];
-    assert(find!("a > 2")(a) == [3]);
+    int[] r = [ 1, 2, 3 ];
+    assert(find!(a=>a > 2)(r) == [3]);
     bool pred(int x) { return x + 1 > 1.5; }
-    assert(find!(pred)(a) == a);
+    assert(find!(pred)(r) == r);
+
+    assert(find!(a=>a > 'v')("hello world") == "world");
+    assert(find!(a=>a%4 == 0)("日本語") == "本語");
 }
 
 // findSkip