From 6cb27c68f607ac54702ca4069d0722fe264cb849 Mon Sep 17 00:00:00 2001 From: viatoriche / Maxim Panfilov Date: Mon, 9 Sep 2019 18:04:34 +0300 Subject: [PATCH] add support for pointers fields --- decode.go | 10 ++++++++++ decode_test.go | 13 ++++++++++++- encode.go | 24 ++++++++++++++++++------ encode_test.go | 12 +++++++++++- 4 files changed, 51 insertions(+), 8 deletions(-) diff --git a/decode.go b/decode.go index f4f9130..ae04ff2 100644 --- a/decode.go +++ b/decode.go @@ -109,6 +109,16 @@ func (d *decoder) coerce(query string, target reflect.Kind, field reflect.Value) var err error var c interface{} + if target == reflect.Ptr { + elemType := field.Type().Elem() + k := elemType.Kind() + if field.IsNil() { + field.Set(reflect.New(elemType)) + } + elem := field.Elem() + return d.coerce(query, k, elem) + } + switch target { case reflect.String: field.SetString(query) diff --git a/decode_test.go b/decode_test.go index 1c8b04e..920fb96 100644 --- a/decode_test.go +++ b/decode_test.go @@ -49,7 +49,9 @@ type TestStruct struct { Float32s []float32 Float64s []float64 hidden int - Hidden int `qstring:"-"` + Hidden int `qstring:"-"` + PtrInt *int `qstring:"ptrInt"` + PtrNil *struct{} `qstring:"ptrNil"` } func TestUnmarshall(t *testing.T) { @@ -83,6 +85,7 @@ func TestUnmarshall(t *testing.T) { "ubigs": []string{"12", "13", "14"}, "float32s": []string{"6000", "6001", "6002"}, "float64s": []string{"7000", "7001", "7002"}, + "ptrInt": []string{"10"}, } err := Unmarshal(query, &ts) @@ -97,6 +100,14 @@ func TestUnmarshall(t *testing.T) { if len(ts.Fields) != 2 { t.Errorf("Expected 2 fields, got %d", len(ts.Fields)) } + + if *ts.PtrInt != 10 { + t.Errorf("Wrong PtrInt value: %v", ts.PtrInt) + } + + if ts.PtrNil != nil { + t.Errorf("Wrong PtrNil value: %v", ts.PtrNil) + } } func TestUnmarshalNested(t *testing.T) { diff --git a/encode.go b/encode.go index 9bc437a..d151239 100644 --- a/encode.go +++ b/encode.go @@ -73,14 +73,14 @@ func (e *encoder) marshal() (url.Values, error) { func (e *encoder) value(val reflect.Value) (url.Values, error) { elem := val.Elem() - typ := elem.Type() + elemType := elem.Type() var err error var output = make(url.Values) for i := 0; i < elem.NumField(); i++ { // pull out the qstring struct tag elemField := elem.Field(i) - typField := typ.Field(i) + typField := elemType.Field(i) qstring, omit := parseTag(typField.Tag.Get(Tag)) if qstring == "" { // resolvable fields must have at least the `flag` struct tag @@ -93,16 +93,28 @@ func (e *encoder) value(val reflect.Value) (url.Values, error) { continue } + var typFieldCheck reflect.Type + + if typField.Type.Kind() == reflect.Ptr { + typFieldCheck = typField.Type.Elem() + if elemField.IsNil() { + continue + } + elemField = elemField.Elem() + } else { + typFieldCheck = typField.Type + } + // only do work if the current fields query string parameter was provided - switch k := typField.Type.Kind(); k { + switch k := typFieldCheck.Kind(); k { default: output.Set(qstring, marshalValue(elemField, k)) case reflect.Slice: output[qstring] = marshalSlice(elemField) - case reflect.Ptr: - marshalStruct(output, qstring, reflect.Indirect(elemField), k) case reflect.Struct: - marshalStruct(output, qstring, elemField, k) + if err := marshalStruct(output, qstring, elemField, k); err != nil { + return nil, err + } } } return output, err diff --git a/encode_test.go b/encode_test.go index cd12cd5..16b8e10 100644 --- a/encode_test.go +++ b/encode_test.go @@ -8,6 +8,7 @@ import ( ) func TestMarshallString(t *testing.T) { + i := 10 ts := TestStruct{ Name: "SomeName", Do: true, @@ -37,6 +38,8 @@ func TestMarshallString(t *testing.T) { UBigs: []uint64{12, 13}, Float32s: []float32{6000, 6001}, Float64s: []float64{7000, 7001}, + PtrInt: &i, + PtrNil: nil, } expected := []string{"name=SomeName", "do=true", "page=1", "id=12", "small=13", @@ -46,7 +49,7 @@ func TestMarshallString(t *testing.T) { "smalls=7", "meds=9", "meds=10", "bigs=12", "bigs=13", "upages=2", "upages=3", "uids=5", "uids=6", "usmalls=8", "usmalls=9", "umeds=9", "umeds=10", "ubigs=12", "ubigs=13", "float32s=6000", "float32s=6001", - "float64s=7000", "float64s=7001"} + "float64s=7000", "float64s=7001", "ptrInt=10"} query, err := MarshalString(&ts) if err != nil { t.Fatal(err.Error()) @@ -60,6 +63,7 @@ func TestMarshallString(t *testing.T) { } func TestMarshallValues(t *testing.T) { + i := 10 ts := TestStruct{ Name: "SomeName", Do: true, @@ -89,6 +93,8 @@ func TestMarshallValues(t *testing.T) { UBigs: []uint64{12, 13}, Float32s: []float32{6000, 6001}, Float64s: []float64{7000, 7001}, + PtrInt: &i, + PtrNil: nil, } expected := url.Values{ @@ -120,6 +126,7 @@ func TestMarshallValues(t *testing.T) { "ubigs": []string{"12", "13", "14"}, "float32s": []string{"6000", "6001", "6002"}, "float64s": []string{"7000", "7001", "7002"}, + "ptrInt": []string{"10"}, } values, err := Marshal(&ts) if err != nil { @@ -130,6 +137,9 @@ func TestMarshallValues(t *testing.T) { t.Errorf("Expected %d fields, got %d. Hidden is %q", len(expected), len(values), values["hidden"]) } + if values["ptrInt"][0] != expected["ptrInt"][0] { + t.Errorf("Wrong ptrInt value: %v", values["ptrInt"][0]) + } } func TestInvalidMarshalString(t *testing.T) {