diff --git a/README.md b/README.md index fe26e192..7da89b68 100644 --- a/README.md +++ b/README.md @@ -154,6 +154,41 @@ switches are most important to you to have implemented next in the new sqlcmd. - `:Connect` now has an optional `-G` parameter to select one of the authentication methods for Azure SQL Database - `SqlAuthentication`, `ActiveDirectoryDefault`, `ActiveDirectoryIntegrated`, `ActiveDirectoryServicePrincipal`, `ActiveDirectoryManagedIdentity`, `ActiveDirectoryPassword`. If `-G` is not provided, either Integrated security or SQL Authentication will be used, dependent on the presence of a `-U` username parameter. - The new `--driver-logging-level` command line parameter allows you to see traces from the `go-mssqldb` client driver. Use `64` to see all traces. - Sqlcmd can now print results using a vertical format. Use the new `--vertical` command line option to set it. It's also controlled by the `SQLCMDFORMAT` scripting variable. +- `:help` displays a list of available sqlcmd commands. +- `:serverlist` lists local SQL Server instances discovered via the SQL Server Browser service (UDP port 1434). The command queries the SQL Browser service and displays the server name and instance name for each discovered instance. If no instances are found or the Browser service is not running, no output is produced. Non-timeout errors are printed to stderr. + +``` +1> :serverlist +MYSERVER\SQL2019 +MYSERVER\SQL2022 +``` + +#### Using :serverlist in batch scripts + +When automating server discovery, you can capture the output and check for errors: + +```batch +@echo off +REM Discover local SQL Server instances and connect to the first one +sqlcmd -Q ":serverlist" 2>nul > servers.txt +if %errorlevel% neq 0 ( + echo Error discovering servers + exit /b 1 +) +for /f "tokens=1" %%s in (servers.txt) do ( + echo Connecting to %%s... + sqlcmd -S %%s -Q "SELECT @@SERVERNAME" + goto :done +) +echo No SQL Server instances found +:done +``` + +To capture stderr separately (for error logging): +```batch +sqlcmd -Q ":serverlist" 2>errors.log > servers.txt +if exist errors.log if not "%%~z errors.log"=="0" type errors.log +``` ``` 1> select session_id, client_interface_name, program_name from sys.dm_exec_sessions where session_id=@@spid diff --git a/cmd/sqlcmd/sqlcmd.go b/cmd/sqlcmd/sqlcmd.go index ea655b47..eb769f0a 100644 --- a/cmd/sqlcmd/sqlcmd.go +++ b/cmd/sqlcmd/sqlcmd.go @@ -5,20 +5,16 @@ package sqlcmd import ( - "context" "errors" "fmt" - "net" "os" "regexp" "runtime/trace" "strconv" "strings" - "time" mssql "github.com/microsoft/go-mssqldb" "github.com/microsoft/go-mssqldb/azuread" - "github.com/microsoft/go-mssqldb/msdsn" "github.com/microsoft/go-sqlcmd/internal/localizer" "github.com/microsoft/go-sqlcmd/pkg/console" "github.com/microsoft/go-sqlcmd/pkg/sqlcmd" @@ -236,7 +232,7 @@ func Execute(version string) { fmt.Println() fmt.Println(localizer.Sprintf("Servers:")) } - listLocalServers() + sqlcmd.ListLocalServers(os.Stdout) os.Exit(0) } if len(argss) > 0 { @@ -911,76 +907,3 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { s.SetError(nil) return s.Exitcode, err } - -func listLocalServers() { - bmsg := []byte{byte(msdsn.BrowserAllInstances)} - resp := make([]byte, 16*1024-1) - dialer := &net.Dialer{} - ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) - defer cancel() - conn, err := dialer.DialContext(ctx, "udp", ":1434") - // silently ignore failures to connect, same as ODBC - if err != nil { - return - } - defer conn.Close() - dl, _ := ctx.Deadline() - _ = conn.SetDeadline(dl) - _, err = conn.Write(bmsg) - if err != nil { - if !errors.Is(err, os.ErrDeadlineExceeded) { - fmt.Println(err) - } - return - } - read, err := conn.Read(resp) - if err != nil { - if !errors.Is(err, os.ErrDeadlineExceeded) { - fmt.Println(err) - } - return - } - - data := parseInstances(resp[:read]) - instances := make([]string, 0, len(data)) - for s := range data { - if s == "MSSQLSERVER" { - - instances = append(instances, "(local)", data[s]["ServerName"]) - } else { - instances = append(instances, fmt.Sprintf(`%s\%s`, data[s]["ServerName"], s)) - } - } - for _, s := range instances { - fmt.Println(" ", s) - } -} - -func parseInstances(msg []byte) msdsn.BrowserData { - results := msdsn.BrowserData{} - if len(msg) > 3 && msg[0] == 5 { - out_s := string(msg[3:]) - tokens := strings.Split(out_s, ";") - instdict := map[string]string{} - got_name := false - var name string - for _, token := range tokens { - if got_name { - instdict[name] = token - got_name = false - } else { - name = token - if len(name) == 0 { - if len(instdict) == 0 { - break - } - results[strings.ToUpper(instdict["InstanceName"])] = instdict - instdict = map[string]string{} - continue - } - got_name = true - } - } - } - return results -} diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index 66dd1dba..548632e6 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -1,644 +1,712 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package sqlcmd - -import ( - "flag" - "fmt" - "os" - "regexp" - "sort" - "strconv" - "strings" - - "github.com/microsoft/go-sqlcmd/internal/color" - "golang.org/x/text/encoding/unicode" - "golang.org/x/text/transform" -) - -// Command defines a sqlcmd action which can be intermixed with the SQL batch -// Commands for sqlcmd are defined at https://docs.microsoft.com/sql/tools/sqlcmd-utility#sqlcmd-commands -type Command struct { - // regex must include at least one group if it has parameters - // Will be matched using FindStringSubmatch - regex *regexp.Regexp - // The function that implements the command. Third parameter is the line number - action func(*Sqlcmd, []string, uint) error - // Name of the command - name string - // whether the command is a system command - isSystem bool -} - -// Commands is the set of sqlcmd command implementations -type Commands map[string]*Command - -func newCommands() Commands { - // Commands is the set of Command implementations - return map[string]*Command{ - "EXIT": { - regex: regexp.MustCompile(`(?im)^[\t ]*?:?EXIT([\( \t]+.*\)*$|$)`), - action: exitCommand, - name: "EXIT", - }, - "QUIT": { - regex: regexp.MustCompile(`(?im)^[\t ]*?:?QUIT(?:[ \t]+(.*$)|$)`), - action: quitCommand, - name: "QUIT", - }, - "GO": { - regex: regexp.MustCompile(batchTerminatorRegex("GO")), - action: goCommand, - name: "GO", - }, - "OUT": { - regex: regexp.MustCompile(`(?im)^[ \t]*:OUT(?:[ \t]+(.*$)|$)`), - action: outCommand, - name: "OUT", - }, - "ERROR": { - regex: regexp.MustCompile(`(?im)^[ \t]*:ERROR(?:[ \t]+(.*$)|$)`), - action: errorCommand, - name: "ERROR", - }, "READFILE": { - regex: regexp.MustCompile(`(?im)^[ \t]*:R(?:[ \t]+(.*$)|$)`), - action: readFileCommand, - name: "READFILE", - }, - "SETVAR": { - regex: regexp.MustCompile(`(?im)^[ \t]*:SETVAR(?:[ \t]+(.*$)|$)`), - action: setVarCommand, - name: "SETVAR", - }, - "LISTVAR": { - regex: regexp.MustCompile(`(?im)^[\t ]*?:LISTVAR(?:[ \t]+(.*$)|$)`), - action: listVarCommand, - name: "LISTVAR", - }, - "RESET": { - regex: regexp.MustCompile(`(?im)^[ \t]*?:?RESET(?:[ \t]+(.*$)|$)`), - action: resetCommand, - name: "RESET", - }, - "LIST": { - regex: regexp.MustCompile(`(?im)^[ \t]*:LIST(?:[ \t]+(.*$)|$)`), - action: listCommand, - name: "LIST", - }, - "CONNECT": { - regex: regexp.MustCompile(`(?im)^[ \t]*:CONNECT(?:[ \t]+(.*$)|$)`), - action: connectCommand, - name: "CONNECT", - }, - "EXEC": { - regex: regexp.MustCompile(`(?im)^[ \t]*?:?!!(.*$)`), - action: execCommand, - name: "EXEC", - isSystem: true, - }, - "EDIT": { - regex: regexp.MustCompile(`(?im)^[\t ]*?:?ED(?:[ \t]+(.*$)|$)`), - action: editCommand, - name: "EDIT", - isSystem: true, - }, - "ONERROR": { - regex: regexp.MustCompile(`(?im)^[\t ]*?:?ON ERROR(?:[ \t]+(.*$)|$)`), - action: onerrorCommand, - name: "ONERROR", - }, - "XML": { - regex: regexp.MustCompile(`(?im)^[\t ]*?:XML(?:[ \t]+(.*$)|$)`), - action: xmlCommand, - name: "XML", - }, - } -} - -// DisableSysCommands disables the ED and :!! commands. -// When exitOnCall is true, running those commands will exit the process. -func (c Commands) DisableSysCommands(exitOnCall bool) { - f := warnDisabled - if exitOnCall { - f = errorDisabled - } - for _, cmd := range c { - if cmd.isSystem { - cmd.action = f - } - } -} - -func (c Commands) matchCommand(line string) (*Command, []string) { - for _, cmd := range c { - matchedCommand := cmd.regex.FindStringSubmatch(line) - if matchedCommand != nil { - return cmd, removeComments(matchedCommand[1:]) - } - } - return nil, nil -} - -func removeComments(args []string) []string { - var pos int - quote := false - for i := range args { - pos, quote = commentStart([]rune(args[i]), quote) - if pos > -1 { - out := make([]string, i+1) - if i > 0 { - copy(out, args[:i]) - } - out[i] = args[i][:pos] - return out - } - } - return args -} - -func commentStart(arg []rune, quote bool) (int, bool) { - var i int - space := true - for ; i < len(arg); i++ { - c, next := arg[i], grab(arg, i+1, len(arg)) - switch { - case quote && c == '"' && next != '"': - quote = false - case quote && c == '"' && next == '"': - i++ - case c == '\t' || c == ' ': - space = true - // Note we assume none of the regexes would split arguments on non-whitespace boundaries such that "text -- comment" would get split into "text -" and "- comment" - case !quote && space && c == '-' && next == '-': - return i, false - case !quote && c == '"': - quote = true - default: - space = false - } - } - return -1, quote -} - -func warnDisabled(s *Sqlcmd, args []string, line uint) error { - s.WriteError(s.GetError(), ErrCommandsDisabled) - return nil -} - -func errorDisabled(s *Sqlcmd, args []string, line uint) error { - s.WriteError(s.GetError(), ErrCommandsDisabled) - s.Exitcode = 1 - return ErrExitRequested -} - -func batchTerminatorRegex(terminator string) string { - return fmt.Sprintf(`(?im)^[\t ]*?%s(?:[ ]+(.*$)|$)`, regexp.QuoteMeta(terminator)) -} - -// SetBatchTerminator attempts to set the batch terminator to the given value -// Returns an error if the new value is not usable in the regex -func (c Commands) SetBatchTerminator(terminator string) error { - cmd := c["GO"] - regex, err := regexp.Compile(batchTerminatorRegex(terminator)) - if err != nil { - return err - } - cmd.regex = regex - return nil -} - -// exitCommand has 3 modes. -// With no (), it just exits without running any query -// With () it runs whatever batch is in the buffer then exits -// With any text between () it runs the text as a query then exits -func exitCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) == 0 { - return ErrExitRequested - } - params := strings.TrimSpace(args[0]) - if params == "" { - return ErrExitRequested - } - if !strings.HasPrefix(params, "(") || !strings.HasSuffix(params, ")") { - return InvalidCommandError("EXIT", line) - } - // First we save the current batch - query1 := s.batch.String() - if len(query1) > 0 { - query1 = s.getRunnableQuery(query1) - } - // Now parse the params of EXIT as a batch without commands - cmd := s.batch.cmd - s.batch.cmd = nil - defer func() { - s.batch.cmd = cmd - }() - query2 := strings.TrimSpace(params[1 : len(params)-1]) - if len(query2) > 0 { - s.batch.Reset([]rune(query2)) - _, _, err := s.batch.Next() - if err != nil { - return err - } - query2 = s.batch.String() - if len(query2) > 0 { - query2 = s.getRunnableQuery(query2) - } - } - - if len(query1) > 0 || len(query2) > 0 { - query := query1 + SqlcmdEol + query2 - s.Exitcode, _ = s.runQuery(query) - } - return ErrExitRequested -} - -// quitCommand immediately exits the program without running any more batches -func quitCommand(s *Sqlcmd, args []string, line uint) error { - if args != nil && strings.TrimSpace(args[0]) != "" { - return InvalidCommandError("QUIT", line) - } - return ErrExitRequested -} - -// goCommand runs the current batch the number of times specified -func goCommand(s *Sqlcmd, args []string, line uint) error { - // default to 1 execution - n := 1 - var err error - if len(args) > 0 { - cnt := strings.TrimSpace(args[0]) - if cnt != "" { - if cnt, err = resolveArgumentVariables(s, []rune(cnt), true); err != nil { - return err - } - _, err = fmt.Sscanf(cnt, "%d", &n) - } - } - if err != nil || n < 1 { - return InvalidCommandError("GO", line) - } - if s.EchoInput { - err = listCommand(s, []string{}, line) - } - if err != nil { - return InvalidCommandError("GO", line) - } - query := s.batch.String() - if query == "" { - return nil - } - query = s.getRunnableQuery(query) - for i := 0; i < n; i++ { - if retcode, err := s.runQuery(query); err != nil { - s.Exitcode = retcode - return err - } - } - s.batch.Reset(nil) - return nil -} - -// outCommand changes the output writer to use a file -func outCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) == 0 || args[0] == "" { - return InvalidCommandError("OUT", line) - } - filePath, err := resolveArgumentVariables(s, []rune(args[0]), true) - if err != nil { - return err - } - - switch { - case strings.EqualFold(filePath, "stdout"): - s.SetOutput(os.Stdout) - case strings.EqualFold(filePath, "stderr"): - s.SetOutput(os.Stderr) - default: - o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) - if err != nil { - return InvalidFileError(err, args[0]) - } - if s.UnicodeOutputFile { - // ODBC sqlcmd doesn't write a BOM but we will. - // Maybe the endian-ness should be configurable. - win16le := unicode.UTF16(unicode.LittleEndian, unicode.UseBOM) - encoder := transform.NewWriter(o, win16le.NewEncoder()) - s.SetOutput(encoder) - } else { - s.SetOutput(o) - } - } - return nil -} - -// errorCommand changes the error writer to use a file -func errorCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) == 0 || args[0] == "" { - return InvalidCommandError("ERROR", line) - } - filePath, err := resolveArgumentVariables(s, []rune(args[0]), true) - if err != nil { - return err - } - switch { - case strings.EqualFold(filePath, "stderr"): - s.SetError(os.Stderr) - case strings.EqualFold(filePath, "stdout"): - s.SetError(os.Stdout) - default: - o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) - if err != nil { - return InvalidFileError(err, args[0]) - } - s.SetError(o) - } - return nil -} - -func readFileCommand(s *Sqlcmd, args []string, line uint) error { - if args == nil || len(args) != 1 { - return InvalidCommandError(":R", line) - } - fileName, _ := resolveArgumentVariables(s, []rune(args[0]), false) - return s.IncludeFile(fileName, false) -} - -// setVarCommand parses a variable setting and applies it to the current Sqlcmd variables -func setVarCommand(s *Sqlcmd, args []string, line uint) error { - if args == nil || len(args) != 1 || args[0] == "" { - return InvalidCommandError(":SETVAR", line) - } - - varname := args[0] - val := "" - // The prior incarnation of sqlcmd doesn't require a space between the variable name and its value - // in some very unexpected cases. This version will require the space. - sp := strings.IndexRune(args[0], ' ') - if sp > -1 { - val = strings.TrimSpace(varname[sp:]) - varname = varname[:sp] - } - if err := s.vars.Setvar(varname, val); err != nil { - switch e := err.(type) { - case *VariableError: - return e - default: - return InvalidCommandError(":SETVAR", line) - } - } - return nil -} - -// listVarCommand prints the set of Sqlcmd scripting variables. -// Builtin values are printed first, followed by user-set values in sorted order. -func listVarCommand(s *Sqlcmd, args []string, line uint) error { - if args != nil && strings.TrimSpace(args[0]) != "" { - return InvalidCommandError("LISTVAR", line) - } - - vars := s.vars.All() - keys := make([]string, 0, len(vars)) - for k := range vars { - if !contains(builtinVariables, k) { - keys = append(keys, k) - } - } - sort.Strings(keys) - keys = append(builtinVariables, keys...) - for _, k := range keys { - fmt.Fprintf(s.GetOutput(), `%s = "%s"%s`, k, vars[k], SqlcmdEol) - } - return nil -} - -// resetCommand resets the statement cache -func resetCommand(s *Sqlcmd, args []string, line uint) error { - if s.batch != nil { - s.batch.Reset(nil) - } - - return nil -} - -// listCommand displays statements currently in the statement cache -func listCommand(s *Sqlcmd, args []string, line uint) (err error) { - cmd := "" - if args != nil { - if len(args) > 0 { - cmd = strings.ToLower(strings.TrimSpace(args[0])) - if len(args) > 1 || (cmd != "color" && cmd != "") { - return InvalidCommandError("LIST", line) - } - } - } - output := s.GetOutput() - if cmd == "color" { - sample := "select 'literal' as literal, 100 as number from [sys].[tables]" - clr := color.TextTypeTSql - if s.Format.IsXmlMode() { - sample = `value` - clr = color.TextTypeXml - } - // ignoring errors since it's not critical output - for _, style := range s.colorizer.Styles() { - _, _ = output.Write([]byte(style + ": ")) - _ = s.colorizer.Write(output, sample, style, clr) - _, _ = output.Write([]byte(SqlcmdEol)) - } - return - } - if s.batch == nil || s.batch.String() == "" { - return - } - - if err = s.colorizer.Write(output, s.batch.String(), s.vars.ColorScheme(), color.TextTypeTSql); err == nil { - _, err = output.Write([]byte(SqlcmdEol)) - } - - return -} - -func connectCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) == 0 { - return InvalidCommandError("CONNECT", line) - } - - commandArgs := strings.Fields(args[0]) - - // Parse flags - flags := flag.NewFlagSet("connect", flag.ContinueOnError) - database := flags.String("D", "", "database name") - username := flags.String("U", "", "user name") - password := flags.String("P", "", "password") - loginTimeout := flags.String("l", "", "login timeout") - authenticationMethod := flags.String("G", "", "authentication method") - - err := flags.Parse(commandArgs[1:]) - //err := flags.Parse(args[1:]) - if err != nil { - return InvalidCommandError("CONNECT", line) - } - - connect := *s.Connect - connect.UserName, _ = resolveArgumentVariables(s, []rune(*username), false) - connect.Password, _ = resolveArgumentVariables(s, []rune(*password), false) - connect.Database, _ = resolveArgumentVariables(s, []rune(*database), false) - - timeout, _ := resolveArgumentVariables(s, []rune(*loginTimeout), false) - if timeout != "" { - if timeoutSeconds, err := strconv.ParseInt(timeout, 10, 32); err == nil { - if timeoutSeconds < 0 { - return InvalidCommandError("CONNECT", line) - } - connect.LoginTimeoutSeconds = int(timeoutSeconds) - } - } - - connect.AuthenticationMethod = *authenticationMethod - - // Set server name as the first positional argument - if len(commandArgs) > 0 { - connect.ServerName, _ = resolveArgumentVariables(s, []rune(commandArgs[0]), false) - } - - // If no user name is provided we switch to integrated auth - _ = s.ConnectDb(&connect, s.lineIo == nil) - - // ConnectDb prints connection errors already, and failure to connect is not fatal even with -b option - return nil -} - -func execCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) == 0 { - return InvalidCommandError("EXEC", line) - } - cmdLine := strings.TrimSpace(args[0]) - if cmdLine == "" { - return InvalidCommandError("EXEC", line) - } - if cmdLine, err := resolveArgumentVariables(s, []rune(cmdLine), true); err != nil { - return err - } else { - cmd := sysCommand(cmdLine) - cmd.Stderr = s.GetError() - cmd.Stdout = s.GetOutput() - _ = cmd.Run() - } - return nil -} - -func editCommand(s *Sqlcmd, args []string, line uint) error { - if args != nil && strings.TrimSpace(args[0]) != "" { - return InvalidCommandError("ED", line) - } - file, err := os.CreateTemp("", "sq*.sql") - if err != nil { - return err - } - fileName := file.Name() - defer os.Remove(fileName) - text := s.batch.String() - if s.batch.State() == "-" { - text = fmt.Sprintf("%s%s", text, SqlcmdEol) - } - _, err = file.WriteString(text) - if err != nil { - return err - } - file.Close() - cmd := sysCommand(s.vars.TextEditor() + " " + `"` + fileName + `"`) - cmd.Stderr = s.GetError() - cmd.Stdout = s.GetOutput() - err = cmd.Run() - if err != nil { - return err - } - wasEcho := s.echoFileLines - s.echoFileLines = true - s.batch.Reset(nil) - _ = s.IncludeFile(fileName, false) - s.echoFileLines = wasEcho - return nil -} - -func onerrorCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) == 0 || args[0] == "" { - return InvalidCommandError("ON ERROR", line) - } - params := strings.TrimSpace(args[0]) - - if strings.EqualFold(strings.ToLower(params), "exit") { - s.Connect.ExitOnError = true - } else if strings.EqualFold(strings.ToLower(params), "ignore") { - s.Connect.IgnoreError = true - s.Connect.ExitOnError = false - } else { - return InvalidCommandError("ON ERROR", line) - } - return nil -} - -func xmlCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) != 1 || args[0] == "" { - return InvalidCommandError("XML", line) - } - params := strings.TrimSpace(args[0]) - // "OFF" and "ON" are documented as the allowed values. - // ODBC sqlcmd treats any value other than "ON" the same as "OFF". - // So we will too. - if strings.EqualFold(params, "on") { - s.Format.XmlMode(true) - } else { - s.Format.XmlMode(false) - } - return nil -} - -func resolveArgumentVariables(s *Sqlcmd, arg []rune, failOnUnresolved bool) (string, error) { - var b *strings.Builder - end := len(arg) - for i := 0; i < end && !s.Connect.DisableVariableSubstitution; { - c, next := arg[i], grab(arg, i+1, end) - switch { - case c == '$' && next == '(': - vl, ok := readVariableReference(arg, i+2, end) - if ok { - varName := string(arg[i+2 : vl]) - val, ok := s.resolveVariable(varName) - if ok { - if b == nil { - b = new(strings.Builder) - b.Grow(len(arg)) - b.WriteString(string(arg[0:i])) - } - b.WriteString(val) - } else { - if failOnUnresolved { - return "", UndefinedVariable(varName) - } - s.WriteError(s.GetError(), UndefinedVariable(varName)) - if b != nil { - b.WriteString(string(arg[i : vl+1])) - } - } - i += ((vl - i) + 1) - } else { - if b != nil { - b.WriteString("$(") - } - i += 2 - } - default: - if b != nil { - b.WriteRune(c) - } - i++ - } - } - if b == nil { - return string(arg), nil - } - return b.String(), nil -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "flag" + "fmt" + "os" + "regexp" + "sort" + "strconv" + "strings" + + "github.com/microsoft/go-sqlcmd/internal/color" + "golang.org/x/text/encoding/unicode" + "golang.org/x/text/transform" +) + +// Command defines a sqlcmd action which can be intermixed with the SQL batch +// Commands for sqlcmd are defined at https://docs.microsoft.com/sql/tools/sqlcmd-utility#sqlcmd-commands +type Command struct { + // regex must include at least one group if it has parameters + // Will be matched using FindStringSubmatch + regex *regexp.Regexp + // The function that implements the command. Third parameter is the line number + action func(*Sqlcmd, []string, uint) error + // Name of the command + name string + // whether the command is a system command + isSystem bool +} + +// Commands is the set of sqlcmd command implementations +type Commands map[string]*Command + +func newCommands() Commands { + // Commands is the set of Command implementations + return map[string]*Command{ + "EXIT": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:?EXIT([\( \t]+.*\)*$|$)`), + action: exitCommand, + name: "EXIT", + }, + "QUIT": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:?QUIT(?:[ \t]+(.*$)|$)`), + action: quitCommand, + name: "QUIT", + }, + "GO": { + regex: regexp.MustCompile(batchTerminatorRegex("GO")), + action: goCommand, + name: "GO", + }, + "OUT": { + regex: regexp.MustCompile(`(?im)^[ \t]*:OUT(?:[ \t]+(.*$)|$)`), + action: outCommand, + name: "OUT", + }, + "ERROR": { + regex: regexp.MustCompile(`(?im)^[ \t]*:ERROR(?:[ \t]+(.*$)|$)`), + action: errorCommand, + name: "ERROR", + }, "READFILE": { + regex: regexp.MustCompile(`(?im)^[ \t]*:R(?:[ \t]+(.*$)|$)`), + action: readFileCommand, + name: "READFILE", + }, + "SETVAR": { + regex: regexp.MustCompile(`(?im)^[ \t]*:SETVAR(?:[ \t]+(.*$)|$)`), + action: setVarCommand, + name: "SETVAR", + }, + "LISTVAR": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:LISTVAR(?:[ \t]+(.*$)|$)`), + action: listVarCommand, + name: "LISTVAR", + }, + "RESET": { + regex: regexp.MustCompile(`(?im)^[ \t]*?:?RESET(?:[ \t]+(.*$)|$)`), + action: resetCommand, + name: "RESET", + }, + "LIST": { + regex: regexp.MustCompile(`(?im)^[ \t]*:LIST(?:[ \t]+(.*$)|$)`), + action: listCommand, + name: "LIST", + }, + "CONNECT": { + regex: regexp.MustCompile(`(?im)^[ \t]*:CONNECT(?:[ \t]+(.*$)|$)`), + action: connectCommand, + name: "CONNECT", + }, + "EXEC": { + regex: regexp.MustCompile(`(?im)^[ \t]*?:?!!(.*$)`), + action: execCommand, + name: "EXEC", + isSystem: true, + }, + "EDIT": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:?ED(?:[ \t]+(.*$)|$)`), + action: editCommand, + name: "EDIT", + isSystem: true, + }, + "ONERROR": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:?ON ERROR(?:[ \t]+(.*$)|$)`), + action: onerrorCommand, + name: "ONERROR", + }, + "XML": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:XML(?:[ \t]+(.*$)|$)`), + action: xmlCommand, + name: "XML", + }, + "HELP": { + regex: regexp.MustCompile(`(?im)^[ \t]*:HELP(?:[ \t]+(.*$)|$)`), + action: helpCommand, + name: "HELP", + }, + "SERVERLIST": { + regex: regexp.MustCompile(`(?im)^[ \t]*:SERVERLIST(?:[ \t]+(.*$)|$)`), + action: serverlistCommand, + name: "SERVERLIST", + }, + } +} + +// DisableSysCommands disables the ED and :!! commands. +// When exitOnCall is true, running those commands will exit the process. +func (c Commands) DisableSysCommands(exitOnCall bool) { + f := warnDisabled + if exitOnCall { + f = errorDisabled + } + for _, cmd := range c { + if cmd.isSystem { + cmd.action = f + } + } +} + +func (c Commands) matchCommand(line string) (*Command, []string) { + for _, cmd := range c { + matchedCommand := cmd.regex.FindStringSubmatch(line) + if matchedCommand != nil { + return cmd, removeComments(matchedCommand[1:]) + } + } + return nil, nil +} + +func removeComments(args []string) []string { + var pos int + quote := false + for i := range args { + pos, quote = commentStart([]rune(args[i]), quote) + if pos > -1 { + out := make([]string, i+1) + if i > 0 { + copy(out, args[:i]) + } + out[i] = args[i][:pos] + return out + } + } + return args +} + +func commentStart(arg []rune, quote bool) (int, bool) { + var i int + space := true + for ; i < len(arg); i++ { + c, next := arg[i], grab(arg, i+1, len(arg)) + switch { + case quote && c == '"' && next != '"': + quote = false + case quote && c == '"' && next == '"': + i++ + case c == '\t' || c == ' ': + space = true + // Note we assume none of the regexes would split arguments on non-whitespace boundaries such that "text -- comment" would get split into "text -" and "- comment" + case !quote && space && c == '-' && next == '-': + return i, false + case !quote && c == '"': + quote = true + default: + space = false + } + } + return -1, quote +} + +func warnDisabled(s *Sqlcmd, args []string, line uint) error { + s.WriteError(s.GetError(), ErrCommandsDisabled) + return nil +} + +func errorDisabled(s *Sqlcmd, args []string, line uint) error { + s.WriteError(s.GetError(), ErrCommandsDisabled) + s.Exitcode = 1 + return ErrExitRequested +} + +func batchTerminatorRegex(terminator string) string { + return fmt.Sprintf(`(?im)^[\t ]*?%s(?:[ ]+(.*$)|$)`, regexp.QuoteMeta(terminator)) +} + +// SetBatchTerminator attempts to set the batch terminator to the given value +// Returns an error if the new value is not usable in the regex +func (c Commands) SetBatchTerminator(terminator string) error { + cmd := c["GO"] + regex, err := regexp.Compile(batchTerminatorRegex(terminator)) + if err != nil { + return err + } + cmd.regex = regex + return nil +} + +// exitCommand has 3 modes. +// With no (), it just exits without running any query +// With () it runs whatever batch is in the buffer then exits +// With any text between () it runs the text as a query then exits +func exitCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 { + return ErrExitRequested + } + params := strings.TrimSpace(args[0]) + if params == "" { + return ErrExitRequested + } + if !strings.HasPrefix(params, "(") || !strings.HasSuffix(params, ")") { + return InvalidCommandError("EXIT", line) + } + // First we save the current batch + query1 := s.batch.String() + if len(query1) > 0 { + query1 = s.getRunnableQuery(query1) + } + // Now parse the params of EXIT as a batch without commands + cmd := s.batch.cmd + s.batch.cmd = nil + defer func() { + s.batch.cmd = cmd + }() + query2 := strings.TrimSpace(params[1 : len(params)-1]) + if len(query2) > 0 { + s.batch.Reset([]rune(query2)) + _, _, err := s.batch.Next() + if err != nil { + return err + } + query2 = s.batch.String() + if len(query2) > 0 { + query2 = s.getRunnableQuery(query2) + } + } + + if len(query1) > 0 || len(query2) > 0 { + query := query1 + SqlcmdEol + query2 + s.Exitcode, _ = s.runQuery(query) + } + return ErrExitRequested +} + +// quitCommand immediately exits the program without running any more batches +func quitCommand(s *Sqlcmd, args []string, line uint) error { + if args != nil && strings.TrimSpace(args[0]) != "" { + return InvalidCommandError("QUIT", line) + } + return ErrExitRequested +} + +// goCommand runs the current batch the number of times specified +func goCommand(s *Sqlcmd, args []string, line uint) error { + // default to 1 execution + n := 1 + var err error + if len(args) > 0 { + cnt := strings.TrimSpace(args[0]) + if cnt != "" { + if cnt, err = resolveArgumentVariables(s, []rune(cnt), true); err != nil { + return err + } + _, err = fmt.Sscanf(cnt, "%d", &n) + } + } + if err != nil || n < 1 { + return InvalidCommandError("GO", line) + } + if s.EchoInput { + err = listCommand(s, []string{}, line) + } + if err != nil { + return InvalidCommandError("GO", line) + } + query := s.batch.String() + if query == "" { + return nil + } + query = s.getRunnableQuery(query) + for i := 0; i < n; i++ { + if retcode, err := s.runQuery(query); err != nil { + s.Exitcode = retcode + return err + } + } + s.batch.Reset(nil) + return nil +} + +// outCommand changes the output writer to use a file +func outCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 || args[0] == "" { + return InvalidCommandError("OUT", line) + } + filePath, err := resolveArgumentVariables(s, []rune(args[0]), true) + if err != nil { + return err + } + + switch { + case strings.EqualFold(filePath, "stdout"): + s.SetOutput(os.Stdout) + case strings.EqualFold(filePath, "stderr"): + s.SetOutput(os.Stderr) + default: + o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return InvalidFileError(err, args[0]) + } + if s.UnicodeOutputFile { + // ODBC sqlcmd doesn't write a BOM but we will. + // Maybe the endian-ness should be configurable. + win16le := unicode.UTF16(unicode.LittleEndian, unicode.UseBOM) + encoder := transform.NewWriter(o, win16le.NewEncoder()) + s.SetOutput(encoder) + } else { + s.SetOutput(o) + } + } + return nil +} + +// errorCommand changes the error writer to use a file +func errorCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 || args[0] == "" { + return InvalidCommandError("ERROR", line) + } + filePath, err := resolveArgumentVariables(s, []rune(args[0]), true) + if err != nil { + return err + } + switch { + case strings.EqualFold(filePath, "stderr"): + s.SetError(os.Stderr) + case strings.EqualFold(filePath, "stdout"): + s.SetError(os.Stdout) + default: + o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return InvalidFileError(err, args[0]) + } + s.SetError(o) + } + return nil +} + +func readFileCommand(s *Sqlcmd, args []string, line uint) error { + if args == nil || len(args) != 1 { + return InvalidCommandError(":R", line) + } + fileName, _ := resolveArgumentVariables(s, []rune(args[0]), false) + return s.IncludeFile(fileName, false) +} + +// setVarCommand parses a variable setting and applies it to the current Sqlcmd variables +func setVarCommand(s *Sqlcmd, args []string, line uint) error { + if args == nil || len(args) != 1 || args[0] == "" { + return InvalidCommandError(":SETVAR", line) + } + + varname := args[0] + val := "" + // The prior incarnation of sqlcmd doesn't require a space between the variable name and its value + // in some very unexpected cases. This version will require the space. + sp := strings.IndexRune(args[0], ' ') + if sp > -1 { + val = strings.TrimSpace(varname[sp:]) + varname = varname[:sp] + } + if err := s.vars.Setvar(varname, val); err != nil { + switch e := err.(type) { + case *VariableError: + return e + default: + return InvalidCommandError(":SETVAR", line) + } + } + return nil +} + +// listVarCommand prints the set of Sqlcmd scripting variables. +// Builtin values are printed first, followed by user-set values in sorted order. +func listVarCommand(s *Sqlcmd, args []string, line uint) error { + if args != nil && strings.TrimSpace(args[0]) != "" { + return InvalidCommandError("LISTVAR", line) + } + + vars := s.vars.All() + keys := make([]string, 0, len(vars)) + for k := range vars { + if !contains(builtinVariables, k) { + keys = append(keys, k) + } + } + sort.Strings(keys) + keys = append(builtinVariables, keys...) + for _, k := range keys { + fmt.Fprintf(s.GetOutput(), `%s = "%s"%s`, k, vars[k], SqlcmdEol) + } + return nil +} + +// resetCommand resets the statement cache +func resetCommand(s *Sqlcmd, args []string, line uint) error { + if s.batch != nil { + s.batch.Reset(nil) + } + + return nil +} + +// listCommand displays statements currently in the statement cache +func listCommand(s *Sqlcmd, args []string, line uint) (err error) { + cmd := "" + if args != nil { + if len(args) > 0 { + cmd = strings.ToLower(strings.TrimSpace(args[0])) + if len(args) > 1 || (cmd != "color" && cmd != "") { + return InvalidCommandError("LIST", line) + } + } + } + output := s.GetOutput() + if cmd == "color" { + sample := "select 'literal' as literal, 100 as number from [sys].[tables]" + clr := color.TextTypeTSql + if s.Format.IsXmlMode() { + sample = `value` + clr = color.TextTypeXml + } + // ignoring errors since it's not critical output + for _, style := range s.colorizer.Styles() { + _, _ = output.Write([]byte(style + ": ")) + _ = s.colorizer.Write(output, sample, style, clr) + _, _ = output.Write([]byte(SqlcmdEol)) + } + return + } + if s.batch == nil || s.batch.String() == "" { + return + } + + if err = s.colorizer.Write(output, s.batch.String(), s.vars.ColorScheme(), color.TextTypeTSql); err == nil { + _, err = output.Write([]byte(SqlcmdEol)) + } + + return +} + +func connectCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 { + return InvalidCommandError("CONNECT", line) + } + + commandArgs := strings.Fields(args[0]) + + // Parse flags + flags := flag.NewFlagSet("connect", flag.ContinueOnError) + database := flags.String("D", "", "database name") + username := flags.String("U", "", "user name") + password := flags.String("P", "", "password") + loginTimeout := flags.String("l", "", "login timeout") + authenticationMethod := flags.String("G", "", "authentication method") + + err := flags.Parse(commandArgs[1:]) + //err := flags.Parse(args[1:]) + if err != nil { + return InvalidCommandError("CONNECT", line) + } + + connect := *s.Connect + connect.UserName, _ = resolveArgumentVariables(s, []rune(*username), false) + connect.Password, _ = resolveArgumentVariables(s, []rune(*password), false) + connect.Database, _ = resolveArgumentVariables(s, []rune(*database), false) + + timeout, _ := resolveArgumentVariables(s, []rune(*loginTimeout), false) + if timeout != "" { + if timeoutSeconds, err := strconv.ParseInt(timeout, 10, 32); err == nil { + if timeoutSeconds < 0 { + return InvalidCommandError("CONNECT", line) + } + connect.LoginTimeoutSeconds = int(timeoutSeconds) + } + } + + connect.AuthenticationMethod = *authenticationMethod + + // Set server name as the first positional argument + if len(commandArgs) > 0 { + connect.ServerName, _ = resolveArgumentVariables(s, []rune(commandArgs[0]), false) + } + + // If no user name is provided we switch to integrated auth + _ = s.ConnectDb(&connect, s.lineIo == nil) + + // ConnectDb prints connection errors already, and failure to connect is not fatal even with -b option + return nil +} + +func execCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 { + return InvalidCommandError("EXEC", line) + } + cmdLine := strings.TrimSpace(args[0]) + if cmdLine == "" { + return InvalidCommandError("EXEC", line) + } + if cmdLine, err := resolveArgumentVariables(s, []rune(cmdLine), true); err != nil { + return err + } else { + cmd := sysCommand(cmdLine) + cmd.Stderr = s.GetError() + cmd.Stdout = s.GetOutput() + _ = cmd.Run() + } + return nil +} + +func editCommand(s *Sqlcmd, args []string, line uint) error { + if args != nil && strings.TrimSpace(args[0]) != "" { + return InvalidCommandError("ED", line) + } + file, err := os.CreateTemp("", "sq*.sql") + if err != nil { + return err + } + fileName := file.Name() + defer os.Remove(fileName) + text := s.batch.String() + if s.batch.State() == "-" { + text = fmt.Sprintf("%s%s", text, SqlcmdEol) + } + _, err = file.WriteString(text) + if err != nil { + return err + } + file.Close() + cmd := sysCommand(s.vars.TextEditor() + " " + `"` + fileName + `"`) + cmd.Stderr = s.GetError() + cmd.Stdout = s.GetOutput() + err = cmd.Run() + if err != nil { + return err + } + wasEcho := s.echoFileLines + s.echoFileLines = true + s.batch.Reset(nil) + _ = s.IncludeFile(fileName, false) + s.echoFileLines = wasEcho + return nil +} + +func onerrorCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 || args[0] == "" { + return InvalidCommandError("ON ERROR", line) + } + params := strings.TrimSpace(args[0]) + + if strings.EqualFold(strings.ToLower(params), "exit") { + s.Connect.ExitOnError = true + } else if strings.EqualFold(strings.ToLower(params), "ignore") { + s.Connect.IgnoreError = true + s.Connect.ExitOnError = false + } else { + return InvalidCommandError("ON ERROR", line) + } + return nil +} + +func xmlCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) != 1 || args[0] == "" { + return InvalidCommandError("XML", line) + } + params := strings.TrimSpace(args[0]) + // "OFF" and "ON" are documented as the allowed values. + // ODBC sqlcmd treats any value other than "ON" the same as "OFF". + // So we will too. + if strings.EqualFold(params, "on") { + s.Format.XmlMode(true) + } else { + s.Format.XmlMode(false) + } + return nil +} + +// helpCommand displays the list of available sqlcmd commands +func helpCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) > 0 && strings.TrimSpace(args[0]) != "" { + return InvalidCommandError("HELP", line) + } + helpText := `:!! [] + - Executes a command in the operating system shell. +:connect server[\instance] [-l timeout] [-U user [-P password]] + - Connects to a SQL Server instance. +:ed + - Edits the current or last executed statement cache. +:error + - Redirects error output to a file, stderr, or stdout. +:exit + - Quits sqlcmd immediately. +:exit() + - Execute statement cache; quit with no return value. +:exit() + - Execute the specified query; returns numeric result. +go [] + - Executes the statement cache (n times). +:help + - Shows this list of commands. +:list + - Prints the content of the statement cache. +:listvar + - Lists the set sqlcmd scripting variables. +:on error [exit|ignore] + - Action for batch or sqlcmd command errors. +:out |stderr|stdout + - Redirects query output to a file, stderr, or stdout. +:quit + - Quits sqlcmd immediately. +:r + - Append file contents to the statement cache. +:reset + - Discards the statement cache. +:serverlist + - Lists local SQL Server instances. +:setvar {variable} + - Removes a sqlcmd scripting variable. +:setvar + - Sets a sqlcmd scripting variable. +:xml [on|off] + - Sets XML output mode. +` + _, err := s.GetOutput().Write([]byte(helpText)) + return err +} + +func serverlistCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) > 0 && strings.TrimSpace(args[0]) != "" { + return InvalidCommandError("SERVERLIST", line) + } + ListLocalServers(s.GetOutput()) + return nil +} + +func resolveArgumentVariables(s *Sqlcmd, arg []rune, failOnUnresolved bool) (string, error) { + var b *strings.Builder + end := len(arg) + for i := 0; i < end && !s.Connect.DisableVariableSubstitution; { + c, next := arg[i], grab(arg, i+1, end) + switch { + case c == '$' && next == '(': + vl, ok := readVariableReference(arg, i+2, end) + if ok { + varName := string(arg[i+2 : vl]) + val, ok := s.resolveVariable(varName) + if ok { + if b == nil { + b = new(strings.Builder) + b.Grow(len(arg)) + b.WriteString(string(arg[0:i])) + } + b.WriteString(val) + } else { + if failOnUnresolved { + return "", UndefinedVariable(varName) + } + s.WriteError(s.GetError(), UndefinedVariable(varName)) + if b != nil { + b.WriteString(string(arg[i : vl+1])) + } + } + i += ((vl - i) + 1) + } else { + if b != nil { + b.WriteString("$(") + } + i += 2 + } + default: + if b != nil { + b.WriteRune(c) + } + i++ + } + } + if b == nil { + return string(arg), nil + } + return b.String(), nil +} diff --git a/pkg/sqlcmd/commands_test.go b/pkg/sqlcmd/commands_test.go index 6197aa3f..7895e307 100644 --- a/pkg/sqlcmd/commands_test.go +++ b/pkg/sqlcmd/commands_test.go @@ -54,6 +54,10 @@ func TestCommandParsing(t *testing.T) { {`:XML ON `, "XML", []string{`ON `}}, {`:RESET`, "RESET", []string{""}}, {`RESET`, "RESET", []string{""}}, + {`:HELP`, "HELP", []string{""}}, + {`:help`, "HELP", []string{""}}, + {`:SERVERLIST`, "SERVERLIST", []string{""}}, + {`:serverlist`, "SERVERLIST", []string{""}}, } for _, test := range commands { @@ -458,3 +462,25 @@ func TestExitCommandAppendsParameterToCurrentBatch(t *testing.T) { } } + +func TestHelpCommand(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + s.SetOutput(buf) + + err := helpCommand(s, []string{""}, 1) + assert.NoError(t, err, "helpCommand should not error") + + output := buf.buf.String() + // Verify key commands are listed + assert.Contains(t, output, ":connect", "help should list :connect") + assert.Contains(t, output, ":exit", "help should list :exit") + assert.Contains(t, output, ":help", "help should list :help") + assert.Contains(t, output, ":setvar", "help should list :setvar") + assert.Contains(t, output, ":listvar", "help should list :listvar") + assert.Contains(t, output, ":out", "help should list :out") + assert.Contains(t, output, ":error", "help should list :error") + assert.Contains(t, output, ":r", "help should list :r") + assert.Contains(t, output, ":serverlist", "help should list :serverlist") + assert.Contains(t, output, "go", "help should list go") +} diff --git a/pkg/sqlcmd/serverlist.go b/pkg/sqlcmd/serverlist.go new file mode 100644 index 00000000..efe29c7d --- /dev/null +++ b/pkg/sqlcmd/serverlist.go @@ -0,0 +1,121 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "os" + "sort" + "strings" + "time" + + "github.com/microsoft/go-mssqldb/msdsn" +) + +// ListLocalServers queries the SQL Browser service for available SQL Server instances +// and writes the results to the provided writer. +func ListLocalServers(w io.Writer) { + instances, err := GetLocalServerInstances() + if err != nil { + fmt.Fprintln(os.Stderr, err) + } + for _, s := range instances { + fmt.Fprintf(w, " %s\n", s) + } +} + +// GetLocalServerInstances queries the SQL Browser service and returns a list of +// available SQL Server instances on the local machine. +// Returns an error for non-timeout network errors. +func GetLocalServerInstances() ([]string, error) { + bmsg := []byte{byte(msdsn.BrowserAllInstances)} + resp := make([]byte, 16*1024-1) + dialer := &net.Dialer{} + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + conn, err := dialer.DialContext(ctx, "udp", ":1434") + // silently ignore failures to connect, same as ODBC + if err != nil { + return nil, nil + } + defer conn.Close() + dl, _ := ctx.Deadline() + _ = conn.SetDeadline(dl) + _, err = conn.Write(bmsg) + if err != nil { + // Only return error if it's not a timeout + if !errors.Is(err, os.ErrDeadlineExceeded) { + return nil, err + } + return nil, nil + } + read, err := conn.Read(resp) + if err != nil { + // Only return error if it's not a timeout + if !errors.Is(err, os.ErrDeadlineExceeded) { + return nil, err + } + return nil, nil + } + + data := parseInstances(resp[:read]) + instances := make([]string, 0, len(data)) + + // Sort instance names for deterministic output + instanceNames := make([]string, 0, len(data)) + for s := range data { + instanceNames = append(instanceNames, s) + } + sort.Strings(instanceNames) + + for _, s := range instanceNames { + serverName := data[s]["ServerName"] + if serverName == "" { + // Skip instances without a ServerName + continue + } + if s == "MSSQLSERVER" { + instances = append(instances, "(local)", serverName) + } else { + instances = append(instances, fmt.Sprintf(`%s\%s`, serverName, s)) + } + } + return instances, nil +} + +func parseInstances(msg []byte) msdsn.BrowserData { + results := msdsn.BrowserData{} + if len(msg) > 3 && msg[0] == 5 { + outStr := string(msg[3:]) + tokens := strings.Split(outStr, ";") + instanceDict := map[string]string{} + gotName := false + var name string + for _, token := range tokens { + if gotName { + instanceDict[name] = token + gotName = false + } else { + name = token + if len(name) == 0 { + if len(instanceDict) == 0 { + break + } + // Only add if InstanceName key exists and is non-empty + if instName, ok := instanceDict["InstanceName"]; ok && instName != "" { + results[strings.ToUpper(instName)] = instanceDict + } + instanceDict = map[string]string{} + continue + } + gotName = true + } + } + } + return results +} diff --git a/pkg/sqlcmd/serverlist_test.go b/pkg/sqlcmd/serverlist_test.go new file mode 100644 index 00000000..3ab20920 --- /dev/null +++ b/pkg/sqlcmd/serverlist_test.go @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestListLocalServers(t *testing.T) { + // Test that ListLocalServers writes to the provided writer without error + // Note: actual server discovery depends on SQL Browser service availability + var buf bytes.Buffer + ListLocalServers(&buf) + // We can't assert specific content since it depends on environment, + // but we verify it doesn't panic and writes valid output + t.Logf("ListLocalServers output: %q", buf.String()) +} + +func TestGetLocalServerInstances(t *testing.T) { + // Test that GetLocalServerInstances returns a slice (may be empty if no servers) + instances, err := GetLocalServerInstances() + // instances may be nil or empty if no SQL Browser is running, that's OK + // err may be non-nil for non-timeout network errors + if err != nil { + t.Logf("GetLocalServerInstances returned error (expected in some environments): %v", err) + } + t.Logf("Found %d instances", len(instances)) + for _, inst := range instances { + assert.NotEmpty(t, inst, "Instance name should not be empty") + } +} + +func TestParseInstances(t *testing.T) { + // Test parsing of SQL Browser response + // Format: 0x05 (response type), 2 bytes length, then semicolon-separated key=value pairs + // Each instance ends with two semicolons + + t.Run("empty response", func(t *testing.T) { + result := parseInstances([]byte{}) + assert.Empty(t, result) + }) + + t.Run("invalid header", func(t *testing.T) { + result := parseInstances([]byte{1, 0, 0}) + assert.Empty(t, result) + }) + + t.Run("valid single instance", func(t *testing.T) { + // Simulating SQL Browser response format + // Header: 0x05 followed by 2 length bytes, then the instance data + data := []byte{5, 0, 0} + instanceData := "ServerName;MYSERVER;InstanceName;MSSQLSERVER;IsClustered;No;Version;15.0.2000.5;tcp;1433;;" + data = append(data, []byte(instanceData)...) + + result := parseInstances(data) + assert.Len(t, result, 1) + assert.Contains(t, result, "MSSQLSERVER") + assert.Equal(t, "MYSERVER", result["MSSQLSERVER"]["ServerName"]) + assert.Equal(t, "1433", result["MSSQLSERVER"]["tcp"]) + }) + + t.Run("valid multiple instances", func(t *testing.T) { + data := []byte{5, 0, 0} + instanceData := "ServerName;MYSERVER;InstanceName;MSSQLSERVER;tcp;1433;;ServerName;MYSERVER;InstanceName;SQLEXPRESS;tcp;1434;;" + data = append(data, []byte(instanceData)...) + + result := parseInstances(data) + assert.Len(t, result, 2) + assert.Contains(t, result, "MSSQLSERVER") + assert.Contains(t, result, "SQLEXPRESS") + }) +} + +func TestServerlistCommand(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + + // Run the serverlist command + c := []string{":serverlist"} + err := runSqlCmd(t, s, c) + + // The command should not raise an error even if no servers are found + assert.NoError(t, err, ":serverlist should not raise error") + // Output may be empty if no SQL Browser is running + t.Logf("Serverlist output: %q", buf.buf.String()) +}