Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion wsds/ws_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def _shard_n_samples(self, shard_name: (str, str)) -> int:
raise IndexError(f"Shard not found: {shard_name}")
return r[0]

def iter_shard(self, shard_name):
def iter_shard(self, shard_name: (str, str)):
dataset_path, shard_name = shard_name
if shard_name.endswith(".wsds"):
shard_name = shard_name[:-5]
Expand Down
4 changes: 4 additions & 0 deletions wsds/ws_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def write_batch(self, b, flush=False):
try:
record = pyarrow.RecordBatch.from_pylist(b, self._sink_schema if self._sink else None)
except Exception:
print(f"Batch data causing serialization error: {repr(b)}")
print(f"Schema: {self._sink_schema}")
print(f"Error while serializing: {repr(b)}")
raise
if self._sink is None:
Expand All @@ -73,6 +75,8 @@ def write_batch(self, b, flush=False):
self.batch_size *= 2
return
schema = record.schema.with_metadata({"batch_size": str(len(b))})
print(f"Initializing WSBatchedSink with batch size {self.batch_size} and schema:\n{schema}")
print(f"First batch data:\n{repr(b)}")
self._sink = pyarrow.RecordBatchFileWriter(
self.fname, schema, options=pyarrow.ipc.IpcWriteOptions(compression=self.compression)
)
Expand Down