
Misc improvements, fixes; config classes documentation; triton GRPO loss#478

Draft
jlamypoirier wants to merge 20 commits into jlp_reduce_losses_with_counts from jlp_general_improvements

Conversation


@jlamypoirier jlamypoirier commented Mar 27, 2026

✨ Description

In-depth audit of the codebase, tests, and documentation with Claude.

jlamypoirier and others added 2 commits March 27, 2026 12:47
…_property, overhaul config/data tests

- Rename `FieldUpdate` → `FieldOverride` throughout: clearer that it overrides
  inherited field metadata at class-definition time, distinct from the runtime
  `UpdateType.update` config-merge mechanism
- Convert `Field(init=False, hint=FieldHint.derived)` fields in `DistributedConfig`
  and `WandbAlertConfig` to `functools.cached_property`, removing computed state
  from `_validate()` in favour of lazy evaluation
- Overhaul `tests/config/`: consolidate fixture configs into `common.py`, replace
  weak repr/to_logs tests with parametrized checks against explicit expected dicts,
  restructure `test_field.py` with `FieldTestCase`/`ValidCase` dataclasses, add
  comprehensive `UpdateType` test cases in `test_update.py`
- Replace `result_path` with domain-scoped `data_result_path` fixture in all data
  tests to avoid `-n 12` worker collisions
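The derived-field-to-`cached_property` conversion above can be sketched as follows. This is a minimal illustration of the pattern, not the actual `DistributedConfig`; the class and field names here are hypothetical.

```python
import functools

class DistributedConfigSketch:
    """Hypothetical sketch: a derived value computed lazily on first access
    instead of eagerly inside _validate()."""

    def __init__(self, world_size: int, tensor_parallel: int):
        self.world_size = world_size
        self.tensor_parallel = tensor_parallel

    @functools.cached_property
    def data_parallel(self) -> int:
        # Computed once on first access and cached in the instance dict,
        # so _validate() no longer has to populate a Field(init=False).
        return self.world_size // self.tensor_parallel
```

After the first access, the value lives in `__dict__` and no recomputation occurs.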

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Make TOKEN_CUMSUM_RATE configurable via SamplingConfig.token_cumsum_rate
- Suppress expected non-writable buffer UserWarning in RedisStreamingDocumentData
- Fix cast warning in build_padded_token_cumsum (size_t → int64_t)
- Use copy-on-write memmap mode ("c") for MemmapDataset to allow worker writes
- Only pin_memory when CUDA is available in GPTData DataLoader
- Remove deprecated estimate_critical_batch field from ScheduleConfig
- Add pytest filterwarnings for PYTHONHASHSEED and itertools deprecation noise
- Update docs: mkdocs inline citations, exclude docs/README.md, add release guide
- Update README and recipe docs with revised benchmark numbers and config names
- Add setup.cfg/pyproject.toml tweaks
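The copy-on-write memmap change can be illustrated with plain NumPy (the actual `MemmapDataset` wiring is not shown here):

```python
import numpy as np
import os, tempfile

path = os.path.join(tempfile.mkdtemp(), "data.bin")
np.arange(4, dtype=np.int64).tofile(path)

# mode="c" (copy-on-write): in-memory writes are allowed but never
# flushed to the underlying file, so a DataLoader worker can mutate
# its view without corrupting the shared dataset file.
mm = np.memmap(path, dtype=np.int64, mode="c")
mm[0] = 99

reread = np.fromfile(path, dtype=np.int64)
# the worker sees 99; the file on disk still holds 0
```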

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@jlamypoirier jlamypoirier changed the title Jlp general improvements Misc improvements Mar 27, 2026
jlamypoirier and others added 14 commits March 27, 2026 15:47
- ReductionType enum values were copy-pasted from DataType (wrong string literals)
- _serialize_architecture_field dropped dict keys (set comprehension instead of dict)
- LayerBaseWithNamespace/LayerWithNamespace namespace param had spurious default=None
- ParameterConfig/OptionalParameterConfig lacked _abstract=False (inherited True from ModuleConfig); empty _validate() bodies were accidentally suppressing the abstract check; stray pass before actual code in OptionalParameterConfig.get_parameter
- FillInitializationConfig docstring was copy-pasted from NormalInitializationConfig
- NormalInitializationConfig.max field desc said "Min value" (copy-paste from min)
- UniformInitializationConfig.scale had wrong default=None and FieldHint.optional; mean validator Assert.geq(0) was wrong (mean can be negative)
- TensorLogsConfig.max_elements had skip_valid_if_none despite being non-optional int
- RunnableConfig._load_url opened the auth token file twice (second open unused)
- UpdateType converted to StrEnum for consistency
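The `_serialize_architecture_field` bug above is the classic set-vs-dict comprehension slip. A minimal reproduction (function and field names here are hypothetical):

```python
def serialize_fields_buggy(fields: dict) -> set:
    # Bug pattern: braces around a single expression build a *set*
    # of values, silently dropping the dict keys.
    return {str(value) for key, value in fields.items()}

def serialize_fields_fixed(fields: dict) -> dict:
    # Fix: `key: value` inside the braces makes it a dict comprehension.
    return {key: str(value) for key, value in fields.items()}

fields = {"hidden_size": 512, "num_layers": 4}
```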

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…tput

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
80 tests covering get_module_dir, _relative_link, _unwrap_optional,
render_hint_badge, _class_one_liner, is_user_field, _extract_config_types,
render_type, render_default, format_nav_yaml, and smoke tests for
render_class_page / render_index_page.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Add fast_llm/functional/triton/grpo_loss.py: Triton kernel for GRPO
  loss forward/backward, supporting both non-parallel and vocab-parallel
  (two-pass) modes, mirroring the entropy/z-loss Triton patterns
- Add use_triton field to LanguageModelGRPOLossConfig and dispatch to
  Triton kernel in LanguageModelGRPOLoss._forward_backward
- Update test_lm_losses.py: add num_labels_in_seq, test new_logprobs_mean,
  and add Triton kernel testing (guarded by triton_available)
- Fix Triton interpreter bug in __init__.py: monkeypatch _patch_lang_tensor
  to use .item() instead of int() for tensor.__index__, fixing a pre-existing
  failure of all Triton tests under TRITON_INTERPRET=1 (constexpr int args
  to device functions arrived as 1-d numpy arrays, not scalars)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Fix `self._run.index` → `run_index` in lm_eval evaluator (AttributeError on every eval)
- Fix `broadcast_kwargs` dict iteration missing `.items()` in huggingface inference (ValueError with TP)
- Fix missing `self` parameter in `HuggingfacePreTrainedModel.inner_forward`
- Fix `kwargs.get` → `kwargs.pop` for `gguf_file` in inference config
- Fix `get_state_tensor_iterator` and `import_state_tensor` using weight shard sizes for all shards (wrong for optimizer state shards with frozen params)
- Fix `reset_shard_pad` and debug logging only covering last FSDP in `initialize_weights` (missed frozen FSDP)
- Fix `len(grads_norm_slices) < 0` → `> 0` in grad norm slice merging (always-false condition, no merging ever happened)
- Fix `PowerLRStage._interpolate` and `CosineLRStage._interpolate` incorrectly marked `@abc.abstractmethod`
- Fix `CosineLRStage.lr`/`end_lr` typed as `int` instead of `float`
- Remove hardcoded debug `logger.info` for `layers.1.norm_1.weight` from `SafeLoad`
- Replace open("r"/"w") file handles with `.read_text()`/`.write_text()`/`.touch()` throughout
- Fix `pass` before `wandb.alert(...)` call that suppressed the alert
- Fix duplicate word in docstring ("not not needed")
- Fix misleading field descriptions for `depth_first_micro_batches`/`breadth_first_micro_batches`
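The always-false grad-norm condition is worth spelling out, since `len()` can never return a negative number. A sketch (the merging step itself is hypothetical, only the guard is from the fix above):

```python
grads_norm_slices = [slice(0, 10), slice(10, 20)]

# len() is guaranteed non-negative, so this guard is always False and
# the merging branch it protects is dead code:
buggy_guard = len(grads_norm_slices) < 0

# Intended check: enter the branch whenever slices exist.
merged = None
if len(grads_norm_slices) > 0:
    # hypothetical merge of adjacent slices for illustration
    merged = slice(grads_norm_slices[0].start, grads_norm_slices[-1].stop)
```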

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
layer_out was only assigned inside the isinstance(output, tuple) branch,
causing NameError when out_channel_begin is set with a plain Linear.
Also fix tuple unpack: layer_out, tp_bias = output[0] → output.

Affects value-only LoRA in GQA attention (attention.py uses out_channel_begin).
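The branch-only binding can be reproduced in miniature (strings stand in for tensors; names are illustrative, not the actual linear-layer code):

```python
def forward_buggy(output):
    # `layer_out` is bound only in the tuple branch, so a plain tensor
    # input reaches the return with the name unbound.
    if isinstance(output, tuple):
        layer_out, tp_bias = output  # also note: unpack `output`, not output[0]
    return layer_out  # UnboundLocalError (a NameError) for non-tuple input

def forward_fixed(output):
    if isinstance(output, tuple):
        layer_out, tp_bias = output
    else:
        layer_out = output
    return layer_out
```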

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…age in Llama import_config

- MixtureOfExpertMLP._forward called _add_shared_experts(top_experts, scores) but the
  signature is _add_shared_experts(scores, top_experts). Passing integer indices as
  scores and float scores as the top_experts would produce wrong dtypes for both the
  shared expert index arange and the concatenated scores tensor.
- LlamaAttentionConverter.import_config raised NotImplementedError with
  `type(config.rotary).__name__` where `config` is a dict (no .rotary attribute);
  use the `rope_type` variable already in scope instead.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Lines 16-17 were an exact duplicate of lines 12-13 (_prediction_distance > 1),
dead code that never executes.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- MLPBase._get_intermediate_dims: remove discarded TensorDim("gate_and_up", 2)
  created but immediately thrown away before the ConcatenatedTensorDim call.
- StochasticMixer.get_preprocessing_config: remove first loop that called
  mixer.get_preprocessing_config() on each mixer but discarded every result,
  causing each mixer's method to be called twice.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- DynamicGradScaler.load: restore _scale from checkpoint (was never
  loaded, leaving _scale unset after resume from checkpoint)
- ConstantGradScaler.load: handle case where _scale not yet set
  (when loading checkpoint without prior reset_state call)
- AprielMambaConverter: fix dt_rank auto formula to use
  ceil(hidden_size / 16) matching the reference model
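The corrected `dt_rank` auto formula is just a ceiling division, which the buggy version rounded down:

```python
import math

def dt_rank_auto(hidden_size: int) -> int:
    # "auto" dt_rank as in the reference Mamba model: ceil(hidden_size / 16)
    return math.ceil(hidden_size / 16)
```

For sizes that are not multiples of 16 the difference matters: rounding down would lose a rank.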

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
`yield from torch.load(...)` on a dict yields only keys (strings), not
the `(parameter_name, shard_name, tensor)` tuples that `_load_weights`
is expected to produce. Fix by iterating `.items()` and yielding the
correct 3-tuple.
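The underlying Python gotcha is that iterating a dict yields only its keys. A torch-free sketch (the `shard_name` default here is hypothetical):

```python
def load_weights_buggy(state_dict):
    # Iterating a dict yields only its keys, so this produces bare
    # strings instead of (parameter_name, shard_name, tensor) tuples.
    yield from state_dict

def load_weights_fixed(state_dict, shard_name="weights"):
    # Iterate .items() and yield the expected 3-tuple.
    for parameter_name, tensor in state_dict.items():
        yield parameter_name, shard_name, tensor

state = {"layer.weight": [1.0, 2.0]}
```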

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Since StrEnum values are str subclasses, .value is redundant when used
in string contexts (f-strings, dict keys, DataLoader args, Triton kernels).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@jlamypoirier jlamypoirier changed the title Misc improvements Misc improvements, fixes; config classes documentation; triton GRPO loss Mar 30, 2026
jlamypoirier and others added 4 commits March 30, 2026 17:44
…e _validate

The class defined _validate twice; Python uses the last definition, silently
dropping Assert.multiple(self.value_heads, self.key_heads) from the first one.
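The silent-shadowing behaviour is easy to reproduce: a class body is executed top to bottom, so a second `def` with the same name simply rebinds it. A hypothetical sketch (not the real config class):

```python
class HeadsConfigSketch:
    def _validate(self):
        # This first definition (with the heads check) is silently
        # discarded when the name is rebound below.
        assert self.value_heads % self.key_heads == 0

    def _validate(self):  # noqa: F811 -- the last definition wins
        return "only this body runs"
```

Linters flag this as F811 (redefinition), which is one way such bugs get caught.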

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…set/memmap/

- core/distributed.py: Remove spurious tensor.copy_() after Gloo+GPU send
  (copy-paste from recv; send should never write to the source tensor)
- data/dataset/memmap/token.py: Fix _get_nearest_split always rounding down
  (fraction used cumsum[left]/cumsum[left+1] giving always-negative numerator;
  correct formula uses cumsum[left-1]/cumsum[left] as document span boundaries;
  also fixes OOB access when left == len(cumsum) - 1)
- data/dataset/memmap/config.py: Fix blend_metadata writing rejected_spans data
  to wrong key "image_patches" and reading from wrong source key
- layers/block/config.py: Fix raise warnings.warn() crashing with TypeError
  (raise None); should be warnings.warn() without raise
- layers/block/block.py: Fix double-prefixed gradient tensor name in debug logging
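The `raise warnings.warn()` crash comes from `warnings.warn()` returning `None`, so the statement ends up raising `None`, which is a `TypeError`. A minimal reproduction:

```python
import warnings

# warnings.warn() emits the warning and returns None, so
# `raise warnings.warn(...)` then raises None -> TypeError.
crashed = False
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    try:
        raise warnings.warn("deprecated field")
    except TypeError:
        crashed = True

# Fix: call warn() without raise.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warnings.warn("deprecated field")
```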

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…d timeout

Three issues fixed:
1. ProcessGroupPool was created anonymously and immediately GC'd, calling
   shutdown() on the NCCL communicator before broadcasts completed. Fix by
   storing the pool in a variable (self._pool / pool) and shutting it down
   explicitly in the finally/cleanup block.
2. Consumer used torch.cuda.set_device() via ProcessGroupPool's old use_cuda
   path, corrupting the test context's CUDA device. Fix by adding a device
   parameter to ProcessGroupPool that accepts an explicit torch.device, so
   consumers can pass their already-set current device.
3. Consumer xread timeout was hardcoded at 10s, too short for the first
   training step of SDP/TP configs which require CUDA kernel JIT compilation.
   Fix by using streaming_config.timeout (120s) instead.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- token.py: fix confusing variable names in _get_nearest_split (left/right
  were swapped relative to what searchsorted returns)
- grpo_loss.py: split shared_kwargs from epsilon kwargs to avoid passing
  epsilon params to the parallel_max_logits kernel that doesn't accept them
- linear/config.py: inherit LinearBaseConfig from ModuleConfig and mark
  LinearConfig/AffineLinearConfig/CausalConv1dConfig as non-abstract so
  architecture comparison works correctly for linear layers
- base_model/config.py: assert non-architecture fields are not ModuleConfig
  subclasses (which would silently skip nested architecture validation)
- apriel2.py: use .value on sampling_strategy and activation enum fields
  when serializing to dict
- qwen2.py: fix attention_bias default (False, not True) for Qwen2 import

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>