Skip to content
This repository was archived by the owner on Jan 14, 2022. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 52 additions & 38 deletions binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"time"
)

type requestBinder func(req *http.Request, userStruct FieldMapper) Errors
type requestBinder func(req *http.Request, userStruct FieldMapper) error

// Bind takes data out of the request and deserializes into a struct according
// to the Content-Type of the request. If no Content-Type is specified, there
Expand Down Expand Up @@ -44,7 +44,12 @@ func Bind(req *http.Request, userStruct FieldMapper) error {

if contentType == "" {
errs.Add([]string{}, ContentTypeError, "Empty Content-Type")
errs = validate(errs, req, userStruct)
err := validate(errs, req, userStruct)
if errs2, ok := err.(Errors); ok {
errs = errs2
} else {
return err
}
} else {
errs.Add([]string{}, ContentTypeError, "Unsupported Content-Type")
}
Expand All @@ -59,15 +64,18 @@ func Bind(req *http.Request, userStruct FieldMapper) error {
// This function invokes data validation after deserialization.
func Form(req *http.Request, userStruct FieldMapper) error {
err := formBinder(req, userStruct)
if len(err) > 0 {
return err
if errs, ok := err.(Errors); ok {
if len(errs) > 0 {
return errs
}
return nil
}
return nil
return err
}

var formBinder requestBinder = defaultFormBinder

func defaultFormBinder(req *http.Request, userStruct FieldMapper) Errors {
func defaultFormBinder(req *http.Request, userStruct FieldMapper) error {
var errs Errors

parseErr := req.ParseForm()
Expand All @@ -83,16 +91,18 @@ func defaultFormBinder(req *http.Request, userStruct FieldMapper) Errors {
// This function invokes data validation after deserialization.
func URL(req *http.Request, userStruct FieldMapper) error {
err := urlBinder(req, userStruct)
if len(err) > 0 {
return err
if errs, ok := err.(Errors); ok {
if len(errs) > 0 {
return errs
}
return nil
}
return nil

return err
}

var urlBinder requestBinder = defaultURLBinder

func defaultURLBinder(req *http.Request, userStruct FieldMapper) Errors {
func defaultURLBinder(req *http.Request, userStruct FieldMapper) error {
return bindForm(req, userStruct, req.URL.Query(), nil)
}

Expand All @@ -101,16 +111,18 @@ func defaultURLBinder(req *http.Request, userStruct FieldMapper) Errors {
// *multipart.FileHeader fields.
func MultipartForm(req *http.Request, userStruct FieldMapper) error {
err := multipartFormBinder(req, userStruct)
if len(err) > 0 {
return err
if errs, ok := err.(Errors); ok {
if len(errs) > 0 {
return errs
}
return nil
}

return nil
return err
}

var multipartFormBinder requestBinder = defaultMultipartFormBinder

func defaultMultipartFormBinder(req *http.Request, userStruct FieldMapper) Errors {
func defaultMultipartFormBinder(req *http.Request, userStruct FieldMapper) error {
var errs Errors

multipartReader, err := req.MultipartReader()
Expand All @@ -135,16 +147,18 @@ func defaultMultipartFormBinder(req *http.Request, userStruct FieldMapper) Error
// This function invokes data validation after deserialization.
func Json(req *http.Request, userStruct FieldMapper) error {
err := jsonBinder(req, userStruct)
if len(err) > 0 {
return err
if errs, ok := err.(Errors); ok {
if len(errs) > 0 {
return errs
}
return nil
}

return nil
return err
}

var jsonBinder requestBinder = defaultJsonBinder

func defaultJsonBinder(req *http.Request, userStruct FieldMapper) Errors {
func defaultJsonBinder(req *http.Request, userStruct FieldMapper) error {
var errs Errors

if req.Body != nil {
Expand All @@ -159,27 +173,31 @@ func defaultJsonBinder(req *http.Request, userStruct FieldMapper) Errors {
return errs
}

errs = validate(errs, req, userStruct)
if len(errs) > 0 {
return errs
err := validate(errs, req, userStruct)
if errs2, ok := err.(Errors); ok {
if len(errs2) > 0 {
return errs2
}
return nil
}

return nil
return err
}

// Validate ensures that all conditions have been met on every field in the
// populated struct. Validation should occur after the request has been
// deserialized into the struct.
func Validate(req *http.Request, userStruct FieldMapper) error {
err := validate(Errors{}, req, userStruct)
if len(err) > 0 {
return err
if errs, ok := err.(Errors); ok {
if len(errs) > 0 {
return errs
}
return nil
}

return nil
return err
}

func validate(errs Errors, req *http.Request, userStruct FieldMapper) Errors {
func validate(errs Errors, req *http.Request, userStruct FieldMapper) error {
fm := userStruct.FieldMap(req)

for fieldPointer, fieldNameOrSpec := range fm {
Expand Down Expand Up @@ -395,20 +413,16 @@ func validate(errs Errors, req *http.Request, userStruct FieldMapper) Errors {
case Errors:
errs = append(errs, e...)
default:
errs.Add([]string{}, "", e.Error())
return err
}
}
}

if len(errs) > 0 {
return errs
}

return nil
return errs
}

func bindForm(req *http.Request, userStruct FieldMapper, formData map[string][]string,
formFile map[string][]*multipart.FileHeader) Errors {
formFile map[string][]*multipart.FileHeader) error {

var errs Errors

Expand Down
15 changes: 9 additions & 6 deletions binding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func TestBind(t *testing.T) {
Convey("Should invoke the Form deserializer", func() {
model := new(Model)
invoked := false
formBinder = func(req *http.Request, v FieldMapper) Errors {
formBinder = func(req *http.Request, v FieldMapper) error {
invoked = true
return defaultFormBinder(req, v)
}
Expand Down Expand Up @@ -120,7 +120,7 @@ func TestBind(t *testing.T) {
Convey("Should invoke the MultipartForm deserializer", func() {
model := new(Model)
invoked := false
multipartFormBinder = func(req *http.Request, v FieldMapper) Errors {
multipartFormBinder = func(req *http.Request, v FieldMapper) error {
invoked = true
return defaultMultipartFormBinder(req, v)
}
Expand All @@ -147,7 +147,7 @@ func TestBind(t *testing.T) {
Convey("Should invoke Json deserializer", func() {
model := new(Model)
invoked := false
jsonBinder = func(req *http.Request, v FieldMapper) Errors {
jsonBinder = func(req *http.Request, v FieldMapper) error {
invoked = true
return defaultJsonBinder(req, v)
}
Expand Down Expand Up @@ -177,8 +177,7 @@ func TestBindForm(t *testing.T) {
Convey("When bindForm is called", func() {
req, err := http.NewRequest("POST", "http://www.example.com", nil)
So(err, ShouldBeNil)
var errs Errors
errs = bindForm(req, &actual, formData, nil)
err = bindForm(req, &actual, formData, nil)
Convey("Then all of the struct's fields should be populated", func() {
Convey("Then the Uint8 field should have the expected value", func() {
So(actual.Uint8, ShouldEqual, expected.Uint8)
Expand Down Expand Up @@ -363,6 +362,8 @@ func TestBindForm(t *testing.T) {
})

Convey("Then no errors should be produced", FailureContinues, func() {
errs, ok := err.(Errors)
So(ok, ShouldBeTrue)
So(errs.Len(), ShouldEqual, 0)
if errs.Len() > 0 {
for _, e := range errs {
Expand All @@ -377,7 +378,7 @@ func TestBindForm(t *testing.T) {
Convey("When bindForm is called", func() {
req, err := http.NewRequest("POST", "http://www.example.com", nil)
So(err, ShouldBeNil)
errs := bindForm(req, &actual, map[string][]string{}, nil)
err = bindForm(req, &actual, map[string][]string{}, nil)
Convey("Then none of the struct's fields should be populated", func() {
expected := AllTypes{}
So(reflect.DeepEqual(actual, expected), ShouldBeTrue)
Expand All @@ -388,6 +389,8 @@ func TestBindForm(t *testing.T) {
for _, f := range actual.FieldMap(nil) {
fields[f.(Field).Form] = struct{}{}
}
errs, ok := err.(Errors)
So(ok, ShouldEqual, true)
for _, err := range errs {
So(len(err.Fields()), ShouldEqual, 1)
_, ok := fields[err.Fields()[0]]
Expand Down
4 changes: 2 additions & 2 deletions validate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func TestValidate(t *testing.T) {
}
model := NewCompleteModel()
var errs Errors
errs = validate(errs, req, &model)
errs = validate(errs, req, &model).(Errors)

expectedErrs := make(map[string]bool)
for _, v := range model.FieldMap(nil) {
Expand Down Expand Up @@ -65,7 +65,7 @@ func TestValidate(t *testing.T) {
}
model := new(AllTypes)
var errs Errors
errs = validate(errs, req, model)
errs = validate(errs, req, model).(Errors)

expectedErrs := make(map[string]bool)
for _, v := range model.FieldMap(nil) {
Expand Down