diff --git a/cmd/auth/env.go b/cmd/auth/env.go index 11149af8c0..521f1d871a 100644 --- a/cmd/auth/env.go +++ b/cmd/auth/env.go @@ -1,146 +1,73 @@ package auth import ( - "context" "encoding/json" - "errors" "fmt" - "io/fs" - "net/http" - "net/url" + "maps" + "slices" "strings" - "github.com/databricks/cli/libs/databrickscfg/profile" - "github.com/databricks/databricks-sdk-go/config" + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/auth" + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/cli/libs/flags" "github.com/spf13/cobra" - "gopkg.in/ini.v1" ) -func canonicalHost(host string) (string, error) { - parsedHost, err := url.Parse(host) - if err != nil { - return "", err - } - // If the host is empty, assume the scheme wasn't included. - if parsedHost.Host == "" { - return "https://" + host, nil - } - return "https://" + parsedHost.Host, nil -} - -var ErrNoMatchingProfiles = errors.New("no matching profiles found") - -func resolveSection(cfg *config.Config, iniFile *config.File) (*ini.Section, error) { - var candidates []*ini.Section - configuredHost, err := canonicalHost(cfg.Host) - if err != nil { - return nil, err - } - for _, section := range iniFile.Sections() { - hash := section.KeysHash() - host, ok := hash["host"] - if !ok { - // if host is not set - continue - } - canonical, err := canonicalHost(host) - if err != nil { - // we're fine with other corrupt profiles - continue - } - if canonical != configuredHost { - continue - } - candidates = append(candidates, section) - } - if len(candidates) == 0 { - return nil, ErrNoMatchingProfiles - } - // in the real situations, we don't expect this to happen often - // (if not at all), hence we don't trim the list - if len(candidates) > 1 { - var profiles []string - for _, v := range candidates { - profiles = append(profiles, v.Name()) - } - return nil, fmt.Errorf("%s match %s in %s", - strings.Join(profiles, " and "), cfg.Host, cfg.ConfigFile) - } - return candidates[0], nil -} - -func loadFromDatabricksCfg(ctx context.Context, cfg *config.Config) error { - iniFile, err := profile.DefaultProfiler.Get(ctx) - if errors.Is(err, fs.ErrNotExist) { - // it's fine not to have ~/.databrickscfg - return nil - } - if err != nil { - return err - } - profile, err := resolveSection(cfg, iniFile) - if err == ErrNoMatchingProfiles { - // it's also fine for Azure CLI or Databricks CLI, which - // are resolved by unified auth handling in the Go SDK. - return nil - } - if err != nil { - return err - } - cfg.Profile = profile.Name() - return nil -} - func newEnvCommand() *cobra.Command { cmd := &cobra.Command{ Use: "env", - Short: "Get env", + Short: "Get authentication environment variables for the current CLI context", + Long: `Output the environment variables needed to authenticate as the same identity +the CLI is currently authenticated as. This is useful for configuring downstream +tools that accept Databricks authentication via environment variables.`, } - var host string - var profile string - cmd.Flags().StringVar(&host, "host", host, "Hostname to get auth env for") - cmd.Flags().StringVar(&profile, "profile", profile, "Profile to get auth env for") - cmd.RunE = func(cmd *cobra.Command, args []string) error { - cfg := &config.Config{ - Host: host, - Profile: profile, - } - if profile != "" { - cfg.Profile = profile - } else if cfg.Host == "" { - cfg.Profile = "DEFAULT" - } else if err := loadFromDatabricksCfg(cmd.Context(), cfg); err != nil { - return err - } - // Go SDK is lazy loaded because of Terraform semantics, - // so we're creating a dummy HTTP request as a placeholder - // for headers. - r := &http.Request{Header: http.Header{}} - err := cfg.Authenticate(r.WithContext(cmd.Context())) + _, err := root.MustAnyClient(cmd, args) if err != nil { return err } - vars := map[string]string{} - for _, a := range config.ConfigAttributes { - if a.IsZero(cfg) { - continue - } - envValue := a.GetString(cfg) - for _, envName := range a.EnvVars { - vars[envName] = envValue + + cfg := cmdctx.ConfigUsed(cmd.Context()) + envVars := auth.Env(cfg) + + // Output KEY=VALUE lines when the user explicitly passes --output text. + if cmd.Flag("output").Changed && root.OutputType(cmd) == flags.OutputText { + w := cmd.OutOrStdout() + keys := slices.Sorted(maps.Keys(envVars)) + for _, k := range keys { + _, err := fmt.Fprintf(w, "%s=%s\n", k, quoteEnvValue(envVars[k])) + if err != nil { + return err + } } + return nil } - raw, err := json.MarshalIndent(map[string]any{ - "env": vars, - }, "", " ") + + raw, err := json.MarshalIndent(envVars, "", " ") if err != nil { return err } - _, _ = cmd.OutOrStdout().Write(raw) - return nil + _, err = cmd.OutOrStdout().Write(raw) + return err } return cmd } + +const shellQuotedSpecialChars = " \t\n\r\"\\$`!#&|;(){}[]<>?*~'" + +// quoteEnvValue quotes a value for KEY=VALUE output if it contains spaces or +// shell-special characters. Single quotes prevent shell expansion, and +// embedded single quotes use the POSIX-compatible '\" sequence. +func quoteEnvValue(v string) string { + if v == "" { + return `''` + } + needsQuoting := strings.ContainsAny(v, shellQuotedSpecialChars) + if !needsQuoting { + return v + } + return "'" + strings.ReplaceAll(v, "'", "'\\''") + "'" +} diff --git a/cmd/auth/env_test.go b/cmd/auth/env_test.go new file mode 100644 index 0000000000..d3347dd84d --- /dev/null +++ b/cmd/auth/env_test.go @@ -0,0 +1,102 @@ +package auth + +import ( + "bytes" + "testing" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/flags" + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestQuoteEnvValue(t *testing.T) { + cases := []struct { + name string + in string + want string + }{ + {name: "simple value", in: "hello", want: "hello"}, + {name: "empty value", in: "", want: `''`}, + {name: "value with space", in: "hello world", want: "'hello world'"}, + {name: "value with tab", in: "hello\tworld", want: "'hello\tworld'"}, + {name: "value with double quote", in: `say "hi"`, want: "'say \"hi\"'"}, + {name: "value with backslash", in: `path\to`, want: "'path\\to'"}, + {name: "url value", in: "https://example.com", want: "https://example.com"}, + {name: "value with dollar", in: "price$5", want: "'price$5'"}, + {name: "value with backtick", in: "hello`world", want: "'hello`world'"}, + {name: "value with bang", in: "hello!world", want: "'hello!world'"}, + {name: "value with single quote", in: "it's", want: "'it'\\''s'"}, + {name: "value with newline", in: "line1\nline2", want: "'line1\nline2'"}, + {name: "value with carriage return", in: "line1\rline2", want: "'line1\rline2'"}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got := quoteEnvValue(c.in) + assert.Equal(t, c.want, got) + }) + } +} + +func TestEnvCommand_TextOutput(t *testing.T) { + cases := []struct { + name string + args []string + wantJSON bool + }{ + { + name: "default output is JSON", + args: nil, + wantJSON: true, + }, + { + name: "explicit --output text produces KEY=VALUE lines", + args: []string{"--output", "text"}, + wantJSON: false, + }, + { + name: "explicit --output json produces JSON", + args: []string{"--output", "json"}, + wantJSON: true, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + // Isolate from real config/token cache on the machine. + t.Setenv("DATABRICKS_CONFIG_FILE", t.TempDir()+"/.databrickscfg") + t.Setenv("HOME", t.TempDir()) + // Set env vars so MustAnyClient resolves auth via PAT. + t.Setenv("DATABRICKS_HOST", "https://test.cloud.databricks.com") + t.Setenv("DATABRICKS_TOKEN", "test-token-value") + + parent := &cobra.Command{Use: "databricks"} + outputFlag := flags.OutputText + parent.PersistentFlags().VarP(&outputFlag, "output", "o", "output type: text or json") + parent.PersistentFlags().StringP("profile", "p", "", "~/.databrickscfg profile") + + envCmd := newEnvCommand() + parent.AddCommand(envCmd) + parent.SetContext(cmdio.MockDiscard(t.Context())) + + var buf bytes.Buffer + parent.SetOut(&buf) + parent.SetArgs(append([]string{"env"}, c.args...)) + + err := parent.Execute() + require.NoError(t, err) + + output := buf.String() + if c.wantJSON { + assert.Contains(t, output, "{") + assert.Contains(t, output, "DATABRICKS_HOST") + } else { + assert.NotContains(t, output, "{") + assert.Contains(t, output, "DATABRICKS_HOST=") + assert.Contains(t, output, "=") + assert.NotContains(t, output, `"env"`) + } + }) + } +} diff --git a/cmd/auth/token.go b/cmd/auth/token.go index 79f99726be..7268bb55e6 100644 --- a/cmd/auth/token.go +++ b/cmd/auth/token.go @@ -8,11 +8,13 @@ import ( "strings" "time" + "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/auth" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/databrickscfg" "github.com/databricks/cli/libs/databrickscfg/profile" "github.com/databricks/cli/libs/env" + "github.com/databricks/cli/libs/flags" "github.com/databricks/databricks-sdk-go/config" "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/databricks/databricks-sdk-go/credentials/u2m/cache" @@ -83,17 +85,29 @@ using a client ID and secret is not supported.`, if err != nil { return err } - raw, err := json.MarshalIndent(t, "", " ") - if err != nil { - return err - } - _, _ = cmd.OutOrStdout().Write(raw) - return nil + return writeTokenOutput(cmd, t) } return cmd } +func writeTokenOutput(cmd *cobra.Command, t *oauth2.Token) error { + // Only honor the explicit --output text flag, not implicit text mode + // (e.g. from DATABRICKS_OUTPUT_FORMAT). auth token defaults to JSON, + // and changing that implicitly would break scripts that parse JSON output. + if cmd.Flag("output").Changed && root.OutputType(cmd) == flags.OutputText { + _, err := fmt.Fprintln(cmd.OutOrStdout(), t.AccessToken) + return err + } + + raw, err := json.MarshalIndent(t, "", " ") + if err != nil { + return err + } + _, err = cmd.OutOrStdout().Write(raw) + return err +} + type loadTokenArgs struct { // authArguments is the parsed auth arguments, including the host and optionally the account ID. authArguments *auth.AuthArguments diff --git a/cmd/auth/token_test.go b/cmd/auth/token_test.go index aa343eb372..0e41952995 100644 --- a/cmd/auth/token_test.go +++ b/cmd/auth/token_test.go @@ -1,6 +1,7 @@ package auth import ( + "bytes" "context" "net/http" "testing" @@ -10,8 +11,10 @@ import ( "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/databrickscfg/profile" "github.com/databricks/cli/libs/env" + "github.com/databricks/cli/libs/flags" "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" + "github.com/spf13/cobra" "github.com/stretchr/testify/assert" "golang.org/x/oauth2" ) @@ -729,3 +732,105 @@ func (e errProfiler) LoadProfiles(context.Context, profile.ProfileMatchFunction) func (e errProfiler) GetPath(context.Context) (string, error) { return "", nil } + +func TestTokenCommand_TextOutput(t *testing.T) { + profiler := profile.InMemoryProfiler{ + Profiles: profile.Profiles{ + { + Name: "test-ws", + Host: "https://test-ws.cloud.databricks.com", + }, + }, + } + tokenCache := &inMemoryTokenCache{ + Tokens: map[string]*oauth2.Token{ + "test-ws": { + RefreshToken: "test-ws", + Expiry: time.Now().Add(1 * time.Hour), + }, + }, + } + persistentAuthOpts := []u2m.PersistentAuthOption{ + u2m.WithTokenCache(tokenCache), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshSuccessTokenResponse}}), + } + + cases := []struct { + name string + args []string + wantSubstr string + wantJSON bool + }{ + { + name: "default output is JSON", + args: []string{"--profile", "test-ws"}, + wantSubstr: `"access_token"`, + wantJSON: true, + }, + { + name: "explicit --output json produces JSON", + args: []string{"--profile", "test-ws", "--output", "json"}, + wantSubstr: `"access_token"`, + wantJSON: true, + }, + { + name: "explicit --output text produces plain token with newline", + args: []string{"--profile", "test-ws", "--output", "text"}, + wantSubstr: "new-access-token\n", + wantJSON: false, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + authArgs := &auth.AuthArguments{} + + parent := &cobra.Command{Use: "databricks"} + outputFlag := flags.OutputText + parent.PersistentFlags().VarP(&outputFlag, "output", "o", "output type: text or json") + parent.PersistentFlags().StringP("profile", "p", "", "~/.databrickscfg profile") + + tokenCmd := newTokenCommand(authArgs) + // Override RunE to inject test profiler and token cache while reusing + // the production output formatter. + tokenCmd.RunE = func(cmd *cobra.Command, args []string) error { + profileName := "" + if f := cmd.Flag("profile"); f != nil { + profileName = f.Value.String() + } + tok, err := loadToken(cmd.Context(), loadTokenArgs{ + authArguments: authArgs, + profileName: profileName, + args: args, + tokenTimeout: 1 * time.Hour, + profiler: profiler, + persistentAuthOpts: persistentAuthOpts, + }) + if err != nil { + return err + } + return writeTokenOutput(cmd, tok) + } + + parent.AddCommand(tokenCmd) + parent.SetContext(ctx) + + var buf bytes.Buffer + parent.SetOut(&buf) + parent.SetArgs(append([]string{"token"}, c.args...)) + + err := parent.Execute() + assert.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, c.wantSubstr) + if c.wantJSON { + assert.Contains(t, output, "{") + } else { + assert.NotContains(t, output, "{") + } + }) + } +}