Skip to content

Commit ae41f82

Browse files
authored
GH-48705: [Ruby] Add support for reading dictionary array (#48706)
### Rationale for this change Dictionary array is a special data type. We need to process dictionary batch message for this. ### What changes are included in this PR? * Add `ArrowFormat::DictionaryType` * Add `ArrowFormat::DictionaryArray` * Add support for dictionary batch messages ### Are these changes tested? Yes. ### Are there any user-facing changes? Yes. * GitHub Issue: #48705 Authored-by: Sutou Kouhei <kou@clear-code.com> Signed-off-by: Sutou Kouhei <kou@clear-code.com>
1 parent 6ee7f7e commit ae41f82

File tree

7 files changed

+313
-115
lines changed

7 files changed

+313
-115
lines changed

ruby/red-arrow-format/lib/arrow-format/array.rb

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -79,54 +79,34 @@ def initialize(type, size, validity_buffer, values_buffer)
7979
super(type, size, validity_buffer)
8080
@values_buffer = values_buffer
8181
end
82-
end
8382

84-
class Int8Array < IntArray
8583
def to_a
86-
apply_validity(@values_buffer.values(:S8, 0, @size))
84+
apply_validity(@values_buffer.values(@type.buffer_type, 0, @size))
8785
end
8886
end
8987

88+
class Int8Array < IntArray
89+
end
90+
9091
class UInt8Array < IntArray
91-
def to_a
92-
apply_validity(@values_buffer.values(:U8, 0, @size))
93-
end
9492
end
9593

9694
class Int16Array < IntArray
97-
def to_a
98-
apply_validity(@values_buffer.values(:s16, 0, @size))
99-
end
10095
end
10196

10297
class UInt16Array < IntArray
103-
def to_a
104-
apply_validity(@values_buffer.values(:u16, 0, @size))
105-
end
10698
end
10799

108100
class Int32Array < IntArray
109-
def to_a
110-
apply_validity(@values_buffer.values(:s32, 0, @size))
111-
end
112101
end
113102

114103
class UInt32Array < IntArray
115-
def to_a
116-
apply_validity(@values_buffer.values(:u32, 0, @size))
117-
end
118104
end
119105

120106
class Int64Array < IntArray
121-
def to_a
122-
apply_validity(@values_buffer.values(:s64, 0, @size))
123-
end
124107
end
125108

126109
class UInt64Array < IntArray
127-
def to_a
128-
apply_validity(@values_buffer.values(:u64, 0, @size))
129-
end
130110
end
131111

132112
class FloatingPointArray < Array
@@ -410,6 +390,27 @@ def to_a
410390
end
411391
end
412392

393+
class MapArray < VariableSizeListArray
394+
def to_a
395+
super.collect do |entries|
396+
if entries.nil?
397+
entries
398+
else
399+
hash = {}
400+
entries.each do |key, value|
401+
hash[key] = value
402+
end
403+
hash
404+
end
405+
end
406+
end
407+
408+
private
409+
def offset_type
410+
:s32 # TODO: big endian support
411+
end
412+
end
413+
413414
class UnionArray < Array
414415
def initialize(type, size, types_buffer, children)
415416
super(type, size, nil)
@@ -449,24 +450,27 @@ def to_a
449450
end
450451
end
451452

452-
class MapArray < VariableSizeListArray
453+
class DictionaryArray < Array
454+
def initialize(type, size, validity_buffer, indices_buffer, dictionary)
455+
super(type, size, validity_buffer)
456+
@indices_buffer = indices_buffer
457+
@dictionary = dictionary
458+
end
459+
453460
def to_a
454-
super.collect do |entries|
455-
if entries.nil?
456-
entries
461+
values = []
462+
@dictionary.each do |dictionary_chunk|
463+
values.concat(dictionary_chunk.to_a)
464+
end
465+
buffer_type = @type.index_type.buffer_type
466+
indices = apply_validity(@indices_buffer.values(buffer_type, 0, @size))
467+
indices.collect do |index|
468+
if index.nil?
469+
nil
457470
else
458-
hash = {}
459-
entries.each do |key, value|
460-
hash[key] = value
461-
end
462-
hash
471+
values[index]
463472
end
464473
end
465474
end
466-
467-
private
468-
def offset_type
469-
:s32 # TODO: big endian support
470-
end
471475
end
472476
end

ruby/red-arrow-format/lib/arrow-format/field.rb

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@ module ArrowFormat
1818
class Field
1919
attr_reader :name
2020
attr_reader :type
21-
def initialize(name, type, nullable)
21+
attr_reader :dictionary_id
22+
def initialize(name, type, nullable, dictionary_id)
2223
@name = name
2324
@type = type
2425
@nullable = nullable
26+
@dictionary_id = dictionary_id
2527
end
2628

2729
def nullable?

ruby/red-arrow-format/lib/arrow-format/file-reader.rb

Lines changed: 99 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,65 @@ def initialize(input)
4949

5050
validate
5151
@footer = read_footer
52-
@record_batches = @footer.record_batches
52+
@record_batch_blocks = @footer.record_batches
5353
@schema = read_schema(@footer.schema)
54+
@dictionaries = read_dictionaries
5455
end
5556

5657
def n_record_batches
57-
@record_batches.size
58+
@record_batch_blocks.size
5859
end
5960

6061
def read(i)
61-
block = @record_batches[i]
62+
fb_message, body = read_block(@record_batch_blocks[i])
63+
fb_header = fb_message.header
64+
unless fb_header.is_a?(Org::Apache::Arrow::Flatbuf::RecordBatch)
65+
raise FileReadError.new(@buffer,
66+
"Not a record batch message: #{i}: " +
67+
fb_header.class.name)
68+
end
69+
read_record_batch(fb_header, @schema, body)
70+
end
71+
72+
def each
73+
return to_enum(__method__) {n_record_batches} unless block_given?
74+
75+
@record_batch_blocks.size.times do |i|
76+
yield(read(i))
77+
end
78+
end
79+
80+
private
81+
def validate
82+
minimum_size = STREAMING_FORMAT_START_OFFSET +
83+
FOOTER_SIZE_SIZE +
84+
END_MARKER_SIZE
85+
if @buffer.size < minimum_size
86+
raise FileReadError.new(@buffer,
87+
"Input must be larger than or equal to " +
88+
"#{minimum_size}: #{@buffer.size}")
89+
end
90+
91+
start_marker = @buffer.slice(0, START_MARKER_SIZE)
92+
if start_marker != MAGIC_BUFFER
93+
raise FileReadError.new(@buffer, "No start marker")
94+
end
95+
end_marker = @buffer.slice(@buffer.size - END_MARKER_SIZE,
96+
END_MARKER_SIZE)
97+
if end_marker != MAGIC_BUFFER
98+
raise FileReadError.new(@buffer, "No end marker")
99+
end
100+
end
101+
102+
def read_footer
103+
footer_size_offset = @buffer.size - END_MARKER_SIZE - FOOTER_SIZE_SIZE
104+
footer_size = @buffer.get_value(FOOTER_SIZE_FORMAT, footer_size_offset)
105+
footer_data = @buffer.slice(footer_size_offset - footer_size,
106+
footer_size)
107+
Org::Apache::Arrow::Flatbuf::Footer.new(footer_data)
108+
end
62109

110+
def read_block(block)
63111
offset = block.offset
64112

65113
# If we can report property error information, we can use
@@ -101,54 +149,65 @@ def read(i)
101149

102150
metadata = @buffer.slice(offset, metadata_length)
103151
fb_message = Org::Apache::Arrow::Flatbuf::Message.new(metadata)
104-
fb_header = fb_message.header
105-
unless fb_header.is_a?(Org::Apache::Arrow::Flatbuf::RecordBatch)
106-
raise FileReadError.new(@buffer,
107-
"Not a record batch message: #{i}: " +
108-
fb_header.class.name)
109-
end
110152
offset += metadata_length
111153

112154
body = @buffer.slice(offset, block.body_length)
113-
read_record_batch(fb_header, @schema, body)
114-
end
115155

116-
def each
117-
return to_enum(__method__) {n_record_batches} unless block_given?
118-
119-
@record_batches.size.times do |i|
120-
yield(read(i))
121-
end
156+
[fb_message, body]
122157
end
123158

124-
private
125-
def validate
126-
minimum_size = STREAMING_FORMAT_START_OFFSET +
127-
FOOTER_SIZE_SIZE +
128-
END_MARKER_SIZE
129-
if @buffer.size < minimum_size
130-
raise FileReadError.new(@buffer,
131-
"Input must be larger than or equal to " +
132-
"#{minimum_size}: #{@buffer.size}")
133-
end
159+
def read_dictionaries
160+
dictionary_blocks = @footer.dictionaries
161+
return nil if dictionary_blocks.nil?
134162

135-
start_marker = @buffer.slice(0, START_MARKER_SIZE)
136-
if start_marker != MAGIC_BUFFER
137-
raise FileReadError.new(@buffer, "No start marker")
163+
dictionary_fields = {}
164+
@schema.fields.each do |field|
165+
next unless field.type.is_a?(DictionaryType)
166+
dictionary_fields[field.dictionary_id] = field
138167
end
139-
end_marker = @buffer.slice(@buffer.size - END_MARKER_SIZE,
140-
END_MARKER_SIZE)
141-
if end_marker != MAGIC_BUFFER
142-
raise FileReadError.new(@buffer, "No end marker")
168+
169+
dictionaries = {}
170+
dictionary_blocks.each do |block|
171+
fb_message, body = read_block(block)
172+
fb_header = fb_message.header
173+
unless fb_header.is_a?(Org::Apache::Arrow::Flatbuf::DictionaryBatch)
174+
raise FileReadError.new(@buffer,
175+
"Not a dictionary batch message: " +
176+
fb_header.inspect)
177+
end
178+
179+
id = fb_header.id
180+
if fb_header.delta?
181+
unless dictionaries.key?(id)
182+
raise FileReadError.new(@buffer,
183+
"A delta dictionary batch message " +
184+
"must exist after a non delta " +
185+
"dictionary batch message: " +
186+
fb_header.inspect)
187+
end
188+
else
189+
if dictionaries.key?(id)
190+
raise FileReadError.new(@buffer,
191+
"Multiple non delta dictionary batch " +
192+
"messages for the same ID is invalid: " +
193+
fb_header.inspect)
194+
end
195+
end
196+
197+
value_type = dictionary_fields[id].type.value_type
198+
schema = Schema.new([Field.new("dummy", value_type, true, nil)])
199+
record_batch = read_record_batch(fb_header.data, schema, body)
200+
if fb_header.delta?
201+
dictionaries[id] << record_batch.columns[0]
202+
else
203+
dictionaries[id] = [record_batch.columns[0]]
204+
end
143205
end
206+
dictionaries
144207
end
145208

146-
def read_footer
147-
footer_size_offset = @buffer.size - END_MARKER_SIZE - FOOTER_SIZE_SIZE
148-
footer_size = @buffer.get_value(FOOTER_SIZE_FORMAT, footer_size_offset)
149-
footer_data = @buffer.slice(footer_size_offset - footer_size,
150-
footer_size)
151-
Org::Apache::Arrow::Flatbuf::Footer.new(footer_data)
209+
def find_dictionary(id)
210+
@dictionaries[id]
152211
end
153212
end
154213
end

0 commit comments

Comments
 (0)