How do you save a Tensorflow dataset to a file?
Solution 1
TFRecordWriter seems to be the most convenient option, but unfortunately it can only write datasets with a single tensor per element. Here are a couple of workarounds you can use. First, since all your tensors have the same type and similar shapes, you can concatenate them all into one tensor, and split them back apart on load:
import tensorflow as tf
# Write
a = tf.zeros((100, 512), tf.int32)
ds = tf.data.Dataset.from_tensor_slices((a, a, a, a[:, 0]))
print(ds)
# <TensorSliceDataset shapes: ((512,), (512,), (512,), ()), types: (tf.int32, tf.int32, tf.int32, tf.int32)>
def write_map_fn(x1, x2, x3, x4):
    return tf.io.serialize_tensor(tf.concat([x1, x2, x3, tf.expand_dims(x4, -1)], -1))
ds = ds.map(write_map_fn)
writer = tf.data.experimental.TFRecordWriter('mydata.tfrecord')
writer.write(ds)
# Read
def read_map_fn(x):
    xp = tf.io.parse_tensor(x, tf.int32)
    # Optionally set shape
    xp.set_shape([1537])  # Do xp.set_shape([None, 1537]) if using batches
    # Use xp[:, :512], ... if using batches
    return xp[:512], xp[512:1024], xp[1024:1536], xp[-1]
ds = tf.data.TFRecordDataset('mydata.tfrecord').map(read_map_fn)
print(ds)
# <MapDataset shapes: ((512,), (512,), (512,), ()), types: (tf.int32, tf.int32, tf.int32, tf.int32)>
But, more generally, you can simply have a separate file per tensor and then read them all:
import tensorflow as tf
# Write
a = tf.zeros((100, 512), tf.int32)
ds = tf.data.Dataset.from_tensor_slices((a, a, a, a[:, 0]))
for i, _ in enumerate(ds.element_spec):
    # Bind i as a default argument so each mapped dataset selects the right component
    ds_i = ds.map(lambda *args, i=i: args[i]).map(tf.io.serialize_tensor)
    writer = tf.data.experimental.TFRecordWriter(f'mydata.{i}.tfrecord')
    writer.write(ds_i)
# Read
NUM_PARTS = 4
parts = []
def read_map_fn(x):
    return tf.io.parse_tensor(x, tf.int32)

for i in range(NUM_PARTS):
    parts.append(tf.data.TFRecordDataset(f'mydata.{i}.tfrecord').map(read_map_fn))
ds = tf.data.Dataset.zip(tuple(parts))
print(ds)
# <ZipDataset shapes: (<unknown>, <unknown>, <unknown>, <unknown>), types: (tf.int32, tf.int32, tf.int32, tf.int32)>
It is possible to have the whole dataset in a single file with multiple separate tensors per element, namely as a file of TFRecords containing tf.train.Examples, but I don't know if there is a way to create those within TensorFlow, that is, without having to pull the data out of the dataset into Python and then write it to the records file.
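For what it's worth, the "pull the data into Python" route can be sketched as follows. The helper name write_examples and the tensor_0, tensor_1, ... feature keys are my own, and this assumes each element is a tuple of int32 tensors:

```python
import tensorflow as tf

def write_examples(ds, path):
    # Iterate the dataset eagerly in Python and write one tf.train.Example
    # per element, with one serialized-tensor bytes feature per component.
    with tf.io.TFRecordWriter(path) as writer:
        for tensors in ds:
            feature = {
                f'tensor_{i}': tf.train.Feature(
                    bytes_list=tf.train.BytesList(
                        value=[tf.io.serialize_tensor(t).numpy()]))
                for i, t in enumerate(tensors)
            }
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())

a = tf.zeros((10, 512), tf.int32)
ds = tf.data.Dataset.from_tensor_slices((a, a, a, a[:, 0]))
write_examples(ds, 'mydata_examples.tfrecord')
```

Reading such a file back is exactly what Solution 4 below does with tf.io.parse_single_example and tf.io.parse_tensor.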
Solution 2
An issue was opened on GitHub, and it appears there is a new feature available in TF 2.3 to write datasets to disk:
https://www.tensorflow.org/api_docs/python/tf/data/experimental/save https://www.tensorflow.org/api_docs/python/tf/data/experimental/load
I haven't tested these features yet, but they seem to do what you want.
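A minimal sketch of those two calls (note that in TF 2.3/2.4, load() also requires the element_spec; the /tmp path is just an example):

```python
import tensorflow as tf

ds = tf.data.Dataset.range(10).map(lambda x: x * 2)

# Write the dataset to a directory of shard files
tf.data.experimental.save(ds, '/tmp/my_saved_ds')

# Read it back; in TF <= 2.4 the element_spec must be supplied explicitly
loaded = tf.data.experimental.load('/tmp/my_saved_ds', element_spec=ds.element_spec)
```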
Solution 3
To add to Yoan's answer:
the tf.data.experimental.save() and load() API works well. You also need to MANUALLY save the ds.element_spec to disk to be able to load() later / within a different context.
Pickling works well for me:
1- Saving:
import pickle

tf.data.experimental.save(
    ds, tf_data_path, compression='GZIP'
)
with open(tf_data_path + '/element_spec', 'wb') as out_:  # also save the element_spec to disk for future loading
    pickle.dump(ds.element_spec, out_)
2- For loading, you need both the folder path with the tf shards and the element_spec that we manually pickled:
with open(tf_data_path + '/element_spec', 'rb') as in_:
    es = pickle.load(in_)
loaded = tf.data.experimental.load(
    tf_data_path, es, compression='GZIP'
)
Solution 4
I have been working on this issue as well, and so far I have written the following util (as found in my repo as well):
import pathlib
from functools import partial
from pathlib import Path
from typing import Callable, Union

import tensorflow as tf

# Note: `transform` is a helper from the author's repo that applies each named
# function to the corresponding key of a dict-structured dataset element.

def cache_with_tf_record(filename: Union[str, pathlib.Path]) -> Callable[[tf.data.Dataset], tf.data.TFRecordDataset]:
    """
    Similar to tf.data.Dataset.cache but writes a TFRecord file instead. Compared to the base .cache method, it also ensures that the whole
    dataset is cached.
    """
    def _cache(dataset):
        if not isinstance(dataset.element_spec, dict):
            raise ValueError(f"dataset.element_spec should be a dict but is {type(dataset.element_spec)} instead")
        Path(filename).parent.mkdir(parents=True, exist_ok=True)
        with tf.io.TFRecordWriter(str(filename)) as writer:
            for sample in dataset.map(transform(**{name: tf.io.serialize_tensor for name in dataset.element_spec.keys()})):
                writer.write(
                    tf.train.Example(
                        features=tf.train.Features(
                            feature={
                                key: tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.numpy()]))
                                for key, value in sample.items()
                            }
                        )
                    ).SerializeToString()
                )
        return (
            tf.data.TFRecordDataset(str(filename), num_parallel_reads=tf.data.experimental.AUTOTUNE)
            .map(
                partial(
                    tf.io.parse_single_example,
                    features={name: tf.io.FixedLenFeature((), tf.string) for name in dataset.element_spec.keys()},
                ),
                num_parallel_calls=tf.data.experimental.AUTOTUNE,
            )
            .map(
                transform(
                    **{name: partial(tf.io.parse_tensor, out_type=spec.dtype) for name, spec in dataset.element_spec.items()}
                )
            )
            .map(
                transform(**{name: partial(tf.ensure_shape, shape=spec.shape) for name, spec in dataset.element_spec.items()})
            )
        )
    return _cache
With this util, I can do:
dataset.apply(cache_with_tf_record("filename")).map(...)
and also load the dataset directly for later use with only the second part of the util.
I am still working on it, so it may change later on, especially to serialize with the correct types instead of all bytes, to save space (I guess).
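For those later runs, the second half of the util can be extracted into a standalone loader. The name load_tf_record is my own, and it expects the element_spec (a dict of tf.TensorSpec) of the original dataset, e.g. pickled to disk as in Solution 3:

```python
from functools import partial

import tensorflow as tf

def load_tf_record(filename, element_spec):
    # Rebuild the dataset from a TFRecord file written with one bytes feature
    # per dict key, given the element_spec of the original dataset.
    features = {name: tf.io.FixedLenFeature((), tf.string) for name in element_spec}

    def _parse(raw):
        parsed = tf.io.parse_single_example(raw, features)
        return {
            name: tf.ensure_shape(
                tf.io.parse_tensor(parsed[name], out_type=spec.dtype), spec.shape)
            for name, spec in element_spec.items()
        }

    return tf.data.TFRecordDataset(str(filename)).map(
        _parse, num_parallel_calls=tf.data.experimental.AUTOTUNE)
```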
Vivek Subramanian
Currently a post-doc in machine learning, with focus on NLP. PhD in neuroengineering + computational neuroscience.
Updated on June 13, 2022

Comments
-
Vivek Subramanian, almost 2 years ago:
There are at least two more questions like this on SO but not a single one has been answered.
I have a dataset of the form:
<TensorSliceDataset shapes: ((512,), (512,), (512,), ()), types: (tf.int32, tf.int32, tf.int32, tf.int32)>
and another of the form:
<BatchDataset shapes: ((None, 512), (None, 512), (None, 512), (None,)), types: (tf.int32, tf.int32, tf.int32, tf.int32)>
I have looked and looked but I can't find the code to save these datasets to files that can be loaded later. The closest I got was this page in the TensorFlow docs, which suggests serializing the tensors using tf.io.serialize_tensor and then writing them to a file using tf.data.experimental.TFRecordWriter. However, when I tried this using the code:
dataset.map(tf.io.serialize_tensor)
writer = tf.data.experimental.TFRecordWriter('mydata.tfrecord')
writer.write(dataset)
I get an error on the first line:
TypeError: serialize_tensor() takes from 1 to 2 positional arguments but 4 were given
How can I modify the above (or do something else) to accomplish my goal?
-
Vivek Subramanian, almost 4 years ago: I'm going to give this a shot shortly and will get back to you. Am very grateful for your help! Have been struggling to get answers to TF-related questions on here lately.
-
Vivek Subramanian, almost 4 years ago: Thank you so much! You have no idea how helpful that was!
-
rodrigo-silveira, over 3 years ago: I'm experimenting with this. A couple of clunky things, but easy to get around: 1. If you save with GZIP compression but don't make it obvious that the file is gzipped, then when you load it without specifying compression='GZIP', it will load the data without complaint, but when you try to use it, it will say "data corrupted". It's not obvious why it's corrupted. 2. You need to specify a tf.TypeSpec. It would be nice if tf.data.experimental.save created the required protobuf for you so you wouldn't need to worry about it.
-
rodrigo-silveira, over 3 years ago: Correction to the above comment: you need to specify a tf.TensorSpec (slightly different). But what is unfortunate from my brief experimenting with it: the size of the file is huge. I have a parquet file (gzipped) with 7M rows of mostly uint16s, at 277 MB. The tf.data.experimental.save artifact (several directories and "shards"), which has fewer "columns" than the parquet file but is also gzipped, is over 600 MB.
-
greedybuddha, over 2 years ago: With TF 2.5+ you can use tf.data.experimental.save(...) and load without the element spec. But with older versions of TF (2.4 and below) this seems to be the approach.
-
gary69, almost 2 years ago: TensorFlow 2.9 docs say this is deprecated: tensorflow.org/api_docs/python/tf/data/experimental/save
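(For anyone landing here later: as of roughly TF 2.10 the non-experimental replacements appear to be the Dataset.save method and the tf.data.Dataset.load static method, which no longer need a manually supplied element_spec. A quick sketch, with an example /tmp path of my own:)

```python
import tensorflow as tf

ds = tf.data.Dataset.range(5)
ds.save('/tmp/ds_nonexperimental')  # replaces tf.data.experimental.save

# The spec is restored automatically, no element_spec argument needed
loaded = tf.data.Dataset.load('/tmp/ds_nonexperimental')
```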
-
Agustin Barrachina, almost 2 years ago: Here is the code on how to use it. It was too long for a comment XD