diff --git a/binding.go b/binding.go index 21d14a2..1c9befa 100644 --- a/binding.go +++ b/binding.go @@ -14,7 +14,7 @@ import ( "time" ) -type requestBinder func(req *http.Request, userStruct FieldMapper) Errors +type requestBinder func(req *http.Request, userStruct FieldMapper) error // Bind takes data out of the request and deserializes into a struct according // to the Content-Type of the request. If no Content-Type is specified, there @@ -44,7 +44,12 @@ func Bind(req *http.Request, userStruct FieldMapper) error { if contentType == "" { errs.Add([]string{}, ContentTypeError, "Empty Content-Type") - errs = validate(errs, req, userStruct) + err := validate(errs, req, userStruct) + if errs2, ok := err.(Errors); ok { + errs = errs2 + } else { + return err + } } else { errs.Add([]string{}, ContentTypeError, "Unsupported Content-Type") } @@ -59,15 +64,18 @@ func Bind(req *http.Request, userStruct FieldMapper) error { // This function invokes data validation after deserialization. func Form(req *http.Request, userStruct FieldMapper) error { err := formBinder(req, userStruct) - if len(err) > 0 { - return err + if errs, ok := err.(Errors); ok { + if len(errs) > 0 { + return errs + } + return nil } - return nil + return err } var formBinder requestBinder = defaultFormBinder -func defaultFormBinder(req *http.Request, userStruct FieldMapper) Errors { +func defaultFormBinder(req *http.Request, userStruct FieldMapper) error { var errs Errors parseErr := req.ParseForm() @@ -83,16 +91,18 @@ func defaultFormBinder(req *http.Request, userStruct FieldMapper) Errors { // This function invokes data validation after deserialization. func URL(req *http.Request, userStruct FieldMapper) error { err := urlBinder(req, userStruct) - if len(err) > 0 { - return err + if errs, ok := err.(Errors); ok { + if len(errs) > 0 { + return errs + } + return nil } - return nil - + return err } var urlBinder requestBinder = defaultURLBinder -func defaultURLBinder(req *http.Request, userStruct FieldMapper) Errors { +func defaultURLBinder(req *http.Request, userStruct FieldMapper) error { return bindForm(req, userStruct, req.URL.Query(), nil) } @@ -101,16 +111,18 @@ func defaultURLBinder(req *http.Request, userStruct FieldMapper) Errors { // *multipart.FileHeader fields. func MultipartForm(req *http.Request, userStruct FieldMapper) error { err := multipartFormBinder(req, userStruct) - if len(err) > 0 { - return err + if errs, ok := err.(Errors); ok { + if len(errs) > 0 { + return errs + } + return nil } - - return nil + return err } var multipartFormBinder requestBinder = defaultMultipartFormBinder -func defaultMultipartFormBinder(req *http.Request, userStruct FieldMapper) Errors { +func defaultMultipartFormBinder(req *http.Request, userStruct FieldMapper) error { var errs Errors multipartReader, err := req.MultipartReader() @@ -135,16 +147,18 @@ func defaultMultipartFormBinder(req *http.Request, userStruct FieldMapper) Error // This function invokes data validation after deserialization. func Json(req *http.Request, userStruct FieldMapper) error { err := jsonBinder(req, userStruct) - if len(err) > 0 { - return err + if errs, ok := err.(Errors); ok { + if len(errs) > 0 { + return errs + } + return nil } - - return nil + return err } var jsonBinder requestBinder = defaultJsonBinder -func defaultJsonBinder(req *http.Request, userStruct FieldMapper) Errors { +func defaultJsonBinder(req *http.Request, userStruct FieldMapper) error { var errs Errors if req.Body != nil { @@ -159,12 +173,14 @@ func defaultJsonBinder(req *http.Request, userStruct FieldMapper) Errors { return errs } - errs = validate(errs, req, userStruct) - if len(errs) > 0 { - return errs + err := validate(errs, req, userStruct) + if errs2, ok := err.(Errors); ok { + if len(errs2) > 0 { + return errs2 + } + return nil } - - return nil + return err } // Validate ensures that all conditions have been met on every field in the @@ -172,14 +188,16 @@ func defaultJsonBinder(req *http.Request, userStruct FieldMapper) Errors { // deserialized into the struct. func Validate(req *http.Request, userStruct FieldMapper) error { err := validate(Errors{}, req, userStruct) - if len(err) > 0 { - return err + if errs, ok := err.(Errors); ok { + if len(errs) > 0 { + return errs + } + return nil } - - return nil + return err } -func validate(errs Errors, req *http.Request, userStruct FieldMapper) Errors { +func validate(errs Errors, req *http.Request, userStruct FieldMapper) error { fm := userStruct.FieldMap(req) for fieldPointer, fieldNameOrSpec := range fm { @@ -395,20 +413,16 @@ func validate(errs Errors, req *http.Request, userStruct FieldMapper) Errors { case Errors: errs = append(errs, e...) default: - errs.Add([]string{}, "", e.Error()) + return err } } } - if len(errs) > 0 { - return errs - } - - return nil + return errs } func bindForm(req *http.Request, userStruct FieldMapper, formData map[string][]string, - formFile map[string][]*multipart.FileHeader) Errors { + formFile map[string][]*multipart.FileHeader) error { var errs Errors diff --git a/binding_test.go b/binding_test.go index cadabbe..2384abf 100644 --- a/binding_test.go +++ b/binding_test.go @@ -85,7 +85,7 @@ func TestBind(t *testing.T) { Convey("Should invoke the Form deserializer", func() { model := new(Model) invoked := false - formBinder = func(req *http.Request, v FieldMapper) Errors { + formBinder = func(req *http.Request, v FieldMapper) error { invoked = true return defaultFormBinder(req, v) } @@ -120,7 +120,7 @@ func TestBind(t *testing.T) { Convey("Should invoke the MultipartForm deserializer", func() { model := new(Model) invoked := false - multipartFormBinder = func(req *http.Request, v FieldMapper) Errors { + multipartFormBinder = func(req *http.Request, v FieldMapper) error { invoked = true return defaultMultipartFormBinder(req, v) } @@ -147,7 +147,7 @@ func TestBind(t *testing.T) { Convey("Should invoke Json deserializer", func() { model := new(Model) invoked := false - jsonBinder = func(req *http.Request, v FieldMapper) Errors { + jsonBinder = func(req *http.Request, v FieldMapper) error { invoked = true return defaultJsonBinder(req, v) } @@ -177,8 +177,7 @@ func TestBindForm(t *testing.T) { Convey("When bindForm is called", func() { req, err := http.NewRequest("POST", "http://www.example.com", nil) So(err, ShouldBeNil) - var errs Errors - errs = bindForm(req, &actual, formData, nil) + err = bindForm(req, &actual, formData, nil) Convey("Then all of the struct's fields should be populated", func() { Convey("Then the Uint8 field should have the expected value", func() { So(actual.Uint8, ShouldEqual, expected.Uint8) @@ -363,6 +362,8 @@ func TestBindForm(t *testing.T) { }) Convey("Then no errors should be produced", FailureContinues, func() { + errs, ok := err.(Errors) + So(ok, ShouldBeTrue) So(errs.Len(), ShouldEqual, 0) if errs.Len() > 0 { for _, e := range errs { @@ -377,7 +378,7 @@ func TestBindForm(t *testing.T) { Convey("When bindForm is called", func() { req, err := http.NewRequest("POST", "http://www.example.com", nil) So(err, ShouldBeNil) - errs := bindForm(req, &actual, map[string][]string{}, nil) + err = bindForm(req, &actual, map[string][]string{}, nil) Convey("Then none of the struct's fields should be populated", func() { expected := AllTypes{} So(reflect.DeepEqual(actual, expected), ShouldBeTrue) @@ -388,6 +389,8 @@ func TestBindForm(t *testing.T) { for _, f := range actual.FieldMap(nil) { fields[f.(Field).Form] = struct{}{} } + errs, ok := err.(Errors) + So(ok, ShouldEqual, true) for _, err := range errs { So(len(err.Fields()), ShouldEqual, 1) _, ok := fields[err.Fields()[0]] diff --git a/validate_test.go b/validate_test.go index 0c06149..558e203 100644 --- a/validate_test.go +++ b/validate_test.go @@ -18,7 +18,7 @@ func TestValidate(t *testing.T) { } model := NewCompleteModel() var errs Errors - errs = validate(errs, req, &model) + errs = validate(errs, req, &model).(Errors) expectedErrs := make(map[string]bool) for _, v := range model.FieldMap(nil) { @@ -65,7 +65,7 @@ func TestValidate(t *testing.T) { } model := new(AllTypes) var errs Errors - errs = validate(errs, req, model) + errs = validate(errs, req, model).(Errors) expectedErrs := make(map[string]bool) for _, v := range model.FieldMap(nil) {