Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 36 additions & 4 deletions src/ipc/metadata/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ export class Message<T extends MessageHeader = any> {
const bodyLength: bigint = _message.bodyLength()!;
const version: MetadataVersion = _message.version();
const headerType: MessageHeader = _message.headerType();
const message = new Message(bodyLength, version, headerType);
const metadata = decodeMessageCustomMetadata(_message);
const message = new Message(bodyLength, version, headerType, undefined, metadata);
message._createHeader = decodeMessageHeader(_message, headerType);
return message;
}
Expand All @@ -98,11 +99,24 @@ export class Message<T extends MessageHeader = any> {
} else if (message.isDictionaryBatch()) {
headerOffset = DictionaryBatch.encode(b, message.header() as DictionaryBatch);
}

// Encode custom metadata if present (must be done before startMessage)
const customMetadataOffset = !(message.metadata && message.metadata.size > 0) ? -1 :
_Message.createCustomMetadataVector(b, [...message.metadata].map(([k, v]) => {
const key = b.createString(`${k}`);
const val = b.createString(`${v}`);
_KeyValue.startKeyValue(b);
_KeyValue.addKey(b, key);
_KeyValue.addValue(b, val);
return _KeyValue.endKeyValue(b);
}));

_Message.startMessage(b);
_Message.addVersion(b, MetadataVersion.V5);
_Message.addHeader(b, headerOffset);
_Message.addHeaderType(b, message.headerType);
_Message.addBodyLength(b, BigInt(message.bodyLength));
if (customMetadataOffset !== -1) { _Message.addCustomMetadata(b, customMetadataOffset); }
_Message.finishMessageBuffer(b, _Message.endMessage(b));
return b.asUint8Array();
}
Expand All @@ -113,7 +127,7 @@ export class Message<T extends MessageHeader = any> {
return new Message(0, MetadataVersion.V5, MessageHeader.Schema, header);
}
if (header instanceof RecordBatch) {
return new Message(bodyLength, MetadataVersion.V5, MessageHeader.RecordBatch, header);
return new Message(bodyLength, MetadataVersion.V5, MessageHeader.RecordBatch, header, header.metadata);
}
if (header instanceof DictionaryBatch) {
return new Message(bodyLength, MetadataVersion.V5, MessageHeader.DictionaryBatch, header);
Expand All @@ -126,24 +140,27 @@ export class Message<T extends MessageHeader = any> {
protected _bodyLength: number;
protected _version: MetadataVersion;
protected _compression: BodyCompression | null;
protected _metadata: Map<string, string>;
public get type() { return this.headerType; }
public get version() { return this._version; }
public get headerType() { return this._headerType; }
public get compression() { return this._compression; }
public get bodyLength() { return this._bodyLength; }
public get metadata() { return this._metadata; }
declare protected _createHeader: MessageHeaderDecoder;
public header() { return this._createHeader<T>(); }
public isSchema(): this is Message<MessageHeader.Schema> { return this.headerType === MessageHeader.Schema; }
public isRecordBatch(): this is Message<MessageHeader.RecordBatch> { return this.headerType === MessageHeader.RecordBatch; }
public isDictionaryBatch(): this is Message<MessageHeader.DictionaryBatch> { return this.headerType === MessageHeader.DictionaryBatch; }

constructor(bodyLength: bigint | number, version: MetadataVersion, headerType: T, header?: any) {
constructor(bodyLength: bigint | number, version: MetadataVersion, headerType: T, header?: any, metadata?: Map<string, string>) {
this._version = version;
this._headerType = headerType;
this.body = new Uint8Array(0);
this._compression = header?.compression;
header && (this._createHeader = () => header);
this._bodyLength = bigIntToNumber(bodyLength);
this._metadata = metadata || new Map();
}
}

Expand All @@ -157,23 +174,27 @@ export class RecordBatch {
protected _buffers: BufferRegion[];
protected _compression: BodyCompression | null;
protected _variadicBufferCounts: number[];
protected _metadata: Map<string, string>;
public get nodes() { return this._nodes; }
public get length() { return this._length; }
public get buffers() { return this._buffers; }
public get compression() { return this._compression; }
public get variadicBufferCounts() { return this._variadicBufferCounts; }
public get metadata() { return this._metadata; }
constructor(
length: bigint | number,
nodes: FieldNode[],
buffers: BufferRegion[],
compression: BodyCompression | null,
variadicBufferCounts: number[] = []
variadicBufferCounts: number[] = [],
metadata?: Map<string, string>
) {
this._nodes = nodes;
this._buffers = buffers;
this._length = bigIntToNumber(length);
this._compression = compression;
this._variadicBufferCounts = variadicBufferCounts;
this._metadata = metadata || new Map();
}
}

Expand Down Expand Up @@ -468,6 +489,17 @@ function decodeCustomMetadata(parent?: _Schema | _Field | null) {
return data;
}

/** @ignore */
function decodeMessageCustomMetadata(message: _Message) {
    // Collect the Message-level custom_metadata KeyValue entries into a Map,
    // skipping any entry that is absent or has a null key.
    const metadata = new Map<string, string>();
    const count = Math.trunc(message.customMetadataLength());
    for (let index = 0; index < count; index++) {
        const entry = message.customMetadata(index);
        if (!entry) { continue; }
        const key = entry.key();
        if (key != null) {
            metadata.set(key, entry.value()!);
        }
    }
    return metadata;
}

/** @ignore */
function decodeIndexType(_type: _Int) {
return new Int(_type.isSigned(), _type.bitWidth() as IntBitWidth);
Expand Down
12 changes: 6 additions & 6 deletions src/ipc/reader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ abstract class RecordBatchReaderImpl<T extends TypeMap = any> implements RecordB
return this;
}

protected _loadRecordBatch(header: metadata.RecordBatch, body: Uint8Array): RecordBatch<T> {
protected _loadRecordBatch(header: metadata.RecordBatch, body: Uint8Array, messageMetadata?: Map<string, string>): RecordBatch<T> {
let children: Data<any>[];
if (header.compression != null) {
const codec = compressionRegistry.get(header.compression.type);
Expand All @@ -379,7 +379,7 @@ abstract class RecordBatchReaderImpl<T extends TypeMap = any> implements RecordB
}

const data = makeData({ type: new Struct(this.schema.fields), length: header.length, children });
return new RecordBatch(this.schema, data);
return new RecordBatch(this.schema, data, messageMetadata);
}

protected _loadDictionaryBatch(header: metadata.DictionaryBatch, body: Uint8Array) {
Expand Down Expand Up @@ -512,7 +512,7 @@ class RecordBatchStreamReaderImpl<T extends TypeMap = any> extends RecordBatchRe
this._recordBatchIndex++;
const header = message.header();
const buffer = reader.readMessageBody(message.bodyLength);
const recordBatch = this._loadRecordBatch(header, buffer);
const recordBatch = this._loadRecordBatch(header, buffer, message.metadata);
return { done: false, value: recordBatch };
} else if (message.isDictionaryBatch()) {
this._dictionaryIndex++;
Expand Down Expand Up @@ -587,7 +587,7 @@ class AsyncRecordBatchStreamReaderImpl<T extends TypeMap = any> extends RecordBa
this._recordBatchIndex++;
const header = message.header();
const buffer = await reader.readMessageBody(message.bodyLength);
const recordBatch = this._loadRecordBatch(header, buffer);
const recordBatch = this._loadRecordBatch(header, buffer, message.metadata);
return { done: false, value: recordBatch };
} else if (message.isDictionaryBatch()) {
this._dictionaryIndex++;
Expand Down Expand Up @@ -640,7 +640,7 @@ class RecordBatchFileReaderImpl<T extends TypeMap = any> extends RecordBatchStre
if (message?.isRecordBatch()) {
const header = message.header();
const buffer = this._reader.readMessageBody(message.bodyLength);
const recordBatch = this._loadRecordBatch(header, buffer);
const recordBatch = this._loadRecordBatch(header, buffer, message.metadata);
return recordBatch;
}
}
Expand Down Expand Up @@ -714,7 +714,7 @@ class AsyncRecordBatchFileReaderImpl<T extends TypeMap = any> extends AsyncRecor
if (message?.isRecordBatch()) {
const header = message.header();
const buffer = await this._reader.readMessageBody(message.bodyLength);
const recordBatch = this._loadRecordBatch(header, buffer);
const recordBatch = this._loadRecordBatch(header, buffer, message.metadata);
return recordBatch;
}
}
Expand Down
6 changes: 5 additions & 1 deletion src/ipc/writer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ export class RecordBatchWriter<T extends TypeMap = any> extends ReadableInterop<
return this;
}

public write(payload?: Table<T> | RecordBatch<T> | Iterable<RecordBatch<T>> | null): void;
// Overload for UnderlyingSink compatibility (used by DOM streams)
public write(chunk: RecordBatch<T>, controller: WritableStreamDefaultController): void;
public write(payload?: Table<T> | RecordBatch<T> | Iterable<RecordBatch<T>> | null) {
let schema: Schema<T> | null = null;

Expand Down Expand Up @@ -275,7 +278,7 @@ export class RecordBatchWriter<T extends TypeMap = any> extends ReadableInterop<

protected _writeRecordBatch(batch: RecordBatch<T>) {
const { byteLength, nodes, bufferRegions, buffers, variadicBufferCounts } = this._assembleRecordBatch(batch);
const recordBatch = new metadata.RecordBatch(batch.numRows, nodes, bufferRegions, this._compression, variadicBufferCounts);
const recordBatch = new metadata.RecordBatch(batch.numRows, nodes, bufferRegions, this._compression, variadicBufferCounts, batch.metadata);
const message = Message.from(recordBatch, byteLength);
return this
._writeDictionaries(batch)
Expand Down Expand Up @@ -589,3 +592,4 @@ function recordBatchToJSON(records: RecordBatch) {
'columns': columns
}, null, 2);
}

25 changes: 17 additions & 8 deletions src/recordbatch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ export interface RecordBatch<T extends TypeMap = any> {
export class RecordBatch<T extends TypeMap = any> {

constructor(columns: { [P in keyof T]: Data<T[P]> });
constructor(schema: Schema<T>, data?: Data<Struct<T>>);
constructor(schema: Schema<T>, data?: Data<Struct<T>>, metadata?: Map<string, string>);
constructor(...args: any[]) {
switch (args.length) {
case 3:
case 2: {
[this.schema] = args;
if (!(this.schema instanceof Schema)) {
Expand All @@ -60,7 +61,8 @@ export class RecordBatch<T extends TypeMap = any> {
nullCount: 0,
type: new Struct<T>(this.schema.fields),
children: this.schema.fields.map((f) => makeData({ type: f.type, nullCount: 0 }))
})
}),
this._metadata = new Map()
] = args;
if (!(this.data instanceof Data)) {
throw new TypeError('RecordBatch constructor expects a [Schema, Data] pair.');
Expand All @@ -84,17 +86,24 @@ export class RecordBatch<T extends TypeMap = any> {
const schema = new Schema<T>(fields);
const data = makeData({ type: new Struct<T>(fields), length, children, nullCount: 0 });
[this.schema, this.data] = ensureSameLengthData<T>(schema, data.children as Data<T[keyof T]>[], length);
this._metadata = new Map();
break;
}
default: throw new TypeError('RecordBatch constructor expects an Object mapping names to child Data, or a [Schema, Data] pair.');
}
}

protected _dictionaries?: Map<number, Vector>;
protected _metadata: Map<string, string>;

public readonly schema: Schema<T>;
public readonly data: Data<Struct<T>>;

/**
* Custom metadata for this RecordBatch.
*/
public get metadata() { return this._metadata; }

public get dictionaries() {
return this._dictionaries || (this._dictionaries = collectDictionaries(this.schema.fields, this.data.children));
}
Expand Down Expand Up @@ -188,7 +197,7 @@ export class RecordBatch<T extends TypeMap = any> {
*/
public slice(begin?: number, end?: number): RecordBatch<T> {
const [slice] = new Vector([this.data]).slice(begin, end).data;
return new RecordBatch(this.schema, slice);
return new RecordBatch(this.schema, slice, this._metadata);
}

/**
Expand Down Expand Up @@ -240,7 +249,7 @@ export class RecordBatch<T extends TypeMap = any> {
schema = new Schema(fields, new Map(this.schema.metadata));
data = makeData({ type: new Struct<T>(fields), children });
}
return new RecordBatch(schema, data);
return new RecordBatch(schema, data, this._metadata);
}

/**
Expand All @@ -259,7 +268,7 @@ export class RecordBatch<T extends TypeMap = any> {
children[index] = this.data.children[index] as Data<T[K]>;
}
}
return new RecordBatch(schema, makeData({ type, length: this.numRows, children }));
return new RecordBatch(schema, makeData({ type, length: this.numRows, children }), this._metadata);
}

/**
Expand All @@ -272,7 +281,7 @@ export class RecordBatch<T extends TypeMap = any> {
const schema = this.schema.selectAt<K>(columnIndices);
const children = columnIndices.map((i) => this.data.children[i]).filter(Boolean);
const subset = makeData({ type: new Struct(schema.fields), length: this.numRows, children });
return new RecordBatch<{ [P in keyof K]: K[P] }>(schema, subset);
return new RecordBatch<{ [P in keyof K]: K[P] }>(schema, subset, this._metadata);
}

// Initialize this static property via an IIFE so bundlers don't tree-shake
Expand Down Expand Up @@ -347,9 +356,9 @@ function collectDictionaries(fields: Field[], children: readonly Data[], diction
* @private
*/
export class _InternalEmptyPlaceholderRecordBatch<T extends TypeMap = any> extends RecordBatch<T> {
constructor(schema: Schema<T>) {
constructor(schema: Schema<T>, metadata?: Map<string, string>) {
const children = schema.fields.map((f) => makeData({ type: f.type }));
const data = makeData({ type: new Struct<T>(schema.fields), nullCount: 0, children });
super(schema, data);
super(schema, data, metadata || new Map());
}
}
Binary file added test/data/test_message_metadata.arrow
Binary file not shown.
97 changes: 97 additions & 0 deletions test/unit/ipc/reader/message-metadata-tests.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

import { readFileSync } from 'node:fs';
import path from 'node:path';
import { tableFromIPC, RecordBatch } from 'apache-arrow';

// Path to the test file with message-level metadata
// Use process.cwd() since tests are run from project root
const testFilePath = path.resolve(process.cwd(), 'test/data/test_message_metadata.arrow');

describe('RecordBatch message metadata', () => {
    const buffer = readFileSync(testFilePath);
    const table = tableFromIPC(buffer);

    test('should read RecordBatch metadata from IPC file', () => {
        expect(table.batches).toHaveLength(3);

        table.batches.forEach((batch, i) => {
            expect(batch).toBeInstanceOf(RecordBatch);
            expect(batch.metadata).toBeInstanceOf(Map);
            expect(batch.metadata.size).toBeGreaterThan(0);

            // Every batch carries these message-level metadata keys
            for (const key of ['batch_index', 'batch_id', 'producer']) {
                expect(batch.metadata.has(key)).toBe(true);
            }

            // batch_index / batch_id encode the batch's position in the file
            expect(batch.metadata.get('batch_index')).toBe(String(i));
            expect(batch.metadata.get('batch_id')).toBe(`batch_${String(i).padStart(4, '0')}`);
        });
    });

    test('should read unicode metadata values', () => {
        const { metadata } = table.batches[0];
        expect(metadata.has('unicode_test')).toBe(true);
        expect(metadata.get('unicode_test')).toBe('Hello 世界 🌍 مرحبا');
    });

    test('should handle empty metadata values', () => {
        const { metadata } = table.batches[0];
        expect(metadata.has('optional_field')).toBe(true);
        expect(metadata.get('optional_field')).toBe('');
    });

    test('should read JSON metadata values', () => {
        const { metadata } = table.batches[0];
        expect(metadata.has('batch_info_json')).toBe(true);
        const parsed = JSON.parse(metadata.get('batch_info_json')!);
        expect(parsed.batch_number).toBe(0);
        expect(parsed.processing_stage).toBe('final');
        expect(parsed.tags).toEqual(['validated', 'complete']);
    });

    describe('metadata preservation', () => {
        // Every transform below should carry the source batch's metadata through unchanged.
        const expectSameMetadata = (derived: RecordBatch, source: RecordBatch) => {
            expect(derived.metadata).toBeInstanceOf(Map);
            expect(derived.metadata.size).toBe(source.metadata.size);
            expect(derived.metadata.get('batch_index')).toBe(source.metadata.get('batch_index'));
        };

        test('should preserve metadata through slice()', () => {
            const batch = table.batches[0];
            expectSameMetadata(batch.slice(0, 2), batch);
        });

        test('should preserve metadata through select()', () => {
            const batch = table.batches[0];
            expectSameMetadata(batch.select(['id', 'name']), batch);
        });

        test('should preserve metadata through selectAt()', () => {
            const batch = table.batches[0];
            expectSameMetadata(batch.selectAt([0, 1]), batch);
        });
    });
});
Loading