diff --git a/README.md b/README.md index fe26e192..5397d79e 100644 --- a/README.md +++ b/README.md @@ -133,6 +133,7 @@ The following switches have different behavior in this version of `sqlcmd` compa - To provide the value of the host name in the server certificate when using strict encryption, pass the host name with `-F`. Example: `-Ns -F myhost.domain.com` - More information about client/server encryption negotiation can be found at - `-u` The generated Unicode output file will have the UTF16 Little-Endian Byte-order mark (BOM) written to it. +- `-f` Specifies the code page for input and output files. Format: `codepage | i:codepage[,o:codepage] | o:codepage[,i:codepage]`. Use `65001` for UTF-8. Supported codepages include Unicode (65001, 1200, 1201), Windows (874, 1250-1258), OEM/DOS (437, 850, etc.), ISO-8859 (28591-28606), CJK (932, 936, 949, 950), and EBCDIC (37, 1047, 1140). Use `--list-codepages` to see all supported code pages. - Some behaviors that were kept to maintain compatibility with `OSQL` may be changed, such as alignment of column headers for some data types. - All commands must fit on one line, even `EXIT`. Interactive mode will not check for open parentheses or quotes for commands and prompt for successive lines. The ODBC sqlcmd allows the query run by `EXIT(query)` to span multiple lines. - `-i` doesn't handle a comma `,` in a file name correctly unless the file name argument is triple quoted. For example: diff --git a/cmd/sqlcmd/sqlcmd.go b/cmd/sqlcmd/sqlcmd.go index ea655b47..4fad0232 100644 --- a/cmd/sqlcmd/sqlcmd.go +++ b/cmd/sqlcmd/sqlcmd.go @@ -82,6 +82,8 @@ type SQLCmdArguments struct { ChangePassword string ChangePasswordAndExit string TraceFile string + CodePage string + ListCodePages bool // Keep Help at the end of the list Help bool } @@ -171,6 +173,10 @@ func (a *SQLCmdArguments) Validate(c *cobra.Command) (err error) { err = rangeParameterError("-t", fmt.Sprint(a.QueryTimeout), 0, 65534, true) case a.ServerCertificate != "" && !encryptConnectionAllowsTLS(a.EncryptConnection): err = localizer.Errorf("The -J parameter requires encryption to be enabled (-N true, -N mandatory, or -N strict).") + case a.CodePage != "": + if _, parseErr := sqlcmd.ParseCodePage(a.CodePage); parseErr != nil { + err = localizer.Errorf(`'-f %s': %v`, a.CodePage, parseErr) + } } } if err != nil { @@ -239,6 +245,17 @@ func Execute(version string) { listLocalServers() os.Exit(0) } + // List supported codepages + if args.ListCodePages { + fmt.Println(localizer.Sprintf("Supported Code Pages:")) + fmt.Println() + fmt.Printf("%-8s %-20s %s\n", "Code", "Name", "Description") + fmt.Printf("%-8s %-20s %s\n", "----", "----", "-----------") + for _, cp := range sqlcmd.SupportedCodePages() { + fmt.Printf("%-8d %-20s %s\n", cp.CodePage, cp.Name, cp.Description) + } + os.Exit(0) + } if len(argss) > 0 { fmt.Printf("%s'%s': Unknown command. Enter '--help' for command help.", sqlcmdErrorPrefix, argss[0]) os.Exit(1) @@ -479,6 +496,8 @@ func setFlags(rootCmd *cobra.Command, args *SQLCmdArguments) { rootCmd.Flags().BoolVarP(&args.EnableColumnEncryption, "enable-column-encryption", "g", false, localizer.Sprintf("Enable column encryption")) rootCmd.Flags().StringVarP(&args.ChangePassword, "change-password", "z", "", localizer.Sprintf("New password")) rootCmd.Flags().StringVarP(&args.ChangePasswordAndExit, "change-password-exit", "Z", "", localizer.Sprintf("New password and exit")) + rootCmd.Flags().StringVarP(&args.CodePage, "code-page", "f", "", localizer.Sprintf("Specifies the code page for input/output. Use 65001 for UTF-8. Format: codepage | i:codepage[,o:codepage] | o:codepage[,i:codepage]")) + rootCmd.Flags().BoolVar(&args.ListCodePages, "list-codepages", false, localizer.Sprintf("List supported code pages and exit")) } func setScriptVariable(v string) string { @@ -813,6 +832,15 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { defer s.StopCloseHandler() s.UnicodeOutputFile = args.UnicodeOutputFile + // Parse and apply codepage settings + if args.CodePage != "" { + codePageSettings, err := sqlcmd.ParseCodePage(args.CodePage) + if err != nil { + return 1, localizer.Errorf("Invalid code page: %v", err) + } + s.CodePage = codePageSettings + } + if args.DisableCmd != nil { s.Cmd.DisableSysCommands(args.errorOnBlockedCmd()) } diff --git a/cmd/sqlcmd/sqlcmd_test.go b/cmd/sqlcmd/sqlcmd_test.go index 511816b2..cfdbcf31 100644 --- a/cmd/sqlcmd/sqlcmd_test.go +++ b/cmd/sqlcmd/sqlcmd_test.go @@ -123,6 +123,22 @@ func TestValidCommandLineToArgsConversion(t *testing.T) { {[]string{"-N", "true", "-J", "/path/to/cert2.pem"}, func(args SQLCmdArguments) bool { return args.EncryptConnection == "true" && args.ServerCertificate == "/path/to/cert2.pem" }}, + // Codepage flag tests + {[]string{"-f", "65001"}, func(args SQLCmdArguments) bool { + return args.CodePage == "65001" + }}, + {[]string{"-f", "i:1252,o:65001"}, func(args SQLCmdArguments) bool { + return args.CodePage == "i:1252,o:65001" + }}, + {[]string{"-f", "o:65001,i:1252"}, func(args SQLCmdArguments) bool { + return args.CodePage == "o:65001,i:1252" + }}, + {[]string{"--code-page", "1252"}, func(args SQLCmdArguments) bool { + return args.CodePage == "1252" + }}, + {[]string{"--list-codepages"}, func(args SQLCmdArguments) bool { + return args.ListCodePages + }}, } for _, test := range commands { @@ -178,6 +194,11 @@ func TestInvalidCommandLine(t *testing.T) { {[]string{"-N", "optional", "-J", "/path/to/cert.pem"}, "The -J parameter requires encryption to be enabled (-N true, -N mandatory, or -N strict)."}, {[]string{"-N", "disable", "-J", "/path/to/cert.pem"}, "The -J parameter requires encryption to be enabled (-N true, -N mandatory, or -N strict)."}, {[]string{"-N", "strict", "-F", "myserver.domain.com", "-J", "/path/to/cert.pem"}, "The -F and the -J options are mutually exclusive."}, + // Codepage validation tests + {[]string{"-f", "invalid"}, `'-f invalid': invalid codepage: invalid`}, + {[]string{"-f", "99999"}, `'-f 99999': unsupported codepage 99999`}, + {[]string{"-f", "i:invalid"}, `'-f i:invalid': invalid input codepage: i:invalid`}, + {[]string{"-f", "x:1252"}, `'-f x:1252': invalid codepage: x:1252`}, } for _, test := range commands { diff --git a/pkg/sqlcmd/codepage.go b/pkg/sqlcmd/codepage.go new file mode 100644 index 00000000..cced2691 --- /dev/null +++ b/pkg/sqlcmd/codepage.go @@ -0,0 +1,203 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "sort" + "strconv" + "strings" + + "github.com/microsoft/go-sqlcmd/internal/localizer" + "golang.org/x/text/encoding" + "golang.org/x/text/encoding/charmap" + "golang.org/x/text/encoding/japanese" + "golang.org/x/text/encoding/korean" + "golang.org/x/text/encoding/simplifiedchinese" + "golang.org/x/text/encoding/traditionalchinese" + "golang.org/x/text/encoding/unicode" +) + +// codepageEntry defines a codepage with its encoding and metadata +type codepageEntry struct { + encoding encoding.Encoding // nil for UTF-8 (Go's native encoding) + name string + description string +} + +// codepageRegistry is the single source of truth for all supported codepages. +// Both GetEncoding and SupportedCodePages use this registry. +var codepageRegistry = map[int]codepageEntry{ + // Unicode + 65001: {nil, "UTF-8", "Unicode (UTF-8)"}, + 1200: {unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM), "UTF-16LE", "Unicode (UTF-16 Little-Endian)"}, + 1201: {unicode.UTF16(unicode.BigEndian, unicode.IgnoreBOM), "UTF-16BE", "Unicode (UTF-16 Big-Endian)"}, + + // OEM/DOS codepages + 437: {charmap.CodePage437, "CP437", "OEM United States"}, + 850: {charmap.CodePage850, "CP850", "OEM Multilingual Latin 1"}, + 852: {charmap.CodePage852, "CP852", "OEM Latin 2"}, + 855: {charmap.CodePage855, "CP855", "OEM Cyrillic"}, + 858: {charmap.CodePage858, "CP858", "OEM Multilingual Latin 1 + Euro"}, + 860: {charmap.CodePage860, "CP860", "OEM Portuguese"}, + 862: {charmap.CodePage862, "CP862", "OEM Hebrew"}, + 863: {charmap.CodePage863, "CP863", "OEM Canadian French"}, + 865: {charmap.CodePage865, "CP865", "OEM Nordic"}, + 866: {charmap.CodePage866, "CP866", "OEM Russian"}, + + // Windows codepages + 874: {charmap.Windows874, "Windows-874", "Thai"}, + 1250: {charmap.Windows1250, "Windows-1250", "Central European"}, + 1251: {charmap.Windows1251, "Windows-1251", "Cyrillic"}, + 1252: {charmap.Windows1252, "Windows-1252", "Western European"}, + 1253: {charmap.Windows1253, "Windows-1253", "Greek"}, + 1254: {charmap.Windows1254, "Windows-1254", "Turkish"}, + 1255: {charmap.Windows1255, "Windows-1255", "Hebrew"}, + 1256: {charmap.Windows1256, "Windows-1256", "Arabic"}, + 1257: {charmap.Windows1257, "Windows-1257", "Baltic"}, + 1258: {charmap.Windows1258, "Windows-1258", "Vietnamese"}, + + // ISO-8859 codepages + 28591: {charmap.ISO8859_1, "ISO-8859-1", "Latin 1 (Western European)"}, + 28592: {charmap.ISO8859_2, "ISO-8859-2", "Latin 2 (Central European)"}, + 28593: {charmap.ISO8859_3, "ISO-8859-3", "Latin 3 (South European)"}, + 28594: {charmap.ISO8859_4, "ISO-8859-4", "Latin 4 (North European)"}, + 28595: {charmap.ISO8859_5, "ISO-8859-5", "Cyrillic"}, + 28596: {charmap.ISO8859_6, "ISO-8859-6", "Arabic"}, + 28597: {charmap.ISO8859_7, "ISO-8859-7", "Greek"}, + 28598: {charmap.ISO8859_8, "ISO-8859-8", "Hebrew"}, + 28599: {charmap.ISO8859_9, "ISO-8859-9", "Turkish"}, + 28600: {charmap.ISO8859_10, "ISO-8859-10", "Nordic"}, + 28603: {charmap.ISO8859_13, "ISO-8859-13", "Baltic"}, + 28604: {charmap.ISO8859_14, "ISO-8859-14", "Celtic"}, + 28605: {charmap.ISO8859_15, "ISO-8859-15", "Latin 9 (Western European with Euro)"}, + 28606: {charmap.ISO8859_16, "ISO-8859-16", "Latin 10 (South-Eastern European)"}, + + // Cyrillic + 20866: {charmap.KOI8R, "KOI8-R", "Russian"}, + 21866: {charmap.KOI8U, "KOI8-U", "Ukrainian"}, + + // Macintosh + 10000: {charmap.Macintosh, "Macintosh", "Mac Roman"}, + 10007: {charmap.MacintoshCyrillic, "x-mac-cyrillic", "Mac Cyrillic"}, + + // EBCDIC + 37: {charmap.CodePage037, "IBM037", "EBCDIC US-Canada"}, + 1047: {charmap.CodePage1047, "IBM1047", "EBCDIC Latin 1/Open System"}, + 1140: {charmap.CodePage1140, "IBM01140", "EBCDIC US-Canada with Euro"}, + + // Japanese + 932: {japanese.ShiftJIS, "Shift_JIS", "Japanese (Shift-JIS)"}, + 20932: {japanese.EUCJP, "EUC-JP", "Japanese (EUC)"}, + 50220: {japanese.ISO2022JP, "ISO-2022-JP", "Japanese (JIS)"}, + 50221: {japanese.ISO2022JP, "csISO2022JP", "Japanese (JIS-Allow 1 byte Kana)"}, + 50222: {japanese.ISO2022JP, "ISO-2022-JP", "Japanese (JIS-Allow 1 byte Kana SO/SI)"}, + + // Korean + 949: {korean.EUCKR, "EUC-KR", "Korean"}, + 51949: {korean.EUCKR, "EUC-KR", "Korean (EUC)"}, + + // Simplified Chinese + 936: {simplifiedchinese.GBK, "GBK", "Chinese Simplified (GBK)"}, + 54936: {simplifiedchinese.GB18030, "GB18030", "Chinese Simplified (GB18030)"}, + 52936: {simplifiedchinese.HZGB2312, "HZ-GB-2312", "Chinese Simplified (HZ)"}, + + // Traditional Chinese + 950: {traditionalchinese.Big5, "Big5", "Chinese Traditional (Big5)"}, +} + +// CodePageSettings holds the input and output codepage settings +type CodePageSettings struct { + InputCodePage int + OutputCodePage int +} + +// ParseCodePage parses the -f codepage argument +// Format: codepage | i:codepage[,o:codepage] | o:codepage[,i:codepage] +func ParseCodePage(arg string) (*CodePageSettings, error) { + if arg == "" { + return nil, nil + } + + settings := &CodePageSettings{} + parts := strings.Split(arg, ",") + + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + + if strings.HasPrefix(strings.ToLower(part), "i:") { + // Input codepage + cp, err := strconv.Atoi(strings.TrimPrefix(strings.ToLower(part), "i:")) + if err != nil { + return nil, localizer.Errorf("invalid input codepage: %s", part) + } + settings.InputCodePage = cp + } else if strings.HasPrefix(strings.ToLower(part), "o:") { + // Output codepage + cp, err := strconv.Atoi(strings.TrimPrefix(strings.ToLower(part), "o:")) + if err != nil { + return nil, localizer.Errorf("invalid output codepage: %s", part) + } + settings.OutputCodePage = cp + } else { + // Both input and output + cp, err := strconv.Atoi(part) + if err != nil { + return nil, localizer.Errorf("invalid codepage: %s", part) + } + settings.InputCodePage = cp + settings.OutputCodePage = cp + } + } + + // Validate codepages + if settings.InputCodePage != 0 { + if _, err := GetEncoding(settings.InputCodePage); err != nil { + return nil, err + } + } + if settings.OutputCodePage != 0 { + if _, err := GetEncoding(settings.OutputCodePage); err != nil { + return nil, err + } + } + + return settings, nil +} + +// GetEncoding returns the encoding for a given Windows codepage number. +// Returns nil for UTF-8 (65001) since Go uses UTF-8 natively. +func GetEncoding(codepage int) (encoding.Encoding, error) { + entry, ok := codepageRegistry[codepage] + if !ok { + return nil, localizer.Errorf("unsupported codepage %s", strconv.Itoa(codepage)) + } + return entry.encoding, nil +} + +// CodePageInfo describes a supported codepage +type CodePageInfo struct { + CodePage int + Name string + Description string +} + +// SupportedCodePages returns a list of all supported codepages with descriptions +func SupportedCodePages() []CodePageInfo { + result := make([]CodePageInfo, 0, len(codepageRegistry)) + for cp, entry := range codepageRegistry { + result = append(result, CodePageInfo{ + CodePage: cp, + Name: entry.name, + Description: entry.description, + }) + } + // Sort by codepage number for consistent output + sort.Slice(result, func(i, j int) bool { + return result[i].CodePage < result[j].CodePage + }) + return result +} diff --git a/pkg/sqlcmd/codepage_test.go b/pkg/sqlcmd/codepage_test.go new file mode 100644 index 00000000..47f7fae3 --- /dev/null +++ b/pkg/sqlcmd/codepage_test.go @@ -0,0 +1,265 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseCodePage(t *testing.T) { + tests := []struct { + name string + arg string + wantInput int + wantOutput int + wantErr bool + errContains string + }{ + { + name: "empty string", + arg: "", + wantInput: 0, + wantOutput: 0, + wantErr: false, + }, + { + name: "single codepage sets both", + arg: "65001", + wantInput: 65001, + wantOutput: 65001, + wantErr: false, + }, + { + name: "input only", + arg: "i:1252", + wantInput: 1252, + wantOutput: 0, + wantErr: false, + }, + { + name: "output only", + arg: "o:65001", + wantInput: 0, + wantOutput: 65001, + wantErr: false, + }, + { + name: "input and output", + arg: "i:1252,o:65001", + wantInput: 1252, + wantOutput: 65001, + wantErr: false, + }, + { + name: "output and input reversed", + arg: "o:65001,i:1252", + wantInput: 1252, + wantOutput: 65001, + wantErr: false, + }, + { + name: "uppercase prefix", + arg: "I:1252,O:65001", + wantInput: 1252, + wantOutput: 65001, + wantErr: false, + }, + { + name: "invalid codepage number", + arg: "abc", + wantErr: true, + errContains: "invalid codepage", + }, + { + name: "invalid input codepage", + arg: "i:abc", + wantErr: true, + errContains: "invalid input codepage", + }, + { + name: "invalid output codepage", + arg: "o:xyz", + wantErr: true, + errContains: "invalid output codepage", + }, + { + name: "unsupported codepage", + arg: "99999", + wantErr: true, + errContains: "unsupported codepage", + }, + { + name: "Japanese Shift JIS", + arg: "932", + wantInput: 932, + wantOutput: 932, + wantErr: false, + }, + { + name: "Chinese GBK", + arg: "936", + wantInput: 936, + wantOutput: 936, + wantErr: false, + }, + { + name: "Korean", + arg: "949", + wantInput: 949, + wantOutput: 949, + wantErr: false, + }, + { + name: "Traditional Chinese Big5", + arg: "950", + wantInput: 950, + wantOutput: 950, + wantErr: false, + }, + { + name: "EBCDIC", + arg: "37", + wantInput: 37, + wantOutput: 37, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + settings, err := ParseCodePage(tt.arg) + if tt.wantErr { + assert.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + assert.NoError(t, err) + if tt.arg == "" { + assert.Nil(t, settings) + return + } + assert.NotNil(t, settings) + assert.Equal(t, tt.wantInput, settings.InputCodePage) + assert.Equal(t, tt.wantOutput, settings.OutputCodePage) + }) + } +} + +func TestGetEncoding(t *testing.T) { + tests := []struct { + codepage int + wantNil bool // UTF-8 returns nil encoding + wantErr bool + }{ + // Unicode + {65001, true, false}, // UTF-8 + {1200, false, false}, // UTF-16LE + {1201, false, false}, // UTF-16BE + + // OEM/DOS + {437, false, false}, + {850, false, false}, + {866, false, false}, + + // Windows + {874, false, false}, + {1250, false, false}, + {1251, false, false}, + {1252, false, false}, + {1253, false, false}, + {1254, false, false}, + {1255, false, false}, + {1256, false, false}, + {1257, false, false}, + {1258, false, false}, + + // ISO-8859 + {28591, false, false}, + {28592, false, false}, + {28605, false, false}, + + // Cyrillic + {20866, false, false}, + {21866, false, false}, + + // Macintosh + {10000, false, false}, + {10007, false, false}, + + // EBCDIC + {37, false, false}, + {1047, false, false}, + {1140, false, false}, + + // CJK + {932, false, false}, // Japanese Shift JIS + {20932, false, false}, // EUC-JP + {50220, false, false}, // ISO-2022-JP + {949, false, false}, // Korean EUC-KR + {936, false, false}, // Chinese GBK + {54936, false, false}, // GB18030 + {950, false, false}, // Big5 + + // Invalid + {99999, false, true}, + {12345, false, true}, + } + + for _, tt := range tests { + t.Run(strconv.Itoa(tt.codepage), func(t *testing.T) { + enc, err := GetEncoding(tt.codepage) + if tt.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + if tt.wantNil { + assert.Nil(t, enc, "UTF-8 should return nil encoding") + } else { + assert.NotNil(t, enc, "non-UTF-8 codepage should return encoding") + } + }) + } +} + +func TestSupportedCodePages(t *testing.T) { + cps := SupportedCodePages() + + // Should have entries + assert.Greater(t, len(cps), 0, "should return codepages") + + // Each returned codepage should be valid in GetEncoding + for _, cp := range cps { + _, err := GetEncoding(cp.CodePage) + assert.NoError(t, err, "SupportedCodePages entry %d should be valid in GetEncoding", cp.CodePage) + assert.NotEmpty(t, cp.Name, "codepage %d should have a name", cp.CodePage) + assert.NotEmpty(t, cp.Description, "codepage %d should have a description", cp.CodePage) + } + + // Result should be sorted by codepage number + for i := 1; i < len(cps); i++ { + assert.Less(t, cps[i-1].CodePage, cps[i].CodePage, "codepages should be sorted") + } + + // Check some well-known codepages are present + known := map[int]bool{ + 65001: false, // UTF-8 + 1252: false, // Windows Western + 437: false, // DOS US + 932: false, // Japanese + } + for _, cp := range cps { + if _, ok := known[cp.CodePage]; ok { + known[cp.CodePage] = true + } + } + for cp, found := range known { + assert.True(t, found, "well-known codepage %d should be in list", cp) + } +} diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index 66dd1dba..c3938997 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -6,6 +6,7 @@ package sqlcmd import ( "flag" "fmt" + "io" "os" "regexp" "sort" @@ -13,10 +14,28 @@ import ( "strings" "github.com/microsoft/go-sqlcmd/internal/color" + "github.com/microsoft/go-sqlcmd/internal/localizer" "golang.org/x/text/encoding/unicode" "golang.org/x/text/transform" ) +// transformWriteCloser wraps a transform.Writer and ensures the underlying +// file is closed when Close() is called. +type transformWriteCloser struct { + *transform.Writer + underlying io.Closer +} + +// Close flushes the transform writer and closes the underlying file. +func (t *transformWriteCloser) Close() error { + // Close the transform writer (flushes pending data) + if err := t.Writer.Close(); err != nil { + _ = t.underlying.Close() + return err + } + return t.underlying.Close() +} + // Command defines a sqlcmd action which can be intermixed with the SQL batch // Commands for sqlcmd are defined at https://docs.microsoft.com/sql/tools/sqlcmd-utility#sqlcmd-commands type Command struct { @@ -324,8 +343,29 @@ func outCommand(s *Sqlcmd, args []string, line uint) error { // ODBC sqlcmd doesn't write a BOM but we will. // Maybe the endian-ness should be configurable. win16le := unicode.UTF16(unicode.LittleEndian, unicode.UseBOM) - encoder := transform.NewWriter(o, win16le.NewEncoder()) + encoder := &transformWriteCloser{ + Writer: transform.NewWriter(o, win16le.NewEncoder()), + underlying: o, + } s.SetOutput(encoder) + } else if s.CodePage != nil && s.CodePage.OutputCodePage != 0 { + // Use specified output codepage + enc, err := GetEncoding(s.CodePage.OutputCodePage) + if err != nil { + _ = o.Close() + return err + } + if enc != nil { + // Transform from UTF-8 to specified encoding + encoder := &transformWriteCloser{ + Writer: transform.NewWriter(o, enc.NewEncoder()), + underlying: o, + } + s.SetOutput(encoder) + } else { + // UTF-8, no transformation needed + s.SetOutput(o) + } } else { s.SetOutput(o) } @@ -352,7 +392,28 @@ func errorCommand(s *Sqlcmd, args []string, line uint) error { if err != nil { return InvalidFileError(err, args[0]) } - s.SetError(o) + // Apply output codepage if configured + if s.CodePage != nil && s.CodePage.OutputCodePage != 0 { + enc, err := GetEncoding(s.CodePage.OutputCodePage) + if err != nil { + if cerr := o.Close(); cerr != nil { + return localizer.Errorf("%v; additionally, closing error file %q failed: %v", err, args[0], cerr) + } + return err + } + if enc == nil { + // UTF-8 (or default) encoding: write directly without transform + s.SetError(o) + } else { + encoder := &transformWriteCloser{ + Writer: transform.NewWriter(o, enc.NewEncoder()), + underlying: o, + } + s.SetError(encoder) + } + } else { + s.SetError(o) + } } return nil } diff --git a/pkg/sqlcmd/commands_test.go b/pkg/sqlcmd/commands_test.go index 6197aa3f..76612b77 100644 --- a/pkg/sqlcmd/commands_test.go +++ b/pkg/sqlcmd/commands_test.go @@ -458,3 +458,100 @@ func TestExitCommandAppendsParameterToCurrentBatch(t *testing.T) { } } + +func TestOutputCodePageCommand(t *testing.T) { + tests := []struct { + name string + codepage int + expectedBytes []byte + inputText string + skipOnEncError bool + }{ + { + name: "UTF-8 output", + codepage: 65001, + inputText: "café", + expectedBytes: []byte("café"), + }, + { + name: "Windows-1252 output", + codepage: 1252, + inputText: "café", + expectedBytes: []byte{0x63, 0x61, 0x66, 0xe9}, // "café" in Windows-1252 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + + // Set up codepage + s.CodePage = &CodePageSettings{ + OutputCodePage: tt.codepage, + } + + // Create temp file for output + file, err := os.CreateTemp("", "sqlcmdout") + require.NoError(t, err, "os.CreateTemp") + defer os.Remove(file.Name()) + fileName := file.Name() + _ = file.Close() + + // Run the OUT command + err = outCommand(s, []string{fileName}, 1) + require.NoError(t, err, "outCommand") + + // Write some text + _, err = s.GetOutput().Write([]byte(tt.inputText)) + require.NoError(t, err, "Write") + + // Close to flush + if closer, ok := s.GetOutput().(interface{ Close() error }); ok { + require.NoError(t, closer.Close(), "Close output") + } + + // Read the file and check encoding + content, err := os.ReadFile(fileName) + require.NoError(t, err, "ReadFile") + assert.Equal(t, tt.expectedBytes, content, "Output encoding mismatch") + }) + } +} + +func TestErrorCodePageCommand(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + + // Set up codepage for Windows-1252 + s.CodePage = &CodePageSettings{ + OutputCodePage: 1252, + } + + // Create temp file for error output + file, err := os.CreateTemp("", "sqlcmderr") + require.NoError(t, err, "os.CreateTemp") + defer os.Remove(file.Name()) + fileName := file.Name() + _ = file.Close() + + // Run the ERROR command + err = errorCommand(s, []string{fileName}, 1) + require.NoError(t, err, "errorCommand") + + // Write some text with special characters + _, err = s.err.Write([]byte("Error: café")) + require.NoError(t, err, "Write") + + // Close to flush + if closer, ok := s.err.(interface{ Close() error }); ok { + require.NoError(t, closer.Close(), "Close error") + } + + // Read the file and check encoding + content, err := os.ReadFile(fileName) + require.NoError(t, err, "ReadFile") + // "Error: café" in Windows-1252 + expected := []byte{0x45, 0x72, 0x72, 0x6f, 0x72, 0x3a, 0x20, 0x63, 0x61, 0x66, 0xe9} + assert.Equal(t, expected, content, "Error output encoding mismatch") +} diff --git a/pkg/sqlcmd/sqlcmd.go b/pkg/sqlcmd/sqlcmd.go index 5e572a94..7a861c40 100644 --- a/pkg/sqlcmd/sqlcmd.go +++ b/pkg/sqlcmd/sqlcmd.go @@ -86,6 +86,8 @@ type Sqlcmd struct { UnicodeOutputFile bool // EchoInput tells the GO command to print the batch text before running the query EchoInput bool + // CodePage specifies input/output file encoding + CodePage *CodePageSettings colorizer color.Colorizer termchan chan os.Signal } @@ -331,9 +333,29 @@ func (s *Sqlcmd) IncludeFile(path string, processAll bool) error { } defer f.Close() b := s.batch.batchline - utf16bom := unicode.BOMOverride(unicode.UTF8.NewDecoder()) - unicodeReader := transform.NewReader(f, utf16bom) - scanner := bufio.NewReader(unicodeReader) + + // Set up the reader with appropriate encoding + var reader io.Reader + if s.CodePage != nil && s.CodePage.InputCodePage != 0 { + // Use specified input codepage + enc, err := GetEncoding(s.CodePage.InputCodePage) + if err != nil { + return err + } + if enc != nil { + // Transform from specified encoding to UTF-8 + reader = transform.NewReader(f, enc.NewDecoder()) + } else { + // UTF-8 codepage: still apply BOM stripping + utf8bom := unicode.BOMOverride(unicode.UTF8.NewDecoder()) + reader = transform.NewReader(f, utf8bom) + } + } else { + // Default: auto-detect BOM for UTF-16, fallback to UTF-8 + utf16bom := unicode.BOMOverride(unicode.UTF8.NewDecoder()) + reader = transform.NewReader(f, utf16bom) + } + scanner := bufio.NewReader(reader) curLine := s.batch.read echoFileLines := s.echoFileLines ln := make([]byte, 0, 2*1024*1024)