diff --git a/bool.go b/bool.go index 99b43bb..e22f937 100644 --- a/bool.go +++ b/bool.go @@ -39,3 +39,7 @@ func (n *NullBool) Scan(src any) error { n.Valid = reflect.TypeOf(src) != nil return nil } + +func (n *NullBool) IsZero() bool { + return !n.Valid +} diff --git a/byte.go b/byte.go index 4ddfa0c..7a0c7ca 100644 --- a/byte.go +++ b/byte.go @@ -39,3 +39,7 @@ func (n *NullByte) Scan(src any) error { n.Valid = reflect.TypeOf(src) != nil return nil } + +func (n *NullByte) IsZero() bool { + return !n.Valid +} diff --git a/float64.go b/float64.go index 741ccbc..43261f1 100644 --- a/float64.go +++ b/float64.go @@ -39,3 +39,7 @@ func (n *NullFloat64) Scan(src any) error { n.Valid = reflect.TypeOf(src) != nil return nil } + +func (n *NullFloat64) IsZero() bool { + return !n.Valid +} diff --git a/generic.go b/generic.go index efd3029..61580e5 100644 --- a/generic.go +++ b/generic.go @@ -40,3 +40,7 @@ func (n *Null[T]) Scan(src any) error { n.Valid = reflect.TypeOf(src) != nil return nil } + +func (n *Null[T]) IsZero() bool { + return !n.Valid +} diff --git a/go.mod b/go.mod index 7fe11c5..71de205 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/gkits/sqlnull -go 1.23.2 +go 1.24 require ( github.com/mattn/go-sqlite3 v1.14.24 diff --git a/int16.go b/int16.go index ed66ced..9af25da 100644 --- a/int16.go +++ b/int16.go @@ -39,3 +39,7 @@ func (n *NullInt16) Scan(src any) error { n.Valid = reflect.TypeOf(src) != nil return nil } + +func (n *NullInt16) IsZero() bool { + return !n.Valid +} diff --git a/int32.go b/int32.go index 6b881f6..d4078ff 100644 --- a/int32.go +++ b/int32.go @@ -39,3 +39,7 @@ func (n *NullInt32) Scan(src any) error { n.Valid = reflect.TypeOf(src) != nil return nil } + +func (n *NullInt32) IsZero() bool { + return !n.Valid +} diff --git a/int64.go b/int64.go index 9564dc1..05189f7 100644 --- a/int64.go +++ b/int64.go @@ -39,3 +39,7 @@ func (n *NullInt64) Scan(src any) error { n.Valid = reflect.TypeOf(src) != nil return nil } + +func (n *NullInt64) IsZero() bool { + return !n.Valid +} diff --git a/sqlnull_test.go b/sqlnull_test.go index b42199a..e96696f 100644 --- a/sqlnull_test.go +++ b/sqlnull_test.go @@ -24,6 +24,18 @@ type target struct { Time sqlnull.NullTime `json:"time"` } +type targetOmitZero struct { + Generic sqlnull.Null[string] `json:"generic,omitzero"` + String sqlnull.NullString `json:"string,omitzero"` + Bool sqlnull.NullBool `json:"bool,omitzero"` + Byte sqlnull.NullByte `json:"byte,omitzero"` + Int16 sqlnull.NullInt16 `json:"int16,omitzero"` + Int32 sqlnull.NullInt32 `json:"int32,omitzero"` + Int64 sqlnull.NullInt64 `json:"int64,omitzero"` + Float64 sqlnull.NullFloat64 `json:"float64,omitzero"` + Time sqlnull.NullTime `json:"time,omitzero"` +} + func Test_MarshalJSON(t *testing.T) { cases := []struct { name string @@ -98,7 +110,7 @@ func Test_UnmarshalJSON(t *testing.T) { want target }{ { - name: "successfully marshal with values", + name: "successfully unmarshal with values", in: []byte(`{ "generic": "generic", "string": "string", @@ -123,7 +135,7 @@ func Test_UnmarshalJSON(t *testing.T) { }, }, { - name: "successfully marshal with null", + name: "successfully unmarshal with null", in: []byte(`{ "generic": null, "string": null, @@ -147,6 +159,21 @@ func Test_UnmarshalJSON(t *testing.T) { Time: sqlnull.NullTime{Time: time.Time{}, Valid: false}, }, }, + { + name: "successfully unmarshal with empty json", + in: []byte(`{}`), + want: target{ + Generic: sqlnull.Null[string]{V: "", Valid: false}, + String: sqlnull.NullString{String: "", Valid: false}, + Bool: sqlnull.NullBool{Bool: false, Valid: false}, + Byte: sqlnull.NullByte{Byte: 0, Valid: false}, + Int16: sqlnull.NullInt16{Int16: 0, Valid: false}, + Int32: sqlnull.NullInt32{Int32: 0, Valid: false}, + Int64: sqlnull.NullInt64{Int64: 0, Valid: false}, + Float64: sqlnull.NullFloat64{Float64: 0, Valid: false}, + Time: sqlnull.NullTime{Time: time.Time{}, Valid: false}, + }, + }, } for _, c := range cases { @@ -244,3 +271,61 @@ func Test_Scan(t *testing.T) { }) } } + +func Test_MarshalJSONOmitZero(t *testing.T) { + cases := []struct { + name string + in targetOmitZero + want string + }{ + { + name: "successfully marshal with values", + in: targetOmitZero{ + Generic: sqlnull.Null[string]{V: "generic", Valid: true}, + String: sqlnull.NullString{String: "string", Valid: true}, + Bool: sqlnull.NullBool{Bool: true, Valid: true}, + Byte: sqlnull.NullByte{Byte: 255, Valid: true}, + Int16: sqlnull.NullInt16{Int16: 16, Valid: true}, + Int32: sqlnull.NullInt32{Int32: 32, Valid: true}, + Int64: sqlnull.NullInt64{Int64: 64, Valid: true}, + Float64: sqlnull.NullFloat64{Float64: 64.6464, Valid: true}, + Time: sqlnull.NullTime{Time: time.Date(2024, 10, 23, 17, 50, 0, 0, time.UTC), Valid: true}, + }, + want: `{ + "generic": "generic", + "string": "string", + "bool": true, + "byte": 255, + "int16": 16, + "int32": 32, + "int64": 64, + "float64": 64.6464, + "time": "2024-10-23T17:50:00Z" + }`, + }, + { + name: "successfully marshal empty json", + in: targetOmitZero{ + Generic: sqlnull.Null[string]{V: "generic", Valid: false}, + String: sqlnull.NullString{String: "string", Valid: false}, + Bool: sqlnull.NullBool{Bool: true, Valid: false}, + Byte: sqlnull.NullByte{Byte: 255, Valid: false}, + Int16: sqlnull.NullInt16{Int16: 16, Valid: false}, + Int32: sqlnull.NullInt32{Int32: 32, Valid: false}, + Int64: sqlnull.NullInt64{Int64: 64, Valid: false}, + Float64: sqlnull.NullFloat64{Float64: 64.6464, Valid: false}, + Time: sqlnull.NullTime{Time: time.Date(2024, 10, 23, 17, 50, 0, 0, time.UTC), Valid: false}, + }, + want: `{}`, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got, err := json.Marshal(c.in) + t.Log(string(got)) + require.NoError(t, err) + assert.JSONEq(t, c.want, string(got)) + }) + } +} diff --git a/string.go b/string.go index 05a8fdf..00a15e0 100644 --- a/string.go +++ b/string.go @@ -39,3 +39,7 @@ func (n *NullString) Scan(src any) error { n.Valid = reflect.TypeOf(src) != nil return nil } + +func (n *NullString) IsZero() bool { + return !n.Valid +} diff --git a/time.go b/time.go index 59f0f1d..aa87eda 100644 --- a/time.go +++ b/time.go @@ -40,3 +40,7 @@ func (n *NullTime) Scan(src any) error { n.Valid = reflect.TypeOf(src) != nil return nil } + +func (n *NullTime) IsZero() bool { + return !n.Valid +}