From 41976931441fc3004b058786d6831406047e8d23 Mon Sep 17 00:00:00 2001 From: David Levy Date: Sat, 24 Jan 2026 21:51:11 -0600 Subject: [PATCH 1/8] Implement :serverlist command - Add serverlist command to list SQL Server instances via SQL Browser service - Move server listing logic from cmd/sqlcmd to pkg/sqlcmd for reuse - Both -L flag and :serverlist command now use shared ListLocalServers function - Add comprehensive tests for serverlist functionality --- cmd/sqlcmd/sqlcmd.go | 79 +- pkg/sqlcmd/commands.go | 1298 +++++++++++++++++---------------- pkg/sqlcmd/serverlist.go | 98 +++ pkg/sqlcmd/serverlist_test.go | 86 +++ 4 files changed, 839 insertions(+), 722 deletions(-) create mode 100644 pkg/sqlcmd/serverlist.go create mode 100644 pkg/sqlcmd/serverlist_test.go diff --git a/cmd/sqlcmd/sqlcmd.go b/cmd/sqlcmd/sqlcmd.go index ea655b47..eb769f0a 100644 --- a/cmd/sqlcmd/sqlcmd.go +++ b/cmd/sqlcmd/sqlcmd.go @@ -5,20 +5,16 @@ package sqlcmd import ( - "context" "errors" "fmt" - "net" "os" "regexp" "runtime/trace" "strconv" "strings" - "time" mssql "github.com/microsoft/go-mssqldb" "github.com/microsoft/go-mssqldb/azuread" - "github.com/microsoft/go-mssqldb/msdsn" "github.com/microsoft/go-sqlcmd/internal/localizer" "github.com/microsoft/go-sqlcmd/pkg/console" "github.com/microsoft/go-sqlcmd/pkg/sqlcmd" @@ -236,7 +232,7 @@ func Execute(version string) { fmt.Println() fmt.Println(localizer.Sprintf("Servers:")) } - listLocalServers() + sqlcmd.ListLocalServers(os.Stdout) os.Exit(0) } if len(argss) > 0 { @@ -911,76 +907,3 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { s.SetError(nil) return s.Exitcode, err } - -func listLocalServers() { - bmsg := []byte{byte(msdsn.BrowserAllInstances)} - resp := make([]byte, 16*1024-1) - dialer := &net.Dialer{} - ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) - defer cancel() - conn, err := dialer.DialContext(ctx, "udp", ":1434") - // silently ignore failures to connect, same as ODBC - if err != nil { - return - } - defer conn.Close() - dl, _ := ctx.Deadline() - _ = conn.SetDeadline(dl) - _, err = conn.Write(bmsg) - if err != nil { - if !errors.Is(err, os.ErrDeadlineExceeded) { - fmt.Println(err) - } - return - } - read, err := conn.Read(resp) - if err != nil { - if !errors.Is(err, os.ErrDeadlineExceeded) { - fmt.Println(err) - } - return - } - - data := parseInstances(resp[:read]) - instances := make([]string, 0, len(data)) - for s := range data { - if s == "MSSQLSERVER" { - - instances = append(instances, "(local)", data[s]["ServerName"]) - } else { - instances = append(instances, fmt.Sprintf(`%s\%s`, data[s]["ServerName"], s)) - } - } - for _, s := range instances { - fmt.Println(" ", s) - } -} - -func parseInstances(msg []byte) msdsn.BrowserData { - results := msdsn.BrowserData{} - if len(msg) > 3 && msg[0] == 5 { - out_s := string(msg[3:]) - tokens := strings.Split(out_s, ";") - instdict := map[string]string{} - got_name := false - var name string - for _, token := range tokens { - if got_name { - instdict[name] = token - got_name = false - } else { - name = token - if len(name) == 0 { - if len(instdict) == 0 { - break - } - results[strings.ToUpper(instdict["InstanceName"])] = instdict - instdict = map[string]string{} - continue - } - got_name = true - } - } - } - return results -} diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index 66dd1dba..f3c5850a 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -1,644 +1,654 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package sqlcmd - -import ( - "flag" - "fmt" - "os" - "regexp" - "sort" - "strconv" - "strings" - - "github.com/microsoft/go-sqlcmd/internal/color" - "golang.org/x/text/encoding/unicode" - "golang.org/x/text/transform" -) - -// Command defines a sqlcmd action which can be intermixed with the SQL batch -// Commands for sqlcmd are defined at https://docs.microsoft.com/sql/tools/sqlcmd-utility#sqlcmd-commands -type Command struct { - // regex must include at least one group if it has parameters - // Will be matched using FindStringSubmatch - regex *regexp.Regexp - // The function that implements the command. Third parameter is the line number - action func(*Sqlcmd, []string, uint) error - // Name of the command - name string - // whether the command is a system command - isSystem bool -} - -// Commands is the set of sqlcmd command implementations -type Commands map[string]*Command - -func newCommands() Commands { - // Commands is the set of Command implementations - return map[string]*Command{ - "EXIT": { - regex: regexp.MustCompile(`(?im)^[\t ]*?:?EXIT([\( \t]+.*\)*$|$)`), - action: exitCommand, - name: "EXIT", - }, - "QUIT": { - regex: regexp.MustCompile(`(?im)^[\t ]*?:?QUIT(?:[ \t]+(.*$)|$)`), - action: quitCommand, - name: "QUIT", - }, - "GO": { - regex: regexp.MustCompile(batchTerminatorRegex("GO")), - action: goCommand, - name: "GO", - }, - "OUT": { - regex: regexp.MustCompile(`(?im)^[ \t]*:OUT(?:[ \t]+(.*$)|$)`), - action: outCommand, - name: "OUT", - }, - "ERROR": { - regex: regexp.MustCompile(`(?im)^[ \t]*:ERROR(?:[ \t]+(.*$)|$)`), - action: errorCommand, - name: "ERROR", - }, "READFILE": { - regex: regexp.MustCompile(`(?im)^[ \t]*:R(?:[ \t]+(.*$)|$)`), - action: readFileCommand, - name: "READFILE", - }, - "SETVAR": { - regex: regexp.MustCompile(`(?im)^[ \t]*:SETVAR(?:[ \t]+(.*$)|$)`), - action: setVarCommand, - name: "SETVAR", - }, - "LISTVAR": { - regex: regexp.MustCompile(`(?im)^[\t ]*?:LISTVAR(?:[ \t]+(.*$)|$)`), - action: listVarCommand, - name: "LISTVAR", - }, - "RESET": { - regex: regexp.MustCompile(`(?im)^[ \t]*?:?RESET(?:[ \t]+(.*$)|$)`), - action: resetCommand, - name: "RESET", - }, - "LIST": { - regex: regexp.MustCompile(`(?im)^[ \t]*:LIST(?:[ \t]+(.*$)|$)`), - action: listCommand, - name: "LIST", - }, - "CONNECT": { - regex: regexp.MustCompile(`(?im)^[ \t]*:CONNECT(?:[ \t]+(.*$)|$)`), - action: connectCommand, - name: "CONNECT", - }, - "EXEC": { - regex: regexp.MustCompile(`(?im)^[ \t]*?:?!!(.*$)`), - action: execCommand, - name: "EXEC", - isSystem: true, - }, - "EDIT": { - regex: regexp.MustCompile(`(?im)^[\t ]*?:?ED(?:[ \t]+(.*$)|$)`), - action: editCommand, - name: "EDIT", - isSystem: true, - }, - "ONERROR": { - regex: regexp.MustCompile(`(?im)^[\t ]*?:?ON ERROR(?:[ \t]+(.*$)|$)`), - action: onerrorCommand, - name: "ONERROR", - }, - "XML": { - regex: regexp.MustCompile(`(?im)^[\t ]*?:XML(?:[ \t]+(.*$)|$)`), - action: xmlCommand, - name: "XML", - }, - } -} - -// DisableSysCommands disables the ED and :!! commands. -// When exitOnCall is true, running those commands will exit the process. -func (c Commands) DisableSysCommands(exitOnCall bool) { - f := warnDisabled - if exitOnCall { - f = errorDisabled - } - for _, cmd := range c { - if cmd.isSystem { - cmd.action = f - } - } -} - -func (c Commands) matchCommand(line string) (*Command, []string) { - for _, cmd := range c { - matchedCommand := cmd.regex.FindStringSubmatch(line) - if matchedCommand != nil { - return cmd, removeComments(matchedCommand[1:]) - } - } - return nil, nil -} - -func removeComments(args []string) []string { - var pos int - quote := false - for i := range args { - pos, quote = commentStart([]rune(args[i]), quote) - if pos > -1 { - out := make([]string, i+1) - if i > 0 { - copy(out, args[:i]) - } - out[i] = args[i][:pos] - return out - } - } - return args -} - -func commentStart(arg []rune, quote bool) (int, bool) { - var i int - space := true - for ; i < len(arg); i++ { - c, next := arg[i], grab(arg, i+1, len(arg)) - switch { - case quote && c == '"' && next != '"': - quote = false - case quote && c == '"' && next == '"': - i++ - case c == '\t' || c == ' ': - space = true - // Note we assume none of the regexes would split arguments on non-whitespace boundaries such that "text -- comment" would get split into "text -" and "- comment" - case !quote && space && c == '-' && next == '-': - return i, false - case !quote && c == '"': - quote = true - default: - space = false - } - } - return -1, quote -} - -func warnDisabled(s *Sqlcmd, args []string, line uint) error { - s.WriteError(s.GetError(), ErrCommandsDisabled) - return nil -} - -func errorDisabled(s *Sqlcmd, args []string, line uint) error { - s.WriteError(s.GetError(), ErrCommandsDisabled) - s.Exitcode = 1 - return ErrExitRequested -} - -func batchTerminatorRegex(terminator string) string { - return fmt.Sprintf(`(?im)^[\t ]*?%s(?:[ ]+(.*$)|$)`, regexp.QuoteMeta(terminator)) -} - -// SetBatchTerminator attempts to set the batch terminator to the given value -// Returns an error if the new value is not usable in the regex -func (c Commands) SetBatchTerminator(terminator string) error { - cmd := c["GO"] - regex, err := regexp.Compile(batchTerminatorRegex(terminator)) - if err != nil { - return err - } - cmd.regex = regex - return nil -} - -// exitCommand has 3 modes. -// With no (), it just exits without running any query -// With () it runs whatever batch is in the buffer then exits -// With any text between () it runs the text as a query then exits -func exitCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) == 0 { - return ErrExitRequested - } - params := strings.TrimSpace(args[0]) - if params == "" { - return ErrExitRequested - } - if !strings.HasPrefix(params, "(") || !strings.HasSuffix(params, ")") { - return InvalidCommandError("EXIT", line) - } - // First we save the current batch - query1 := s.batch.String() - if len(query1) > 0 { - query1 = s.getRunnableQuery(query1) - } - // Now parse the params of EXIT as a batch without commands - cmd := s.batch.cmd - s.batch.cmd = nil - defer func() { - s.batch.cmd = cmd - }() - query2 := strings.TrimSpace(params[1 : len(params)-1]) - if len(query2) > 0 { - s.batch.Reset([]rune(query2)) - _, _, err := s.batch.Next() - if err != nil { - return err - } - query2 = s.batch.String() - if len(query2) > 0 { - query2 = s.getRunnableQuery(query2) - } - } - - if len(query1) > 0 || len(query2) > 0 { - query := query1 + SqlcmdEol + query2 - s.Exitcode, _ = s.runQuery(query) - } - return ErrExitRequested -} - -// quitCommand immediately exits the program without running any more batches -func quitCommand(s *Sqlcmd, args []string, line uint) error { - if args != nil && strings.TrimSpace(args[0]) != "" { - return InvalidCommandError("QUIT", line) - } - return ErrExitRequested -} - -// goCommand runs the current batch the number of times specified -func goCommand(s *Sqlcmd, args []string, line uint) error { - // default to 1 execution - n := 1 - var err error - if len(args) > 0 { - cnt := strings.TrimSpace(args[0]) - if cnt != "" { - if cnt, err = resolveArgumentVariables(s, []rune(cnt), true); err != nil { - return err - } - _, err = fmt.Sscanf(cnt, "%d", &n) - } - } - if err != nil || n < 1 { - return InvalidCommandError("GO", line) - } - if s.EchoInput { - err = listCommand(s, []string{}, line) - } - if err != nil { - return InvalidCommandError("GO", line) - } - query := s.batch.String() - if query == "" { - return nil - } - query = s.getRunnableQuery(query) - for i := 0; i < n; i++ { - if retcode, err := s.runQuery(query); err != nil { - s.Exitcode = retcode - return err - } - } - s.batch.Reset(nil) - return nil -} - -// outCommand changes the output writer to use a file -func outCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) == 0 || args[0] == "" { - return InvalidCommandError("OUT", line) - } - filePath, err := resolveArgumentVariables(s, []rune(args[0]), true) - if err != nil { - return err - } - - switch { - case strings.EqualFold(filePath, "stdout"): - s.SetOutput(os.Stdout) - case strings.EqualFold(filePath, "stderr"): - s.SetOutput(os.Stderr) - default: - o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) - if err != nil { - return InvalidFileError(err, args[0]) - } - if s.UnicodeOutputFile { - // ODBC sqlcmd doesn't write a BOM but we will. - // Maybe the endian-ness should be configurable. - win16le := unicode.UTF16(unicode.LittleEndian, unicode.UseBOM) - encoder := transform.NewWriter(o, win16le.NewEncoder()) - s.SetOutput(encoder) - } else { - s.SetOutput(o) - } - } - return nil -} - -// errorCommand changes the error writer to use a file -func errorCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) == 0 || args[0] == "" { - return InvalidCommandError("ERROR", line) - } - filePath, err := resolveArgumentVariables(s, []rune(args[0]), true) - if err != nil { - return err - } - switch { - case strings.EqualFold(filePath, "stderr"): - s.SetError(os.Stderr) - case strings.EqualFold(filePath, "stdout"): - s.SetError(os.Stdout) - default: - o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) - if err != nil { - return InvalidFileError(err, args[0]) - } - s.SetError(o) - } - return nil -} - -func readFileCommand(s *Sqlcmd, args []string, line uint) error { - if args == nil || len(args) != 1 { - return InvalidCommandError(":R", line) - } - fileName, _ := resolveArgumentVariables(s, []rune(args[0]), false) - return s.IncludeFile(fileName, false) -} - -// setVarCommand parses a variable setting and applies it to the current Sqlcmd variables -func setVarCommand(s *Sqlcmd, args []string, line uint) error { - if args == nil || len(args) != 1 || args[0] == "" { - return InvalidCommandError(":SETVAR", line) - } - - varname := args[0] - val := "" - // The prior incarnation of sqlcmd doesn't require a space between the variable name and its value - // in some very unexpected cases. This version will require the space. - sp := strings.IndexRune(args[0], ' ') - if sp > -1 { - val = strings.TrimSpace(varname[sp:]) - varname = varname[:sp] - } - if err := s.vars.Setvar(varname, val); err != nil { - switch e := err.(type) { - case *VariableError: - return e - default: - return InvalidCommandError(":SETVAR", line) - } - } - return nil -} - -// listVarCommand prints the set of Sqlcmd scripting variables. -// Builtin values are printed first, followed by user-set values in sorted order. -func listVarCommand(s *Sqlcmd, args []string, line uint) error { - if args != nil && strings.TrimSpace(args[0]) != "" { - return InvalidCommandError("LISTVAR", line) - } - - vars := s.vars.All() - keys := make([]string, 0, len(vars)) - for k := range vars { - if !contains(builtinVariables, k) { - keys = append(keys, k) - } - } - sort.Strings(keys) - keys = append(builtinVariables, keys...) - for _, k := range keys { - fmt.Fprintf(s.GetOutput(), `%s = "%s"%s`, k, vars[k], SqlcmdEol) - } - return nil -} - -// resetCommand resets the statement cache -func resetCommand(s *Sqlcmd, args []string, line uint) error { - if s.batch != nil { - s.batch.Reset(nil) - } - - return nil -} - -// listCommand displays statements currently in the statement cache -func listCommand(s *Sqlcmd, args []string, line uint) (err error) { - cmd := "" - if args != nil { - if len(args) > 0 { - cmd = strings.ToLower(strings.TrimSpace(args[0])) - if len(args) > 1 || (cmd != "color" && cmd != "") { - return InvalidCommandError("LIST", line) - } - } - } - output := s.GetOutput() - if cmd == "color" { - sample := "select 'literal' as literal, 100 as number from [sys].[tables]" - clr := color.TextTypeTSql - if s.Format.IsXmlMode() { - sample = `value` - clr = color.TextTypeXml - } - // ignoring errors since it's not critical output - for _, style := range s.colorizer.Styles() { - _, _ = output.Write([]byte(style + ": ")) - _ = s.colorizer.Write(output, sample, style, clr) - _, _ = output.Write([]byte(SqlcmdEol)) - } - return - } - if s.batch == nil || s.batch.String() == "" { - return - } - - if err = s.colorizer.Write(output, s.batch.String(), s.vars.ColorScheme(), color.TextTypeTSql); err == nil { - _, err = output.Write([]byte(SqlcmdEol)) - } - - return -} - -func connectCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) == 0 { - return InvalidCommandError("CONNECT", line) - } - - commandArgs := strings.Fields(args[0]) - - // Parse flags - flags := flag.NewFlagSet("connect", flag.ContinueOnError) - database := flags.String("D", "", "database name") - username := flags.String("U", "", "user name") - password := flags.String("P", "", "password") - loginTimeout := flags.String("l", "", "login timeout") - authenticationMethod := flags.String("G", "", "authentication method") - - err := flags.Parse(commandArgs[1:]) - //err := flags.Parse(args[1:]) - if err != nil { - return InvalidCommandError("CONNECT", line) - } - - connect := *s.Connect - connect.UserName, _ = resolveArgumentVariables(s, []rune(*username), false) - connect.Password, _ = resolveArgumentVariables(s, []rune(*password), false) - connect.Database, _ = resolveArgumentVariables(s, []rune(*database), false) - - timeout, _ := resolveArgumentVariables(s, []rune(*loginTimeout), false) - if timeout != "" { - if timeoutSeconds, err := strconv.ParseInt(timeout, 10, 32); err == nil { - if timeoutSeconds < 0 { - return InvalidCommandError("CONNECT", line) - } - connect.LoginTimeoutSeconds = int(timeoutSeconds) - } - } - - connect.AuthenticationMethod = *authenticationMethod - - // Set server name as the first positional argument - if len(commandArgs) > 0 { - connect.ServerName, _ = resolveArgumentVariables(s, []rune(commandArgs[0]), false) - } - - // If no user name is provided we switch to integrated auth - _ = s.ConnectDb(&connect, s.lineIo == nil) - - // ConnectDb prints connection errors already, and failure to connect is not fatal even with -b option - return nil -} - -func execCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) == 0 { - return InvalidCommandError("EXEC", line) - } - cmdLine := strings.TrimSpace(args[0]) - if cmdLine == "" { - return InvalidCommandError("EXEC", line) - } - if cmdLine, err := resolveArgumentVariables(s, []rune(cmdLine), true); err != nil { - return err - } else { - cmd := sysCommand(cmdLine) - cmd.Stderr = s.GetError() - cmd.Stdout = s.GetOutput() - _ = cmd.Run() - } - return nil -} - -func editCommand(s *Sqlcmd, args []string, line uint) error { - if args != nil && strings.TrimSpace(args[0]) != "" { - return InvalidCommandError("ED", line) - } - file, err := os.CreateTemp("", "sq*.sql") - if err != nil { - return err - } - fileName := file.Name() - defer os.Remove(fileName) - text := s.batch.String() - if s.batch.State() == "-" { - text = fmt.Sprintf("%s%s", text, SqlcmdEol) - } - _, err = file.WriteString(text) - if err != nil { - return err - } - file.Close() - cmd := sysCommand(s.vars.TextEditor() + " " + `"` + fileName + `"`) - cmd.Stderr = s.GetError() - cmd.Stdout = s.GetOutput() - err = cmd.Run() - if err != nil { - return err - } - wasEcho := s.echoFileLines - s.echoFileLines = true - s.batch.Reset(nil) - _ = s.IncludeFile(fileName, false) - s.echoFileLines = wasEcho - return nil -} - -func onerrorCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) == 0 || args[0] == "" { - return InvalidCommandError("ON ERROR", line) - } - params := strings.TrimSpace(args[0]) - - if strings.EqualFold(strings.ToLower(params), "exit") { - s.Connect.ExitOnError = true - } else if strings.EqualFold(strings.ToLower(params), "ignore") { - s.Connect.IgnoreError = true - s.Connect.ExitOnError = false - } else { - return InvalidCommandError("ON ERROR", line) - } - return nil -} - -func xmlCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) != 1 || args[0] == "" { - return InvalidCommandError("XML", line) - } - params := strings.TrimSpace(args[0]) - // "OFF" and "ON" are documented as the allowed values. - // ODBC sqlcmd treats any value other than "ON" the same as "OFF". - // So we will too. - if strings.EqualFold(params, "on") { - s.Format.XmlMode(true) - } else { - s.Format.XmlMode(false) - } - return nil -} - -func resolveArgumentVariables(s *Sqlcmd, arg []rune, failOnUnresolved bool) (string, error) { - var b *strings.Builder - end := len(arg) - for i := 0; i < end && !s.Connect.DisableVariableSubstitution; { - c, next := arg[i], grab(arg, i+1, end) - switch { - case c == '$' && next == '(': - vl, ok := readVariableReference(arg, i+2, end) - if ok { - varName := string(arg[i+2 : vl]) - val, ok := s.resolveVariable(varName) - if ok { - if b == nil { - b = new(strings.Builder) - b.Grow(len(arg)) - b.WriteString(string(arg[0:i])) - } - b.WriteString(val) - } else { - if failOnUnresolved { - return "", UndefinedVariable(varName) - } - s.WriteError(s.GetError(), UndefinedVariable(varName)) - if b != nil { - b.WriteString(string(arg[i : vl+1])) - } - } - i += ((vl - i) + 1) - } else { - if b != nil { - b.WriteString("$(") - } - i += 2 - } - default: - if b != nil { - b.WriteRune(c) - } - i++ - } - } - if b == nil { - return string(arg), nil - } - return b.String(), nil -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "flag" + "fmt" + "os" + "regexp" + "sort" + "strconv" + "strings" + + "github.com/microsoft/go-sqlcmd/internal/color" + "golang.org/x/text/encoding/unicode" + "golang.org/x/text/transform" +) + +// Command defines a sqlcmd action which can be intermixed with the SQL batch +// Commands for sqlcmd are defined at https://docs.microsoft.com/sql/tools/sqlcmd-utility#sqlcmd-commands +type Command struct { + // regex must include at least one group if it has parameters + // Will be matched using FindStringSubmatch + regex *regexp.Regexp + // The function that implements the command. Third parameter is the line number + action func(*Sqlcmd, []string, uint) error + // Name of the command + name string + // whether the command is a system command + isSystem bool +} + +// Commands is the set of sqlcmd command implementations +type Commands map[string]*Command + +func newCommands() Commands { + // Commands is the set of Command implementations + return map[string]*Command{ + "EXIT": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:?EXIT([\( \t]+.*\)*$|$)`), + action: exitCommand, + name: "EXIT", + }, + "QUIT": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:?QUIT(?:[ \t]+(.*$)|$)`), + action: quitCommand, + name: "QUIT", + }, + "GO": { + regex: regexp.MustCompile(batchTerminatorRegex("GO")), + action: goCommand, + name: "GO", + }, + "OUT": { + regex: regexp.MustCompile(`(?im)^[ \t]*:OUT(?:[ \t]+(.*$)|$)`), + action: outCommand, + name: "OUT", + }, + "ERROR": { + regex: regexp.MustCompile(`(?im)^[ \t]*:ERROR(?:[ \t]+(.*$)|$)`), + action: errorCommand, + name: "ERROR", + }, "READFILE": { + regex: regexp.MustCompile(`(?im)^[ \t]*:R(?:[ \t]+(.*$)|$)`), + action: readFileCommand, + name: "READFILE", + }, + "SETVAR": { + regex: regexp.MustCompile(`(?im)^[ \t]*:SETVAR(?:[ \t]+(.*$)|$)`), + action: setVarCommand, + name: "SETVAR", + }, + "LISTVAR": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:LISTVAR(?:[ \t]+(.*$)|$)`), + action: listVarCommand, + name: "LISTVAR", + }, + "RESET": { + regex: regexp.MustCompile(`(?im)^[ \t]*?:?RESET(?:[ \t]+(.*$)|$)`), + action: resetCommand, + name: "RESET", + }, + "LIST": { + regex: regexp.MustCompile(`(?im)^[ \t]*:LIST(?:[ \t]+(.*$)|$)`), + action: listCommand, + name: "LIST", + }, + "CONNECT": { + regex: regexp.MustCompile(`(?im)^[ \t]*:CONNECT(?:[ \t]+(.*$)|$)`), + action: connectCommand, + name: "CONNECT", + }, + "EXEC": { + regex: regexp.MustCompile(`(?im)^[ \t]*?:?!!(.*$)`), + action: execCommand, + name: "EXEC", + isSystem: true, + }, + "EDIT": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:?ED(?:[ \t]+(.*$)|$)`), + action: editCommand, + name: "EDIT", + isSystem: true, + }, + "ONERROR": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:?ON ERROR(?:[ \t]+(.*$)|$)`), + action: onerrorCommand, + name: "ONERROR", + }, + "XML": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:XML(?:[ \t]+(.*$)|$)`), + action: xmlCommand, + name: "XML", + }, + "SERVERLIST": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:SERVERLIST(?:[ \t]+(.*$)|$)`), + action: serverlistCommand, + name: "SERVERLIST", + }, + } +} + +// DisableSysCommands disables the ED and :!! commands. +// When exitOnCall is true, running those commands will exit the process. +func (c Commands) DisableSysCommands(exitOnCall bool) { + f := warnDisabled + if exitOnCall { + f = errorDisabled + } + for _, cmd := range c { + if cmd.isSystem { + cmd.action = f + } + } +} + +func (c Commands) matchCommand(line string) (*Command, []string) { + for _, cmd := range c { + matchedCommand := cmd.regex.FindStringSubmatch(line) + if matchedCommand != nil { + return cmd, removeComments(matchedCommand[1:]) + } + } + return nil, nil +} + +func removeComments(args []string) []string { + var pos int + quote := false + for i := range args { + pos, quote = commentStart([]rune(args[i]), quote) + if pos > -1 { + out := make([]string, i+1) + if i > 0 { + copy(out, args[:i]) + } + out[i] = args[i][:pos] + return out + } + } + return args +} + +func commentStart(arg []rune, quote bool) (int, bool) { + var i int + space := true + for ; i < len(arg); i++ { + c, next := arg[i], grab(arg, i+1, len(arg)) + switch { + case quote && c == '"' && next != '"': + quote = false + case quote && c == '"' && next == '"': + i++ + case c == '\t' || c == ' ': + space = true + // Note we assume none of the regexes would split arguments on non-whitespace boundaries such that "text -- comment" would get split into "text -" and "- comment" + case !quote && space && c == '-' && next == '-': + return i, false + case !quote && c == '"': + quote = true + default: + space = false + } + } + return -1, quote +} + +func warnDisabled(s *Sqlcmd, args []string, line uint) error { + s.WriteError(s.GetError(), ErrCommandsDisabled) + return nil +} + +func errorDisabled(s *Sqlcmd, args []string, line uint) error { + s.WriteError(s.GetError(), ErrCommandsDisabled) + s.Exitcode = 1 + return ErrExitRequested +} + +func batchTerminatorRegex(terminator string) string { + return fmt.Sprintf(`(?im)^[\t ]*?%s(?:[ ]+(.*$)|$)`, regexp.QuoteMeta(terminator)) +} + +// SetBatchTerminator attempts to set the batch terminator to the given value +// Returns an error if the new value is not usable in the regex +func (c Commands) SetBatchTerminator(terminator string) error { + cmd := c["GO"] + regex, err := regexp.Compile(batchTerminatorRegex(terminator)) + if err != nil { + return err + } + cmd.regex = regex + return nil +} + +// exitCommand has 3 modes. +// With no (), it just exits without running any query +// With () it runs whatever batch is in the buffer then exits +// With any text between () it runs the text as a query then exits +func exitCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 { + return ErrExitRequested + } + params := strings.TrimSpace(args[0]) + if params == "" { + return ErrExitRequested + } + if !strings.HasPrefix(params, "(") || !strings.HasSuffix(params, ")") { + return InvalidCommandError("EXIT", line) + } + // First we save the current batch + query1 := s.batch.String() + if len(query1) > 0 { + query1 = s.getRunnableQuery(query1) + } + // Now parse the params of EXIT as a batch without commands + cmd := s.batch.cmd + s.batch.cmd = nil + defer func() { + s.batch.cmd = cmd + }() + query2 := strings.TrimSpace(params[1 : len(params)-1]) + if len(query2) > 0 { + s.batch.Reset([]rune(query2)) + _, _, err := s.batch.Next() + if err != nil { + return err + } + query2 = s.batch.String() + if len(query2) > 0 { + query2 = s.getRunnableQuery(query2) + } + } + + if len(query1) > 0 || len(query2) > 0 { + query := query1 + SqlcmdEol + query2 + s.Exitcode, _ = s.runQuery(query) + } + return ErrExitRequested +} + +// quitCommand immediately exits the program without running any more batches +func quitCommand(s *Sqlcmd, args []string, line uint) error { + if args != nil && strings.TrimSpace(args[0]) != "" { + return InvalidCommandError("QUIT", line) + } + return ErrExitRequested +} + +// goCommand runs the current batch the number of times specified +func goCommand(s *Sqlcmd, args []string, line uint) error { + // default to 1 execution + n := 1 + var err error + if len(args) > 0 { + cnt := strings.TrimSpace(args[0]) + if cnt != "" { + if cnt, err = resolveArgumentVariables(s, []rune(cnt), true); err != nil { + return err + } + _, err = fmt.Sscanf(cnt, "%d", &n) + } + } + if err != nil || n < 1 { + return InvalidCommandError("GO", line) + } + if s.EchoInput { + err = listCommand(s, []string{}, line) + } + if err != nil { + return InvalidCommandError("GO", line) + } + query := s.batch.String() + if query == "" { + return nil + } + query = s.getRunnableQuery(query) + for i := 0; i < n; i++ { + if retcode, err := s.runQuery(query); err != nil { + s.Exitcode = retcode + return err + } + } + s.batch.Reset(nil) + return nil +} + +// outCommand changes the output writer to use a file +func outCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 || args[0] == "" { + return InvalidCommandError("OUT", line) + } + filePath, err := resolveArgumentVariables(s, []rune(args[0]), true) + if err != nil { + return err + } + + switch { + case strings.EqualFold(filePath, "stdout"): + s.SetOutput(os.Stdout) + case strings.EqualFold(filePath, "stderr"): + s.SetOutput(os.Stderr) + default: + o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return InvalidFileError(err, args[0]) + } + if s.UnicodeOutputFile { + // ODBC sqlcmd doesn't write a BOM but we will. + // Maybe the endian-ness should be configurable. + win16le := unicode.UTF16(unicode.LittleEndian, unicode.UseBOM) + encoder := transform.NewWriter(o, win16le.NewEncoder()) + s.SetOutput(encoder) + } else { + s.SetOutput(o) + } + } + return nil +} + +// errorCommand changes the error writer to use a file +func errorCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 || args[0] == "" { + return InvalidCommandError("ERROR", line) + } + filePath, err := resolveArgumentVariables(s, []rune(args[0]), true) + if err != nil { + return err + } + switch { + case strings.EqualFold(filePath, "stderr"): + s.SetError(os.Stderr) + case strings.EqualFold(filePath, "stdout"): + s.SetError(os.Stdout) + default: + o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return InvalidFileError(err, args[0]) + } + s.SetError(o) + } + return nil +} + +func readFileCommand(s *Sqlcmd, args []string, line uint) error { + if args == nil || len(args) != 1 { + return InvalidCommandError(":R", line) + } + fileName, _ := resolveArgumentVariables(s, []rune(args[0]), false) + return s.IncludeFile(fileName, false) +} + +// setVarCommand parses a variable setting and applies it to the current Sqlcmd variables +func setVarCommand(s *Sqlcmd, args []string, line uint) error { + if args == nil || len(args) != 1 || args[0] == "" { + return InvalidCommandError(":SETVAR", line) + } + + varname := args[0] + val := "" + // The prior incarnation of sqlcmd doesn't require a space between the variable name and its value + // in some very unexpected cases. This version will require the space. + sp := strings.IndexRune(args[0], ' ') + if sp > -1 { + val = strings.TrimSpace(varname[sp:]) + varname = varname[:sp] + } + if err := s.vars.Setvar(varname, val); err != nil { + switch e := err.(type) { + case *VariableError: + return e + default: + return InvalidCommandError(":SETVAR", line) + } + } + return nil +} + +// listVarCommand prints the set of Sqlcmd scripting variables. +// Builtin values are printed first, followed by user-set values in sorted order. +func listVarCommand(s *Sqlcmd, args []string, line uint) error { + if args != nil && strings.TrimSpace(args[0]) != "" { + return InvalidCommandError("LISTVAR", line) + } + + vars := s.vars.All() + keys := make([]string, 0, len(vars)) + for k := range vars { + if !contains(builtinVariables, k) { + keys = append(keys, k) + } + } + sort.Strings(keys) + keys = append(builtinVariables, keys...) + for _, k := range keys { + fmt.Fprintf(s.GetOutput(), `%s = "%s"%s`, k, vars[k], SqlcmdEol) + } + return nil +} + +// resetCommand resets the statement cache +func resetCommand(s *Sqlcmd, args []string, line uint) error { + if s.batch != nil { + s.batch.Reset(nil) + } + + return nil +} + +// listCommand displays statements currently in the statement cache +func listCommand(s *Sqlcmd, args []string, line uint) (err error) { + cmd := "" + if args != nil { + if len(args) > 0 { + cmd = strings.ToLower(strings.TrimSpace(args[0])) + if len(args) > 1 || (cmd != "color" && cmd != "") { + return InvalidCommandError("LIST", line) + } + } + } + output := s.GetOutput() + if cmd == "color" { + sample := "select 'literal' as literal, 100 as number from [sys].[tables]" + clr := color.TextTypeTSql + if s.Format.IsXmlMode() { + sample = `value` + clr = color.TextTypeXml + } + // ignoring errors since it's not critical output + for _, style := range s.colorizer.Styles() { + _, _ = output.Write([]byte(style + ": ")) + _ = s.colorizer.Write(output, sample, style, clr) + _, _ = output.Write([]byte(SqlcmdEol)) + } + return + } + if s.batch == nil || s.batch.String() == "" { + return + } + + if err = s.colorizer.Write(output, s.batch.String(), s.vars.ColorScheme(), color.TextTypeTSql); err == nil { + _, err = output.Write([]byte(SqlcmdEol)) + } + + return +} + +func connectCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 { + return InvalidCommandError("CONNECT", line) + } + + commandArgs := strings.Fields(args[0]) + + // Parse flags + flags := flag.NewFlagSet("connect", flag.ContinueOnError) + database := flags.String("D", "", "database name") + username := flags.String("U", "", "user name") + password := flags.String("P", "", "password") + loginTimeout := flags.String("l", "", "login timeout") + authenticationMethod := flags.String("G", "", "authentication method") + + err := flags.Parse(commandArgs[1:]) + //err := flags.Parse(args[1:]) + if err != nil { + return InvalidCommandError("CONNECT", line) + } + + connect := *s.Connect + connect.UserName, _ = resolveArgumentVariables(s, []rune(*username), false) + connect.Password, _ = resolveArgumentVariables(s, []rune(*password), false) + connect.Database, _ = resolveArgumentVariables(s, []rune(*database), false) + + timeout, _ := resolveArgumentVariables(s, []rune(*loginTimeout), false) + if timeout != "" { + if timeoutSeconds, err := strconv.ParseInt(timeout, 10, 32); err == nil { + if timeoutSeconds < 0 { + return InvalidCommandError("CONNECT", line) + } + connect.LoginTimeoutSeconds = int(timeoutSeconds) + } + } + + connect.AuthenticationMethod = *authenticationMethod + + // Set server name as the first positional argument + if len(commandArgs) > 0 { + connect.ServerName, _ = resolveArgumentVariables(s, []rune(commandArgs[0]), false) + } + + // If no user name is provided we switch to integrated auth + _ = s.ConnectDb(&connect, s.lineIo == nil) + + // ConnectDb prints connection errors already, and failure to connect is not fatal even with -b option + return nil +} + +func execCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 { + return InvalidCommandError("EXEC", line) + } + cmdLine := strings.TrimSpace(args[0]) + if cmdLine == "" { + return InvalidCommandError("EXEC", line) + } + if cmdLine, err := resolveArgumentVariables(s, []rune(cmdLine), true); err != nil { + return err + } else { + cmd := sysCommand(cmdLine) + cmd.Stderr = s.GetError() + cmd.Stdout = s.GetOutput() + _ = cmd.Run() + } + return nil +} + +func editCommand(s *Sqlcmd, args []string, line uint) error { + if args != nil && strings.TrimSpace(args[0]) != "" { + return InvalidCommandError("ED", line) + } + file, err := os.CreateTemp("", "sq*.sql") + if err != nil { + return err + } + fileName := file.Name() + defer os.Remove(fileName) + text := s.batch.String() + if s.batch.State() == "-" { + text = fmt.Sprintf("%s%s", text, SqlcmdEol) + } + _, err = file.WriteString(text) + if err != nil { + return err + } + file.Close() + cmd := sysCommand(s.vars.TextEditor() + " " + `"` + fileName + `"`) + cmd.Stderr = s.GetError() + cmd.Stdout = s.GetOutput() + err = cmd.Run() + if err != nil { + return err + } + wasEcho := s.echoFileLines + s.echoFileLines = true + s.batch.Reset(nil) + _ = s.IncludeFile(fileName, false) + s.echoFileLines = wasEcho + return nil +} + +func onerrorCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 || args[0] == "" { + return InvalidCommandError("ON ERROR", line) + } + params := strings.TrimSpace(args[0]) + + if strings.EqualFold(strings.ToLower(params), "exit") { + s.Connect.ExitOnError = true + } else if strings.EqualFold(strings.ToLower(params), "ignore") { + s.Connect.IgnoreError = true + s.Connect.ExitOnError = false + } else { + return InvalidCommandError("ON ERROR", line) + } + return nil +} + +func xmlCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) != 1 || args[0] == "" { + return InvalidCommandError("XML", line) + } + params := strings.TrimSpace(args[0]) + // "OFF" and "ON" are documented as the allowed values. + // ODBC sqlcmd treats any value other than "ON" the same as "OFF". + // So we will too. + if strings.EqualFold(params, "on") { + s.Format.XmlMode(true) + } else { + s.Format.XmlMode(false) + } + return nil +} + +func serverlistCommand(s *Sqlcmd, args []string, line uint) error { + ListLocalServers(s.GetOutput()) + return nil +} + +func resolveArgumentVariables(s *Sqlcmd, arg []rune, failOnUnresolved bool) (string, error) { + var b *strings.Builder + end := len(arg) + for i := 0; i < end && !s.Connect.DisableVariableSubstitution; { + c, next := arg[i], grab(arg, i+1, end) + switch { + case c == '$' && next == '(': + vl, ok := readVariableReference(arg, i+2, end) + if ok { + varName := string(arg[i+2 : vl]) + val, ok := s.resolveVariable(varName) + if ok { + if b == nil { + b = new(strings.Builder) + b.Grow(len(arg)) + b.WriteString(string(arg[0:i])) + } + b.WriteString(val) + } else { + if failOnUnresolved { + return "", UndefinedVariable(varName) + } + s.WriteError(s.GetError(), UndefinedVariable(varName)) + if b != nil { + b.WriteString(string(arg[i : vl+1])) + } + } + i += ((vl - i) + 1) + } else { + if b != nil { + b.WriteString("$(") + } + i += 2 + } + default: + if b != nil { + b.WriteRune(c) + } + i++ + } + } + if b == nil { + return string(arg), nil + } + return b.String(), nil +} diff --git a/pkg/sqlcmd/serverlist.go b/pkg/sqlcmd/serverlist.go new file mode 100644 index 00000000..48c1ebde --- /dev/null +++ b/pkg/sqlcmd/serverlist.go @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "os" + "strings" + "time" + + "github.com/microsoft/go-mssqldb/msdsn" +) + +// ListLocalServers queries the SQL Browser service for available SQL Server instances +// and writes the results to the provided writer. +func ListLocalServers(w io.Writer) { + instances := GetLocalServerInstances() + for _, s := range instances { + fmt.Fprintln(w, " ", s) + } +} + +// GetLocalServerInstances queries the SQL Browser service and returns a list of +// available SQL Server instances on the local machine. +func GetLocalServerInstances() []string { + bmsg := []byte{byte(msdsn.BrowserAllInstances)} + resp := make([]byte, 16*1024-1) + dialer := &net.Dialer{} + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + conn, err := dialer.DialContext(ctx, "udp", ":1434") + // silently ignore failures to connect, same as ODBC + if err != nil { + return nil + } + defer conn.Close() + dl, _ := ctx.Deadline() + _ = conn.SetDeadline(dl) + _, err = conn.Write(bmsg) + if err != nil { + if !errors.Is(err, os.ErrDeadlineExceeded) { + // Silently ignore errors, same as ODBC + } + return nil + } + read, err := conn.Read(resp) + if err != nil { + if !errors.Is(err, os.ErrDeadlineExceeded) { + // Silently ignore errors, same as ODBC + } + return nil + } + + data := parseInstances(resp[:read]) + instances := make([]string, 0, len(data)) + for s := range data { + if s == "MSSQLSERVER" { + instances = append(instances, "(local)", data[s]["ServerName"]) + } else { + instances = append(instances, fmt.Sprintf(`%s\%s`, data[s]["ServerName"], s)) + } + } + return instances +} + +func parseInstances(msg []byte) msdsn.BrowserData { + results := msdsn.BrowserData{} + if len(msg) > 3 && msg[0] == 5 { + out_s := string(msg[3:]) + tokens := strings.Split(out_s, ";") + instdict := map[string]string{} + got_name := false + var name string + for _, token := range tokens { + if got_name { + instdict[name] = token + got_name = false + } else { + name = token + if len(name) == 0 { + if len(instdict) == 0 { + break + } + results[strings.ToUpper(instdict["InstanceName"])] = instdict + instdict = map[string]string{} + continue + } + got_name = true + } + } + } + return results +} diff --git a/pkg/sqlcmd/serverlist_test.go b/pkg/sqlcmd/serverlist_test.go new file mode 100644 index 00000000..59342ecc --- /dev/null +++ b/pkg/sqlcmd/serverlist_test.go @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestListLocalServers(t *testing.T) { + // Test that ListLocalServers writes to the provided writer without error + // Note: actual server discovery depends on SQL Browser service availability + var buf bytes.Buffer + ListLocalServers(&buf) + // We can't assert specific content since it depends on environment, + // but we verify it doesn't panic and writes valid output + t.Logf("ListLocalServers output: %q", buf.String()) +} + +func TestGetLocalServerInstances(t *testing.T) { + // Test that GetLocalServerInstances returns a slice (may be empty if no servers) + instances := GetLocalServerInstances() + // instances may be nil or empty if no SQL Browser is running, that's OK + t.Logf("Found %d instances", len(instances)) + for _, inst := range instances { + assert.NotEmpty(t, inst, "Instance name should not be empty") + } +} + +func TestParseInstances(t *testing.T) { + // Test parsing of SQL Browser response + // Format: 0x05 (response type), 2 bytes length, then semicolon-separated key=value pairs + // Each instance ends with two semicolons + + t.Run("empty response", func(t *testing.T) { + result := parseInstances([]byte{}) + assert.Empty(t, result) + }) + + t.Run("invalid header", func(t *testing.T) { + result := parseInstances([]byte{1, 0, 0}) + assert.Empty(t, result) + }) + + t.Run("valid single instance", func(t *testing.T) { + // Simulating SQL Browser response format + // Header: 0x05 followed by 2 length bytes, then the instance data + data := []byte{5, 0, 0} + instanceData := "ServerName;MYSERVER;InstanceName;MSSQLSERVER;IsClustered;No;Version;15.0.2000.5;tcp;1433;;" + data = append(data, []byte(instanceData)...) + + result := parseInstances(data) + assert.Len(t, result, 1) + assert.Contains(t, result, "MSSQLSERVER") + assert.Equal(t, "MYSERVER", result["MSSQLSERVER"]["ServerName"]) + assert.Equal(t, "1433", result["MSSQLSERVER"]["tcp"]) + }) + + t.Run("valid multiple instances", func(t *testing.T) { + data := []byte{5, 0, 0} + instanceData := "ServerName;MYSERVER;InstanceName;MSSQLSERVER;tcp;1433;;ServerName;MYSERVER;InstanceName;SQLEXPRESS;tcp;1434;;" + data = append(data, []byte(instanceData)...) + + result := parseInstances(data) + assert.Len(t, result, 2) + assert.Contains(t, result, "MSSQLSERVER") + assert.Contains(t, result, "SQLEXPRESS") + }) +} + +func TestServerlistCommand(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + + // Run the serverlist command + c := []string{":serverlist"} + err := runSqlCmd(t, s, c) + + // The command should not raise an error even if no servers are found + assert.NoError(t, err, ":serverlist should not raise error") + // Output may be empty if no SQL Browser is running + t.Logf("Serverlist output: %q", buf.buf.String()) +} From fb3825e1d2395325bdf98b909b432cd87d530bcb Mon Sep 17 00:00:00 2001 From: David Levy Date: Sat, 24 Jan 2026 21:57:29 -0600 Subject: [PATCH 2/8] Fix linting errors in serverlist.go Remove empty conditional branches that triggered staticcheck SA9003. Remove unused imports (errors, os). --- pkg/sqlcmd/serverlist.go | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/pkg/sqlcmd/serverlist.go b/pkg/sqlcmd/serverlist.go index 48c1ebde..3a9194c9 100644 --- a/pkg/sqlcmd/serverlist.go +++ b/pkg/sqlcmd/serverlist.go @@ -5,11 +5,9 @@ package sqlcmd import ( "context" - "errors" "fmt" "io" "net" - "os" "strings" "time" @@ -43,16 +41,12 @@ func GetLocalServerInstances() []string { _ = conn.SetDeadline(dl) _, err = conn.Write(bmsg) if err != nil { - if !errors.Is(err, os.ErrDeadlineExceeded) { - // Silently ignore errors, same as ODBC - } + // Silently ignore errors, same as ODBC return nil } read, err := conn.Read(resp) if err != nil { - if !errors.Is(err, os.ErrDeadlineExceeded) { - // Silently ignore errors, same as ODBC - } + // Silently ignore errors, same as ODBC return nil } From ebd4b2f17bd85026b05f9876176e5900f4b5a0df Mon Sep 17 00:00:00 2001 From: David Levy Date: Sat, 24 Jan 2026 22:32:31 -0600 Subject: [PATCH 3/8] Implement :help command Add the :help command to display available sqlcmd commands. This improves compatibility with legacy ODBC sqlcmd. Changes: - Added HELP command to command registry - Added helpCommand function with full command list - Added tests for command parsing and functionality - Updated README.md --- README.md | 1 + pkg/sqlcmd/commands.go | 54 +++++++++++++++++++++++++++++++++++++ pkg/sqlcmd/commands_test.go | 23 ++++++++++++++++ 3 files changed, 78 insertions(+) diff --git a/README.md b/README.md index fe26e192..711ac9b6 100644 --- a/README.md +++ b/README.md @@ -154,6 +154,7 @@ switches are most important to you to have implemented next in the new sqlcmd. - `:Connect` now has an optional `-G` parameter to select one of the authentication methods for Azure SQL Database - `SqlAuthentication`, `ActiveDirectoryDefault`, `ActiveDirectoryIntegrated`, `ActiveDirectoryServicePrincipal`, `ActiveDirectoryManagedIdentity`, `ActiveDirectoryPassword`. If `-G` is not provided, either Integrated security or SQL Authentication will be used, dependent on the presence of a `-U` username parameter. - The new `--driver-logging-level` command line parameter allows you to see traces from the `go-mssqldb` client driver. Use `64` to see all traces. - Sqlcmd can now print results using a vertical format. Use the new `--vertical` command line option to set it. It's also controlled by the `SQLCMDFORMAT` scripting variable. +- `:help` displays a list of available sqlcmd commands. ``` 1> select session_id, client_interface_name, program_name from sys.dm_exec_sessions where session_id=@@spid diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index 66dd1dba..8528c4b5 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -113,6 +113,11 @@ func newCommands() Commands { action: xmlCommand, name: "XML", }, + "HELP": { + regex: regexp.MustCompile(`(?im)^[ \t]*:HELP(?:[ \t]+(.*$)|$)`), + action: helpCommand, + name: "HELP", + }, } } @@ -596,6 +601,55 @@ func xmlCommand(s *Sqlcmd, args []string, line uint) error { return nil } +// helpCommand displays the list of available sqlcmd commands +func helpCommand(s *Sqlcmd, args []string, line uint) error { + helpText := `:!! [] + - Executes a command in the operating system shell. +:connect server[\instance] [-l timeout] [-U user [-P password]] + - Connects to a SQL Server instance. +:ed + - Edits the current or last executed statement cache. +:error + - Redirects error output to a file, stderr, or stdout. +:exit + - Quits sqlcmd immediately. +:exit() + - Execute statement cache; quit with no return value. +:exit() + - Execute the specified query; returns numeric result. +go [] + - Executes the statement cache (n times). +:help + - Shows this list of commands. +:list + - Prints the content of the statement cache. +:listvar + - Lists the set sqlcmd scripting variables. +:on error [exit|ignore] + - Action for batch or sqlcmd command errors. +:out |stderr|stdout + - Redirects query output to a file, stderr, or stdout. +:perftrace |stderr|stdout + - Redirects timing output to a file, stderr, or stdout. +:quit + - Quits sqlcmd immediately. +:r + - Append file contents to the statement cache. +:reset + - Discards the statement cache. +:serverlist + - Lists local and SQL Servers on the network. +:setvar {variable} + - Removes a sqlcmd scripting variable. +:setvar + - Sets a sqlcmd scripting variable. +:xml [on|off] + - Sets XML output mode. +` + _, err := s.GetOutput().Write([]byte(helpText)) + return err +} + func resolveArgumentVariables(s *Sqlcmd, arg []rune, failOnUnresolved bool) (string, error) { var b *strings.Builder end := len(arg) diff --git a/pkg/sqlcmd/commands_test.go b/pkg/sqlcmd/commands_test.go index 6197aa3f..5d527b87 100644 --- a/pkg/sqlcmd/commands_test.go +++ b/pkg/sqlcmd/commands_test.go @@ -54,6 +54,8 @@ func TestCommandParsing(t *testing.T) { {`:XML ON `, "XML", []string{`ON `}}, {`:RESET`, "RESET", []string{""}}, {`RESET`, "RESET", []string{""}}, + {`:HELP`, "HELP", []string{""}}, + {`:help`, "HELP", []string{""}}, } for _, test := range commands { @@ -458,3 +460,24 @@ func TestExitCommandAppendsParameterToCurrentBatch(t *testing.T) { } } + +func TestHelpCommand(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + s.SetOutput(buf) + + err := helpCommand(s, []string{""}, 1) + assert.NoError(t, err, "helpCommand should not error") + + output := buf.buf.String() + // Verify key commands are listed + assert.Contains(t, output, ":connect", "help should list :connect") + assert.Contains(t, output, ":exit", "help should list :exit") + assert.Contains(t, output, ":help", "help should list :help") + assert.Contains(t, output, ":setvar", "help should list :setvar") + assert.Contains(t, output, ":listvar", "help should list :listvar") + assert.Contains(t, output, ":out", "help should list :out") + assert.Contains(t, output, ":error", "help should list :error") + assert.Contains(t, output, ":r", "help should list :r") + assert.Contains(t, output, "go", "help should list go") +} From e1b1dcd8cde79d92649c7a4ad7e5441b60eea064 Mon Sep 17 00:00:00 2001 From: David Levy Date: Sun, 25 Jan 2026 11:43:39 -0600 Subject: [PATCH 4/8] Fix review comments for PR #630 - Rename variables to follow Go naming conventions: - out_s -> outStr - got_name -> gotName - instdict -> instanceDict - Add argument validation to serverlistCommand --- pkg/sqlcmd/commands.go | 3 +++ pkg/sqlcmd/serverlist.go | 22 +++++++++++----------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index f3c5850a..58fc199f 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -602,6 +602,9 @@ func xmlCommand(s *Sqlcmd, args []string, line uint) error { } func serverlistCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) > 0 && args[0] != "" { + return InvalidCommandError("SERVERLIST", line) + } ListLocalServers(s.GetOutput()) return nil } diff --git a/pkg/sqlcmd/serverlist.go b/pkg/sqlcmd/serverlist.go index 3a9194c9..b663a1ae 100644 --- a/pkg/sqlcmd/serverlist.go +++ b/pkg/sqlcmd/serverlist.go @@ -65,26 +65,26 @@ func GetLocalServerInstances() []string { func parseInstances(msg []byte) msdsn.BrowserData { results := msdsn.BrowserData{} if len(msg) > 3 && msg[0] == 5 { - out_s := string(msg[3:]) - tokens := strings.Split(out_s, ";") - instdict := map[string]string{} - got_name := false + outStr := string(msg[3:]) + tokens := strings.Split(outStr, ";") + instanceDict := map[string]string{} + gotName := false var name string for _, token := range tokens { - if got_name { - instdict[name] = token - got_name = false + if gotName { + instanceDict[name] = token + gotName = false } else { name = token if len(name) == 0 { - if len(instdict) == 0 { + if len(instanceDict) == 0 { break } - results[strings.ToUpper(instdict["InstanceName"])] = instdict - instdict = map[string]string{} + results[strings.ToUpper(instanceDict["InstanceName"])] = instanceDict + instanceDict = map[string]string{} continue } - got_name = true + gotName = true } } } From 4bea3f26108ee5406b27b154be308be7f8ecc52a Mon Sep 17 00:00:00 2001 From: David Levy Date: Sun, 25 Jan 2026 11:47:13 -0600 Subject: [PATCH 5/8] Fix review comments for PR #634 - Remove :serverlist and :perftrace from help text - These commands are in separate PRs and not yet merged - Help text should only list commands that exist in this branch --- pkg/sqlcmd/commands.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index 8528c4b5..d67d6c17 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -629,16 +629,12 @@ go [] - Action for batch or sqlcmd command errors. :out |stderr|stdout - Redirects query output to a file, stderr, or stdout. -:perftrace |stderr|stdout - - Redirects timing output to a file, stderr, or stdout. :quit - Quits sqlcmd immediately. :r - Append file contents to the statement cache. :reset - Discards the statement cache. -:serverlist - - Lists local and SQL Servers on the network. :setvar {variable} - Removes a sqlcmd scripting variable. :setvar From ec0a3cd7afb93ae38d5cbe9c9fab1aeedafb1f83 Mon Sep 17 00:00:00 2001 From: David Levy Date: Sun, 25 Jan 2026 13:01:37 -0600 Subject: [PATCH 6/8] Address Copilot review comments on serverlist command - Fix fmt.Fprintf spacing in ListLocalServers - Sort instance names for deterministic output - Add validation for missing ServerName in instances - Add argument validation for helpCommand - Add test cases for :SERVERLIST and :serverlist - Add assertion for :serverlist in TestHelpCommand --- pkg/sqlcmd/commands.go | 3 +++ pkg/sqlcmd/commands_test.go | 5 +++-- pkg/sqlcmd/serverlist.go | 20 +++++++++++++++++--- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index 43d22438..699bc1a4 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -608,6 +608,9 @@ func xmlCommand(s *Sqlcmd, args []string, line uint) error { // helpCommand displays the list of available sqlcmd commands func helpCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) > 0 && strings.TrimSpace(args[0]) != "" { + return InvalidCommandError("HELP", line) + } helpText := `:!! [] - Executes a command in the operating system shell. :connect server[\instance] [-l timeout] [-U user [-P password]] diff --git a/pkg/sqlcmd/commands_test.go b/pkg/sqlcmd/commands_test.go index 5d527b87..e2ee2dbb 100644 --- a/pkg/sqlcmd/commands_test.go +++ b/pkg/sqlcmd/commands_test.go @@ -55,8 +55,8 @@ func TestCommandParsing(t *testing.T) { {`:RESET`, "RESET", []string{""}}, {`RESET`, "RESET", []string{""}}, {`:HELP`, "HELP", []string{""}}, - {`:help`, "HELP", []string{""}}, - } + {`:help`, "HELP", []string{""}}, {`:SERVERLIST`, "SERVERLIST", []string{""}}, + {`:serverlist`, "SERVERLIST", []string{""}}} for _, test := range commands { cmd, args := c.matchCommand(test.line) @@ -479,5 +479,6 @@ func TestHelpCommand(t *testing.T) { assert.Contains(t, output, ":out", "help should list :out") assert.Contains(t, output, ":error", "help should list :error") assert.Contains(t, output, ":r", "help should list :r") + assert.Contains(t, output, ":serverlist", "help should list :serverlist") assert.Contains(t, output, "go", "help should list go") } diff --git a/pkg/sqlcmd/serverlist.go b/pkg/sqlcmd/serverlist.go index b663a1ae..4e77f4e5 100644 --- a/pkg/sqlcmd/serverlist.go +++ b/pkg/sqlcmd/serverlist.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net" + "sort" "strings" "time" @@ -19,7 +20,7 @@ import ( func ListLocalServers(w io.Writer) { instances := GetLocalServerInstances() for _, s := range instances { - fmt.Fprintln(w, " ", s) + fmt.Fprintf(w, " %s\n", s) } } @@ -52,11 +53,24 @@ func GetLocalServerInstances() []string { data := parseInstances(resp[:read]) instances := make([]string, 0, len(data)) + + // Sort instance names for deterministic output + instanceNames := make([]string, 0, len(data)) for s := range data { + instanceNames = append(instanceNames, s) + } + sort.Strings(instanceNames) + + for _, s := range instanceNames { + serverName := data[s]["ServerName"] + if serverName == "" { + // Skip instances without a ServerName + continue + } if s == "MSSQLSERVER" { - instances = append(instances, "(local)", data[s]["ServerName"]) + instances = append(instances, "(local)", serverName) } else { - instances = append(instances, fmt.Sprintf(`%s\%s`, data[s]["ServerName"], s)) + instances = append(instances, fmt.Sprintf(`%s\%s`, serverName, s)) } } return instances From f241917a831a62ef9909cb4b56a1eece4dc51f9e Mon Sep 17 00:00:00 2001 From: David Levy Date: Sun, 25 Jan 2026 14:10:39 -0600 Subject: [PATCH 7/8] Address Copilot review comments for serverlist command - Use strings.TrimSpace in serverlistCommand for consistency - Fix test formatting to put each test case on its own line - Add InstanceName validation before using as map key - Document :serverlist command in README.md --- README.md | 1 + pkg/sqlcmd/commands.go | 2 +- pkg/sqlcmd/commands_test.go | 6 ++++-- pkg/sqlcmd/serverlist.go | 5 ++++- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 711ac9b6..5ab4679f 100644 --- a/README.md +++ b/README.md @@ -155,6 +155,7 @@ switches are most important to you to have implemented next in the new sqlcmd. - The new `--driver-logging-level` command line parameter allows you to see traces from the `go-mssqldb` client driver. Use `64` to see all traces. - Sqlcmd can now print results using a vertical format. Use the new `--vertical` command line option to set it. It's also controlled by the `SQLCMDFORMAT` scripting variable. - `:help` displays a list of available sqlcmd commands. +- `:serverlist` lists local SQL Server instances discovered via the SQL Server Browser service. ``` 1> select session_id, client_interface_name, program_name from sys.dm_exec_sessions where session_id=@@spid diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index 699bc1a4..548632e6 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -657,7 +657,7 @@ go [] } func serverlistCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) > 0 && args[0] != "" { + if len(args) > 0 && strings.TrimSpace(args[0]) != "" { return InvalidCommandError("SERVERLIST", line) } ListLocalServers(s.GetOutput()) diff --git a/pkg/sqlcmd/commands_test.go b/pkg/sqlcmd/commands_test.go index e2ee2dbb..7895e307 100644 --- a/pkg/sqlcmd/commands_test.go +++ b/pkg/sqlcmd/commands_test.go @@ -55,8 +55,10 @@ func TestCommandParsing(t *testing.T) { {`:RESET`, "RESET", []string{""}}, {`RESET`, "RESET", []string{""}}, {`:HELP`, "HELP", []string{""}}, - {`:help`, "HELP", []string{""}}, {`:SERVERLIST`, "SERVERLIST", []string{""}}, - {`:serverlist`, "SERVERLIST", []string{""}}} + {`:help`, "HELP", []string{""}}, + {`:SERVERLIST`, "SERVERLIST", []string{""}}, + {`:serverlist`, "SERVERLIST", []string{""}}, + } for _, test := range commands { cmd, args := c.matchCommand(test.line) diff --git a/pkg/sqlcmd/serverlist.go b/pkg/sqlcmd/serverlist.go index 4e77f4e5..ee3fba80 100644 --- a/pkg/sqlcmd/serverlist.go +++ b/pkg/sqlcmd/serverlist.go @@ -94,7 +94,10 @@ func parseInstances(msg []byte) msdsn.BrowserData { if len(instanceDict) == 0 { break } - results[strings.ToUpper(instanceDict["InstanceName"])] = instanceDict + // Only add if InstanceName key exists and is non-empty + if instName, ok := instanceDict["InstanceName"]; ok && instName != "" { + results[strings.ToUpper(instName)] = instanceDict + } instanceDict = map[string]string{} continue } From c103e26ce93c69c80a335e1d16857c3db1802432 Mon Sep 17 00:00:00 2001 From: David Levy Date: Sun, 25 Jan 2026 17:09:10 -0600 Subject: [PATCH 8/8] Restore error printing for non-timeout errors and add documentation - Changed GetLocalServerInstances() to return ([]string, error) - Only return error if NOT os.ErrDeadlineExceeded (timeout is expected) - ListLocalServers() prints errors to stderr (matches ODBC sqlcmd behavior) - Expanded README documentation for :serverlist command - Added batch script examples for error handling and automation --- README.md | 35 ++++++++++++++++++++++++++++++++++- pkg/sqlcmd/serverlist.go | 28 ++++++++++++++++++++-------- pkg/sqlcmd/serverlist_test.go | 6 +++++- 3 files changed, 59 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 5ab4679f..7da89b68 100644 --- a/README.md +++ b/README.md @@ -155,7 +155,40 @@ switches are most important to you to have implemented next in the new sqlcmd. - The new `--driver-logging-level` command line parameter allows you to see traces from the `go-mssqldb` client driver. Use `64` to see all traces. - Sqlcmd can now print results using a vertical format. Use the new `--vertical` command line option to set it. It's also controlled by the `SQLCMDFORMAT` scripting variable. - `:help` displays a list of available sqlcmd commands. -- `:serverlist` lists local SQL Server instances discovered via the SQL Server Browser service. +- `:serverlist` lists local SQL Server instances discovered via the SQL Server Browser service (UDP port 1434). The command queries the SQL Browser service and displays the server name and instance name for each discovered instance. If no instances are found or the Browser service is not running, no output is produced. Non-timeout errors are printed to stderr. + +``` +1> :serverlist +MYSERVER\SQL2019 +MYSERVER\SQL2022 +``` + +#### Using :serverlist in batch scripts + +When automating server discovery, you can capture the output and check for errors: + +```batch +@echo off +REM Discover local SQL Server instances and connect to the first one +sqlcmd -Q ":serverlist" 2>nul > servers.txt +if %errorlevel% neq 0 ( + echo Error discovering servers + exit /b 1 +) +for /f "tokens=1" %%s in (servers.txt) do ( + echo Connecting to %%s... + sqlcmd -S %%s -Q "SELECT @@SERVERNAME" + goto :done +) +echo No SQL Server instances found +:done +``` + +To capture stderr separately (for error logging): +```batch +sqlcmd -Q ":serverlist" 2>errors.log > servers.txt +if exist errors.log if not "%%~z errors.log"=="0" type errors.log +``` ``` 1> select session_id, client_interface_name, program_name from sys.dm_exec_sessions where session_id=@@spid diff --git a/pkg/sqlcmd/serverlist.go b/pkg/sqlcmd/serverlist.go index ee3fba80..efe29c7d 100644 --- a/pkg/sqlcmd/serverlist.go +++ b/pkg/sqlcmd/serverlist.go @@ -5,9 +5,11 @@ package sqlcmd import ( "context" + "errors" "fmt" "io" "net" + "os" "sort" "strings" "time" @@ -18,7 +20,10 @@ import ( // ListLocalServers queries the SQL Browser service for available SQL Server instances // and writes the results to the provided writer. func ListLocalServers(w io.Writer) { - instances := GetLocalServerInstances() + instances, err := GetLocalServerInstances() + if err != nil { + fmt.Fprintln(os.Stderr, err) + } for _, s := range instances { fmt.Fprintf(w, " %s\n", s) } @@ -26,7 +31,8 @@ func ListLocalServers(w io.Writer) { // GetLocalServerInstances queries the SQL Browser service and returns a list of // available SQL Server instances on the local machine. -func GetLocalServerInstances() []string { +// Returns an error for non-timeout network errors. +func GetLocalServerInstances() ([]string, error) { bmsg := []byte{byte(msdsn.BrowserAllInstances)} resp := make([]byte, 16*1024-1) dialer := &net.Dialer{} @@ -35,20 +41,26 @@ func GetLocalServerInstances() []string { conn, err := dialer.DialContext(ctx, "udp", ":1434") // silently ignore failures to connect, same as ODBC if err != nil { - return nil + return nil, nil } defer conn.Close() dl, _ := ctx.Deadline() _ = conn.SetDeadline(dl) _, err = conn.Write(bmsg) if err != nil { - // Silently ignore errors, same as ODBC - return nil + // Only return error if it's not a timeout + if !errors.Is(err, os.ErrDeadlineExceeded) { + return nil, err + } + return nil, nil } read, err := conn.Read(resp) if err != nil { - // Silently ignore errors, same as ODBC - return nil + // Only return error if it's not a timeout + if !errors.Is(err, os.ErrDeadlineExceeded) { + return nil, err + } + return nil, nil } data := parseInstances(resp[:read]) @@ -73,7 +85,7 @@ func GetLocalServerInstances() []string { instances = append(instances, fmt.Sprintf(`%s\%s`, serverName, s)) } } - return instances + return instances, nil } func parseInstances(msg []byte) msdsn.BrowserData { diff --git a/pkg/sqlcmd/serverlist_test.go b/pkg/sqlcmd/serverlist_test.go index 59342ecc..3ab20920 100644 --- a/pkg/sqlcmd/serverlist_test.go +++ b/pkg/sqlcmd/serverlist_test.go @@ -22,8 +22,12 @@ func TestListLocalServers(t *testing.T) { func TestGetLocalServerInstances(t *testing.T) { // Test that GetLocalServerInstances returns a slice (may be empty if no servers) - instances := GetLocalServerInstances() + instances, err := GetLocalServerInstances() // instances may be nil or empty if no SQL Browser is running, that's OK + // err may be non-nil for non-timeout network errors + if err != nil { + t.Logf("GetLocalServerInstances returned error (expected in some environments): %v", err) + } t.Logf("Found %d instances", len(instances)) for _, inst := range instances { assert.NotEmpty(t, inst, "Instance name should not be empty")