Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 39 additions & 20 deletions experimental/ssh/internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,10 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
return errors.New("either --cluster or --name must be provided")
}

if !opts.ProxyMode {
cmdio.LogString(ctx, fmt.Sprintf("Connecting to %s...", sessionID))
}

if opts.IDE != "" && !opts.ProxyMode {
if err := vscode.CheckIDECommand(opts.IDE); err != nil {
return err
Expand Down Expand Up @@ -238,6 +242,7 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt

// Only check cluster state for dedicated clusters
if !opts.IsServerlessMode() {
cmdio.LogString(ctx, "Checking cluster state...")
err := checkClusterState(ctx, client, opts.ClusterID, opts.AutoStartCluster)
if err != nil {
return err
Expand All @@ -263,8 +268,8 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
if err != nil {
return fmt.Errorf("failed to save SSH key pair locally: %w", err)
}
cmdio.LogString(ctx, "Using SSH key: "+keyPath)
cmdio.LogString(ctx, fmt.Sprintf("Secrets scope: %s, key name: %s", secretScopeName, opts.ClientPublicKeyName))
log.Infof(ctx, "Using SSH key: %s", keyPath)
log.Infof(ctx, "Secrets scope: %s, key name: %s", secretScopeName, opts.ClientPublicKeyName)

var userName string
var serverPort int
Expand All @@ -273,10 +278,14 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
version := build.GetInfo().Version

if opts.ServerMetadata == "" {
cmdio.LogString(ctx, "Checking for ssh-tunnel binaries to upload...")
cmdio.LogString(ctx, "Uploading binaries...")
sp := cmdio.NewSpinner(ctx)
sp.Update("Uploading binaries...")
if err := UploadTunnelReleases(ctx, client, version, opts.ReleasesDir); err != nil {
sp.Close()
return fmt.Errorf("failed to upload ssh-tunnel binaries: %w", err)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Agent Swarm Review] [Nit]

The spinner for the binary upload is created and closed inline rather than with defer sp.Close(). An early return between sp := cmdio.NewSpinner(ctx) and sp.Close() would leak the spinner. The other spinners in this PR correctly use defer sp.Close().

Suggestion: Move sp.Close() to a defer and remove the explicit close calls.

}
sp.Close()
userName, serverPort, clusterID, err = ensureSSHServerIsRunning(ctx, client, version, secretScopeName, opts)
if err != nil {
return fmt.Errorf("failed to ensure that ssh server is running: %w", err)
Expand Down Expand Up @@ -307,18 +316,22 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
return errors.New("cluster ID is required for serverless connections but was not found in metadata")
}

cmdio.LogString(ctx, "Remote user name: "+userName)
cmdio.LogString(ctx, fmt.Sprintf("Server port: %d", serverPort))
log.Infof(ctx, "Remote user name: %s", userName)
log.Infof(ctx, "Server port: %d", serverPort)
if opts.IsServerlessMode() {
cmdio.LogString(ctx, "Cluster ID (from serverless job): "+clusterID)
log.Infof(ctx, "Cluster ID (from serverless job): %s", clusterID)
}

if !opts.ProxyMode {
cmdio.LogString(ctx, "Connected!")
}

if opts.ProxyMode {
return runSSHProxy(ctx, client, serverPort, clusterID, opts)
} else if opts.IDE != "" {
return runIDE(ctx, client, userName, keyPath, serverPort, clusterID, opts)
} else {
cmdio.LogString(ctx, fmt.Sprintf("Additional SSH arguments: %v", opts.AdditionalArgs))
log.Infof(ctx, "Additional SSH arguments: %v", opts.AdditionalArgs)
return spawnSSHClient(ctx, userName, keyPath, serverPort, clusterID, opts)
}
}
Expand Down Expand Up @@ -372,7 +385,7 @@ func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, k
return err
}

cmdio.LogString(ctx, fmt.Sprintf("Updated SSH config entry for '%s'", hostName))
log.Infof(ctx, "Updated SSH config entry for '%s'", hostName)
return nil
}

Expand Down Expand Up @@ -471,7 +484,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
"serverless": strconv.FormatBool(opts.IsServerlessMode()),
}

cmdio.LogString(ctx, "Submitting a job to start the ssh server...")
log.Infof(ctx, "Submitting a job to start the ssh server...")

