XLA TensorFlow Tutorial
Getting Started
Flags to use:
os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2'
os.environ['XLA_FLAGS'] = '--xla_dump_hlo_as_dot --xla_dump_hlo_as_text --xla_dump_to=xla_dump/'
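A minimal sketch of setting these in practice, assuming the dump directory xla_dump/ is relative to the working directory; the environment variables should be in place before TensorFlow initializes so the JIT and dump options are picked up:

```python
import os

# enable auto-clustering and HLO dumps before TensorFlow is imported/initialized
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=2"
os.environ["XLA_FLAGS"] = (
    "--xla_dump_hlo_as_dot "
    "--xla_dump_hlo_as_text "
    "--xla_dump_to=xla_dump/"
)

import tensorflow as tf  # imported after the flags are set
```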
Interpreting an XLA Compiled Toy Model
Reading an XLA graph:
#! TODO run dummy model with --xla and put the graph here / schedule follow-up XLA meeting next Monday?
#! TODO attach the code for the dummy script
#! -> get the code to wrap things in tf.function from Jack -> am I able to use the sequential model with tf.function? Might as well get rid of it because you will need to subclass Model to do something with a loop
Each graph shows both the operations and the tensor shape handling, such as padding, reshaping, and broadcasting. There is a root node at the end of each graph for the output of that subgraph.
XLA Toy Example
#todo add some pictures or examples of full filenames
If we just enable the dot files to be created with the --xla_dump_hlo_as_dot flag, our toy model generates 18 different files (a small script for checking this breakdown is sketched after the list):
- 4 .dot files with numeric prefixes. XLA has determined that there are two clusters, cluster_0 and cluster_1, and each cluster has a graph for before and after optimization.
- 8 .ll files. Of the 4 files per cluster, we have before- and after-optimization representations of the graph, with and without constant folding/propagation.
- 2 .mlir files, one per cluster
- 2 .ptx files, one per cluster
- 2 .thunk_schedule files, one per cluster
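One quick way to verify this breakdown, assuming the xla_dump/ directory from the flags above, is to count the dumped files by extension (sketch only):

```python
from collections import Counter
from pathlib import Path

# tally the dump files by extension to compare against the breakdown above
dump_dir = Path("xla_dump")
counts = Counter(p.suffix for p in dump_dir.iterdir() if p.is_file())
for ext, n in sorted(counts.items()):
    print(f"{ext or '(no extension)'}: {n}")
```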
Toy Model
Each XLA cluster will get its own module.
“In the case of toy models, there is a clear correspondence between module_0031 and cluster_0, and module_0045 and cluster_1. This pattern suggests that each module might correspond to a single cluster.” -chatGPT
Related terms: module, computation, stage.
V5 Model
Profiling our full v5 model gives a few different groups of HTML files:
- get_cam_pos (patch extraction?)
- do_preprocess
- do_s0
- do_s1
- do_s2
Questions
- What does occupancy tell you? Is it meaningful when we are memory-bandwidth bound? Given a certain kernel, how can I optimize the shape of the tensor to increase the occupancy of the kernel? Why are the top FN kernels by time at 66% occupancy, and what would we expect them to be? Is it the same on the g5s as it is on my laptop? https://docs.nvidia.com/gameworks/content/developertools/desktop/analysis/report/cudaexperiments/kernellevel/achievedoccupancy.htm
- Why does making the FocalNet deeper change the ratio of precomputed_conv to wgrad kernels (blue and pink)?
Filenames
Each filename is quite descriptive, containing several key pieces of information:
“Module name: module_0000
Computation name: a_inference_get_cam_pos_17
Configuration and optimization flags: XlaMustCompile_true_config_proto_6001324581131673121_executor_type_11160318154034397263__21.before_optimizations

This means:
- module_0000: This likely indicates the index of the module within your TensorFlow model.
- a_inference_get_cam_pos_17: This suggests the function or layer name being dumped. In this case, it looks like an inference function related to camera position.
- XlaMustCompile_true: Indicates that XLA compilation was mandatory.
- config_proto_6001324581131673121: Configuration protocol identifier.
- executor_type_11160318154034397263__21: Executor type identifier and possibly a version or sequence number.
- before_optimizations: Indicates that this dump was taken before optimizations were applied.” -chatGPT
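As an illustration only, here is a sketch that splits an example filename along the '.' separators into the parts described above; the exact layout is an assumption inferred from this single example and may not hold for every dump:

```python
# split one example dump filename into the components described above (assumed layout)
example = (
    "module_0000."
    "a_inference_get_cam_pos_17_XlaMustCompile_true_config_proto_"
    "6001324581131673121_executor_type_11160318154034397263__21."
    "before_optimizations"
)

module_part, computation_and_flags, stage = example.split(".", 2)
print("module:", module_part)                      # module_0000
print("computation + flags:", computation_and_flags)
print("stage:", stage)                             # before_optimizations
```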
Extensions
Using different output flags will dump different information, but they all seek to explain what the structure of the graph was originally and how it gets updated after one or more passes.
Created with the --xla_dump_hlo_as_dot flag
Created with the --xla_dump_hlo_as_text flag
Created with the --xla_dump_hlo_as_html flag
.ll files - low-level intermediate representations (IR) created by LLVM
.mlir files - higher level intermediate representations
.dot files - the control flow for the graph. This file is in SSA form and can be compiled into a PNG using Graphviz with the command below (a batch-conversion sketch follows this list)
dot -Tpng path/to/dotfile.dot -o path/to/output.png
.ptx files - short for Parallel Thread eXecution; these files show what is essentially GPU assembly code
.thunk_schedule files - the execution order of low-level operations
buffer.txt files - memory allocations for different ops
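For convenience, a small sketch (assuming Graphviz's dot binary is on the PATH and the dump directory is xla_dump/) that renders every dumped .dot file to a PNG:

```python
import subprocess
from pathlib import Path

# render every dumped .dot graph to a PNG alongside it
dump_dir = Path("xla_dump")
for dot_file in dump_dir.glob("*.dot"):
    png_file = dot_file.with_suffix(".png")
    subprocess.run(["dot", "-Tpng", str(dot_file), "-o", str(png_file)], check=True)
    print(f"wrote {png_file}")
```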
*TODO - add section on gpu lowering as it relates to ptx
Finding Bottlenecks
The name of the game is to use the static compute graph and the live sequence of ops in the trace viewer to triangulate which parts of the code are bottlenecks.
Here is the whole process of finding oddities and optimizations in your XLA graph. You can start from either end.
- TensorBoard profiling: open the trace viewer and look at the XLA ops. The forward and backward passes should be in the op name. If anything is taking a lot longer than expected, find the name of the XLA op in the trace viewer and go look for it in the HTML file. Bonus: the kernel stats page. (A minimal profile-capture sketch follows these steps.)
- XLA dump as HTML and text --> look at the post-optimization graph for things that look expensive.
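A minimal sketch of capturing a trace that the trace viewer and kernel stats pages can read; the toy model and dataset here are hypothetical stand-ins for your own training loop:

```python
import tensorflow as tf

# hypothetical stand-ins for the real model and data
model = tf.keras.Sequential([tf.keras.layers.Dense(8)])
dataset = tf.data.Dataset.from_tensor_slices(tf.random.normal([64, 4])).batch(8)

@tf.function(jit_compile=True)
def train_step(batch):
    return model(batch)

tf.profiler.experimental.start("logdir")        # TensorBoard reads traces from logdir/
for step, batch in enumerate(dataset):
    with tf.profiler.experimental.Trace("train", step_num=step, _r=1):
        train_step(batch)
tf.profiler.experimental.stop()
```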
-ir-no-opt and -ir-with-opt signify non-optimized and optimized IR, respectively. For a given cluster, examining the changes between these files shows the effect the optimizations had.
-noconst - constant folding/propagation has been disabled
What determines whether auto clustering or manual clustering is used?
Manual clustering means you add a tf.function() decorator around certain bits of code that you want to lump together and tell TensorFlow to use the XLA JIT compiler while you do it. Altogether, this looks like
@tf.function(jit_compile=True)
Auto clustering can be enabled with an environment variable, which looks like TF_XLA_FLAGS=--tf_xla_auto_jit=2 (the same flag shown in Getting Started above).
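Putting both together, a minimal sketch; the function body here is just an illustrative stand-in:

```python
import os

# auto clustering: set the flag before TensorFlow is imported/initialized
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=2"

import tensorflow as tf

# manual clustering: everything inside this function is compiled together by XLA
@tf.function(jit_compile=True)
def fused_op(x, y):
    return tf.nn.relu(tf.matmul(x, y) + 1.0)

print(fused_op(tf.random.normal([4, 4]), tf.random.normal([4, 4])))
```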
Misc
XLA Modules hold XLA computations
tf.test.is_built_with_xla() - check that the version of TensorFlow you're using was built with XLA support
TF_CPP_MIN_VLOG_LEVEL - ranges from 0 to 3, with higher numbers reporting more information
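A quick sanity-check sketch combining the two items above:

```python
import os

# more verbose TF/XLA logging (0-3; higher is noisier) - set before importing TensorFlow
os.environ["TF_CPP_MIN_VLOG_LEVEL"] = "1"

import tensorflow as tf

print("built with XLA:", tf.test.is_built_with_xla())
```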
“There is always a module_ prefix, which indicates that this is the graph for an HLO Module. The first XXXX is the HLO module’s unique ID, generated by XLA. The HLO modules have CamelCase class names by convention. For the file names, these are converted to snake_case.” [1]
References
[1] https://docs.graphcore.ai/projects/tensorflow-user-guide/en/latest/tensorflow/logging.html

Resources
- Clustering with tf.function bare-bones tutorial: https://openxla.org/xla/tf2xla/tutorials/jit_compile
- Autoclustering tutorial: https://openxla.org/xla/tf2xla/tutorials/autoclustering_xla
- Pretty good talk on the matter: https://openxla.org/xla/tf2xla#talks