diff --git a/job.go b/job.go index 64cdab20..44e2bd06 100644 --- a/job.go +++ b/job.go @@ -356,6 +356,25 @@ func (j *Job) updateConnections() { if u.User != nil { user = u.User.Username() } + + // For SQL Server connections, + // url.path is reserved for sql server instance not the database name + // database name can be specified in multiple ways: + // 1. Query parameter: database=dbname + // 2. Query parameter: initial catalog=dbname + database := "" + if strings.HasPrefix(conn, "sqlserver://") { + // Check for 'database' parameter first + if dbParam := getQueryStringCaseInsensitive(u.Query(), "database"); dbParam != "" { + database = dbParam + } else if catalogParam := getQueryStringCaseInsensitive(u.Query(), "initial catalog"); catalogParam != "" { + // 'initial catalog' is an alternative to 'database' parameter + database = catalogParam + } + } else { + database = strings.TrimPrefix(u.Path, "/") + } + // we expose some of the connection variables as labels, so we need to // remember them newConn := &connection{ @@ -363,7 +382,7 @@ func (j *Job) updateConnections() { url: conn, driver: u.Scheme, host: u.Host, - database: strings.TrimPrefix(u.Path, "/"), + database: database, user: user, } if newConn.driver == "athena" { @@ -387,11 +406,11 @@ func (j *Job) updateConnections() { privateKeyPath := os.ExpandEnv(queryParams.Get("private_key_file")) cfg := &gosnowflake.Config{ - Account: u.Host, - User: u.User.Username(), - Role: queryParams.Get("role"), + Account: u.Host, + User: u.User.Username(), + Role: queryParams.Get("role"), Database: queryParams.Get("database"), - Schema: queryParams.Get("schema"), + Schema: queryParams.Get("schema"), } if privateKeyPath != "" { @@ -683,3 +702,13 @@ func (c *connection) connect(job *Job) error { c.conn = conn return nil } + +func getQueryStringCaseInsensitive(values url.Values, key string) string { + key = strings.ToLower(key) + for k, v := range values { + if strings.ToLower(k) == key && len(v) > 0 { + return v[0] + } + } + return "" +}