diff --git a/wsds/ws_dataset.py b/wsds/ws_dataset.py
index a8ca555..4ae8475 100644
--- a/wsds/ws_dataset.py
+++ b/wsds/ws_dataset.py
@@ -214,7 +214,7 @@ def _shard_n_samples(self, shard_name: (str, str)) -> int:
             raise IndexError(f"Shard not found: {shard_name}")
         return r[0]
 
-    def iter_shard(self, shard_name):
+    def iter_shard(self, shard_name: tuple[str, str]):
         dataset_path, shard_name = shard_name
         if shard_name.endswith(".wsds"):
             shard_name = shard_name[:-5]
diff --git a/wsds/ws_sink.py b/wsds/ws_sink.py
index ed8678b..242db44 100644
--- a/wsds/ws_sink.py
+++ b/wsds/ws_sink.py
@@ -65,6 +65,7 @@ def write_batch(self, b, flush=False):
         try:
             record = pyarrow.RecordBatch.from_pylist(b, self._sink_schema if self._sink else None)
         except Exception:
+            print(f"Schema: {self._sink_schema}")
             print(f"Error while serializing: {repr(b)}")
             raise
         if self._sink is None:
@@ -73,6 +74,8 @@
             self.batch_size *= 2
             return
         schema = record.schema.with_metadata({"batch_size": str(len(b))})
+        print(f"Initializing WSBatchedSink with batch size {self.batch_size} and schema:\n{schema}")
+        print(f"First batch data:\n{repr(b)}")
         self._sink = pyarrow.RecordBatchFileWriter(
             self.fname, schema, options=pyarrow.ipc.IpcWriteOptions(compression=self.compression)
         )