Skip to content

Commit 45de6e5

Browse files
authored
Add type inference for literal equality in guards (#15041)
1 parent 15f2165 commit 45de6e5

6 files changed

Lines changed: 271 additions & 63 deletions

File tree

lib/elixir/lib/module/types/apply.ex

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,7 @@ defmodule Module.Types.Apply do
477477
remote_domain(mod, fun, args, expected, elem(expr, 1), stack, context)
478478
end
479479

480+
@number union(integer(), float())
480481
@empty_list empty_list()
481482
@non_empty_list non_empty_list(term())
482483
@empty_map empty_map()
@@ -522,21 +523,65 @@ defmodule Module.Types.Apply do
522523

523524
{actual, context} = of_fun.(arg, expected, expr, stack, context)
524525
result = if compatible?(actual, expected), do: return, else: boolean()
526+
527+
# We can skip return compare because literal is always an integer,
528+
# so it cannot be a disjoint comparison
525529
{result, context}
526530
end
527531
end
528532

529-
defp custom_compare(name, left, right, _expected, expr, stack, context, of_fun) do
530-
compare(name, left, right, false, expr, stack, context, of_fun)
533+
defp custom_compare(name, arg, literal, expected, expr, stack, context, of_fun) do
534+
case booleaness(expected) do
535+
booleaness when booleaness in [:maybe_both, :none] ->
536+
compare(name, arg, literal, false, expr, stack, context, of_fun)
537+
538+
booleaness ->
539+
{literal_type, context} = of_fun.(literal, term(), expr, stack, context)
540+
541+
{polarity, return} =
542+
case booleaness do
543+
:maybe_true -> {name in [:==, :"=:="], @atom_true}
544+
:maybe_false -> {name in [:"/=", :"=/="], @atom_false}
545+
end
546+
547+
# If it is a singleton, we can always be precise
548+
if singleton?(literal_type) do
549+
expected = if polarity, do: literal_type, else: negation(literal_type)
550+
{arg_type, context} = of_fun.(arg, expected, expr, stack, context)
551+
result = if compatible?(arg_type, expected), do: return, else: boolean()
552+
553+
# Because reverse polarity means we will infer negated types
554+
# (which are naturally disjoint), we skip checks in such cases
555+
skip_check? = not polarity
556+
return_compare(name, arg_type, literal_type, result, skip_check?, expr, stack, context)
557+
else
558+
expected =
559+
cond do
560+
# We are checking for `not x == 1` or similar, we can't say anything about x
561+
polarity == false -> term()
562+
# We are checking for `x == 1`, make sure x is integer or float
563+
number_type?(literal_type) and name in [:==, :"/="] -> union(literal_type, @number)
564+
# Otherwise we have the literal type as is
565+
true -> literal_type
566+
end
567+
568+
{arg_type, context} = of_fun.(arg, expected, expr, stack, context)
569+
return_compare(name, arg_type, literal_type, boolean(), false, expr, stack, context)
570+
end
571+
end
531572
end
532573

533-
defp compare(name, left, right, literal?, expr, stack, context, of_fun) do
574+
defp compare(name, left, right, both_literal?, expr, stack, context, of_fun) do
534575
{left_type, context} = of_fun.(left, term(), expr, stack, context)
535576
{right_type, context} = of_fun.(right, term(), expr, stack, context)
536-
result = return(boolean(), [left_type, right_type], stack)
577+
return_compare(name, left_type, right_type, boolean(), both_literal?, expr, stack, context)
578+
end
579+
580+
defp return_compare(name, left_type, right_type, result, skip_check?, expr, stack, context) do
581+
result = return(result, [left_type, right_type], stack)
537582

538583
cond do
539-
literal? or not is_warning(stack) ->
584+
skip_check? or not is_warning(stack) ->
540585
{result, context}
541586

542587
name in [:==, :"/="] and number_type?(left_type) and number_type?(right_type) ->

lib/elixir/lib/module/types/descr.ex

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ defmodule Module.Types.Descr do
270270
end
271271

272272
defp unwrap_domain_tuple(%{tuple: bdd} = descr, transform) when map_size(descr) == 1 do
273-
tuple_normalize(bdd) |> Enum.map(transform)
273+
tuple_bdd_to_dnf(bdd) |> Enum.map(transform)
274274
end
275275

276276
defp unwrap_domain_tuple(descr, _transform) when descr == %{}, do: []
@@ -604,6 +604,56 @@ defmodule Module.Types.Descr do
604604
defp empty_key?(:tuple, value), do: tuple_empty?(value)
605605
defp empty_key?(_, _value), do: false
606606

607+
@doc """
608+
Returns if the type is a singleton.
609+
"""
610+
def singleton?(:term), do: false
611+
def singleton?(descr), do: static_singleton?(Map.get(descr, :dynamic, descr))
612+
613+
defp static_singleton?(:term), do: false
614+
defp static_singleton?(%{optional: _}), do: false
615+
defp static_singleton?(%{list: _}), do: false
616+
defp static_singleton?(%{fun: _}), do: false
617+
defp static_singleton?(descr), do: each_singleton?(descr, [:atom, :bitmap, :map, :tuple], false)
618+
619+
defp each_singleton?(descr, [key | keys], acc) do
620+
case descr do
621+
%{^key => value} ->
622+
case each_singleton?(key, value) do
623+
true when acc == true -> false
624+
true -> each_singleton?(descr, keys, true)
625+
false -> false
626+
:empty -> each_singleton?(descr, keys, acc)
627+
end
628+
629+
%{} ->
630+
each_singleton?(descr, keys, acc)
631+
end
632+
end
633+
634+
defp each_singleton?(_descr, [], acc), do: acc
635+
636+
# Implement for each type
637+
defp each_singleton?(:bitmap, bitmap), do: bitmap == @bit_empty_list
638+
639+
defp each_singleton?(:atom, atoms), do: match?({:union, set} when map_size(set) == 1, atoms)
640+
641+
defp each_singleton?(:tuple, bdd) do
642+
case tuple_bdd_to_dnf(bdd) do
643+
[] -> :empty
644+
[{:closed, entries}] -> Enum.all?(entries, &static_singleton?/1)
645+
_ -> false
646+
end
647+
end
648+
649+
defp each_singleton?(:map, bdd) do
650+
case map_bdd_to_dnf(bdd) do
651+
[] -> :empty
652+
[{:closed, fields, _negs}] -> Enum.all?(fields, fn {_, v} -> static_singleton?(v) end)
653+
_ -> false
654+
end
655+
end
656+
607657
@doc """
608658
Converts a descr to its quoted representation.
609659
@@ -3850,15 +3900,6 @@ defmodule Module.Types.Descr do
38503900
end)
38513901
end
38523902

3853-
# Use heuristics to normalize a map bdd for pretty printing.
3854-
defp map_normalize(bdd) do
3855-
map_bdd_to_dnf(bdd)
3856-
|> Enum.map(fn {tag, fields, negs} ->
3857-
map_eliminate_while_negs_decrease(tag, fields, negs)
3858-
end)
3859-
|> map_fusion()
3860-
end
3861-
38623903
# Continue to eliminate negations while length of list of negs decreases
38633904
defp map_eliminate_while_negs_decrease(tag, fields, []), do: {tag, fields, []}
38643905

@@ -3950,7 +3991,12 @@ defmodule Module.Types.Descr do
39503991
end
39513992

39523993
defp map_to_quoted(bdd, opts) do
3953-
map_normalize(bdd)
3994+
bdd
3995+
|> map_bdd_to_dnf()
3996+
|> Enum.map(fn {tag, fields, negs} ->
3997+
map_eliminate_while_negs_decrease(tag, fields, negs)
3998+
end)
3999+
|> map_fusion()
39544000
|> Enum.map(&map_each_to_quoted(&1, opts))
39554001
end
39564002

@@ -4472,14 +4518,14 @@ defmodule Module.Types.Descr do
44724518
end
44734519

44744520
defp tuple_to_quoted(bdd, opts) do
4475-
tuple_normalize(bdd)
4521+
tuple_bdd_to_dnf(bdd)
44764522
|> tuple_fusion()
44774523
|> Enum.map(&tuple_literal_to_quoted(&1, opts))
44784524
end
44794525

44804526
# Transforms a bdd into a union of tuples with no negations.
44814527
# Note: it is important to compose the results with tuple_dnf_union/2 to avoid duplicates
4482-
defp tuple_normalize(bdd) do
4528+
defp tuple_bdd_to_dnf(bdd) do
44834529
bdd_to_dnf(bdd)
44844530
|> Enum.reduce([], fn {positive_tuples, negative_tuples}, acc ->
44854531
case non_empty_tuple_literals_intersection(positive_tuples) do
@@ -4639,7 +4685,7 @@ defmodule Module.Types.Descr do
46394685
end
46404686

46414687
defp tuple_get(bdd, index) do
4642-
tuple_normalize(bdd)
4688+
tuple_bdd_to_dnf(bdd)
46434689
|> Enum.reduce(none(), fn
46444690
{tag, elements}, acc -> Enum.at(elements, index, tuple_tag_to_type(tag)) |> union(acc)
46454691
end)
@@ -4670,7 +4716,7 @@ defmodule Module.Types.Descr do
46704716
end
46714717

46724718
defp process_tuples_values(bdd) do
4673-
tuple_normalize(bdd)
4719+
tuple_bdd_to_dnf(bdd)
46744720
|> Enum.reduce(none(), fn {tag, elements}, acc ->
46754721
cond do
46764722
Enum.any?(elements, &empty?/1) -> none()
@@ -4808,7 +4854,7 @@ defmodule Module.Types.Descr do
48084854
defp tuple_of_size_at_least_static?(descr, index) do
48094855
case descr do
48104856
%{tuple: bdd} ->
4811-
tuple_normalize(bdd)
4857+
tuple_bdd_to_dnf(bdd)
48124858
|> Enum.all?(fn {_, elements} -> length(elements) >= index end)
48134859

48144860
%{} ->

lib/elixir/test/elixir/application_test.exs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,11 @@ defmodule ApplicationTest do
163163
assert is_list(Application.spec(:elixir))
164164
assert Application.spec(:unknown) == nil
165165
assert Application.spec(:unknown, :description) == nil
166-
167166
assert Application.spec(:elixir, :description) == ~c"elixir"
168-
assert_raise FunctionClauseError, fn -> Application.spec(:elixir, :unknown) end
167+
168+
assert_raise FunctionClauseError, fn ->
169+
Application.spec(:elixir, Process.get(:unknown, :unknown))
170+
end
169171
end
170172

171173
test "application module" do

lib/elixir/test/elixir/calendar/iso_test.exs

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -142,16 +142,6 @@ defmodule Calendar.ISOTest do
142142
assert Calendar.ISO.parse_date("20150123", :extended) == {:error, :invalid_format}
143143
assert Calendar.ISO.parse_date("2015-01-23", :extended) == {:ok, {2015, 1, 23}}
144144
end
145-
146-
test "errors on other format names" do
147-
assert_raise FunctionClauseError, fn ->
148-
Calendar.ISO.parse_date("20150123", :other)
149-
end
150-
151-
assert_raise FunctionClauseError, fn ->
152-
Calendar.ISO.parse_date("2015-01-23", :other)
153-
end
154-
end
155145
end
156146

157147
describe "parse_time/1" do
@@ -225,16 +215,6 @@ defmodule Calendar.ISOTest do
225215
assert Calendar.ISO.parse_time("235007", :extended) == {:error, :invalid_format}
226216
assert Calendar.ISO.parse_time("23:50:07", :extended) == {:ok, {23, 50, 7, {0, 0}}}
227217
end
228-
229-
test "errors on other format names" do
230-
assert_raise FunctionClauseError, fn ->
231-
Calendar.ISO.parse_time("235007", :other)
232-
end
233-
234-
assert_raise FunctionClauseError, fn ->
235-
Calendar.ISO.parse_time("23:50:07", :other)
236-
end
237-
end
238218
end
239219

240220
describe "parse_naive_datetime/1" do
@@ -312,16 +292,6 @@ defmodule Calendar.ISOTest do
312292
assert Calendar.ISO.parse_naive_datetime("2015-01-23 23:50:07.123", :extended) ==
313293
{:ok, {2015, 1, 23, 23, 50, 7, {123_000, 3}}}
314294
end
315-
316-
test "errors on other format names" do
317-
assert_raise FunctionClauseError, fn ->
318-
Calendar.ISO.parse_naive_datetime("20150123 235007.123", :other)
319-
end
320-
321-
assert_raise FunctionClauseError, fn ->
322-
Calendar.ISO.parse_naive_datetime("2015-01-23 23:50:07.123", :other)
323-
end
324-
end
325295
end
326296

327297
describe "parse_utc_datetime/1" do
@@ -400,16 +370,6 @@ defmodule Calendar.ISOTest do
400370
{:ok, {2015, 1, 23, 23, 50, 7, {123_000, 3}}, 0}
401371
end
402372

403-
test "errors on other format names" do
404-
assert_raise FunctionClauseError, fn ->
405-
Calendar.ISO.parse_naive_datetime("20150123 235007.123Z", :other)
406-
end
407-
408-
assert_raise FunctionClauseError, fn ->
409-
Calendar.ISO.parse_naive_datetime("2015-01-23 23:50:07.123Z", :other)
410-
end
411-
end
412-
413373
test "errors on mixed basic and extended formats" do
414374
assert Calendar.ISO.parse_utc_datetime("20150123 23:50:07.123Z", :basic) ==
415375
{:error, :invalid_format}

lib/elixir/test/elixir/module/types/descr_test.exs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,6 +1222,57 @@ defmodule Module.Types.DescrTest do
12221222
end
12231223
end
12241224

1225+
describe "singleton?" do
1226+
test "non-singleton?" do
1227+
refute singleton?(term())
1228+
refute singleton?(none())
1229+
refute singleton?(dynamic())
1230+
refute singleton?(integer())
1231+
refute singleton?(float())
1232+
refute singleton?(pid())
1233+
refute singleton?(reference())
1234+
refute singleton?(fun(1))
1235+
refute singleton?(non_empty_list(atom([:foo])))
1236+
end
1237+
1238+
@disguised_empty_map closed_map(key: atom([:value]))
1239+
|> difference(open_map(key: atom(), optional: if_set(atom())))
1240+
1241+
test "atoms" do
1242+
assert singleton?(atom([:foo]))
1243+
refute singleton?(atom([:foo, :bar]))
1244+
assert singleton?(atom([:foo]) |> union(@disguised_empty_map))
1245+
refute singleton?(atom() |> difference(atom([:foo])))
1246+
end
1247+
1248+
test "empty list" do
1249+
assert singleton?(empty_list())
1250+
refute singleton?(non_empty_list(term()))
1251+
refute singleton?(union(empty_list(), atom([:foo])))
1252+
assert singleton?(union(empty_list(), @disguised_empty_map))
1253+
end
1254+
1255+
test "maps" do
1256+
assert singleton?(empty_map())
1257+
assert singleton?(closed_map(key: atom([:value])))
1258+
assert singleton?(closed_map(key: atom([:value])) |> union(@disguised_empty_map))
1259+
refute singleton?(closed_map(key: binary()))
1260+
refute singleton?(closed_map(key: if_set(atom([:value]))))
1261+
refute singleton?(open_map())
1262+
refute singleton?(open_map(key: atom([:value])))
1263+
refute singleton?(union(closed_map(key: atom([:value])), closed_map(other: atom([:value]))))
1264+
end
1265+
1266+
test "tuples" do
1267+
assert singleton?(tuple([]))
1268+
assert singleton?(tuple([atom([:foo])]))
1269+
refute singleton?(tuple([binary()]))
1270+
refute singleton?(open_tuple([]))
1271+
refute singleton?(union(tuple([atom([:value])]), tuple([atom([:other_value])])))
1272+
refute singleton?(union(tuple([atom([:value])]), closed_map(other: atom([:value]))))
1273+
end
1274+
end
1275+
12251276
describe "projections" do
12261277
test "booleaness" do
12271278
for type <- [none(), open_map(), negation(boolean()), difference(atom(), boolean())] do

0 commit comments

Comments
 (0)