Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 36 additions & 4 deletions decode_hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,30 @@ func typedDecodeHook(h DecodeHookFunc) DecodeHookFunc {
return nil
}

// cachedDecodeHook takes a raw DecodeHookFunc (an interface{}) and turns
// it into a closure to be used directly
// if the type fails to convert we return a closure always erroring to keep the previous behaviour
func cachedDecodeHook(raw DecodeHookFunc) func(from reflect.Value, to reflect.Value) (interface{}, error) {
switch f := typedDecodeHook(raw).(type) {
case DecodeHookFuncType:
return func(from reflect.Value, to reflect.Value) (interface{}, error) {
return f(from.Type(), to.Type(), from.Interface())
}
case DecodeHookFuncKind:
return func(from reflect.Value, to reflect.Value) (interface{}, error) {
return f(from.Kind(), to.Kind(), from.Interface())
}
case DecodeHookFuncValue:
return func(from reflect.Value, to reflect.Value) (interface{}, error) {
return f(from, to)
}
default:
return func(from reflect.Value, to reflect.Value) (interface{}, error) {
return nil, errors.New("invalid decode hook signature")
}
}
}

// DecodeHookExec executes the given decode hook. This should be used
// since it'll naturally degrade to the older backwards compatible DecodeHookFunc
// that took reflect.Kind instead of reflect.Type.
Expand All @@ -61,13 +85,17 @@ func DecodeHookExec(
// The composed funcs are called in order, with the result of the
// previous transformation.
func ComposeDecodeHookFunc(fs ...DecodeHookFunc) DecodeHookFunc {
cached := make([]func(from reflect.Value, to reflect.Value) (interface{}, error), 0, len(fs))
for _, f := range fs {
cached = append(cached, cachedDecodeHook(f))
}
return func(f reflect.Value, t reflect.Value) (interface{}, error) {
var err error
data := f.Interface()

newFrom := f
for _, f1 := range fs {
data, err = DecodeHookExec(f1, newFrom, t)
for _, c := range cached {
data, err = c(newFrom, t)
if err != nil {
return nil, err
}
Expand All @@ -81,13 +109,17 @@ func ComposeDecodeHookFunc(fs ...DecodeHookFunc) DecodeHookFunc {
// OrComposeDecodeHookFunc executes all input hook functions until one of them returns no error. In that case its value is returned.
// If all hooks return an error, OrComposeDecodeHookFunc returns an error concatenating all error messages.
func OrComposeDecodeHookFunc(ff ...DecodeHookFunc) DecodeHookFunc {
cached := make([]func(from reflect.Value, to reflect.Value) (interface{}, error), 0, len(ff))
for _, f := range ff {
cached = append(cached, cachedDecodeHook(f))
}
return func(a, b reflect.Value) (interface{}, error) {
var allErrs string
var out interface{}
var err error

for _, f := range ff {
out, err = DecodeHookExec(f, a, b)
for _, c := range cached {
out, err = c(a, b)
if err != nil {
allErrs += err.Error() + "\n"
continue
Expand Down
10 changes: 7 additions & 3 deletions mapstructure.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,8 @@ type DecoderConfig struct {
// structure. The top-level Decode method is just a convenience that sets
// up the most basic Decoder.
type Decoder struct {
config *DecoderConfig
config *DecoderConfig
cachedDecodeHook func(from reflect.Value, to reflect.Value) (interface{}, error)
}

// Metadata contains information about decoding a structure that
Expand Down Expand Up @@ -408,6 +409,9 @@ func NewDecoder(config *DecoderConfig) (*Decoder, error) {
result := &Decoder{
config: config,
}
if config.DecodeHook != nil {
result.cachedDecodeHook = cachedDecodeHook(config.DecodeHook)
}

return result, nil
}
Expand Down Expand Up @@ -462,10 +466,10 @@ func (d *Decoder) decode(name string, input interface{}, outVal reflect.Value) e
return nil
}

if d.config.DecodeHook != nil {
if d.cachedDecodeHook != nil {
// We have a DecodeHook, so let's pre-process the input.
var err error
input, err = DecodeHookExec(d.config.DecodeHook, inputVal, outVal)
input, err = d.cachedDecodeHook(inputVal, outVal)
if err != nil {
return fmt.Errorf("error decoding '%s': %w", name, err)
}
Expand Down