Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions cmd/auth/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ using a client ID and secret is not supported.`,
cmd.Flags().DurationVar(&tokenTimeout, "timeout", defaultTimeout,
"Timeout for acquiring a token.")

var refreshBefore time.Duration
cmd.Flags().DurationVar(&refreshBefore, "refresh-before", 0,
"Refresh the token if it expires within this duration (e.g., 5m, 30s).")

cmd.RunE = func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
profileName := ""
Expand All @@ -78,6 +82,7 @@ using a client ID and secret is not supported.`,
profileName: profileName,
args: args,
tokenTimeout: tokenTimeout,
refreshBefore: refreshBefore,
profiler: profile.DefaultProfiler,
persistentAuthOpts: nil,
})
Expand Down Expand Up @@ -108,6 +113,9 @@ type loadTokenArgs struct {
// tokenTimeout is the timeout for retrieving (and potentially refreshing) an OAuth token.
tokenTimeout time.Duration

// refreshBefore triggers a token refresh if the token expires within this duration.
refreshBefore time.Duration

// profiler is the profiler to use for reading the host and account ID from the .databrickscfg file.
profiler profile.Profiler

Expand Down Expand Up @@ -242,6 +250,9 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) {
return nil, err
}
allArgs := append(args.persistentAuthOpts, u2m.WithOAuthArgument(oauthArgument))
if args.refreshBefore > 0 {
allArgs = append(allArgs, u2m.WithExpiryDelta(args.refreshBefore))
}
persistentAuth, err := u2m.NewPersistentAuth(ctx, allArgs...)
if err != nil {
helpMsg := helpfulError(ctx, args.profileName, oauthArgument)
Expand Down
64 changes: 64 additions & 0 deletions cmd/auth/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ func TestToken_loadToken(t *testing.T) {
Name: "legacy-ws",
Host: "https://legacy-ws.cloud.databricks.com",
},
{
Name: "valid-token",
Host: "https://accounts.cloud.databricks.com",
AccountID: "valid-token",
},
{
Name: "m2m-profile",
Host: "https://m2m.cloud.databricks.com",
Expand Down Expand Up @@ -642,6 +647,65 @@ func TestToken_loadToken(t *testing.T) {
},
validateToken: validateToken,
},
{
name: "refreshBefore skips refresh when token has enough time",
args: loadTokenArgs{
authArguments: &auth.AuthArguments{},
profileName: "valid-token",
args: []string{},
tokenTimeout: 1 * time.Hour,
refreshBefore: 5 * time.Minute,
profiler: profiler,
persistentAuthOpts: []u2m.PersistentAuthOption{
u2m.WithTokenCache(&inMemoryTokenCache{Tokens: map[string]*oauth2.Token{
"valid-token": {AccessToken: "still-valid", RefreshToken: "valid-token", Expiry: time.Now().Add(1 * time.Hour)},
}}),
u2m.WithOAuthEndpointSupplier(&MockApiClient{}),
},
},
validateToken: func(resp *oauth2.Token) {
assert.Equal(t, "still-valid", resp.AccessToken)
},
},
{
name: "refreshBefore zero preserves default behavior",
args: loadTokenArgs{
authArguments: &auth.AuthArguments{},
profileName: "valid-token",
args: []string{},
tokenTimeout: 1 * time.Hour,
refreshBefore: 0,
profiler: profiler,
persistentAuthOpts: []u2m.PersistentAuthOption{
u2m.WithTokenCache(&inMemoryTokenCache{Tokens: map[string]*oauth2.Token{
"valid-token": {AccessToken: "still-valid", RefreshToken: "valid-token", Expiry: time.Now().Add(1 * time.Hour)},
}}),
u2m.WithOAuthEndpointSupplier(&MockApiClient{}),
},
},
validateToken: func(resp *oauth2.Token) {
assert.Equal(t, "still-valid", resp.AccessToken)
},
},
{
name: "refreshBefore forces refresh when token expires within window",
args: loadTokenArgs{
authArguments: &auth.AuthArguments{},
profileName: "valid-token",
args: []string{},
tokenTimeout: 1 * time.Hour,
refreshBefore: 2 * time.Hour,
profiler: profiler,
persistentAuthOpts: []u2m.PersistentAuthOption{
u2m.WithTokenCache(&inMemoryTokenCache{Tokens: map[string]*oauth2.Token{
"valid-token": {AccessToken: "still-valid", RefreshToken: "valid-token", Expiry: time.Now().Add(1 * time.Hour)},
}}),
u2m.WithOAuthEndpointSupplier(&MockApiClient{}),
u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshSuccessTokenResponse}}),
},
},
validateToken: validateToken,
},
{
name: "host flag with profile env var disambiguates multi-profile",
setupCtx: func(ctx context.Context) context.Context {
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,5 @@ require (
google.golang.org/grpc v1.78.0 // indirect
google.golang.org/protobuf v1.36.11 // indirect
)

replace github.com/databricks/databricks-sdk-go => /Users/anthony.ivan/projects/databricks-sdk-go
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
github.com/cyphar/filepath-securejoin v0.4.1 h1:JyxxyPEaktOD+GAnqIqTf9A8tHyAG22rowi7HkoSU1s=
github.com/cyphar/filepath-securejoin v0.4.1/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGLDGQL7h7bg04C/+u9jI=
github.com/databricks/databricks-sdk-go v0.119.0 h1:Fot5T4bBGxfuFHII0xLPXuzkBmALWiJeUBeuXQX2Pcw=
github.com/databricks/databricks-sdk-go v0.119.0/go.mod h1:hWoHnHbNLjPKiTm5K/7bcIv3J3Pkgo5x9pPzh8K3RVE=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down