diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index a05c0adca7..112b5bddc1 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -20,12 +20,15 @@ import ( "github.com/databricks/cli/experimental/ssh/internal/keys" "github.com/databricks/cli/experimental/ssh/internal/proxy" + "github.com/databricks/cli/experimental/ssh/internal/sessions" "github.com/databricks/cli/experimental/ssh/internal/sshconfig" "github.com/databricks/cli/experimental/ssh/internal/vscode" sshWorkspace "github.com/databricks/cli/experimental/ssh/internal/workspace" "github.com/databricks/cli/internal/build" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/log" + "github.com/databricks/cli/libs/telemetry" + "github.com/databricks/cli/libs/telemetry/protos" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/retries" "github.com/databricks/databricks-sdk-go/service/compute" @@ -99,11 +102,11 @@ type ClientOptions struct { } func (o *ClientOptions) Validate() error { - if !o.ProxyMode && o.ClusterID == "" && o.ConnectionName == "" { - return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)") + if !o.ProxyMode && o.ClusterID == "" && o.ConnectionName == "" && o.Accelerator == "" { + return errors.New("please provide --cluster or --accelerator flag") } - if o.Accelerator != "" && o.ConnectionName == "" { - return errors.New("--accelerator flag can only be used with serverless compute (--name flag)") + if o.Accelerator != "" && o.ClusterID != "" { + return errors.New("--accelerator flag can only be used with serverless compute, not with --cluster") } // TODO: Remove when we add support for serverless CPU if o.ConnectionName != "" && o.Accelerator == "" { @@ -122,7 +125,7 @@ func (o *ClientOptions) Validate() error { } func (o *ClientOptions) IsServerlessMode() bool { - return o.ClusterID == "" && o.ConnectionName != "" + return o.ClusterID == "" && (o.ConnectionName != "" || o.Accelerator != "") } // SessionIdentifier returns the unique identifier for the session. @@ -202,9 +205,67 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt cancel() }() + event := BuildTelemetryEvent(opts) + + runErr := runConnect(ctx, client, opts, event) + if runErr != nil { + event.IsSuccess = false + } else { + event.IsSuccess = true + } + + telemetry.Log(ctx, protos.DatabricksCliLog{ + SshTunnelEvent: event, + }) + + return runErr +} + +// BuildTelemetryEvent creates an SshTunnelEvent pre-populated with data from client options. +func BuildTelemetryEvent(opts ClientOptions) *protos.SshTunnelEvent { + event := &protos.SshTunnelEvent{ + AcceleratorType: opts.Accelerator, + IdeType: opts.IDE, + AutoStartCluster: opts.AutoStartCluster, + } + + if opts.IsServerlessMode() { + event.ComputeType = protos.SshTunnelComputeTypeServerless + } else { + event.ComputeType = protos.SshTunnelComputeTypeDedicated + } + + switch { + case opts.ProxyMode: + event.ClientMode = protos.SshTunnelClientModeProxy + case opts.IDE != "": + event.ClientMode = protos.SshTunnelClientModeIDE + default: + event.ClientMode = protos.SshTunnelClientModeSSH + } + + // If metadata is provided, the server is already running — this is a reconnect from ProxyCommand. + event.IsReconnect = opts.ServerMetadata != "" + + return event +} + +func runConnect(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOptions, event *protos.SshTunnelEvent) error { + // For serverless without explicit --name: auto-generate or reconnect to existing session. + if opts.IsServerlessMode() && opts.ConnectionName == "" && !opts.ProxyMode { + err := resolveServerlessSession(ctx, client, &opts) + if err != nil { + return err + } + } + sessionID := opts.SessionIdentifier() if sessionID == "" { - return errors.New("either --cluster or --name must be provided") + return errors.New("either --cluster or --accelerator must be provided") + } + + if !opts.ProxyMode { + cmdio.LogString(ctx, fmt.Sprintf("Connecting to %s...", sessionID)) } if opts.IDE != "" && !opts.ProxyMode { @@ -238,6 +299,7 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt // Only check cluster state for dedicated clusters if !opts.IsServerlessMode() { + cmdio.LogString(ctx, "Checking cluster state...") err := checkClusterState(ctx, client, opts.ClusterID, opts.AutoStartCluster) if err != nil { return err @@ -263,8 +325,8 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt if err != nil { return fmt.Errorf("failed to save SSH key pair locally: %w", err) } - cmdio.LogString(ctx, "Using SSH key: "+keyPath) - cmdio.LogString(ctx, fmt.Sprintf("Secrets scope: %s, key name: %s", secretScopeName, opts.ClientPublicKeyName)) + log.Infof(ctx, "Using SSH key: %s", keyPath) + log.Infof(ctx, "Secrets scope: %s, key name: %s", secretScopeName, opts.ClientPublicKeyName) var userName string var serverPort int @@ -273,14 +335,22 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt version := build.GetInfo().Version if opts.ServerMetadata == "" { - cmdio.LogString(ctx, "Checking for ssh-tunnel binaries to upload...") + cmdio.LogString(ctx, "Uploading binaries...") + sp := cmdio.NewSpinner(ctx) + sp.TrackElapsedTime() + sp.Update("Uploading binaries...") if err := UploadTunnelReleases(ctx, client, version, opts.ReleasesDir); err != nil { + sp.Close() return fmt.Errorf("failed to upload ssh-tunnel binaries: %w", err) } + sp.Close() + + serverStartTime := time.Now() userName, serverPort, clusterID, err = ensureSSHServerIsRunning(ctx, client, version, secretScopeName, opts) if err != nil { return fmt.Errorf("failed to ensure that ssh server is running: %w", err) } + event.ServerStartTimeMs = time.Since(serverStartTime).Milliseconds() } else { // Metadata format: ",," metadata := strings.Split(opts.ServerMetadata, ",") @@ -307,10 +377,28 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt return errors.New("cluster ID is required for serverless connections but was not found in metadata") } - cmdio.LogString(ctx, "Remote user name: "+userName) - cmdio.LogString(ctx, fmt.Sprintf("Server port: %d", serverPort)) + log.Infof(ctx, "Remote user name: %s", userName) + log.Infof(ctx, "Server port: %d", serverPort) if opts.IsServerlessMode() { - cmdio.LogString(ctx, "Cluster ID (from serverless job): "+clusterID) + log.Infof(ctx, "Cluster ID (from serverless job): %s", clusterID) + } + + if !opts.ProxyMode { + cmdio.LogString(ctx, "Connected!") + } + + // Persist the session for future reconnects. + if opts.IsServerlessMode() && !opts.ProxyMode { + err = sessions.Add(ctx, sessions.Session{ + Name: opts.ConnectionName, + Accelerator: opts.Accelerator, + WorkspaceHost: client.Config.Host, + CreatedAt: time.Now(), + ClusterID: clusterID, + }) + if err != nil { + log.Warnf(ctx, "Failed to save session state: %v", err) + } } if opts.ProxyMode { @@ -318,7 +406,7 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt } else if opts.IDE != "" { return runIDE(ctx, client, userName, keyPath, serverPort, clusterID, opts) } else { - cmdio.LogString(ctx, fmt.Sprintf("Additional SSH arguments: %v", opts.AdditionalArgs)) + log.Infof(ctx, "Additional SSH arguments: %v", opts.AdditionalArgs) return spawnSSHClient(ctx, userName, keyPath, serverPort, clusterID, opts) } } @@ -365,14 +453,19 @@ func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, k return fmt.Errorf("failed to generate ProxyCommand: %w", err) } - hostConfig := sshconfig.GenerateHostConfig(hostName, userName, keyPath, proxyCommand) + var hostConfig string + if opts.IsServerlessMode() { + hostConfig = sshconfig.GenerateServerlessHostConfig(hostName, userName, keyPath, proxyCommand) + } else { + hostConfig = sshconfig.GenerateHostConfig(hostName, userName, keyPath, proxyCommand) + } _, err = sshconfig.CreateOrUpdateHostConfig(ctx, hostName, hostConfig, true) if err != nil { return err } - cmdio.LogString(ctx, fmt.Sprintf("Updated SSH config entry for '%s'", hostName)) + log.Infof(ctx, "Updated SSH config entry for '%s'", hostName) return nil } @@ -471,7 +564,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, "serverless": strconv.FormatBool(opts.IsServerlessMode()), } - cmdio.LogString(ctx, "Submitting a job to start the ssh server...") + log.Infof(ctx, "Submitting a job to start the ssh server...") task := jobs.SubmitTask{ TaskKey: sshServerTaskKey, @@ -485,7 +578,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, if opts.IsServerlessMode() { task.EnvironmentKey = serverlessEnvironmentKey if opts.Accelerator != "" { - cmdio.LogString(ctx, "Using accelerator: "+opts.Accelerator) + log.Infof(ctx, "Using accelerator: %s", opts.Accelerator) task.Compute = &jobs.Compute{ HardwareAccelerator: compute.HardwareAcceleratorType(opts.Accelerator), } @@ -516,7 +609,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, return fmt.Errorf("failed to submit job: %w", err) } - cmdio.LogString(ctx, fmt.Sprintf("Job submitted successfully with run ID: %d", waiter.RunId)) + log.Infof(ctx, "Job submitted successfully with run ID: %d", waiter.RunId) return waitForJobToStart(ctx, client, waiter.RunId, opts.TaskStartupTimeout) } @@ -533,15 +626,22 @@ func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, server hostName := opts.SessionIdentifier() + hostKeyChecking := "StrictHostKeyChecking=accept-new" + if opts.IsServerlessMode() { + hostKeyChecking = "StrictHostKeyChecking=no" + } + sshArgs := []string{ "-l", userName, "-i", privateKeyPath, "-o", "IdentitiesOnly=yes", - "-o", "StrictHostKeyChecking=accept-new", + "-o", hostKeyChecking, "-o", "ConnectTimeout=360", "-o", "ProxyCommand=" + proxyCommand, } - if opts.UserKnownHostsFile != "" { + if opts.IsServerlessMode() { + sshArgs = append(sshArgs, "-o", "UserKnownHostsFile=/dev/null") + } else if opts.UserKnownHostsFile != "" { sshArgs = append(sshArgs, "-o", "UserKnownHostsFile="+opts.UserKnownHostsFile) } sshArgs = append(sshArgs, hostName) @@ -568,14 +668,17 @@ func runSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, server } func checkClusterState(ctx context.Context, client *databricks.WorkspaceClient, clusterID string, autoStart bool) error { + sp := cmdio.NewSpinner(ctx) + sp.TrackElapsedTime() + defer sp.Close() if autoStart { - cmdio.LogString(ctx, "Ensuring the cluster is running: "+clusterID) + sp.Update("Ensuring the cluster is running...") err := client.Clusters.EnsureClusterIsRunning(ctx, clusterID) if err != nil { return fmt.Errorf("failed to ensure that the cluster is running: %w", err) } } else { - cmdio.LogString(ctx, "Checking cluster state: "+clusterID) + sp.Update("Checking cluster state...") cluster, err := client.Clusters.GetByClusterId(ctx, clusterID) if err != nil { return fmt.Errorf("failed to get cluster info: %w", err) @@ -590,7 +693,10 @@ func checkClusterState(ctx context.Context, client *databricks.WorkspaceClient, // waitForJobToStart polls the task status until the SSH server task is in RUNNING state or terminates. // Returns an error if the task fails to start or if polling times out. func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient, runID int64, taskStartupTimeout time.Duration) error { - cmdio.LogString(ctx, "Waiting for the SSH server task to start...") + sp := cmdio.NewSpinner(ctx) + sp.TrackElapsedTime() + defer sp.Close() + sp.Update("Starting SSH server...") var prevState jobs.RunLifecycleStateV2State _, err := retries.Poll(ctx, taskStartupTimeout, func() (*jobs.RunTask, *retries.Err) { @@ -620,15 +726,14 @@ func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient, currentState := sshTask.Status.State - // Print status if it changed + // Update spinner if state changed if currentState != prevState { - cmdio.LogString(ctx, fmt.Sprintf("Task status: %s", currentState)) + sp.Update(fmt.Sprintf("Starting SSH server... (task: %s)", currentState)) prevState = currentState } // Check if task is running if currentState == jobs.RunLifecycleStateV2StateRunning { - cmdio.LogString(ctx, "SSH server task is now running, proceeding to connect...") return sshTask, nil } @@ -651,26 +756,30 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC serverPort, userName, effectiveClusterID, err := getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap) if errors.Is(err, errServerMetadata) { - cmdio.LogString(ctx, "SSH server is not running, starting it now...") + cmdio.LogString(ctx, "Starting SSH server...") err := submitSSHTunnelJob(ctx, client, version, secretScopeName, opts) if err != nil { return "", 0, "", fmt.Errorf("failed to submit and start ssh server job: %w", err) } - cmdio.LogString(ctx, "Waiting for the ssh server to start...") + sp := cmdio.NewSpinner(ctx) + sp.TrackElapsedTime() + sp.Update("Waiting for the SSH server to start...") maxRetries := 30 for retries := range maxRetries { if ctx.Err() != nil { + sp.Close() return "", 0, "", ctx.Err() } serverPort, userName, effectiveClusterID, err = getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap) if err == nil { - cmdio.LogString(ctx, "Health check successful, starting ssh WebSocket connection...") + sp.Close() break } else if retries < maxRetries-1 { time.Sleep(2 * time.Second) } else { + sp.Close() return "", 0, "", fmt.Errorf("failed to start the ssh server: %w", err) } } @@ -680,3 +789,97 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC return userName, serverPort, effectiveClusterID, nil } + +// resolveServerlessSession handles auto-generation and reconnection for serverless sessions. +// It checks local state for existing sessions matching the workspace and accelerator, +// probes them to see if they're still alive, and prompts the user to reconnect or create new. +func resolveServerlessSession(ctx context.Context, client *databricks.WorkspaceClient, opts *ClientOptions) error { + version := build.GetInfo().Version + + matching, err := sessions.FindMatching(ctx, client.Config.Host, opts.Accelerator) + if err != nil { + log.Warnf(ctx, "Failed to load session state: %v", err) + } + + // Probe sessions to find alive ones (limit to 5 most recent to avoid latency). + const maxProbe = 5 + if len(matching) > maxProbe { + matching = matching[len(matching)-maxProbe:] + } + + var alive []sessions.Session + for _, s := range matching { + _, _, _, probeErr := getServerMetadata(ctx, client, s.Name, s.ClusterID, version, opts.Liteswap) + if probeErr == nil { + alive = append(alive, s) + } else { + cleanupStaleSession(ctx, client, s, version) + } + } + + if len(alive) > 0 && cmdio.IsPromptSupported(ctx) { + choices := make([]string, 0, len(alive)+1) + for _, s := range alive { + choices = append(choices, fmt.Sprintf("Reconnect to %s (started %s)", s.Name, s.CreatedAt.Format(time.RFC822))) + } + choices = append(choices, "Create new session") + + choice, choiceErr := cmdio.AskSelect(ctx, "Found existing sessions:", choices) + if choiceErr != nil { + return fmt.Errorf("failed to prompt user: %w", choiceErr) + } + + for i, s := range alive { + if choice == choices[i] { + opts.ConnectionName = s.Name + cmdio.LogString(ctx, "Reconnecting to session: "+s.Name) + return nil + } + } + } + + // No alive session selected — generate a new name. + opts.ConnectionName = sessions.GenerateSessionName(opts.Accelerator) + cmdio.LogString(ctx, "Creating new session: "+opts.ConnectionName) + return nil +} + +// cleanupStaleSession removes all local and remote artifacts for a stale session. +func cleanupStaleSession(ctx context.Context, client *databricks.WorkspaceClient, s sessions.Session, version string) { + // Remove local SSH keys. + keyPath, err := keys.GetLocalSSHKeyPath(ctx, s.Name, "") + if err == nil { + os.RemoveAll(filepath.Dir(keyPath)) + } + + // Remove SSH config entry. + if err := sshconfig.RemoveHostConfig(ctx, s.Name); err != nil { + log.Debugf(ctx, "Failed to remove SSH config for %s: %v", s.Name, err) + } + + // Delete secret scope (best-effort). + me, err := client.CurrentUser.Me(ctx) + if err == nil { + scopeName := fmt.Sprintf("%s-%s-ssh-tunnel-keys", me.UserName, s.Name) + deleteErr := client.Secrets.DeleteScope(ctx, workspace.DeleteScope{Scope: scopeName}) + if deleteErr != nil { + log.Debugf(ctx, "Failed to delete secret scope %s: %v", scopeName, deleteErr) + } + } + + // Remove workspace content directory (best-effort). + contentDir, err := sshWorkspace.GetWorkspaceContentDir(ctx, client, version, s.Name) + if err == nil { + deleteErr := client.Workspace.Delete(ctx, workspace.Delete{Path: contentDir, Recursive: true}) + if deleteErr != nil { + log.Debugf(ctx, "Failed to delete workspace content for %s: %v", s.Name, deleteErr) + } + } + + // Remove from local state. + if err := sessions.Remove(ctx, s.Name); err != nil { + log.Debugf(ctx, "Failed to remove session %s from state: %v", s.Name, err) + } + + log.Infof(ctx, "Cleaned up stale session: %s", s.Name) +} diff --git a/experimental/ssh/internal/client/client_test.go b/experimental/ssh/internal/client/client_test.go index 57df2fed2c..3a361e464c 100644 --- a/experimental/ssh/internal/client/client_test.go +++ b/experimental/ssh/internal/client/client_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/databricks/cli/experimental/ssh/internal/client" + "github.com/databricks/cli/libs/telemetry/protos" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -18,9 +19,9 @@ func TestValidate(t *testing.T) { wantErr string }{ { - name: "no cluster or connection name", + name: "no cluster or connection name or accelerator", opts: client.ClientOptions{}, - wantErr: "please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)", + wantErr: "please provide --cluster or --accelerator flag", }, { name: "proxy mode skips cluster/name check", @@ -31,9 +32,13 @@ func TestValidate(t *testing.T) { opts: client.ClientOptions{ClusterID: "abc-123"}, }, { - name: "accelerator without connection name", + name: "accelerator with cluster ID", opts: client.ClientOptions{ClusterID: "abc-123", Accelerator: "GPU_1xA10"}, - wantErr: "--accelerator flag can only be used with serverless compute (--name flag)", + wantErr: "--accelerator flag can only be used with serverless compute, not with --cluster", + }, + { + name: "accelerator only (auto-generate session name)", + opts: client.ClientOptions{Accelerator: "GPU_1xA10"}, }, { name: "connection name without accelerator", @@ -55,8 +60,9 @@ func TestValidate(t *testing.T) { opts: client.ClientOptions{ConnectionName: "my-conn_1", Accelerator: "GPU_1xA10"}, }, { - name: "both cluster ID and connection name", - opts: client.ClientOptions{ClusterID: "abc-123", ConnectionName: "my-conn", Accelerator: "GPU_1xA10"}, + name: "both cluster ID and connection name (no accelerator)", + opts: client.ClientOptions{ClusterID: "abc-123", ConnectionName: "my-conn"}, + wantErr: "--name flag requires --accelerator to be set (for now we only support serverless GPU compute)", }, { name: "proxy mode with invalid connection name", @@ -164,3 +170,71 @@ func TestToProxyCommand(t *testing.T) { }) } } + +func TestBuildTelemetryEvent(t *testing.T) { + tests := []struct { + name string + opts client.ClientOptions + want *protos.SshTunnelEvent + }{ + { + name: "dedicated cluster with SSH client", + opts: client.ClientOptions{ + ClusterID: "abc-123", + AutoStartCluster: true, + }, + want: &protos.SshTunnelEvent{ + ComputeType: protos.SshTunnelComputeTypeDedicated, + ClientMode: protos.SshTunnelClientModeSSH, + AutoStartCluster: true, + }, + }, + { + name: "serverless with IDE", + opts: client.ClientOptions{ + ConnectionName: "my-conn", + Accelerator: "GPU_1xA10", + IDE: "vscode", + }, + want: &protos.SshTunnelEvent{ + ComputeType: protos.SshTunnelComputeTypeServerless, + ClientMode: protos.SshTunnelClientModeIDE, + AcceleratorType: "GPU_1xA10", + IdeType: "vscode", + }, + }, + { + name: "proxy mode with metadata (reconnect)", + opts: client.ClientOptions{ + ClusterID: "abc-123", + ProxyMode: true, + ServerMetadata: "user,2222,abc-123", + }, + want: &protos.SshTunnelEvent{ + ComputeType: protos.SshTunnelComputeTypeDedicated, + ClientMode: protos.SshTunnelClientModeProxy, + IsReconnect: true, + }, + }, + { + name: "serverless proxy mode", + opts: client.ClientOptions{ + ConnectionName: "my-conn", + Accelerator: "GPU_8xH100", + ProxyMode: true, + }, + want: &protos.SshTunnelEvent{ + ComputeType: protos.SshTunnelComputeTypeServerless, + ClientMode: protos.SshTunnelClientModeProxy, + AcceleratorType: "GPU_8xH100", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := client.BuildTelemetryEvent(tt.opts) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/experimental/ssh/internal/client/releases.go b/experimental/ssh/internal/client/releases.go index f147244e9e..6c6ad800aa 100644 --- a/experimental/ssh/internal/client/releases.go +++ b/experimental/ssh/internal/client/releases.go @@ -12,8 +12,8 @@ import ( "strings" "github.com/databricks/cli/experimental/ssh/internal/workspace" - "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/filer" + "github.com/databricks/cli/libs/log" "github.com/databricks/databricks-sdk-go" ) @@ -48,7 +48,7 @@ func uploadReleases(ctx context.Context, workspaceFiler filer.Filer, getRelease _, err := workspaceFiler.Stat(ctx, remoteBinaryPath) if err == nil { - cmdio.LogString(ctx, fmt.Sprintf("File %s already exists in the workspace, skipping upload", remoteBinaryPath)) + log.Infof(ctx, "File %s already exists in the workspace, skipping upload", remoteBinaryPath) continue } else if !errors.Is(err, fs.ErrNotExist) { return fmt.Errorf("failed to check if file %s exists in workspace: %w", remoteBinaryPath, err) @@ -60,14 +60,14 @@ func uploadReleases(ctx context.Context, workspaceFiler filer.Filer, getRelease } defer releaseReader.Close() - cmdio.LogString(ctx, fmt.Sprintf("Uploading %s to the workspace", fileName)) + log.Infof(ctx, "Uploading %s to the workspace", fileName) // workspace-files/import-file API will automatically unzip the payload, // producing the filerRoot/remoteSubFolder/*archive-contents* structure, with 'databricks' binary inside. err = workspaceFiler.Write(ctx, remoteArchivePath, releaseReader, filer.OverwriteIfExists, filer.CreateParentDirectories) if err != nil { return fmt.Errorf("failed to upload file %s to workspace: %w", remoteArchivePath, err) } - cmdio.LogString(ctx, fmt.Sprintf("Successfully uploaded %s to workspace", remoteBinaryPath)) + log.Infof(ctx, "Successfully uploaded %s to workspace", remoteBinaryPath) } return nil @@ -81,7 +81,7 @@ func getReleaseName(architecture, version string) string { } func getLocalRelease(ctx context.Context, architecture, version, releasesDir string) (io.ReadCloser, error) { - cmdio.LogString(ctx, "Looking for CLI releases in directory: "+releasesDir) + log.Infof(ctx, "Looking for CLI releases in directory: %s", releasesDir) releaseName := getReleaseName(architecture, version) releasePath := filepath.Join(releasesDir, releaseName) file, err := os.Open(releasePath) @@ -95,7 +95,7 @@ func getGithubRelease(ctx context.Context, architecture, version, releasesDir st // TODO: download and check databricks_cli__SHA256SUMS fileName := getReleaseName(architecture, version) downloadURL := fmt.Sprintf("https://github.com/databricks/cli/releases/download/v%s/%s", version, fileName) - cmdio.LogString(ctx, fmt.Sprintf("Downloading %s from %s", fileName, downloadURL)) + log.Infof(ctx, "Downloading %s from %s", fileName, downloadURL) resp, err := http.Get(downloadURL) if err != nil { diff --git a/experimental/ssh/internal/proxy/client.go b/experimental/ssh/internal/proxy/client.go index 20e9eab0e1..89be5967c9 100644 --- a/experimental/ssh/internal/proxy/client.go +++ b/experimental/ssh/internal/proxy/client.go @@ -6,19 +6,19 @@ import ( "io" "time" - "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/log" "golang.org/x/sync/errgroup" ) func RunClientProxy(ctx context.Context, src io.ReadCloser, dst io.Writer, requestHandoverTick func() <-chan time.Time, createConn createWebsocketConnectionFunc) error { proxy := newProxyConnection(createConn) - cmdio.LogString(ctx, "Establishing SSH proxy connection...") + log.Infof(ctx, "Establishing SSH proxy connection...") g, gCtx := errgroup.WithContext(ctx) if err := proxy.connect(gCtx); err != nil { return fmt.Errorf("failed to connect to proxy: %w", err) } defer proxy.close() - cmdio.LogString(ctx, "SSH proxy connection established") + log.Infof(ctx, "SSH proxy connection established") g.Go(func() error { for { diff --git a/experimental/ssh/internal/sessions/namegen.go b/experimental/ssh/internal/sessions/namegen.go new file mode 100644 index 0000000000..ab66a93eba --- /dev/null +++ b/experimental/ssh/internal/sessions/namegen.go @@ -0,0 +1,28 @@ +package sessions + +import ( + "crypto/rand" + "encoding/hex" + "strings" + "time" +) + +// acceleratorPrefixes maps known accelerator types to short human-readable prefixes. +var acceleratorPrefixes = map[string]string{ + "GPU_1xA10": "gpu-a10", + "GPU_8xH100": "gpu-h100", +} + +// GenerateSessionName creates a human-readable session name from the accelerator type. +// Format: -, e.g. "gpu-a10-f3a2b1c0". +func GenerateSessionName(accelerator string) string { + prefix, ok := acceleratorPrefixes[accelerator] + if !ok { + prefix = strings.ToLower(strings.ReplaceAll(accelerator, "_", "-")) + } + + date := time.Now().Format("20060102") + b := make([]byte, 3) + _, _ = rand.Read(b) + return "databricks-" + prefix + "-" + date + "-" + hex.EncodeToString(b) +} diff --git a/experimental/ssh/internal/sessions/sessions.go b/experimental/ssh/internal/sessions/sessions.go new file mode 100644 index 0000000000..15d93a4002 --- /dev/null +++ b/experimental/ssh/internal/sessions/sessions.go @@ -0,0 +1,147 @@ +package sessions + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/databricks/cli/libs/env" +) + +const ( + stateFileName = "ssh-tunnel-sessions.json" + + // Sessions older than this are considered expired and cleaned up automatically. + sessionMaxAge = 24 * time.Hour +) + +// Session represents a tracked SSH tunnel session. +type Session struct { + Name string `json:"name"` + Accelerator string `json:"accelerator"` + WorkspaceHost string `json:"workspace_host"` + CreatedAt time.Time `json:"created_at"` + ClusterID string `json:"cluster_id,omitempty"` +} + +// SessionStore holds all tracked sessions. +type SessionStore struct { + Sessions []Session `json:"sessions"` +} + +func getStateFilePath(ctx context.Context) (string, error) { + homeDir, err := env.UserHomeDir(ctx) + if err != nil { + return "", fmt.Errorf("failed to get home directory: %w", err) + } + return filepath.Join(homeDir, ".databricks", stateFileName), nil +} + +// Load reads the session store from disk. Returns an empty store if the file does not exist. +func Load(ctx context.Context) (*SessionStore, error) { + path, err := getStateFilePath(ctx) + if err != nil { + return nil, err + } + + data, err := os.ReadFile(path) + if os.IsNotExist(err) { + return &SessionStore{}, nil + } + if err != nil { + return nil, fmt.Errorf("failed to read session state file: %w", err) + } + + var store SessionStore + if err := json.Unmarshal(data, &store); err != nil { + return nil, fmt.Errorf("failed to parse session state file: %w", err) + } + return &store, nil +} + +// Save writes the session store to disk atomically. +func Save(ctx context.Context, store *SessionStore) error { + path, err := getStateFilePath(ctx) + if err != nil { + return err + } + + if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil { + return fmt.Errorf("failed to create state directory: %w", err) + } + + data, err := json.MarshalIndent(store, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal session state: %w", err) + } + + // Atomic write: write to temp file, then rename. + tmpPath := path + ".tmp" + if err := os.WriteFile(tmpPath, data, 0o600); err != nil { + return fmt.Errorf("failed to write session state file: %w", err) + } + if err := os.Rename(tmpPath, path); err != nil { + return fmt.Errorf("failed to rename session state file: %w", err) + } + return nil +} + +// Add persists a new session to the store, replacing any existing session with the same name. +func Add(ctx context.Context, s Session) error { + store, err := Load(ctx) + if err != nil { + return err + } + + // Replace existing session with the same name. + found := false + for i, existing := range store.Sessions { + if existing.Name == s.Name { + store.Sessions[i] = s + found = true + break + } + } + if !found { + store.Sessions = append(store.Sessions, s) + } + + return Save(ctx, store) +} + +// Remove deletes a session by name. +func Remove(ctx context.Context, name string) error { + store, err := Load(ctx) + if err != nil { + return err + } + + filtered := store.Sessions[:0] + for _, s := range store.Sessions { + if s.Name != name { + filtered = append(filtered, s) + } + } + store.Sessions = filtered + return Save(ctx, store) +} + +// FindMatching returns non-expired sessions that match the given workspace host and accelerator. +func FindMatching(ctx context.Context, workspaceHost, accelerator string) ([]Session, error) { + store, err := Load(ctx) + if err != nil { + return nil, err + } + + cutoff := time.Now().Add(-sessionMaxAge) + var result []Session + for _, s := range store.Sessions { + if s.WorkspaceHost == workspaceHost && s.Accelerator == accelerator && s.CreatedAt.After(cutoff) { + result = append(result, s) + } + } + return result, nil +} diff --git a/experimental/ssh/internal/sessions/sessions_test.go b/experimental/ssh/internal/sessions/sessions_test.go new file mode 100644 index 0000000000..9ccc82aa5e --- /dev/null +++ b/experimental/ssh/internal/sessions/sessions_test.go @@ -0,0 +1,183 @@ +package sessions + +import ( + "path/filepath" + "regexp" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLoadEmpty(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + t.Setenv("USERPROFILE", t.TempDir()) + + store, err := Load(t.Context()) + require.NoError(t, err) + assert.Empty(t, store.Sessions) +} + +func TestSaveAndLoad(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + store := &SessionStore{ + Sessions: []Session{ + { + Name: "gpu-a10-abcd1234", + Accelerator: "GPU_1xA10", + WorkspaceHost: "https://test.databricks.com", + CreatedAt: time.Date(2026, 3, 10, 12, 0, 0, 0, time.UTC), + ClusterID: "0310-120000-abc", + }, + }, + } + + err := Save(t.Context(), store) + require.NoError(t, err) + + loaded, err := Load(t.Context()) + require.NoError(t, err) + require.Len(t, loaded.Sessions, 1) + assert.Equal(t, "gpu-a10-abcd1234", loaded.Sessions[0].Name) + assert.Equal(t, "GPU_1xA10", loaded.Sessions[0].Accelerator) + assert.Equal(t, "https://test.databricks.com", loaded.Sessions[0].WorkspaceHost) + assert.Equal(t, "0310-120000-abc", loaded.Sessions[0].ClusterID) +} + +func TestAddAndRemove(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + ctx := t.Context() + + err := Add(ctx, Session{Name: "sess-1", Accelerator: "GPU_1xA10", WorkspaceHost: "https://a.com"}) + require.NoError(t, err) + + err = Add(ctx, Session{Name: "sess-2", Accelerator: "GPU_8xH100", WorkspaceHost: "https://b.com"}) + require.NoError(t, err) + + store, err := Load(ctx) + require.NoError(t, err) + assert.Len(t, store.Sessions, 2) + + err = Remove(ctx, "sess-1") + require.NoError(t, err) + + store, err = Load(ctx) + require.NoError(t, err) + require.Len(t, store.Sessions, 1) + assert.Equal(t, "sess-2", store.Sessions[0].Name) +} + +func TestRemoveNonExistent(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + err := Remove(t.Context(), "no-such-session") + assert.NoError(t, err) +} + +func TestFindMatching(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + ctx := t.Context() + host := "https://test.databricks.com" + + now := time.Now() + + err := Add(ctx, Session{Name: "s1", Accelerator: "GPU_1xA10", WorkspaceHost: host, CreatedAt: now}) + require.NoError(t, err) + err = Add(ctx, Session{Name: "s2", Accelerator: "GPU_8xH100", WorkspaceHost: host, CreatedAt: now}) + require.NoError(t, err) + err = Add(ctx, Session{Name: "s3", Accelerator: "GPU_1xA10", WorkspaceHost: "https://other.com", CreatedAt: now}) + require.NoError(t, err) + err = Add(ctx, Session{Name: "s4", Accelerator: "GPU_1xA10", WorkspaceHost: host, CreatedAt: now}) + require.NoError(t, err) + + matches, err := FindMatching(ctx, host, "GPU_1xA10") + require.NoError(t, err) + assert.Len(t, matches, 2) + assert.Equal(t, "s1", matches[0].Name) + assert.Equal(t, "s4", matches[1].Name) + + matches, err = FindMatching(ctx, host, "GPU_8xH100") + require.NoError(t, err) + assert.Len(t, matches, 1) + assert.Equal(t, "s2", matches[0].Name) + + matches, err = FindMatching(ctx, host, "GPU_4xA100") + require.NoError(t, err) + assert.Empty(t, matches) +} + +func TestFindMatchingExpiresOldSessions(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + ctx := t.Context() + host := "https://test.databricks.com" + + err := Add(ctx, Session{Name: "old", Accelerator: "GPU_1xA10", WorkspaceHost: host, CreatedAt: time.Now().Add(-25 * time.Hour)}) + require.NoError(t, err) + err = Add(ctx, Session{Name: "recent", Accelerator: "GPU_1xA10", WorkspaceHost: host, CreatedAt: time.Now()}) + require.NoError(t, err) + + matches, err := FindMatching(ctx, host, "GPU_1xA10") + require.NoError(t, err) + require.Len(t, matches, 1) + assert.Equal(t, "recent", matches[0].Name) +} + +func TestStateFilePath(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + path, err := getStateFilePath(t.Context()) + require.NoError(t, err) + assert.Equal(t, filepath.Join(tmpDir, ".databricks", stateFileName), path) +} + +// connectionNameRegex mirrors the regex in client.go. +var connectionNameRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]*$`) + +func TestGenerateSessionName(t *testing.T) { + tests := []struct { + accelerator string + wantPrefix string + wantDatePrefix string + }{ + {"GPU_1xA10", "databricks-gpu-a10-", "databricks-gpu-a10-20"}, + {"GPU_8xH100", "databricks-gpu-h100-", "databricks-gpu-h100-20"}, + {"UNKNOWN_TYPE", "databricks-unknown-type-", "databricks-unknown-type-20"}, + } + + for _, tt := range tests { + t.Run(tt.accelerator, func(t *testing.T) { + name := GenerateSessionName(tt.accelerator) + assert.Greater(t, len(name), len(tt.wantPrefix), "name should be longer than prefix") + assert.Equal(t, tt.wantPrefix, name[:len(tt.wantPrefix)]) + // Verify date component is present (starts with "20" for 2000s dates). + assert.Equal(t, tt.wantDatePrefix, name[:len(tt.wantDatePrefix)]) + assert.True(t, connectionNameRegex.MatchString(name), "generated name %q must match connection name regex", name) + }) + } +} + +func TestGenerateSessionNameUniqueness(t *testing.T) { + seen := make(map[string]bool) + for range 100 { + name := GenerateSessionName("GPU_1xA10") + assert.False(t, seen[name], "duplicate name generated: %s", name) + seen[name] = true + } +} diff --git a/experimental/ssh/internal/sshconfig/sshconfig.go b/experimental/ssh/internal/sshconfig/sshconfig.go index f6886a4be9..2e37e3c1e1 100644 --- a/experimental/ssh/internal/sshconfig/sshconfig.go +++ b/experimental/ssh/internal/sshconfig/sshconfig.go @@ -160,14 +160,45 @@ func PromptRecreateConfig(ctx context.Context, hostName string) (bool, error) { return response, nil } +// RemoveHostConfig deletes the SSH config file for a given host name. +func RemoveHostConfig(ctx context.Context, hostName string) error { + configPath, err := GetHostConfigPath(ctx, hostName) + if err != nil { + return err + } + err = os.Remove(configPath) + if os.IsNotExist(err) { + return nil + } + return err +} + +// GenerateHostConfig generates an SSH host config block. func GenerateHostConfig(hostName, userName, identityFile, proxyCommand string) string { + return generateHostConfig(hostName, userName, identityFile, proxyCommand, false) +} + +// GenerateServerlessHostConfig generates an SSH host config block for serverless compute. +// It disables strict host key checking since serverless containers generate fresh keys each time, +// and identity is already verified through Databricks authentication and Driver Proxy. +func GenerateServerlessHostConfig(hostName, userName, identityFile, proxyCommand string) string { + return generateHostConfig(hostName, userName, identityFile, proxyCommand, true) +} + +func generateHostConfig(hostName, userName, identityFile, proxyCommand string, serverless bool) string { + hostKeyChecking := "StrictHostKeyChecking accept-new" + knownHostsLine := "" + if serverless { + hostKeyChecking = "StrictHostKeyChecking no" + knownHostsLine = " UserKnownHostsFile /dev/null\n" + } return fmt.Sprintf(` Host %s User %s ConnectTimeout 360 - StrictHostKeyChecking accept-new - IdentitiesOnly yes + %s +%s IdentitiesOnly yes IdentityFile %q ProxyCommand %s -`, hostName, userName, identityFile, proxyCommand) +`, hostName, userName, hostKeyChecking, knownHostsLine, identityFile, proxyCommand) } diff --git a/experimental/ssh/internal/vscode/run.go b/experimental/ssh/internal/vscode/run.go index fb88c32edd..fa48630ff8 100644 --- a/experimental/ssh/internal/vscode/run.go +++ b/experimental/ssh/internal/vscode/run.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/log" "golang.org/x/mod/semver" ) @@ -148,7 +149,7 @@ func LaunchIDE(ctx context.Context, ideOption, connectionName, userName, databri remoteURI := fmt.Sprintf("ssh-remote+%s@%s", userName, connectionName) remotePath := fmt.Sprintf("/Workspace/Users/%s/", databricksUserName) - cmdio.LogString(ctx, fmt.Sprintf("Launching %s with remote URI: %s and path: %s", ideOption, remoteURI, remotePath)) + log.Infof(ctx, "Launching %s with remote URI: %s and path: %s", ideOption, remoteURI, remotePath) ideCmd := exec.CommandContext(ctx, ide.Command, "--remote", remoteURI, remotePath) ideCmd.Stdout = os.Stdout diff --git a/experimental/ssh/internal/vscode/settings.go b/experimental/ssh/internal/vscode/settings.go index fa71cc70a2..0877a57bff 100644 --- a/experimental/ssh/internal/vscode/settings.go +++ b/experimental/ssh/internal/vscode/settings.go @@ -214,29 +214,33 @@ func validateSettings(v hujson.Value, connectionName string) *missingSettings { func settingsMessage(connectionName string, missing *missingSettings) string { var lines []string if missing.portRange { - lines = append(lines, fmt.Sprintf(" \"%s\": {\"%s\": \"%s\"}", serverPickPortsKey, connectionName, portRange)) + lines = append(lines, fmt.Sprintf(" \"%s\": {\"%s\": \"%s\"}", serverPickPortsKey, connectionName, portRange)) } if missing.platform { - lines = append(lines, fmt.Sprintf(" \"%s\": {\"%s\": \"%s\"}", remotePlatformKey, connectionName, remotePlatform)) + lines = append(lines, fmt.Sprintf(" \"%s\": {\"%s\": \"%s\"}", remotePlatformKey, connectionName, remotePlatform)) } if missing.listenOnSocket { - lines = append(lines, fmt.Sprintf(" \"%s\": true // Global setting that affects all remote ssh connections", listenOnSocketKey)) + lines = append(lines, fmt.Sprintf(" \"%s\": true // Global setting", listenOnSocketKey)) } if len(missing.extensions) > 0 { quoted := make([]string, len(missing.extensions)) for i, ext := range missing.extensions { quoted[i] = fmt.Sprintf("\"%s\"", ext) } - lines = append(lines, fmt.Sprintf(" \"%s\": [%s] // Global setting that affects all remote ssh connections", defaultExtensionsKey, strings.Join(quoted, ", "))) + lines = append(lines, fmt.Sprintf(" \"%s\": [%s] // Global setting", defaultExtensionsKey, strings.Join(quoted, ", "))) } - return strings.Join(lines, "\n") + return " {\n" + strings.Join(lines, ",\n") + "\n }" } func promptUserForUpdate(ctx context.Context, ide, connectionName string, missing *missingSettings) (bool, error) { question := fmt.Sprintf( - "The following settings will be applied to %s for '%s':\n%s\nApply these settings?", + "The following settings will be applied to %s for '%s':\n\n%s\n\nApply these settings?", getIDE(ide).Name, connectionName, settingsMessage(connectionName, missing)) - return cmdio.AskYesOrNo(ctx, question) + ans, err := cmdio.Ask(ctx, question+" [Y/n]", "y") + if err != nil { + return false, err + } + return strings.ToLower(ans) == "y", nil } func handleMissingFile(ctx context.Context, ide, connectionName, settingsPath string) error { @@ -289,11 +293,11 @@ func backupSettings(ctx context.Context, path string) error { latestBak := path + ".latest.bak" if _, err := os.Stat(originalBak); os.IsNotExist(err) { - cmdio.LogString(ctx, "Backing up settings to "+filepath.ToSlash(originalBak)) + log.Infof(ctx, "Backing up settings to %s", filepath.ToSlash(originalBak)) return os.WriteFile(originalBak, data, 0o600) } - cmdio.LogString(ctx, "Backing up settings to "+filepath.ToSlash(latestBak)) + log.Infof(ctx, "Backing up settings to %s", filepath.ToSlash(latestBak)) return os.WriteFile(latestBak, data, 0o600) } diff --git a/libs/cmdio/spinner.go b/libs/cmdio/spinner.go index e9ca438f9b..ed711b7019 100644 --- a/libs/cmdio/spinner.go +++ b/libs/cmdio/spinner.go @@ -2,6 +2,7 @@ package cmdio import ( "context" + "fmt" "sync" "time" @@ -12,15 +13,17 @@ import ( // spinnerModel is the Bubble Tea model for the spinner. type spinnerModel struct { - spinner bubblespinner.Model - suffix string - quitting bool + spinner bubblespinner.Model + suffix string + quitting bool + startTime time.Time // non-zero when elapsed time display is enabled } // Message types for spinner updates. type ( - suffixMsg string - quitMsg struct{} + suffixMsg string + quitMsg struct{} + elapsedTimeMsg struct{ startTime time.Time } ) // newSpinnerModel creates a new spinner model. @@ -50,6 +53,10 @@ func (m spinnerModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.suffix = string(msg) return m, nil + case elapsedTimeMsg: + m.startTime = msg.startTime + return m, nil + case quitMsg: m.quitting = true return m, tea.Quit @@ -69,11 +76,15 @@ func (m spinnerModel) View() string { return "" } + result := m.spinner.View() if m.suffix != "" { - return m.spinner.View() + " " + m.suffix + result += " " + m.suffix } - - return m.spinner.View() + if !m.startTime.IsZero() { + elapsed := time.Since(m.startTime) + result += fmt.Sprintf(" %02d:%02d", int(elapsed.Minutes()), int(elapsed.Seconds())%60) + } + return result } // spinner provides a structured interface for displaying progress indicators. @@ -90,6 +101,13 @@ type spinner struct { done chan struct{} // Closed when tea.Program finishes } +// TrackElapsedTime enables an elapsed time display (MM:SS) next to the spinner message. +func (sp *spinner) TrackElapsedTime() { + if sp.p != nil { + sp.p.Send(elapsedTimeMsg{startTime: time.Now()}) + } +} + // Update sends a status message to the spinner. // This operation sends directly to the tea.Program. func (sp *spinner) Update(msg string) { diff --git a/libs/gorules/rule_time_now_in_testserver.go b/libs/gorules/rule_time_now_in_testserver.go new file mode 100644 index 0000000000..ac20fb8e42 --- /dev/null +++ b/libs/gorules/rule_time_now_in_testserver.go @@ -0,0 +1,14 @@ +package gorules + +import "github.com/quasilyte/go-ruleguard/dsl" + +// NoTimeNowUnixMilliInTestServer forbids direct time.Now().UnixMilli() calls in libs/testserver. +// Use nowMilli() instead to guarantee unique, strictly increasing timestamps. +// Integer millisecond timestamps get indexed replacements in test output (e.g. [UNIX_TIME_MILLIS][0]) +// and collisions between resources cause flaky tests. +func NoTimeNowUnixMilliInTestServer(m dsl.Matcher) { + m.Match(`time.Now().UnixMilli()`). + Where(m.File().PkgPath.Matches(`.*/libs/testserver`) && + !m.File().Name.Matches(`fake_workspace\.go$`)). + Report(`Use nowMilli() instead of time.Now().UnixMilli() in testserver to ensure unique timestamps`) +} diff --git a/libs/telemetry/protos/frontend_log.go b/libs/telemetry/protos/frontend_log.go index 7e6ab1012b..816297a8ee 100644 --- a/libs/telemetry/protos/frontend_log.go +++ b/libs/telemetry/protos/frontend_log.go @@ -19,4 +19,5 @@ type DatabricksCliLog struct { CliTestEvent *CliTestEvent `json:"cli_test_event,omitempty"` BundleInitEvent *BundleInitEvent `json:"bundle_init_event,omitempty"` BundleDeployEvent *BundleDeployEvent `json:"bundle_deploy_event,omitempty"` + SshTunnelEvent *SshTunnelEvent `json:"ssh_tunnel_event,omitempty"` } diff --git a/libs/telemetry/protos/ssh_tunnel_event.go b/libs/telemetry/protos/ssh_tunnel_event.go new file mode 100644 index 0000000000..cd7abf1ecd --- /dev/null +++ b/libs/telemetry/protos/ssh_tunnel_event.go @@ -0,0 +1,48 @@ +package protos + +// SshTunnelComputeType represents the type of compute used for SSH tunnel. +type SshTunnelComputeType string + +const ( + SshTunnelComputeTypeUnspecified SshTunnelComputeType = "TYPE_UNSPECIFIED" + SshTunnelComputeTypeDedicated SshTunnelComputeType = "DEDICATED" + SshTunnelComputeTypeServerless SshTunnelComputeType = "SERVERLESS" +) + +// SshTunnelClientMode represents how the SSH tunnel client is used. +type SshTunnelClientMode string + +const ( + SshTunnelClientModeUnspecified SshTunnelClientMode = "TYPE_UNSPECIFIED" + SshTunnelClientModeSSH SshTunnelClientMode = "SSH_CLIENT" + SshTunnelClientModeProxy SshTunnelClientMode = "PROXY" + SshTunnelClientModeIDE SshTunnelClientMode = "IDE" +) + +// SshTunnelEvent tracks SSH tunnel connection lifecycle and usage. +type SshTunnelEvent struct { + // Type of compute: dedicated cluster or serverless. + ComputeType SshTunnelComputeType `json:"compute_type,omitempty"` + + // GPU accelerator type for serverless compute (e.g., "GPU_1xA10", "GPU_8xH100"). + AcceleratorType string `json:"accelerator_type,omitempty"` + + // IDE used for the connection (e.g., "vscode", "cursor"), empty if none. + IdeType string `json:"ide_type,omitempty"` + + // How the client is used: SSH client, proxy mode, or IDE mode. + ClientMode SshTunnelClientMode `json:"client_mode,omitempty"` + + // Whether this is a reconnection to an existing session. + IsReconnect bool `json:"is_reconnect,omitempty"` + + // Whether the cluster was auto-started by the CLI. + AutoStartCluster bool `json:"auto_start_cluster,omitempty"` + + // Time in milliseconds spent starting the SSH server (including job submission + // and waiting for the server to become ready). Zero if server was already running. + ServerStartTimeMs int64 `json:"server_start_time_ms"` + + // Flag to indicate if the connection was successful + IsSuccess bool `json:"is_success,omitempty"` +} diff --git a/libs/testserver/catalogs.go b/libs/testserver/catalogs.go index bd9598d600..859721ee73 100644 --- a/libs/testserver/catalogs.go +++ b/libs/testserver/catalogs.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" "net/http" - "time" "github.com/databricks/databricks-sdk-go/service/catalog" ) @@ -29,14 +28,14 @@ func (s *FakeWorkspace) CatalogsCreate(req Request) Response { Options: createRequest.Options, Properties: createRequest.Properties, FullName: createRequest.Name, - CreatedAt: time.Now().UnixMilli(), + CreatedAt: nowMilli(), CreatedBy: s.CurrentUser().UserName, - UpdatedAt: time.Now().UnixMilli(), UpdatedBy: s.CurrentUser().UserName, MetastoreId: nextUUID(), Owner: s.CurrentUser().UserName, CatalogType: catalog.CatalogTypeManagedCatalog, } + catalogInfo.UpdatedAt = catalogInfo.CreatedAt s.Catalogs[createRequest.Name] = catalogInfo return Response{ @@ -79,7 +78,7 @@ func (s *FakeWorkspace) CatalogsUpdate(req Request, name string) Response { name = updateRequest.NewName } - existing.UpdatedAt = time.Now().UnixMilli() + existing.UpdatedAt = nowMilli() existing.UpdatedBy = s.CurrentUser().UserName s.Catalogs[name] = existing diff --git a/libs/testserver/external_locations.go b/libs/testserver/external_locations.go index 0606ad76e5..b0000deb2a 100644 --- a/libs/testserver/external_locations.go +++ b/libs/testserver/external_locations.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" "net/http" - "time" "github.com/databricks/databricks-sdk-go/service/catalog" ) @@ -37,13 +36,13 @@ func (s *FakeWorkspace) ExternalLocationsCreate(req Request) Response { Fallback: createRequest.Fallback, EncryptionDetails: createRequest.EncryptionDetails, FileEventQueue: createRequest.FileEventQueue, - CreatedAt: time.Now().UnixMilli(), + CreatedAt: nowMilli(), CreatedBy: s.CurrentUser().UserName, - UpdatedAt: time.Now().UnixMilli(), UpdatedBy: s.CurrentUser().UserName, MetastoreId: nextUUID(), Owner: s.CurrentUser().UserName, } + locationInfo.UpdatedAt = locationInfo.CreatedAt s.ExternalLocations[createRequest.Name] = locationInfo return Response{ @@ -95,7 +94,7 @@ func (s *FakeWorkspace) ExternalLocationsUpdate(req Request, name string) Respon name = updateRequest.NewName } - existing.UpdatedAt = time.Now().UnixMilli() + existing.UpdatedAt = nowMilli() existing.UpdatedBy = s.CurrentUser().UserName s.ExternalLocations[name] = existing diff --git a/libs/testserver/pipelines.go b/libs/testserver/pipelines.go index 763a38be84..a6ce25b022 100644 --- a/libs/testserver/pipelines.go +++ b/libs/testserver/pipelines.go @@ -3,7 +3,6 @@ package testserver import ( "encoding/json" "fmt" - "time" "github.com/databricks/databricks-sdk-go/service/pipelines" ) @@ -41,7 +40,7 @@ func (s *FakeWorkspace) PipelineCreate(req Request) Response { pipelineId := nextUUID() r.PipelineId = pipelineId r.CreatorUserName = "tester@databricks.com" - r.LastModified = time.Now().UnixMilli() + r.LastModified = nowMilli() r.Name = r.Spec.Name r.RunAsUserName = "tester@databricks.com" r.State = "IDLE" diff --git a/libs/testserver/registered_models.go b/libs/testserver/registered_models.go index 74815ae7a8..e3723a95e9 100644 --- a/libs/testserver/registered_models.go +++ b/libs/testserver/registered_models.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" "net/http" - "time" "github.com/databricks/databricks-sdk-go/service/catalog" ) @@ -30,13 +29,13 @@ func (s *FakeWorkspace) RegisteredModelsCreate(req Request) Response { SchemaName: createRequest.SchemaName, StorageLocation: createRequest.StorageLocation, FullName: fullName, - CreatedAt: time.Now().UnixMilli(), + CreatedAt: nowMilli(), CreatedBy: s.CurrentUser().UserName, - UpdatedAt: time.Now().UnixMilli(), UpdatedBy: s.CurrentUser().UserName, MetastoreId: nextUUID(), Owner: s.CurrentUser().UserName, } + registeredModel.UpdatedAt = registeredModel.CreatedAt s.RegisteredModels[fullName] = registeredModel return Response{ @@ -78,7 +77,7 @@ func (s *FakeWorkspace) RegisteredModelsUpdate(req Request, fullName string) Res fullName = existing.CatalogName + "." + existing.SchemaName + "." + updateRequest.NewName } - existing.UpdatedAt = time.Now().UnixMilli() + existing.UpdatedAt = nowMilli() s.RegisteredModels[fullName] = existing return Response{ Body: existing,