diff --git a/builtin/builtin.go b/builtin/builtin.go index 87e73614..6f88e40a 100644 --- a/builtin/builtin.go +++ b/builtin/builtin.go @@ -1034,6 +1034,45 @@ var Builtins = []*Function{ new(func([]int) []any), ), }, + { + Name: "merge", + Safe: func(args ...any) (any, uint, error) { + if len(args) < 2 { + return nil, 0, fmt.Errorf("invalid number of arguments (expected at least 2, got %d)", len(args)) + } + + out := make(map[string]any) + + for _, arg := range args { + v := reflect.ValueOf(arg) + + if v.Kind() != reflect.Map { + return nil, 0, fmt.Errorf("cannot merge %s", v.Kind()) + } + + for _, key := range v.MapKeys() { + out[fmt.Sprint(key.Interface())] = v.MapIndex(key).Interface() + } + } + + return out, uint(len(out)), nil + }, + Validate: func(args []reflect.Type) (reflect.Type, error) { + if len(args) < 2 { + return anyType, fmt.Errorf("invalid number of arguments (expected at least 2, got %d)", len(args)) + } + + for _, arg := range args { + switch kind(arg) { + case reflect.Interface, reflect.Map: + default: + return anyType, fmt.Errorf("cannot merge %s", arg) + } + } + + return mapType, nil + }, + }, bitFunc("bitand", func(x, y int) (any, error) { return x & y, nil }), diff --git a/builtin/builtin_test.go b/builtin/builtin_test.go index 0d0dec35..2de63e4f 100644 --- a/builtin/builtin_test.go +++ b/builtin/builtin_test.go @@ -197,6 +197,12 @@ func TestBuiltin(t *testing.T) { {`flatten([["a", "b"], [1, 2, [3, [[[["c", "d"], "e"]]], 4]]])`, []any{"a", "b", 1, 2, 3, "c", "d", "e", 4}}, {`uniq([1, 15, "a", 2, 3, 5, 2, "a", 2, "b"])`, []any{1, 15, "a", 2, 3, 5, "b"}}, {`uniq([[1, 2], "a", 2, 3, [1, 2], [1, 3]])`, []any{[]any{1, 2}, "a", 2, 3, []any{1, 3}}}, + {`merge({"a": 1}, {"b": 2})`, map[string]any{"a": 1, "b": 2}}, + {`merge({"a": 1, "b": 2}, {"b": 3})`, map[string]any{"a": 1, "b": 3}}, + {`merge({"a": 1}, {"b": 2}, {"c": 3})`, map[string]any{"a": 1, "b": 2, "c": 3}}, + {`merge({"a": 1}, {"a": 2})`, map[string]any{"a": 2}}, + {`merge({}, {"a": 1})`, map[string]any{"a": 1}}, + {`merge({"a": 1}, {})`, map[string]any{"a": 1}}, } for _, test := range tests { @@ -219,6 +225,7 @@ func TestBuiltin_works_with_any(t *testing.T) { "get": {2}, "take": {2}, "sortBy": {2}, + "merge": {2}, } for _, b := range builtin.Builtins { @@ -284,6 +291,10 @@ func TestBuiltin_errors(t *testing.T) { {`flatten([1, 2], [3, 4])`, "invalid number of arguments (expected 1, got 2)"}, {`flatten(1)`, "cannot flatten int"}, {`fromJSON("5e2482")`, "cannot unmarshal number"}, + {`merge()`, "invalid number of arguments (expected at least 2, got 0)"}, + {`merge({"a": 1})`, "invalid number of arguments (expected at least 2, got 1)"}, + {`merge(1, {"a": 1})`, "cannot merge int"}, + {`merge({"a": 1}, 2)`, "cannot merge int"}, } for _, test := range errorTests { t.Run(test.input, func(t *testing.T) { diff --git a/test/issues/895/issue_test.go b/test/issues/895/issue_test.go new file mode 100644 index 00000000..0a72c86e --- /dev/null +++ b/test/issues/895/issue_test.go @@ -0,0 +1,56 @@ +package main + +import ( + "testing" + + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/internal/testify/require" +) + +func TestIssue895(t *testing.T) { + env := map[string]any{ + "a": map[string]any{"a": 1, "b": 2}, + "b": map[string]any{"b": 3, "c": 4}, + } + + program, err := expr.Compile(`merge(a, b)`, expr.Env(env)) + require.NoError(t, err) + + output, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, map[string]any{"a": 1, "b": 3, "c": 4}, output) +} + +func TestIssue895_multiple_maps(t *testing.T) { + env := map[string]any{ + "a": map[string]any{"x": 1}, + "b": map[string]any{"y": 2}, + "c": map[string]any{"z": 3}, + } + + program, err := expr.Compile(`merge(a, b, c)`, expr.Env(env)) + require.NoError(t, err) + + output, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, map[string]any{"x": 1, "y": 2, "z": 3}, output) +} + +func TestIssue895_does_not_modify_input(t *testing.T) { + a := map[string]any{"a": 1} + b := map[string]any{"b": 2} + env := map[string]any{ + "a": a, + "b": b, + } + + program, err := expr.Compile(`merge(a, b)`, expr.Env(env)) + require.NoError(t, err) + + _, err = expr.Run(program, env) + require.NoError(t, err) + + // Original maps must be unmodified. + require.Equal(t, map[string]any{"a": 1}, a) + require.Equal(t, map[string]any{"b": 2}, b) +}