diff --git a/cmd/cli/commands/run.go b/cmd/cli/commands/run.go
index 58e92913..ed7feb1e 100644
--- a/cmd/cli/commands/run.go
+++ b/cmd/cli/commands/run.go
@@ -809,11 +809,34 @@ func newRunCmd() *cobra.Command {
 				return nil
 			}
 
-			_, err := desktopClient.Inspect(model, false)
-			if err != nil {
-				if !errors.Is(err, desktop.ErrNotFound) {
-					return handleClientError(err, "Failed to inspect model")
+			modelInfo, err := desktopClient.Inspect(model, false)
+			modelFoundLocally := err == nil
+			if err != nil && !errors.Is(err, desktop.ErrNotFound) {
+				return handleClientError(err, "Failed to inspect model")
+			}
+
+			if !modelFoundLocally {
+				remoteInfo, remoteErr := desktopClient.Inspect(model, true)
+				if remoteErr == nil {
+					modelInfo = remoteInfo
+				}
+			}
+
+			backend := ""
+			if modelInfo.ID != "" {
+				backend, _ = GetRequiredBackendFromModelInfo(&modelInfo)
+			}
+
+			if backend != "" {
+				if err := EnsureBackendAvailable(backend, cmd); err != nil {
+					if errors.Is(err, errBackendInstallationCancelled) {
+						return nil
+					}
+					return err
 				}
+			}
+
+			if !modelFoundLocally {
 				cmd.Println("Unable to find model '" + model + "' locally. Pulling from the server.")
 				if err := pullModel(cmd, desktopClient, model); err != nil {
 					return err
diff --git a/cmd/cli/commands/utils.go b/cmd/cli/commands/utils.go
index 643955fc..f5f4dbea 100644
--- a/cmd/cli/commands/utils.go
+++ b/cmd/cli/commands/utils.go
@@ -1,7 +1,9 @@
 package commands
 
 import (
+	"bufio"
 	"bytes"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
@@ -11,7 +13,11 @@ import (
 	"github.com/docker/model-runner/cmd/cli/desktop"
 	"github.com/docker/model-runner/cmd/cli/pkg/standalone"
 	"github.com/docker/model-runner/pkg/distribution/oci/reference"
+	"github.com/docker/model-runner/pkg/distribution/types"
+	"github.com/docker/model-runner/pkg/inference/backends/diffusers"
+	"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
 	"github.com/docker/model-runner/pkg/inference/backends/vllm"
+	dmrm "github.com/docker/model-runner/pkg/inference/models"
 	"github.com/moby/term"
 	"github.com/olekukonko/tablewriter"
 	"github.com/olekukonko/tablewriter/renderer"
@@ -42,6 +48,10 @@ func getDefaultRegistry() string {
 
 var errNotRunning = fmt.Errorf("Docker Model Runner is not running. Please start it and try again.\n")
 
+// errBackendInstallationCancelled is returned by EnsureBackendAvailable when
+// the user declines the interactive install prompt.
+var errBackendInstallationCancelled = errors.New("backend installation cancelled")
+
 func handleClientError(err error, message string) error {
 	if errors.Is(err, desktop.ErrServiceUnavailable) {
 		err = errNotRunning
@@ -270,6 +280,132 @@ func newTable(w io.Writer) *tablewriter.Table {
 	)
 }
 
+// CheckBackendInstalled reports whether the named backend is usable according
+// to the status map returned by the model runner. A backend counts as installed
+// when its status starts with "installed" or "running" (case-insensitive); a
+// missing entry or any other status, such as "not running", reports false.
+func CheckBackendInstalled(backend string) (bool, error) {
+	status := desktopClient.Status()
+	if status.Error != nil {
+		return false, fmt.Errorf("failed to get backend status: %w", status.Error)
+	}
+
+	var backendStatus map[string]string
+	if err := json.Unmarshal(status.Status, &backendStatus); err != nil {
+		return false, fmt.Errorf("failed to parse backend status: %w", err)
+	}
+
+	backendState, exists := backendStatus[backend]
+	if !exists {
+		return false, nil
+	}
+
+	state := strings.TrimSpace(strings.ToLower(backendState))
+	if strings.HasPrefix(state, "not ") || strings.HasPrefix(state, "error") {
+		return false, nil
+	}
+
+	return strings.HasPrefix(state, "installed") || strings.HasPrefix(state, "running"), nil
+}
+
+// PromptInstallBackend asks on cmd's output whether the missing backend should
+// be installed and reads one line of answer from cmd's input. An empty line,
+// "y" or "yes" (case-insensitive) confirms; any other answer declines.
+func PromptInstallBackend(backend string, cmd *cobra.Command) (bool, error) {
+	fmt.Fprintf(cmd.OutOrStdout(), "Backend %q is not installed. Download and install it now? [Y/n]: ", backend)
+
+	reader := bufio.NewReader(cmd.InOrStdin())
+	input, err := reader.ReadString('\n')
+	// ReadString returns io.EOF together with the data read so far when the
+	// input ends without a newline (for example `printf y | docker model run`).
+	// Accept such an answer; only a read that yields no answer is an error, so
+	// a closed stdin never silently selects the default "yes".
+	if err != nil && (!errors.Is(err, io.EOF) || strings.TrimSpace(input) == "") {
+		return false, fmt.Errorf("failed to read input: %w", err)
+	}
+
+	input = strings.TrimSpace(strings.ToLower(input))
+	return input == "" || input == "y" || input == "yes", nil
+}
+
+// InstallBackend asks the model runner to install the named backend and wraps
+// any failure with the backend name.
+func InstallBackend(backend string) error {
+	if err := desktopClient.InstallBackend(backend); err != nil {
+		return fmt.Errorf("failed to install backend %s: %w", backend, err)
+	}
+
+	return nil
+}
+
+// EnsureBackendAvailable makes sure the named backend is installed before a
+// model that needs it is used. If the backend is missing the user is prompted;
+// declining prints the manual install command and returns
+// errBackendInstallationCancelled. After an install the status is checked
+// again, and an error is returned if the backend is still not reported as
+// installed.
+func EnsureBackendAvailable(backend string, cmd *cobra.Command) error {
+	installed, err := CheckBackendInstalled(backend)
+	if err != nil {
+		return err
+	}
+
+	if installed {
+		return nil
+	}
+
+	confirm, err := PromptInstallBackend(backend, cmd)
+	if err != nil {
+		return err
+	}
+
+	if !confirm {
+		cmd.Printf("Run 'docker model install-runner --backend %s' to install it manually.\n", backend)
+		return errBackendInstallationCancelled
+	}
+
+	if err := InstallBackend(backend); err != nil {
+		return err
+	}
+
+	installed, err = CheckBackendInstalled(backend)
+	if err != nil {
+		return err
+	}
+	if !installed {
+		return fmt.Errorf("backend %q is still not installed; run 'docker model install-runner --backend %s'", backend, backend)
+	}
+
+	cmd.Printf("Backend %q installed successfully.\n", backend)
+	return nil
+}
+
+// GetRequiredBackendFromModelInfo maps the model's format to the name of the
+// backend that serves it: safetensors selects vllm, diffusers selects
+// diffusers, and GGUF, unknown formats or a config of another type select
+// llamacpp. It returns an error only when modelInfo is nil.
+func GetRequiredBackendFromModelInfo(modelInfo *dmrm.Model) (string, error) {
+	if modelInfo == nil {
+		return "", errors.New("model info is nil")
+	}
+
+	config, ok := modelInfo.Config.(*types.Config)
+	if !ok {
+		return llamacpp.Name, nil
+	}
+
+	switch config.Format {
+	case types.FormatSafetensors:
+		return vllm.Name, nil
+	case types.FormatGGUF:
+		return llamacpp.Name, nil
+	case types.FormatDiffusers:
+		return diffusers.Name, nil
+	default:
+		return llamacpp.Name, nil
+	}
+}
+
 func printNextSteps(out io.Writer, messages []string) {
 	if len(messages) == 0 {
 		return
diff --git a/cmd/cli/commands/utils_test.go b/cmd/cli/commands/utils_test.go
index 0433fc06..cb62ff4e 100644
--- a/cmd/cli/commands/utils_test.go
+++ b/cmd/cli/commands/utils_test.go
@@ -1,9 +1,26 @@
 package commands
 
 import (
+	"bytes"
+	"encoding/json"
 	"errors"
 	"fmt"
+	"io"
+	"net/http"
+	"strings"
 	"testing"
+
+	"github.com/docker/model-runner/cmd/cli/desktop"
+	mockdesktop "github.com/docker/model-runner/cmd/cli/mocks"
+	"github.com/docker/model-runner/pkg/distribution/types"
+	"github.com/docker/model-runner/pkg/inference"
+	"github.com/docker/model-runner/pkg/inference/backends/diffusers"
+	"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
+	"github.com/docker/model-runner/pkg/inference/backends/vllm"
+	dmrm "github.com/docker/model-runner/pkg/inference/models"
+	"github.com/spf13/cobra"
+	"github.com/stretchr/testify/require"
+	"go.uber.org/mock/gomock"
 )
 
 func TestStripDefaultsFromModelName(t *testing.T) {
@@ -112,3 +129,127 @@ func TestHandleClientErrorFormat(t *testing.T) {
 		}
 	})
 }
+
+func setupDesktopClientStatusMock(t *testing.T, ctrl *gomock.Controller, backendStatus map[string]string) {
+	t.Helper()
+
+	client := mockdesktop.NewMockDockerHttpClient(ctrl)
+	modelRunner = desktop.NewContextForMock(client)
+	desktopClient = desktop.New(modelRunner)
+
+	statusJSON, err := json.Marshal(backendStatus)
+	require.NoError(t, err)
+
+	expectedModelsURL := modelRunner.URL(inference.ModelsPrefix)
+	expectedStatusURL := modelRunner.URL(inference.InferencePrefix + "/status")
+	expectedUserAgent := "docker-model-cli/" + desktop.Version
+
+	client.EXPECT().Do(gomock.Cond(func(req any) bool {
+		r, ok := req.(*http.Request)
+		return ok && r.URL.String() == expectedModelsURL && r.Header.Get("User-Agent") == expectedUserAgent
+	})).Return(&http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader("{}"))}, nil)
+
+	client.EXPECT().Do(gomock.Cond(func(req any) bool {
+		r, ok := req.(*http.Request)
+		return ok && r.URL.String() == expectedStatusURL && r.Header.Get("User-Agent") == expectedUserAgent
+	})).Return(&http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(statusJSON))}, nil)
+}
+
+func TestCheckBackendInstalled(t *testing.T) {
+	t.Run("running status string is treated as installed", func(t *testing.T) {
+		ctrl := gomock.NewController(t)
+		defer ctrl.Finish()
+
+		setupDesktopClientStatusMock(t, ctrl, map[string]string{"vllm": "running vllm latest-cuda"})
+
+		installed, err := CheckBackendInstalled(vllm.Name)
+		require.NoError(t, err)
+		require.True(t, installed)
+	})
+
+	t.Run("not running status is treated as missing", func(t *testing.T) {
+		ctrl := gomock.NewController(t)
+		defer ctrl.Finish()
+
+		setupDesktopClientStatusMock(t, ctrl, map[string]string{"vllm": "not running"})
+
+		installed, err := CheckBackendInstalled(vllm.Name)
+		require.NoError(t, err)
+		require.False(t, installed)
+	})
+
+	t.Run("error status is treated as missing", func(t *testing.T) {
+		ctrl := gomock.NewController(t)
+		defer ctrl.Finish()
+
+		setupDesktopClientStatusMock(t, ctrl, map[string]string{"vllm": "error failed to start"})
+
+		installed, err := CheckBackendInstalled(vllm.Name)
+		require.NoError(t, err)
+		require.False(t, installed)
+	})
+}
+
+func TestPromptInstallBackend(t *testing.T) {
+	cmd := &cobra.Command{Use: "test"}
+	cmd.SetIn(strings.NewReader("yes\n"))
+	out := new(bytes.Buffer)
+	cmd.SetOut(out)
+
+	confirmed, err := PromptInstallBackend(vllm.Name, cmd)
+	require.NoError(t, err)
+	require.True(t, confirmed)
+	require.Contains(t, out.String(), "Backend \"vllm\" is not installed")
+}
+
+func TestPromptInstallBackendWithoutTrailingNewline(t *testing.T) {
+	cmd := &cobra.Command{Use: "test"}
+	cmd.SetIn(strings.NewReader("y"))
+	cmd.SetOut(new(bytes.Buffer))
+
+	confirmed, err := PromptInstallBackend(vllm.Name, cmd)
+	require.NoError(t, err)
+	require.True(t, confirmed)
+}
+
+func TestEnsureBackendAvailableCancelled(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	setupDesktopClientStatusMock(t, ctrl, map[string]string{"vllm": "not running"})
+
+	cmd := &cobra.Command{Use: "test"}
+	cmd.SetIn(strings.NewReader("n\n"))
+	out := new(bytes.Buffer)
+	cmd.SetOut(out)
+
+	err := EnsureBackendAvailable(vllm.Name, cmd)
+	require.Error(t, err)
+	require.ErrorIs(t, err, errBackendInstallationCancelled)
+	require.Contains(t, out.String(), "docker model install-runner --backend vllm")
+}
+
+func TestGetRequiredBackendFromModelInfo(t *testing.T) {
+	t.Run("safetensors chooses vllm", func(t *testing.T) {
+		backend, err := GetRequiredBackendFromModelInfo(&dmrm.Model{Config: &types.Config{Format: types.FormatSafetensors}})
+		require.NoError(t, err)
+		require.Equal(t, vllm.Name, backend)
+	})
+
+	t.Run("gguf chooses llamacpp", func(t *testing.T) {
+		backend, err := GetRequiredBackendFromModelInfo(&dmrm.Model{Config: &types.Config{Format: types.FormatGGUF}})
+		require.NoError(t, err)
+		require.Equal(t, llamacpp.Name, backend)
+	})
+
+	t.Run("diffusers chooses diffusers backend", func(t *testing.T) {
+		backend, err := GetRequiredBackendFromModelInfo(&dmrm.Model{Config: &types.Config{Format: types.FormatDiffusers}})
+		require.NoError(t, err)
+		require.Equal(t, diffusers.Name, backend)
+	})
+
+	t.Run("nil model info is rejected", func(t *testing.T) {
+		_, err := GetRequiredBackendFromModelInfo(nil)
+		require.Error(t, err)
+	})
+}