diff --git a/README.md b/README.md index 56b42850..21aa997e 100644 --- a/README.md +++ b/README.md @@ -174,6 +174,29 @@ MODEL_RUNNER_HOST=http://localhost:13434 ./model-cli list - [Model Specification](https://github.com/docker/model-spec/blob/main/spec.md) - [Community Slack Channel](https://dockercommunity.slack.com/archives/C09H9P5E57B) +### ModelPack Compatibility + +Docker Model Runner supports both Docker model-spec artifacts and CNCF ModelPack artifacts stored in OCI registries. + +For ModelPack images, Docker Model Runner accepts: + +- config media type: `application/vnd.cncf.model.config.v1+json` +- weight layer media types, including: + - `application/vnd.cncf.model.weight.v1.gguf` + - `application/vnd.cncf.model.weight.v1.safetensors` + +This means you can pull and run a ModelPack artifact with the same user workflow: + +```bash +# Pull from any OCI-compliant registry +docker model pull <registry>/<namespace>/<model>:<tag> + +# Run the model +docker model run <registry>/<namespace>/<model>:<tag> "Hello" +``` + +If you are publishing artifacts for compatibility across tooling, ensure your image config and layer media types follow the ModelPack spec so downstream clients can detect and use the correct format. + ## Using the Makefile This project includes a Makefile to simplify common development tasks. Docker targets require Docker Desktop >= 4.41.0. diff --git a/pkg/distribution/distribution/bundle_test.go b/pkg/distribution/distribution/bundle_test.go index cecb3eee..6e0d624d 100644 --- a/pkg/distribution/distribution/bundle_test.go +++ b/pkg/distribution/distribution/bundle_test.go @@ -1,6 +1,7 @@ package distribution import ( + "bytes" "errors" "os" "path/filepath" @@ -142,8 +143,8 @@ func TestBundle(t *testing.T) { if err != nil { t.Fatalf("Failed to read file with expected contents: %v", err) } - if string(got) != string(expected) { - t.Fatalf("File contents did not match expected contents. 
Expected: %s, got: %s", expected, got) + if !bytes.Equal(got, expected) { + t.Fatalf("File contents did not match expected contents") } } }) diff --git a/pkg/distribution/distribution/client.go b/pkg/distribution/distribution/client.go index c2594543..50559f27 100644 --- a/pkg/distribution/distribution/client.go +++ b/pkg/distribution/distribution/client.go @@ -16,6 +16,7 @@ import ( "github.com/docker/model-runner/pkg/distribution/internal/mutate" "github.com/docker/model-runner/pkg/distribution/internal/progress" "github.com/docker/model-runner/pkg/distribution/internal/store" + "github.com/docker/model-runner/pkg/distribution/modelpack" "github.com/docker/model-runner/pkg/distribution/oci" "github.com/docker/model-runner/pkg/distribution/oci/authn" "github.com/docker/model-runner/pkg/distribution/oci/remote" @@ -786,7 +787,9 @@ func checkCompat(image types.ModelArtifact, log *slog.Logger, reference string, if err != nil { return err } - if manifest.Config.MediaType != types.MediaTypeModelConfigV01 && manifest.Config.MediaType != types.MediaTypeModelConfigV02 { + if manifest.Config.MediaType != types.MediaTypeModelConfigV01 && + manifest.Config.MediaType != types.MediaTypeModelConfigV02 && + manifest.Config.MediaType != oci.MediaType(modelpack.MediaTypeModelConfigV1) { return fmt.Errorf("config type %q is unsupported: %w", manifest.Config.MediaType, ErrUnsupportedMediaType) } diff --git a/pkg/distribution/distribution/client_test.go b/pkg/distribution/distribution/client_test.go index c314dc08..948bbb0b 100644 --- a/pkg/distribution/distribution/client_test.go +++ b/pkg/distribution/distribution/client_test.go @@ -15,22 +15,163 @@ import ( "path/filepath" "strings" "testing" + "time" "github.com/docker/model-runner/pkg/distribution/internal/mutate" + "github.com/docker/model-runner/pkg/distribution/internal/partial" "github.com/docker/model-runner/pkg/distribution/internal/progress" "github.com/docker/model-runner/pkg/distribution/internal/testutil" + 
"github.com/docker/model-runner/pkg/distribution/modelpack" "github.com/docker/model-runner/pkg/distribution/oci" "github.com/docker/model-runner/pkg/distribution/oci/reference" "github.com/docker/model-runner/pkg/distribution/oci/remote" mdregistry "github.com/docker/model-runner/pkg/distribution/registry" "github.com/docker/model-runner/pkg/distribution/registry/testregistry" "github.com/docker/model-runner/pkg/inference/platform" + "github.com/opencontainers/go-digest" ) var ( testGGUFFile = filepath.Join("..", "assets", "dummy.gguf") ) +type modelPackTestArtifact struct { + rawConfig []byte + layers []oci.Layer +} + +func (m *modelPackTestArtifact) Layers() ([]oci.Layer, error) { + return m.layers, nil +} + +func (m *modelPackTestArtifact) MediaType() (oci.MediaType, error) { + manifest, err := m.Manifest() + if err != nil { + return "", err + } + return manifest.MediaType, nil +} + +func (m *modelPackTestArtifact) Size() (int64, error) { + rawManifest, err := m.RawManifest() + if err != nil { + return 0, err + } + size := int64(len(rawManifest) + len(m.rawConfig)) + for _, layer := range m.layers { + layerSize, err := layer.Size() + if err != nil { + return 0, err + } + size += layerSize + } + return size, nil +} + +func (m *modelPackTestArtifact) ConfigName() (oci.Hash, error) { + hash, _, err := oci.SHA256(bytes.NewReader(m.rawConfig)) + return hash, err +} + +func (m *modelPackTestArtifact) ConfigFile() (*oci.ConfigFile, error) { + return nil, errors.New("invalid for model") +} + +func (m *modelPackTestArtifact) RawConfigFile() ([]byte, error) { + return m.rawConfig, nil +} + +func (m *modelPackTestArtifact) Digest() (oci.Hash, error) { + rawManifest, err := m.RawManifest() + if err != nil { + return oci.Hash{}, err + } + hash, _, err := oci.SHA256(bytes.NewReader(rawManifest)) + return hash, err +} + +func (m *modelPackTestArtifact) Manifest() (*oci.Manifest, error) { + return partial.ManifestForLayers(m) +} + +func (m *modelPackTestArtifact) RawManifest() 
([]byte, error) { + manifest, err := m.Manifest() + if err != nil { + return nil, err + } + return json.Marshal(manifest) +} + +func (m *modelPackTestArtifact) LayerByDigest(hash oci.Hash) (oci.Layer, error) { + for _, layer := range m.layers { + layerDigest, err := layer.Digest() + if err != nil { + return nil, err + } + if layerDigest == hash { + return layer, nil + } + } + return nil, fmt.Errorf("layer with digest %s not found", hash) +} + +func (m *modelPackTestArtifact) LayerByDiffID(hash oci.Hash) (oci.Layer, error) { + for _, layer := range m.layers { + layerDiffID, err := layer.DiffID() + if err != nil { + return nil, err + } + if layerDiffID == hash { + return layer, nil + } + } + return nil, fmt.Errorf("layer with diffID %s not found", hash) +} + +func (m *modelPackTestArtifact) GetConfigMediaType() oci.MediaType { + return oci.MediaType(modelpack.MediaTypeModelConfigV1) +} + +func newModelPackTestArtifact(t *testing.T, modelFile string) *modelPackTestArtifact { + t.Helper() + + layer, err := partial.NewLayer(modelFile, oci.MediaType(modelpack.MediaTypeWeightGGUF)) + if err != nil { + t.Fatalf("Failed to create ModelPack layer: %v", err) + } + + diffID, err := layer.DiffID() + if err != nil { + t.Fatalf("Failed to get layer DiffID: %v", err) + } + + now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + mp := modelpack.Model{ + Descriptor: modelpack.ModelDescriptor{ + CreatedAt: &now, + Name: "dummy-modelpack", + }, + Config: modelpack.ModelConfig{ + Format: "gguf", + ParamSize: "8B", + }, + ModelFS: modelpack.ModelFS{ + Type: "layers", + DiffIDs: []digest.Digest{digest.Digest(diffID.String())}, + }, + } + + rawConfig, err := json.Marshal(mp) + if err != nil { + t.Fatalf("Failed to marshal ModelPack config: %v", err) + } + + return &modelPackTestArtifact{ + rawConfig: rawConfig, + layers: []oci.Layer{layer}, + } +} + // newTestClient creates a new client configured for testing with plain HTTP enabled. 
func newTestClient(storeRootPath string) (*Client, error) { return NewClient( @@ -98,8 +239,8 @@ func TestClientPullModel(t *testing.T) { t.Fatalf("Failed to read pulled model: %v", err) } - if string(pulledContent) != string(modelContent) { - t.Errorf("Pulled model content doesn't match original: got %q, want %q", pulledContent, modelContent) + if !bytes.Equal(pulledContent, modelContent) { + t.Errorf("Pulled model content doesn't match original") } }) @@ -137,8 +278,74 @@ func TestClientPullModel(t *testing.T) { t.Fatalf("Failed to read pulled model: %v", err) } - if string(pulledContent) != string(modelContent) { - t.Errorf("Pulled model content doesn't match original: got %q, want %q", pulledContent, modelContent) + if !bytes.Equal(pulledContent, modelContent) { + t.Errorf("Pulled model content doesn't match original") + } + }) + + t.Run("pull modelpack artifact", func(t *testing.T) { + tempDir := t.TempDir() + + testClient, err := newTestClient(tempDir) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + mpTag := registryHost + "/modelpack-test/model:v1.0.0" + ref, err := reference.ParseReference(mpTag) + if err != nil { + t.Fatalf("Failed to parse reference: %v", err) + } + + mpModel := newModelPackTestArtifact(t, testGGUFFile) + if err := remote.Write(ref, mpModel, nil, remote.WithPlainHTTP(true)); err != nil { + t.Fatalf("Failed to push ModelPack model: %v", err) + } + + if err := testClient.PullModel(t.Context(), mpTag, nil); err != nil { + t.Fatalf("Failed to pull ModelPack model: %v", err) + } + + pulledModel, err := testClient.GetModel(mpTag) + if err != nil { + t.Fatalf("Failed to get pulled model: %v", err) + } + + ggufPaths, err := pulledModel.GGUFPaths() + if err != nil { + t.Fatalf("Failed to get GGUF paths: %v", err) + } + if len(ggufPaths) != 1 { + t.Fatalf("Unexpected number of GGUF files: %d", len(ggufPaths)) + } + + pulledContent, err := os.ReadFile(ggufPaths[0]) + if err != nil { + t.Fatalf("Failed to read pulled GGUF 
file: %v", err) + } + + originalContent, err := os.ReadFile(testGGUFFile) + if err != nil { + t.Fatalf("Failed to read source GGUF file: %v", err) + } + + if !bytes.Equal(pulledContent, originalContent) { + t.Errorf("Pulled ModelPack model content doesn't match original") + } + + cfg, err := pulledModel.Config() + if err != nil { + t.Fatalf("Failed to read pulled model config: %v", err) + } + if cfg.GetFormat() != "gguf" { + t.Errorf("Config format = %q, want %q", cfg.GetFormat(), "gguf") + } + if cfg.GetParameters() != "8B" { + t.Errorf("Config parameters = %q, want %q", cfg.GetParameters(), "8B") + } + + if _, ok := cfg.(*modelpack.Model); !ok { + t.Errorf("Config type = %T, want *modelpack.Model", cfg) } }) @@ -332,8 +539,8 @@ func TestClientPullModel(t *testing.T) { t.Fatalf("Failed to read pulled model: %v", err) } - if string(pulledContent) != string(testModelContent) { - t.Errorf("Pulled model content doesn't match original: got %q, want %q", pulledContent, testModelContent) + if !bytes.Equal(pulledContent, testModelContent) { + t.Errorf("Pulled model content doesn't match original") } // Create a modified version of the model @@ -382,8 +589,8 @@ func TestClientPullModel(t *testing.T) { t.Fatalf("Failed to read updated pulled model: %v", err) } - if string(updatedPulledContent) != string(updatedContent) { - t.Errorf("Updated pulled model content doesn't match: got %q, want %q", updatedPulledContent, updatedContent) + if !bytes.Equal(updatedPulledContent, updatedContent) { + t.Errorf("Updated pulled model content doesn't match") } }) @@ -526,7 +733,7 @@ func TestClientPullModel(t *testing.T) { t.Fatalf("Failed to read pulled model: %v", err) } - if string(pulledContent) != string(modelContent) { + if !bytes.Equal(pulledContent, modelContent) { t.Errorf("Pulled model content doesn't match original") } }) diff --git a/pkg/distribution/distribution/ecr_test.go b/pkg/distribution/distribution/ecr_test.go index 699b8dbb..66547d3f 100644 --- 
a/pkg/distribution/distribution/ecr_test.go +++ b/pkg/distribution/distribution/ecr_test.go @@ -1,6 +1,7 @@ package distribution import ( + "bytes" "os" "testing" @@ -79,8 +80,8 @@ func TestECRIntegration(t *testing.T) { t.Fatalf("Failed to read pulled model: %v", err) } - if string(pulledContent) != string(modelContent) { - t.Errorf("Pulled model content doesn't match original: got %q, want %q", pulledContent, modelContent) + if !bytes.Equal(pulledContent, modelContent) { + t.Errorf("Pulled model content doesn't match original") } }) diff --git a/pkg/distribution/distribution/gar_test.go b/pkg/distribution/distribution/gar_test.go index 669e10c1..b92665fe 100644 --- a/pkg/distribution/distribution/gar_test.go +++ b/pkg/distribution/distribution/gar_test.go @@ -1,6 +1,7 @@ package distribution import ( + "bytes" "os" "testing" @@ -80,8 +81,8 @@ func TestGARIntegration(t *testing.T) { t.Fatalf("Failed to read pulled model: %v", err) } - if string(pulledContent) != string(modelContent) { - t.Errorf("Pulled model content doesn't match original: got %q, want %q", pulledContent, modelContent) + if !bytes.Equal(pulledContent, modelContent) { + t.Errorf("Pulled model content doesn't match original") } })