diff --git a/examples/apple/coreml/llama/export_static_llm_coreml.py b/examples/apple/coreml/llama/export_static_llm_coreml.py index 8c68af45f31..94738242f30 100644 --- a/examples/apple/coreml/llama/export_static_llm_coreml.py +++ b/examples/apple/coreml/llama/export_static_llm_coreml.py @@ -17,6 +17,13 @@ --embedding_quantize 4,32 \ --coreml_quantize c4w \ --target_split_size 1048 + + With LoRA adapters (creates multimethod PTE with base + adapter methods): + python export_static_llm_coreml.py \ + --checkpoint /path/to/model.pth \ + --params /path/to/params.json \ + --adapter lora1 /path/to/adapter.safetensors /path/to/adapter_config.json \ + --adapter lora2 /path/to/adapter2.safetensors /path/to/adapter2_config.json """ import argparse @@ -35,10 +42,16 @@ from executorch.examples.apple.coreml.llama.utils import ( replace_linear_with_split_linear, ) +from executorch.examples.models.llama.convert_weights import ( + load_and_convert_unsloth_to_meta, +) from executorch.examples.models.llama.llama_transformer import construct_transformer from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.rope import Rope -from executorch.examples.models.llama.static_attention import StaticAttentionIOManager +from executorch.examples.models.llama.static_attention import ( + StaticAttentionIOManager, + transform_attention_mha_to_static_attention, +) from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower from executorch.exir.backend.utils import format_delegated_graph from executorch.exir.capture._config import ExecutorchBackendConfig @@ -111,8 +124,13 @@ def load_model( params_path: str, max_context_len: int, generate_full_logits: bool = True, + adapter_checkpoint: str = None, + adapter_config: str = None, ): - """Load the model from checkpoint with static_mha attention type. + """Load the model from checkpoint using AttentionMHA, then transform to StaticAttention. 
+ + Both base and LoRA models use the same construction path: + AttentionMHA → transform_attention_mha_to_static_attention(split_mha=False). Args: checkpoint_path: Path to the model checkpoint (.pth) @@ -121,6 +139,9 @@ def load_model( generate_full_logits: If True, output logits for all tokens (needed for lookahead decoding). If False, only output logits for the last token (more efficient for standard autoregressive generation). + adapter_checkpoint: Optional path to LoRA adapter weights (.safetensors) + adapter_config: Optional path to adapter_config.json with r, lora_alpha, + target_modules """ with open(params_path, "r") as f: params = json.loads(f.read()) @@ -130,8 +151,14 @@ def load_model( generate_full_logits=generate_full_logits, **params, ) - args.attention_type = "static_mha" - args.attention_kwargs = {"decompose_sdpa_in_mha": True} + args.attention_type = "mha" + + if adapter_config is not None: + with open(adapter_config, "r") as f: + lora_config = json.loads(f.read()) + args.r = lora_config["r"] + args.lora_alpha = lora_config["lora_alpha"] + args.target_modules = lora_config["target_modules"] with torch.device("meta"): model = construct_transformer(args) @@ -142,20 +169,9 @@ def load_model( if "model" in checkpoint: checkpoint = checkpoint["model"] - # Rename attention weight keys for static attention - for i in range(len(model.layers)): - if f"layers.{i}.attention.wq.weight" in checkpoint: - checkpoint[f"layers.{i}.attention.wqs.0.weight"] = checkpoint.pop( - f"layers.{i}.attention.wq.weight" - ) - if f"layers.{i}.attention.wk.weight" in checkpoint: - checkpoint[f"layers.{i}.attention.wks.0.weight"] = checkpoint.pop( - f"layers.{i}.attention.wk.weight" - ) - if f"layers.{i}.attention.wv.weight" in checkpoint: - checkpoint[f"layers.{i}.attention.wvs.0.weight"] = checkpoint.pop( - f"layers.{i}.attention.wv.weight" - ) + if adapter_checkpoint is not None: + adapter_state_dict = load_and_convert_unsloth_to_meta(adapter_checkpoint) + 
checkpoint.update(adapter_state_dict) missing, unexpected = model.load_state_dict( checkpoint, @@ -167,6 +183,10 @@ def load_model( if unexpected: print(f"Unexpected keys: {unexpected}") + transform_attention_mha_to_static_attention( + model, split_mha=False, decompose_sdpa_in_mha=True + ) + return model, args @@ -309,6 +329,88 @@ def _get_metadata(model_args, example_inputs, input_len, cache_len, float_dtype) } +def _prepare_model(model, args, float_dtype, has_lora=False): + """Apply splitting, quantization, and graph breaks to a model. + + Args: + model: The model to prepare (modified in-place). + args: CLI arguments. + float_dtype: Float dtype (torch.float16 or torch.float32). + has_lora: If True, skip LoRA internals during quantization. + """ + model = model.to(float_dtype).eval() + + if args.target_split_size is not None: + print(f"\nSplitting linear layers with target size {args.target_split_size}...") + replace_linear_with_split_linear( + model, + out_target_split_size=args.target_split_size, + out_max_splits=args.max_splits, + in_target_split_size=1, + in_max_splits=1, + ) + + if args.embedding_quantize: + bitwidth, group_size = args.embedding_quantize.split(",") + bitwidth = int(bitwidth) + group_size = int(group_size) + assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization" + + print(f"\nQuantizing embeddings: {bitwidth}-bit, group_size={group_size}...") + if group_size == 0: + granularity = PerAxis(0) + else: + granularity = PerGroup(group_size) + weight_dtype = getattr(torch, f"int{bitwidth}") + + quantize_( + model, + IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity), + lambda m, fqn: isinstance(m, torch.nn.Embedding), + ) + + linear_filter = None + if has_lora: + + def linear_filter(m, fqn): + return ( + isinstance(m, torch.nn.Linear) + and "lora_a" not in fqn + and "lora_b" not in fqn + ) + + if args.linear_quantize == "b4w": + print("\nQuantizing linear layers: 4-bit blockwise (group_size=32)...") + 
quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=PerGroup(32), + ), + linear_filter, + ) + elif args.linear_quantize == "c4w": + print("\nQuantizing linear layers: 4-bit channelwise...") + quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=PerAxis(0), + ), + linear_filter, + ) + + if not args.no_graph_breaks: + print("\nAdding graph breaks between before/after the transformer blocks...") + n_layers = len(model.layers) + model.layers[0] = BlockWithGraphBreak(model.layers[0], break_before=True) + model.layers[n_layers - 1] = BlockWithGraphBreak( + model.layers[n_layers - 1], break_before=False + ) + + return model + + def main(): parser = argparse.ArgumentParser( description="Export static attention Llama model to CoreML" @@ -401,9 +503,19 @@ def main(): "and generate_full_logits=True for lookahead decoding support.", ) + # LoRA adapter options + parser.add_argument( + "--adapter", + nargs=3, + action="append", + metavar=("NAME", "CHECKPOINT", "CONFIG"), + help="LoRA adapter: method name, path to adapter.safetensors, path to " + "adapter_config.json. 
Can be specified multiple times for multiple adapters.", + ) + args = parser.parse_args() - # Compute cache length + has_adapters = args.adapter is not None print("Export mode:") if args.multifunction: @@ -412,6 +524,9 @@ def main(): ) else: print("\tSingle method: fixed seqlen, generate_full_logits=True (lookahead)") + if has_adapters: + adapter_names = [a[0] for a in args.adapter] + print(f"\tAdapters: {adapter_names}") print("\nQuantization and datatype:") print(f"\tEmbedding quantize: {args.embedding_quantize}") @@ -443,68 +558,23 @@ def main(): # Set dtype float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[args.dtype] - model = model.to(float_dtype).eval() - - # Apply linear splitting (before quantization) - if args.target_split_size is not None: - print(f"\nSplitting linear layers with target size {args.target_split_size}...") - replace_linear_with_split_linear( - model, - out_target_split_size=args.target_split_size, - out_max_splits=args.max_splits, - in_target_split_size=1, - in_max_splits=1, - ) - - # Apply embedding quantization - if args.embedding_quantize: - bitwidth, group_size = args.embedding_quantize.split(",") - bitwidth = int(bitwidth) - group_size = int(group_size) - assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization" - - print(f"\nQuantizing embeddings: {bitwidth}-bit, group_size={group_size}...") - if group_size == 0: - granularity = PerAxis(0) - else: - granularity = PerGroup(group_size) - weight_dtype = getattr(torch, f"int{bitwidth}") - - quantize_( - model, - IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity), - lambda m, fqn: isinstance(m, torch.nn.Embedding), - ) - - # Apply linear quantization - if args.linear_quantize == "b4w": - print("\nQuantizing linear layers: 4-bit blockwise (group_size=32)...") - quantize_( - model, - IntxWeightOnlyConfig( - weight_dtype=torch.int4, - granularity=PerGroup(32), - ), - ) - elif args.linear_quantize == "c4w": - print("\nQuantizing linear 
layers: 4-bit channelwise...") - quantize_( - model, - IntxWeightOnlyConfig( - weight_dtype=torch.int4, - granularity=PerAxis(0), - ), - ) - - # Add graph breaks between transformer blocks - # Keeping model pieces smaller helps with ANE performance - if not args.no_graph_breaks: - print("\nAdding graph breaks between before/after the transformer blocks...") - n_layers = len(model.layers) - model.layers[0] = BlockWithGraphBreak(model.layers[0], break_before=True) - model.layers[n_layers - 1] = BlockWithGraphBreak( - model.layers[n_layers - 1], break_before=False - ) + model = _prepare_model(model, args, float_dtype, has_lora=False) + + # Load and prepare LoRA adapter models + lora_models = {} + if has_adapters: + for adapter_name, adapter_ckpt, adapter_cfg in args.adapter: + print(f"\nLoading adapter '{adapter_name}' from {adapter_ckpt}...") + lora_model, _ = load_model( + args.checkpoint, + args.params, + args.max_context_len, + generate_full_logits=generate_full_logits, + adapter_checkpoint=adapter_ckpt, + adapter_config=adapter_cfg, + ) + lora_model = _prepare_model(lora_model, args, float_dtype, has_lora=True) + lora_models[adapter_name] = lora_model if args.multifunction: # Multifunction mode: separate prefill and decode graphs with weight sharing @@ -558,6 +628,21 @@ def main(): print("Prefill export successful!") print(prefill_ep) + # Export LoRA adapter methods + for adapter_name, lora_model in lora_models.items(): + print(f"\nTesting eager execution ({adapter_name} decode)...") + with torch.no_grad(): + lora_model(*decode_inputs) + print(f"\nTesting eager execution ({adapter_name} prefill)...") + with torch.no_grad(): + lora_model(*prefill_inputs) + + print(f"\nExporting {adapter_name} decode...") + lora_models[adapter_name] = ( + torch.export.export(lora_model, decode_inputs), + torch.export.export(lora_model, prefill_inputs), + ) + # Generate metadata for C++ runner # constant_methods are shared across all methods, so we prefix method-specific # metadata 
with the method name @@ -595,6 +680,26 @@ def main(): "prefill_kv_cache_specs": prefill_metadata["kv_cache_specs"], } + # Add per-adapter metadata (same I/O structure as base decode/prefill) + for adapter_name in list(lora_models.keys()): + for prefix_from, prefix_to in [ + ("decode", f"{adapter_name}_forward"), + ("prefill", f"{adapter_name}_prefill"), + ]: + src = decode_metadata if prefix_from == "decode" else prefill_metadata + constant_methods[f"{prefix_to}_input_len"] = src["forward_input_len"] + constant_methods[f"{prefix_to}_freqs_cos_input_index"] = src[ + "freqs_cos_input_index" + ] + constant_methods[f"{prefix_to}_freqs_sin_input_index"] = src[ + "freqs_sin_input_index" + ] + constant_methods[f"{prefix_to}_mask_specs"] = src["mask_specs"] + constant_methods[f"{prefix_to}_kv_cache_specs"] = src["kv_cache_specs"] + + if has_adapters: + constant_methods["has_lora"] = True + # Setup CoreML partitioner with multimethod weight sharing print("\nSetting up CoreML partitioner (multifunction with weight sharing)...") compile_specs = CoreMLBackend.generate_compile_specs( @@ -622,22 +727,114 @@ def main(): edge_compile_config = EdgeCompileConfig(_check_ir_validity=False) # Create multi-method edge manager with decode as "forward" and prefill as "prefill" + method_programs = {"forward": decode_ep, "prefill": prefill_ep} + for adapter_name, (lora_decode_ep, lora_prefill_ep) in lora_models.items(): + method_programs[f"{adapter_name}_forward"] = lora_decode_ep + method_programs[f"{adapter_name}_prefill"] = lora_prefill_ep + edge_manager = to_edge_transform_and_lower( - {"forward": decode_ep, "prefill": prefill_ep}, + method_programs, partitioner=[partitioner], constant_methods=constant_methods, compile_config=edge_compile_config, ) - print("\nDelegated program (decode/forward):") - print(format_delegated_graph(edge_manager.exported_program().graph_module)) + for method_name in method_programs: + print(f"\nDelegated program ({method_name}):") + print( + 
format_delegated_graph( + edge_manager.exported_program(method_name).graph_module + ) + ) + elif has_adapters: + # Adapters without multifunction: base + adapter methods, all same seqlen + print(f"\nCreating example inputs (seqlen={args.input_len})...") + example_inputs, example_cache_len = _create_example_inputs( + model_args, args.input_len, args.max_context_len, float_dtype + ) - print("\nDelegated program (prefill):") - print( - format_delegated_graph( - edge_manager.exported_program("prefill").graph_module + # Test eager execution + print("\nTesting eager execution...") + with torch.no_grad(): + model(*example_inputs) + print("Eager execution successful!") + + # Export the base model + print("\nExporting base model...") + base_ep = torch.export.export(model, example_inputs) + print("Export successful!") + print(base_ep) + + # Export adapter models + method_programs = {"forward": base_ep} + for adapter_name, lora_model in lora_models.items(): + print(f"\nTesting eager execution ({adapter_name})...") + with torch.no_grad(): + lora_model(*example_inputs) + print(f"\nExporting {adapter_name}...") + lora_ep = torch.export.export(lora_model, example_inputs) + print(f"Export successful ({adapter_name})!") + print(lora_ep) + method_programs[adapter_name] = lora_ep + + # Generate metadata for C++ runner + print("\nGenerating metadata for C++ runner...") + constant_methods = _get_metadata( + model_args, example_inputs, args.input_len, example_cache_len, float_dtype + ) + + # Add per-adapter metadata (same I/O structure as base) + base_metadata = dict(constant_methods) + for adapter_name in lora_models: + for key in [ + "forward_input_len", + "freqs_cos_input_index", + "freqs_sin_input_index", + "mask_specs", + "kv_cache_specs", + ]: + constant_methods[f"{adapter_name}_{key}"] = base_metadata[key] + constant_methods["has_lora"] = True + + # Setup CoreML partitioner with multimethod weight sharing + print("\nSetting up CoreML partitioner (multimethod with weight 
sharing)...") + compile_specs = CoreMLBackend.generate_compile_specs( + minimum_deployment_target=ct.target.iOS18, + compute_precision={ + torch.float16: ct.precision.FLOAT16, + torch.float32: ct.precision.FLOAT32, + }[float_dtype], + compute_unit=ct.ComputeUnit.CPU_AND_NE, + model_type=CoreMLBackend.MODEL_TYPE.MODEL, + ) + compile_specs.append( + CoreMLBackend.generate_multimethod_weight_sharing_strategy_compile_spec( + MULTIMETHOD_WEIGHT_SHARING_STRATEGY.POSITIONAL ) ) + partitioner = CoreMLPartitioner( + compile_specs=compile_specs, + take_over_mutable_buffer=False, + skip_ops_for_coreml_delegation=[], + ) + + # Lower to edge + print(f"\nLowering to edge ({len(method_programs)} methods: {list(method_programs.keys())})...") + edge_compile_config = EdgeCompileConfig(_check_ir_validity=False) + edge_manager = to_edge_transform_and_lower( + method_programs, + partitioner=[partitioner], + constant_methods=constant_methods, + compile_config=edge_compile_config, + ) + + for method_name in method_programs: + print(f"\nDelegated program ({method_name}):") + print( + format_delegated_graph( + edge_manager.exported_program(method_name).graph_module + ) + ) else: # Single method mode: fixed seqlen with generate_full_logits=True for lookahead print(f"\nCreating example inputs (seqlen={args.input_len})...") diff --git a/examples/apple/coreml/llama/utils.py b/examples/apple/coreml/llama/utils.py index 1e5a842fed5..1356587ed57 100644 --- a/examples/apple/coreml/llama/utils.py +++ b/examples/apple/coreml/llama/utils.py @@ -6,6 +6,8 @@ import torch +from executorch.examples.models.llama.lora import LoRALinear + class SplitLinearModule(torch.nn.Module): def __init__( @@ -94,7 +96,9 @@ def replace_linear_with_split_linear( model, out_target_split_size, out_max_splits, in_target_split_size, in_max_splits=1 ): for name, module in model.named_children(): - if isinstance(module, torch.nn.Linear): + if isinstance(module, LoRALinear): + continue + elif isinstance(module, 
torch.nn.Linear):
             assert module.bias is None, "SplitLinearModule does not support bias"
             new_module = SplitLinearModule(
                 module.in_features,
diff --git a/examples/models/llama/lora.py b/examples/models/llama/lora.py
index 12c1c4e5d68..c75d99c7ca2 100644
--- a/examples/models/llama/lora.py
+++ b/examples/models/llama/lora.py
@@ -28,6 +28,7 @@ def __init__(
         self.use_bias = use_bias
         self.dropout = dropout
 
+        # Weight/bias are registered directly on this module rather than kept in an nn.Linear submodule.
         linear = nn.Linear(in_dim, out_dim, bias=use_bias)
         weight = linear.weight
         bias = linear.bias if self.use_bias else None
@@ -41,6 +42,7 @@ def __init__(
         self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Base projection uses the directly-registered weight/bias (see __init__).
         out = torch.nn.functional.linear(x, self.weight, self.bias)
         lora_out = self.lora_a(self.dropout(x))
         lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)