task := jobs.SubmitTask{
TaskKey: sshServerTaskKey,
Expand All @@ -485,7 +498,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
if opts.IsServerlessMode() {
task.EnvironmentKey = serverlessEnvironmentKey
if opts.Accelerator != "" {
cmdio.LogString(ctx, "Using accelerator: "+opts.Accelerator)
log.Infof(ctx, "Using accelerator: %s", opts.Accelerator)
task.Compute = &jobs.Compute{
HardwareAccelerator: compute.HardwareAcceleratorType(opts.Accelerator),
}
Expand Down Expand Up @@ -516,7 +529,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
return fmt.Errorf("failed to submit job: %w", err)
}

cmdio.LogString(ctx, fmt.Sprintf("Job submitted successfully with run ID: %d", waiter.RunId))
log.Infof(ctx, "Job submitted successfully with run ID: %d", waiter.RunId)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The job run ID can be very useful for debugging (and the cluster ID too, actually).
The problem is that we can't really retrieve them after the fact (unless we also store local logs to a file, which we don't do right now).


return waitForJobToStart(ctx, client, waiter.RunId, opts.TaskStartupTimeout)
}
Expand Down Expand Up @@ -568,14 +581,16 @@ func runSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, server
}

func checkClusterState(ctx context.Context, client *databricks.WorkspaceClient, clusterID string, autoStart bool) error {
sp := cmdio.NewSpinner(ctx)
defer sp.Close()
if autoStart {
cmdio.LogString(ctx, "Ensuring the cluster is running: "+clusterID)
sp.Update("Ensuring the cluster is running...")
err := client.Clusters.EnsureClusterIsRunning(ctx, clusterID)
if err != nil {
return fmt.Errorf("failed to ensure that the cluster is running: %w", err)
}
} else {
cmdio.LogString(ctx, "Checking cluster state: "+clusterID)
sp.Update("Checking cluster state...")
cluster, err := client.Clusters.GetByClusterId(ctx, clusterID)
if err != nil {
return fmt.Errorf("failed to get cluster info: %w", err)
Expand All @@ -590,7 +605,9 @@ func checkClusterState(ctx context.Context, client *databricks.WorkspaceClient,
// waitForJobToStart polls the task status until the SSH server task is in RUNNING state or terminates.
// Returns an error if the task fails to start or if polling times out.
func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient, runID int64, taskStartupTimeout time.Duration) error {
cmdio.LogString(ctx, "Waiting for the SSH server task to start...")
sp := cmdio.NewSpinner(ctx)
defer sp.Close()
sp.Update("Starting SSH server...")
var prevState jobs.RunLifecycleStateV2State

_, err := retries.Poll(ctx, taskStartupTimeout, func() (*jobs.RunTask, *retries.Err) {
Expand Down Expand Up @@ -620,15 +637,14 @@ func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient,

currentState := sshTask.Status.State

// Print status if it changed
// Update spinner if state changed
if currentState != prevState {
cmdio.LogString(ctx, fmt.Sprintf("Task status: %s", currentState))
sp.Update(fmt.Sprintf("Starting SSH server... (task: %s)", currentState))
prevState = currentState
}

// Check if task is running
if currentState == jobs.RunLifecycleStateV2StateRunning {
cmdio.LogString(ctx, "SSH server task is now running, proceeding to connect...")
return sshTask, nil
}

Expand All @@ -651,26 +667,29 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC

serverPort, userName, effectiveClusterID, err := getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap)
if errors.Is(err, errServerMetadata) {
cmdio.LogString(ctx, "SSH server is not running, starting it now...")
cmdio.LogString(ctx, "Starting SSH server...")

err := submitSSHTunnelJob(ctx, client, version, secretScopeName, opts)
if err != nil {
return "", 0, "", fmt.Errorf("failed to submit and start ssh server job: %w", err)
}

cmdio.LogString(ctx, "Waiting for the ssh server to start...")
sp := cmdio.NewSpinner(ctx)
sp.Update("Waiting for the SSH server to start...")
maxRetries := 30
for retries := range maxRetries {
if ctx.Err() != nil {
sp.Close()
return "", 0, "", ctx.Err()
}
serverPort, userName, effectiveClusterID, err = getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap)
if err == nil {
cmdio.LogString(ctx, "Health check successful, starting ssh WebSocket connection...")
sp.Close()
break
} else if retries < maxRetries-1 {
time.Sleep(2 * time.Second)
} else {
sp.Close()
return "", 0, "", fmt.Errorf("failed to start the ssh server: %w", err)
}
}
Expand Down
12 changes: 6 additions & 6 deletions experimental/ssh/internal/client/releases.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ import (
"strings"

"github.com/databricks/cli/experimental/ssh/internal/workspace"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/filer"
"github.com/databricks/cli/libs/log"
"github.com/databricks/databricks-sdk-go"
)

Expand Down Expand Up @@ -48,7 +48,7 @@ func uploadReleases(ctx context.Context, workspaceFiler filer.Filer, getRelease

_, err := workspaceFiler.Stat(ctx, remoteBinaryPath)
if err == nil {
cmdio.LogString(ctx, fmt.Sprintf("File %s already exists in the workspace, skipping upload", remoteBinaryPath))
log.Infof(ctx, "File %s already exists in the workspace, skipping upload", remoteBinaryPath)
continue
} else if !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("failed to check if file %s exists in workspace: %w", remoteBinaryPath, err)
Expand All @@ -60,14 +60,14 @@ func uploadReleases(ctx context.Context, workspaceFiler filer.Filer, getRelease
}
defer releaseReader.Close()

cmdio.LogString(ctx, fmt.Sprintf("Uploading %s to the workspace", fileName))
log.Infof(ctx, "Uploading %s to the workspace", fileName)
// workspace-files/import-file API will automatically unzip the payload,
// producing the filerRoot/remoteSubFolder/*archive-contents* structure, with 'databricks' binary inside.
err = workspaceFiler.Write(ctx, remoteArchivePath, releaseReader, filer.OverwriteIfExists, filer.CreateParentDirectories)
if err != nil {
return fmt.Errorf("failed to upload file %s to workspace: %w", remoteArchivePath, err)
}
cmdio.LogString(ctx, fmt.Sprintf("Successfully uploaded %s to workspace", remoteBinaryPath))
log.Infof(ctx, "Successfully uploaded %s to workspace", remoteBinaryPath)
}

return nil
Expand All @@ -81,7 +81,7 @@ func getReleaseName(architecture, version string) string {
}

func getLocalRelease(ctx context.Context, architecture, version, releasesDir string) (io.ReadCloser, error) {
cmdio.LogString(ctx, "Looking for CLI releases in directory: "+releasesDir)
log.Infof(ctx, "Looking for CLI releases in directory: %s", releasesDir)
releaseName := getReleaseName(architecture, version)
releasePath := filepath.Join(releasesDir, releaseName)
file, err := os.Open(releasePath)
Expand All @@ -95,7 +95,7 @@ func getGithubRelease(ctx context.Context, architecture, version, releasesDir st
// TODO: download and check databricks_cli_<version>_SHA256SUMS
fileName := getReleaseName(architecture, version)
downloadURL := fmt.Sprintf("https://github.com/databricks/cli/releases/download/v%s/%s", version, fileName)
cmdio.LogString(ctx, fmt.Sprintf("Downloading %s from %s", fileName, downloadURL))
log.Infof(ctx, "Downloading %s from %s", fileName, downloadURL)

resp, err := http.Get(downloadURL)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions experimental/ssh/internal/proxy/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@ import (
"io"
"time"

"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/log"
"golang.org/x/sync/errgroup"
)

func RunClientProxy(ctx context.Context, src io.ReadCloser, dst io.Writer, requestHandoverTick func() <-chan time.Time, createConn createWebsocketConnectionFunc) error {
proxy := newProxyConnection(createConn)
cmdio.LogString(ctx, "Establishing SSH proxy connection...")
log.Infof(ctx, "Establishing SSH proxy connection...")
g, gCtx := errgroup.WithContext(ctx)
if err := proxy.connect(gCtx); err != nil {
return fmt.Errorf("failed to connect to proxy: %w", err)
}
defer proxy.close()
cmdio.LogString(ctx, "SSH proxy connection established")
log.Infof(ctx, "SSH proxy connection established")

g.Go(func() error {
for {
Expand Down
3 changes: 2 additions & 1 deletion experimental/ssh/internal/vscode/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strings"

"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/log"
"golang.org/x/mod/semver"
)

Expand Down Expand Up @@ -148,7 +149,7 @@ func LaunchIDE(ctx context.Context, ideOption, connectionName, userName, databri
remoteURI := fmt.Sprintf("ssh-remote+%s@%s", userName, connectionName)
remotePath := fmt.Sprintf("/Workspace/Users/%s/", databricksUserName)

cmdio.LogString(ctx, fmt.Sprintf("Launching %s with remote URI: %s and path: %s", ideOption, remoteURI, remotePath))
log.Infof(ctx, "Launching %s with remote URI: %s and path: %s", ideOption, remoteURI, remotePath)

ideCmd := exec.CommandContext(ctx, ide.Command, "--remote", remoteURI, remotePath)
ideCmd.Stdout = os.Stdout
Expand Down
4 changes: 2 additions & 2 deletions experimental/ssh/internal/vscode/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,11 @@ func backupSettings(ctx context.Context, path string) error {
latestBak := path + ".latest.bak"

if _, err := os.Stat(originalBak); os.IsNotExist(err) {
cmdio.LogString(ctx, "Backing up settings to "+filepath.ToSlash(originalBak))
log.Infof(ctx, "Backing up settings to %s", filepath.ToSlash(originalBak))
return os.WriteFile(originalBak, data, 0o600)
}

cmdio.LogString(ctx, "Backing up settings to "+filepath.ToSlash(latestBak))
log.Infof(ctx, "Backing up settings to %s", filepath.ToSlash(latestBak))
return os.WriteFile(latestBak, data, 0o600)
}

Expand Down