From e34a3422cc32c808d2e8b0e0ef51112d53fa896d Mon Sep 17 00:00:00 2001 From: Sebastiaan van Stijn Date: Sun, 24 Aug 2025 18:43:33 +0200 Subject: [PATCH] templates: make "join" work with non-string slices and map values Add a custom join function that allows for non-string slices to be joined, following the same rules as "fmt.Sprint", it will use the fmt.Stringer interface if implemented, or "error" if the type has an "Error()". For maps, it joins the map-values, for example: docker image inspect --format '{{join .Config.Labels ", "}}' ubuntu 24.04, ubuntu Signed-off-by: Sebastiaan van Stijn --- templates/templates.go | 42 ++++++++++++++++- templates/templates_test.go | 90 +++++++++++++++++++++++++++++++++++++ 2 files changed, 131 insertions(+), 1 deletion(-) diff --git a/templates/templates.go b/templates/templates.go index e339050abb3f..1145e82c82ad 100644 --- a/templates/templates.go +++ b/templates/templates.go @@ -6,6 +6,9 @@ package templates import ( "bytes" "encoding/json" + "fmt" + "reflect" + "sort" "strings" "text/template" ) @@ -15,7 +18,7 @@ import ( var basicFunctions = template.FuncMap{ "json": formatJSON, "split": strings.Split, - "join": strings.Join, + "join": joinElements, "title": strings.Title, //nolint:nolintlint,staticcheck // strings.Title is deprecated, but we only use it for ASCII, so replacing with golang.org/x/text is out of scope "lower": strings.ToLower, "upper": strings.ToUpper, @@ -97,3 +100,40 @@ func formatJSON(v any) string { // Remove the trailing new line added by the encoder return strings.TrimSpace(buf.String()) } + +// joinElements joins a slice of items with the given separator. It uses +// [strings.Join] if it's a slice of strings, otherwise uses [fmt.Sprint] +// to join each item to the output. +func joinElements(elems any, sep string) (string, error) { + if elems == nil { + return "", nil + } + + if ss, ok := elems.([]string); ok { + return strings.Join(ss, sep), nil + } + + switch rv := reflect.ValueOf(elems); rv.Kind() { //nolint:exhaustive // ignore: too many options to make exhaustive + case reflect.Array, reflect.Slice: + var b strings.Builder + for i := range rv.Len() { + if i > 0 { + b.WriteString(sep) + } + _, _ = fmt.Fprint(&b, rv.Index(i).Interface()) + } + return b.String(), nil + + case reflect.Map: + var out []string + for _, k := range rv.MapKeys() { + out = append(out, fmt.Sprint(rv.MapIndex(k).Interface())) + } + // Not ideal, but trying to keep a consistent order + sort.Strings(out) + return strings.Join(out, sep), nil + + default: + return "", fmt.Errorf("expected slice, got %T", elems) + } +} diff --git a/templates/templates_test.go b/templates/templates_test.go index e9dbaefd0e5e..ed1ee5b95d13 100644 --- a/templates/templates_test.go +++ b/templates/templates_test.go @@ -3,6 +3,7 @@ package templates import ( "bytes" "testing" + "text/template" "gotest.tools/v3/assert" is "gotest.tools/v3/assert/cmp" @@ -139,3 +140,92 @@ func TestHeaderFunctions(t *testing.T) { }) } } + +type stringerString string + +func (s stringerString) String() string { + return "stringer" + string(s) +} + +type stringerAndError string + +func (s stringerAndError) String() string { + return "stringer" + string(s) +} + +func (s stringerAndError) Error() string { + return "error" + string(s) +} + +func TestJoinElements(t *testing.T) { + tests := []struct { + doc string + data any + expOut string + expErr string + }{ + { + doc: "nil", + data: nil, + expOut: `output: ""`, + }, + { + doc: "non-slice", + data: "hello", + expOut: `output: "`, + expErr: `error calling join: expected slice, got string`, + }, + { + doc: "structs", + data: []struct{ A, B string }{{"1", "2"}, {"3", "4"}}, + expOut: `output: "{1 2}, {3 4}"`, + }, + { + doc: "map with strings", + data: map[string]string{"A": "1", "B": "2", "C": "3"}, + expOut: `output: "1, 2, 3"`, + }, + { + doc: "map with stringers", + data: map[string]stringerString{"A": "1", "B": "2", "C": "3"}, + expOut: `output: "stringer1, stringer2, stringer3"`, + }, + { + doc: "map with errors", + data: []stringerAndError{"1", "2", "3"}, + expOut: `output: "error1, error2, error3"`, + }, + { + doc: "stringers", + data: []stringerString{"1", "2", "3"}, + expOut: `output: "stringer1, stringer2, stringer3"`, + }, + { + doc: "stringer with errors", + data: []stringerAndError{"1", "2", "3"}, + expOut: `output: "error1, error2, error3"`, + }, + { + doc: "slice of bools", + data: []bool{true, false, true}, + expOut: `output: "true, false, true"`, + }, + } + + const formatStr = `output: "{{- join . ", " -}}"` + tmpl, err := New("my-template").Funcs(template.FuncMap{"join": joinElements}).Parse(formatStr) + assert.NilError(t, err) + + for _, tc := range tests { + t.Run(tc.doc, func(t *testing.T) { + var b bytes.Buffer + err := tmpl.Execute(&b, tc.data) + if tc.expErr != "" { + assert.ErrorContains(t, err, tc.expErr) + } else { + assert.NilError(t, err) + } + assert.Equal(t, b.String(), tc.expOut) + }) + } +}