Skip to content

Commit 291d4ad

Browse files
authored
Add guards to the Stream module (#9069)
Before we hadn't added guards to functions in the Stream module because the error message for invoking a function with the wrong number of arguments was better than the FunctionClauseError. Now that FunctionClauseErrors are better and display the failed clauses and everything, having guards in Stream functions means that we can fail fast in many cases instead of waiting to fail until the stream is realized.
1 parent daf47dd commit 291d4ad

File tree

1 file changed

+50
-36
lines changed

1 file changed

+50
-36
lines changed

lib/elixir/lib/stream.ex

Lines changed: 50 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ defmodule Stream do
207207
208208
"""
209209
@spec chunk_by(Enumerable.t(), (element -> any)) :: Enumerable.t()
210-
def chunk_by(enum, fun) do
210+
def chunk_by(enum, fun) when is_function(fun, 1) do
211211
R.chunk_by(&chunk_while/4, enum, fun)
212212
end
213213

@@ -248,7 +248,8 @@ defmodule Stream do
248248
(acc -> {:cont, chunk, acc} | {:cont, acc})
249249
) :: Enumerable.t()
250250
when chunk: any
251-
def chunk_while(enum, acc, chunk_fun, after_fun) do
251+
def chunk_while(enum, acc, chunk_fun, after_fun)
252+
when is_function(chunk_fun, 2) and is_function(after_fun, 1) do
252253
lazy(
253254
enum,
254255
[acc | after_fun],
@@ -314,7 +315,7 @@ defmodule Stream do
314315
315316
"""
316317
@spec dedup_by(Enumerable.t(), (element -> term)) :: Enumerable.t()
317-
def dedup_by(enum, fun) do
318+
def dedup_by(enum, fun) when is_function(fun, 1) do
318319
lazy(enum, nil, fn f1 -> R.dedup(fun, f1) end)
319320
end
320321

@@ -411,7 +412,7 @@ defmodule Stream do
411412
412413
"""
413414
@spec drop_while(Enumerable.t(), (element -> as_boolean(term))) :: Enumerable.t()
414-
def drop_while(enum, fun) do
415+
def drop_while(enum, fun) when is_function(fun, 1) do
415416
lazy(enum, true, fn f1 -> R.drop_while(fun, f1) end)
416417
end
417418

@@ -433,7 +434,7 @@ defmodule Stream do
433434
434435
"""
435436
@spec each(Enumerable.t(), (element -> term)) :: Enumerable.t()
436-
def each(enum, fun) do
437+
def each(enum, fun) when is_function(fun, 1) do
437438
lazy(enum, fn f1 ->
438439
fn x, acc ->
439440
fun.(x)
@@ -460,7 +461,7 @@ defmodule Stream do
460461
461462
"""
462463
@spec flat_map(Enumerable.t(), (element -> Enumerable.t())) :: Enumerable.t()
463-
def flat_map(enum, mapper) do
464+
def flat_map(enum, mapper) when is_function(mapper, 1) do
464465
transform(enum, nil, fn val, nil -> {mapper.(val), nil} end)
465466
end
466467

@@ -476,7 +477,7 @@ defmodule Stream do
476477
477478
"""
478479
@spec filter(Enumerable.t(), (element -> as_boolean(term))) :: Enumerable.t()
479-
def filter(enum, fun) do
480+
def filter(enum, fun) when is_function(fun, 1) do
480481
lazy(enum, fn f1 -> R.filter(fun, f1) end)
481482
end
482483

@@ -505,7 +506,7 @@ defmodule Stream do
505506
506507
"""
507508
@spec interval(non_neg_integer) :: Enumerable.t()
508-
def interval(n) do
509+
def interval(n) when is_integer(n) and n >= 0 do
509510
unfold(0, fn count ->
510511
Process.sleep(n)
511512
{count, count + 1}
@@ -519,7 +520,7 @@ defmodule Stream do
519520
is delayed until the stream is executed. See `run/1` for an example.
520521
"""
521522
@spec into(Enumerable.t(), Collectable.t(), (term -> term)) :: Enumerable.t()
522-
def into(enum, collectable, transform \\ fn x -> x end) do
523+
def into(enum, collectable, transform \\ fn x -> x end) when is_function(transform, 1) do
523524
&do_into(enum, collectable, transform, &1, &2)
524525
end
525526

@@ -564,7 +565,7 @@ defmodule Stream do
564565
565566
"""
566567
@spec map(Enumerable.t(), (element -> any)) :: Enumerable.t()
567-
def map(enum, fun) do
568+
def map(enum, fun) when is_function(fun, 1) do
568569
lazy(enum, fn f1 -> R.map(fun, f1) end)
569570
end
570571

@@ -593,13 +594,15 @@ defmodule Stream do
593594
"""
594595
@doc since: "1.4.0"
595596
@spec map_every(Enumerable.t(), non_neg_integer, (element -> any)) :: Enumerable.t()
596-
def map_every(enum, nth, fun)
597+
def map_every(enum, nth, fun) when is_integer(nth) and nth >= 0 and is_function(fun, 1) do
598+
map_every_after_guards(enum, nth, fun)
599+
end
597600

598-
def map_every(enum, 1, fun), do: map(enum, fun)
599-
def map_every(enum, 0, _fun), do: %Stream{enum: enum}
600-
def map_every([], _nth, _fun), do: %Stream{enum: []}
601+
defp map_every_after_guards(enum, 1, fun), do: map(enum, fun)
602+
defp map_every_after_guards(enum, 0, _fun), do: %Stream{enum: enum}
603+
defp map_every_after_guards([], _nth, _fun), do: %Stream{enum: []}
601604

602-
def map_every(enum, nth, fun) when is_integer(nth) and nth > 0 do
605+
defp map_every_after_guards(enum, nth, fun) do
603606
lazy(enum, nth, fn f1 -> R.map_every(nth, fun, f1) end)
604607
end
605608

@@ -615,7 +618,7 @@ defmodule Stream do
615618
616619
"""
617620
@spec reject(Enumerable.t(), (element -> as_boolean(term))) :: Enumerable.t()
618-
def reject(enum, fun) do
621+
def reject(enum, fun) when is_function(fun, 1) do
619622
lazy(enum, fn f1 -> R.reject(fun, f1) end)
620623
end
621624

@@ -658,7 +661,7 @@ defmodule Stream do
658661
659662
"""
660663
@spec scan(Enumerable.t(), (element, acc -> any)) :: Enumerable.t()
661-
def scan(enum, fun) do
664+
def scan(enum, fun) when is_function(fun, 2) do
662665
lazy(enum, :first, fn f1 -> R.scan2(fun, f1) end)
663666
end
664667

@@ -675,7 +678,7 @@ defmodule Stream do
675678
676679
"""
677680
@spec scan(Enumerable.t(), acc, (element, acc -> any)) :: Enumerable.t()
678-
def scan(enum, acc, fun) do
681+
def scan(enum, acc, fun) when is_function(fun, 2) do
679682
lazy(enum, acc, fn f1 -> R.scan3(fun, f1) end)
680683
end
681684

@@ -705,14 +708,19 @@ defmodule Stream do
705708
706709
"""
707710
@spec take(Enumerable.t(), integer) :: Enumerable.t()
708-
def take(_enum, 0), do: %Stream{enum: []}
709-
def take([], _count), do: %Stream{enum: []}
711+
def take(enum, count) when is_integer(count) do
712+
take_after_guards(enum, count)
713+
end
714+
715+
defp take_after_guards(_enum, 0), do: %Stream{enum: []}
716+
717+
defp take_after_guards([], _count), do: %Stream{enum: []}
710718

711-
def take(enum, count) when is_integer(count) and count > 0 do
719+
defp take_after_guards(enum, count) when count > 0 do
712720
lazy(enum, count, fn f1 -> R.take(f1) end)
713721
end
714722

715-
def take(enum, count) when is_integer(count) and count < 0 do
723+
defp take_after_guards(enum, count) when count < 0 do
716724
&Enumerable.reduce(Enum.take(enum, count), &1, &2)
717725
end
718726

@@ -739,11 +747,15 @@ defmodule Stream do
739747
740748
"""
741749
@spec take_every(Enumerable.t(), non_neg_integer) :: Enumerable.t()
742-
def take_every(enum, nth)
743-
def take_every(_enum, 0), do: %Stream{enum: []}
744-
def take_every([], _nth), do: %Stream{enum: []}
750+
def take_every(enum, nth) when is_integer(nth) and nth >= 0 do
751+
take_every_after_guards(enum, nth)
752+
end
753+
754+
defp take_every_after_guards(_enum, 0), do: %Stream{enum: []}
755+
756+
defp take_every_after_guards([], _nth), do: %Stream{enum: []}
745757

746-
def take_every(enum, nth) when is_integer(nth) and nth > 0 do
758+
defp take_every_after_guards(enum, nth) do
747759
lazy(enum, nth, fn f1 -> R.take_every(nth, f1) end)
748760
end
749761

@@ -759,7 +771,7 @@ defmodule Stream do
759771
760772
"""
761773
@spec take_while(Enumerable.t(), (element -> as_boolean(term))) :: Enumerable.t()
762-
def take_while(enum, fun) do
774+
def take_while(enum, fun) when is_function(fun, 1) do
763775
lazy(enum, fn f1 -> R.take_while(fun, f1) end)
764776
end
765777

@@ -776,7 +788,7 @@ defmodule Stream do
776788
777789
"""
778790
@spec timer(non_neg_integer) :: Enumerable.t()
779-
def timer(n) do
791+
def timer(n) when is_integer(n) and n >= 0 do
780792
take(interval(n), 1)
781793
end
782794

@@ -810,7 +822,7 @@ defmodule Stream do
810822
@spec transform(Enumerable.t(), acc, fun) :: Enumerable.t()
811823
when fun: (element, acc -> {Enumerable.t(), acc} | {:halt, acc}),
812824
acc: any
813-
def transform(enum, acc, reducer) do
825+
def transform(enum, acc, reducer) when is_function(reducer, 2) do
814826
&do_transform(enum, fn -> acc end, reducer, &1, &2, nil)
815827
end
816828

@@ -827,7 +839,8 @@ defmodule Stream do
827839
@spec transform(Enumerable.t(), (() -> acc), fun, (acc -> term)) :: Enumerable.t()
828840
when fun: (element, acc -> {Enumerable.t(), acc} | {:halt, acc}),
829841
acc: any
830-
def transform(enum, start_fun, reducer, after_fun) do
842+
def transform(enum, start_fun, reducer, after_fun)
843+
when is_function(start_fun, 0) and is_function(reducer, 2) and is_function(after_fun, 1) do
831844
&do_transform(enum, start_fun, reducer, &1, &2, after_fun)
832845
end
833846

@@ -1023,7 +1036,7 @@ defmodule Stream do
10231036
10241037
"""
10251038
@spec uniq_by(Enumerable.t(), (element -> term)) :: Enumerable.t()
1026-
def uniq_by(enum, fun) do
1039+
def uniq_by(enum, fun) when is_function(fun, 1) do
10271040
lazy(enum, %{}, fn f1 -> R.uniq_by(fun, f1) end)
10281041
end
10291042

@@ -1045,7 +1058,7 @@ defmodule Stream do
10451058
10461059
"""
10471060
@spec with_index(Enumerable.t(), integer) :: Enumerable.t()
1048-
def with_index(enum, offset \\ 0) do
1061+
def with_index(enum, offset \\ 0) when is_integer(offset) do
10491062
lazy(enum, offset, fn f1 -> R.with_index(f1) end)
10501063
end
10511064

@@ -1303,7 +1316,7 @@ defmodule Stream do
13031316
13041317
"""
13051318
@spec iterate(element, (element -> element)) :: Enumerable.t()
1306-
def iterate(start_value, next_fun) do
1319+
def iterate(start_value, next_fun) when is_function(next_fun, 1) do
13071320
unfold({:ok, start_value}, fn
13081321
{:ok, value} ->
13091322
{value, {:next, value}}
@@ -1326,7 +1339,7 @@ defmodule Stream do
13261339
13271340
"""
13281341
@spec repeatedly((() -> element)) :: Enumerable.t()
1329-
def repeatedly(generator_fun) do
1342+
def repeatedly(generator_fun) when is_function(generator_fun, 0) do
13301343
&do_repeatedly(generator_fun, &1, &2)
13311344
end
13321345

@@ -1374,7 +1387,8 @@ defmodule Stream do
13741387
"""
13751388
@spec resource((() -> acc), (acc -> {[element], acc} | {:halt, acc}), (acc -> term)) ::
13761389
Enumerable.t()
1377-
def resource(start_fun, next_fun, after_fun) do
1390+
def resource(start_fun, next_fun, after_fun)
1391+
when is_function(start_fun, 0) and is_function(next_fun, 1) and is_function(after_fun, 1) do
13781392
&do_resource(start_fun.(), next_fun, &1, &2, after_fun)
13791393
end
13801394

@@ -1482,7 +1496,7 @@ defmodule Stream do
14821496
14831497
"""
14841498
@spec unfold(acc, (acc -> {element, acc} | nil)) :: Enumerable.t()
1485-
def unfold(next_acc, next_fun) do
1499+
def unfold(next_acc, next_fun) when is_function(next_fun, 1) do
14861500
&do_unfold(next_acc, next_fun, &1, &2)
14871501
end
14881502

0 commit comments

Comments
 (0)