diff --git a/.ci/scripts/test_lora.sh b/.ci/scripts/test_lora.sh index 89a7e99460e..6a929518e02 100644 --- a/.ci/scripts/test_lora.sh +++ b/.ci/scripts/test_lora.sh @@ -139,8 +139,7 @@ Okay, so I need to calculate 15% of 80." EXPECTED_QUANT_LORA_PREFIX=" <|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant To calculate 15% of 80, we can multiply 80 by 15/100. -80 * 15/100 = 12. -So, 15% of 80 is 12. +So, 15% of 80 is equal to 80 * 15/100 = 12. #### 12 The answer is: 12<|im_end|>" diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py index 4643ada9336..a0f03205ed5 100644 --- a/backends/xnnpack/operators/node_visitor.py +++ b/backends/xnnpack/operators/node_visitor.py @@ -625,7 +625,7 @@ def get_serialized_buffer_index( f"Serializing constant data node {tensor} but tensor value has no bytes", ) sha256_hash = hashlib.sha256(bytes(array)) - named_key = tensor.name + "_" + sha256_hash.hexdigest() + named_key = sha256_hash.hexdigest() size = const_val.untyped_storage().nbytes() xnn_graph.constant_data.append( diff --git a/examples/models/llama/lora.py b/examples/models/llama/lora.py index 12c1c4e5d68..99d583f52dd 100644 --- a/examples/models/llama/lora.py +++ b/examples/models/llama/lora.py @@ -26,22 +26,31 @@ def __init__( self.rank = rank self.alpha = alpha self.use_bias = use_bias - self.dropout = dropout - - linear = nn.Linear(in_dim, out_dim, bias=use_bias) - weight = linear.weight - bias = linear.bias if self.use_bias else None - self.register_parameter("weight", nn.Parameter(weight)) - self.register_parameter( - "bias", nn.Parameter(bias) if bias is not None else None - ) + self.linear = nn.Linear(in_dim, out_dim, bias=use_bias) self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity() self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False) self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False) + @property + def weight(self): + return self.linear.weight + + @property + def bias(self): + return self.linear.bias + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + # Remap keys to "linear.*" + for attr in ("weight", "bias"): + old_key = prefix + attr + new_key = prefix + "linear." + attr + if old_key in state_dict and new_key not in state_dict: + state_dict[new_key] = state_dict.pop(old_key) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + def forward(self, x: torch.Tensor) -> torch.Tensor: - out = torch.nn.functional.linear(x, self.weight, self.bias) + out = self.linear(x) lora_out = self.lora_a(self.dropout(x)) lora_out = (self.alpha / self.rank) * self.lora_b(lora_out) diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py index dbd2caad5a0..04a67c800dd 100644 --- a/examples/models/llama/source_transformation/quantize.py +++ b/examples/models/llama/source_transformation/quantize.py @@ -144,30 +144,11 @@ def quantize( # noqa C901 from torchao.utils import unwrap_tensor_subclass def filter_fn(m, fqn): - # Check if it's a regular nn.Linear - is_linear = isinstance(m, nn.Linear) - - # Check if it's a LoRALinear (which has a base weight parameter to quantize) - is_lora_linear = False - try: - from executorch.examples.models.llama.lora import LoRALinear - - is_lora_linear = isinstance(m, LoRALinear) - except ImportError: - pass - - # Check if the weight shape is compatible with group size - has_shape_compatible_with_group_size = False - if is_linear or is_lora_linear: - if group_size == 0: - has_shape_compatible_with_group_size = True - else: - has_shape_compatible_with_group_size = ( - m.weight.shape[1] % group_size == 0 - ) - return ( - is_linear or is_lora_linear - ) and has_shape_compatible_with_group_size + if not isinstance(m, nn.Linear): + return False + if group_size == 0: + return True + return m.weight.shape[1] % group_size == 0 weight_dtype = torch.int4 if qmode == "8da4w" else torch.int8 quantize_(