1 year ago
#388990
photon1981
Serialize TF graph after XLA compilation
I was wondering if there is a way to serialize a TF graph after it has been compiled with XLA. Typically we can serialize a graph as a proto file and load it later. However, graph execution is slow the first time we run it after loading, presumably because of JIT compilation. Is there a way to store the graph so that there is no performance overhead on the first run after loading it?
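By "serialize a graph as a proto file" I mean roughly the following (a minimal sketch; the matmul body and the shapes are just placeholders for the real computation):
import tensorflow as tf

@tf.function(jit_compile=True)
def fn(x, y):
    # Placeholder for the expensive computation
    return tf.matmul(x, y)

# Trace once to obtain a concrete function, then dump its GraphDef to disk.
concrete = fn.get_concrete_function(
    tf.TensorSpec(shape=(1000, 1000), dtype=tf.float64),
    tf.TensorSpec(shape=(1000, 1000), dtype=tf.float64),
)
tf.io.write_graph(concrete.graph.as_graph_def(), './tmp', 'graph.pb', as_text=False)

# Later the proto can be read back, but XLA still has to JIT-compile on the first call.
graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile('./tmp/graph.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())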
Consider the following example:
import numpy as np
import tensorflow as tf

class ModelSaver(tf.Module):
    @tf.function(jit_compile=True)
    def fn(self, x, y):
        # Placeholder for something that is computationally expensive
        return tf.matmul(x, y)

x = np.ones(shape=(1000, 1000))
y = np.ones(shape=(1000, 1000))
model = ModelSaver()

# First run is slow (XLA JIT compilation happens here)
model.fn(x, y)
# Second run is fast (the compiled executable is reused)
model.fn(x, y)

# Now serialize the model
tf.saved_model.save(model, './tmp/graph')

# Now restore it
restored = tf.saved_model.load('./tmp/graph')

# The first run is again slow
restored.fn(x, y)
# The second run is fast
restored.fn(x, y)
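To make the slow/fast difference concrete, the two calls can be timed roughly like this (plain wall-clock timing; the helper below is only for illustration and the numbers depend on the machine):
import time

def timed(label, f, *args):
    # Plain wall-clock timing of a single call (illustrative only).
    start = time.perf_counter()
    f(*args)
    print(f'{label}: {time.perf_counter() - start:.3f} s')

timed('first call after load', restored.fn, x, y)   # includes XLA compilation
timed('second call after load', restored.fn, x, y)  # reuses the compiled executable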
So the question: is there a way to serialize the TF graph such that the first run after restoring it is also fast?
In principle the graph was already compiled with XLA, so can we store it in some format that lets XLA skip recompiling it?
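Conceptually, what I am after is something like the following (purely hypothetical sketch; as far as I know neither save_compiled nor load_compiled exists in TensorFlow today):
# Hypothetical API, only to illustrate the desired behaviour.
tf.saved_model.save(model, './tmp/graph')            # what exists today: graph only
save_compiled(model.fn, './tmp/xla_executable')      # hypothetical: also persist the XLA executable
restored_fn = load_compiled('./tmp/xla_executable')  # hypothetical: no JIT on the first call
restored_fn(x, y)  # would ideally run at full speed immediately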
Any help or guidance would be much appreciated.
tensorflow
0 Answers