Skip to content

Commit 18768b7

Browse files
introduce ProtoUnknownFields
1 parent 4dd9296 commit 18768b7

File tree

8 files changed

+446
-18
lines changed

8 files changed

+446
-18
lines changed
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/*
2+
* Copyright 2017-2024 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.serialization.protobuf
6+
7+
import kotlinx.serialization.*
8+
import kotlinx.serialization.protobuf.internal.*
9+
import kotlinx.serialization.protobuf.internal.ProtoWireType
10+
11+
/**
12+
* Mark a property as a holder for unknown fields in protobuf message.
13+
*/
14+
@SerialInfo
15+
@Target(AnnotationTarget.PROPERTY)
16+
@ExperimentalSerializationApi
17+
public annotation class ProtoUnknownFields
18+
19+
@Serializable(with = ProtoMessageSerializer::class)
20+
public class ProtoMessage internal constructor(
21+
public val fields: List<ProtoField>
22+
) {
23+
public companion object {
24+
public val Empty: ProtoMessage = ProtoMessage(emptyList())
25+
}
26+
27+
public val size: Int get() = fields.size
28+
public fun asByteArray(): ByteArray = fields.fold(ByteArray(0)) { acc, protoField -> acc + protoField.asWireContent() }
29+
30+
public constructor(vararg fields: ProtoField) : this(fields.toList())
31+
public fun merge(other: ProtoMessage): ProtoMessage {
32+
return ProtoMessage(fields + other.fields)
33+
}
34+
35+
public fun merge(vararg field: ProtoField): ProtoMessage {
36+
return ProtoMessage(fields + field)
37+
}
38+
39+
override fun hashCode(): Int {
40+
return fields.hashCode()
41+
}
42+
43+
override fun equals(other: Any?): Boolean {
44+
if (this === other) return true
45+
if (other == null || this::class != other::class) return false
46+
47+
other as ProtoMessage
48+
49+
return fields == other.fields
50+
}
51+
}
52+
53+
@OptIn(ExperimentalSerializationApi::class)
54+
@Serializable(with = ProtoFieldSerializer::class)
55+
@KeepGeneratedSerializer
56+
@ConsistentCopyVisibility
57+
public data class ProtoField internal constructor(
58+
internal val id: Int,
59+
internal val wireType: ProtoWireType,
60+
internal val data: ProtoContentHolder
61+
) {
62+
public companion object {
63+
public val Empty: ProtoField = ProtoField(0, ProtoWireType.INVALID, ProtoContentHolder.ByteArrayContent(ByteArray(0)))
64+
}
65+
66+
public fun asWireContent(): ByteArray = byteArrayOf(((id shl 3) or wireType.typeId).toByte()) + data.byteArray
67+
68+
public val contentLength: Int
69+
get() = asWireContent().size
70+
71+
override fun equals(other: Any?): Boolean {
72+
if (this === other) return true
73+
if (other == null || this::class != other::class) return false
74+
75+
other as ProtoField
76+
77+
if (id != other.id) return false
78+
if (wireType != other.wireType) return false
79+
if (!data.contentEquals(other.data)) return false
80+
81+
return true
82+
}
83+
84+
override fun hashCode(): Int {
85+
var result = id
86+
result = 31 * result + wireType.hashCode()
87+
result = 31 * result + data.contentHashCode()
88+
return result
89+
}
90+
}
91+
92+
internal sealed interface ProtoContentHolder {
93+
val byteArray: ByteArray
94+
95+
data class ByteArrayContent(override val byteArray: ByteArray) : ProtoContentHolder {
96+
override fun equals(other: Any?): Boolean {
97+
return other is ProtoContentHolder && this.contentEquals(other)
98+
}
99+
100+
override fun hashCode(): Int {
101+
return this.contentHashCode()
102+
}
103+
}
104+
105+
data class MessageContent(val content: ProtoMessage) : ProtoContentHolder {
106+
override val byteArray: ByteArray
107+
get() = content.asByteArray()
108+
109+
override fun equals(other: Any?): Boolean {
110+
return other is ProtoContentHolder && this.contentEquals(other)
111+
}
112+
113+
override fun hashCode(): Int {
114+
return this.contentHashCode()
115+
}
116+
}
117+
}
118+
119+
internal fun ProtoContentHolder(content: ByteArray): ProtoContentHolder = ProtoContentHolder.ByteArrayContent(content)
120+
121+
internal val ProtoContentHolder.contentLength: Int
122+
get() = byteArray.size
123+
124+
internal fun ProtoContentHolder.contentEquals(other: ProtoContentHolder): Boolean {
125+
return byteArray.contentEquals(other.byteArray)
126+
}
127+
128+
internal fun ProtoContentHolder.contentHashCode(): Int {
129+
return byteArray.contentHashCode()
130+
}
131+

formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/Helpers.kt

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,25 @@ internal enum class ProtoWireType(val typeId: Int) {
4646
}
4747

4848
internal const val ID_HOLDER_ONE_OF = -2
49+
internal const val ID_HOLDER_UNKNOWN_FIELDS = -3
4950

51+
private const val UNKNOWN_FIELD_MASK = 1L shl 37
5052
private const val ONEOFMASK = 1L shl 36
5153
private const val INTTYPEMASK = 3L shl 33
5254
private const val PACKEDMASK = 1L shl 32
5355

5456
@Suppress("NOTHING_TO_INLINE")
55-
internal inline fun ProtoDesc(protoId: Int, type: ProtoIntegerType, packed: Boolean = false, oneOf: Boolean = false): ProtoDesc {
57+
internal inline fun ProtoDesc(
58+
protoId: Int,
59+
type: ProtoIntegerType,
60+
packed: Boolean = false,
61+
oneOf: Boolean = false,
62+
unknown: Boolean = false,
63+
): ProtoDesc {
5664
val packedBits = if (packed) PACKEDMASK else 0L
5765
val oneOfBits = if (oneOf) ONEOFMASK else 0L
58-
return packedBits or oneOfBits or type.signature or protoId.toLong()
66+
val unknownBits = if (unknown) UNKNOWN_FIELD_MASK else 0L
67+
return packedBits or oneOfBits or type.signature or protoId.toLong() or unknownBits
5968
}
6069

6170
internal inline val ProtoDesc.protoId: Int get() = (this and Int.MAX_VALUE.toLong()).toInt()
@@ -81,6 +90,9 @@ internal val ProtoDesc.isPacked: Boolean
8190
internal val ProtoDesc.isOneOf: Boolean
8291
get() = (this and ONEOFMASK) != 0L
8392

93+
internal val ProtoDesc.isUnknown: Boolean
94+
get() = (this and UNKNOWN_FIELD_MASK) != 0L
95+
8496
internal fun ProtoDesc.overrideId(protoId: Int): ProtoDesc {
8597
return this and (0xFFFFFFF00000000L) or protoId.toLong()
8698
}
@@ -91,6 +103,7 @@ internal fun SerialDescriptor.extractParameters(index: Int): ProtoDesc {
91103
var format: ProtoIntegerType = ProtoIntegerType.DEFAULT
92104
var protoPacked = false
93105
var isOneOf = false
106+
var isUnknown = false
94107

95108
for (i in annotations.indices) { // Allocation-friendly loop
96109
val annotation = annotations[i]
@@ -103,6 +116,8 @@ internal fun SerialDescriptor.extractParameters(index: Int): ProtoDesc {
103116
protoPacked = true
104117
} else if (annotation is ProtoOneOf) {
105118
isOneOf = true
119+
} else if (annotation is ProtoUnknownFields) {
120+
isUnknown = true
106121
}
107122
}
108123
if (isOneOf) {
@@ -111,7 +126,7 @@ internal fun SerialDescriptor.extractParameters(index: Int): ProtoDesc {
111126
// See [kotlinx.serialization.protobuf.internal.ProtobufDecoder.decodeElementIndex] for detail
112127
protoId = index + 1
113128
}
114-
return ProtoDesc(protoId, format, protoPacked, isOneOf)
129+
return ProtoDesc(protoId, format, protoPacked, isOneOf, isUnknown)
115130
}
116131

117132
/**
@@ -126,6 +141,9 @@ internal fun extractProtoId(descriptor: SerialDescriptor, index: Int, zeroBasedD
126141
if (annotation is ProtoOneOf) {
127142
// Fast return for one of field
128143
return ID_HOLDER_ONE_OF
144+
} else if (annotation is ProtoUnknownFields) {
145+
// Fast return for unknown fields holder
146+
return ID_HOLDER_UNKNOWN_FIELDS
129147
} else if (annotation is ProtoNumber) {
130148
result = annotation.number
131149
// 0 or negative numbers are acceptable for enums
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Copyright 2017-2024 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.serialization.protobuf.internal
6+
7+
import kotlinx.serialization.*
8+
import kotlinx.serialization.builtins.*
9+
import kotlinx.serialization.descriptors.*
10+
import kotlinx.serialization.encoding.*
11+
import kotlinx.serialization.protobuf.*
12+
13+
internal object ProtoMessageSerializer : KSerializer<ProtoMessage> {
14+
internal val fieldsSerializer = ProtoFieldSerializer
15+
16+
override val descriptor: SerialDescriptor
17+
get() = UnknownFieldsDescriptor(fieldsSerializer.descriptor)
18+
19+
override fun deserialize(decoder: Decoder): ProtoMessage {
20+
if (decoder is ProtobufDecoder) {
21+
return decoder.decodeStructure(descriptor) {
22+
ProtoMessage(fieldsSerializer.deserializeComposite(this))
23+
}
24+
}
25+
return ProtoMessage.Empty
26+
}
27+
28+
override fun serialize(encoder: Encoder, value: ProtoMessage) {
29+
if (encoder is ProtobufEncoder) {
30+
value.fields.forEach {
31+
fieldsSerializer.serialize(encoder, it)
32+
}
33+
}
34+
}
35+
}
36+
37+
internal object ProtoFieldSerializer : KSerializer<ProtoField> {
38+
private val delegate = ByteArraySerializer()
39+
40+
override val descriptor: SerialDescriptor
41+
get() = UnknownFieldsDescriptor(delegate.descriptor)
42+
43+
fun deserializeComposite(compositeDecoder: CompositeDecoder): ProtoField {
44+
if (compositeDecoder is ProtobufDecoder) {
45+
return deserialize(compositeDecoder)
46+
}
47+
return ProtoField.Empty
48+
}
49+
50+
override fun deserialize(decoder: Decoder): ProtoField {
51+
if (decoder is ProtobufDecoder) {
52+
return deserialize(decoder, decoder.currentTag)
53+
}
54+
return ProtoField.Empty
55+
}
56+
57+
internal fun deserialize(protobufDecoder: ProtobufDecoder, currentTag: ProtoDesc): ProtoField {
58+
if (currentTag != MISSING_TAG) {
59+
val id = currentTag.protoId
60+
val type = protobufDecoder.currentType
61+
val data = protobufDecoder.decodeRawElement()
62+
val field = ProtoField(
63+
id = id,
64+
wireType = type,
65+
data = ProtoContentHolder(data),
66+
)
67+
return field
68+
}
69+
return ProtoField.Empty
70+
}
71+
72+
override fun serialize(encoder: Encoder, value: ProtoField) {
73+
if (encoder is ProtobufEncoder) {
74+
encoder.encodeRawElement(value.id, value.wireType, value.data.byteArray)
75+
}
76+
}
77+
}
78+
79+
internal class UnknownFieldsDescriptor(private val original: SerialDescriptor) : SerialDescriptor by original {
80+
override val serialName: String
81+
get() = "UnknownProtoFieldsHolder"
82+
83+
override fun equals(other: Any?): Boolean {
84+
return other is UnknownFieldsDescriptor && other.original == original
85+
}
86+
87+
override fun hashCode(): Int {
88+
return original.hashCode()
89+
}
90+
}

0 commit comments

Comments
 (0)