Torch FX Transformation and Pipeline Parallelism
Torch fx #
torch.fx is a PyTorch module that captures a model and applies transformations for optimization 1.
Model optimization has become increasingly important.
torch.fx enables transparent transformations without touching the original model implementation, allowing fine-grained model optimization.
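For a first feel of the API, here is a minimal sketch (the Toy module is an arbitrary example of mine, not from the original post) that traces a small model into a GraphModule and prints its captured graph:
import torch
import torch.fx

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))

# symbolic_trace captures the forward pass into a GraphModule
toy_traced = torch.fx.symbolic_trace(Toy())
print(toy_traced.graph)  # the captured IR: placeholder -> call_module -> call_function -> output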
Since PyTorch 2.0, it seems TorchDynamo replaces the legacy fx tracer for tracing models. This post focuses on the existing torch.fx module; I will write another post about TorchDynamo if I have a chance.
Using torch.fx to implement pipeline parallelism #
torch.fx can be used for various purposes, mostly performance optimization.
I learned torch.fx to make models pipeline parallelizable, utilizing a characteristic of its output type: GraphModule.
Pipeline parallelism requires the model to behave like torch.nn.Sequential; the output of the previous layer is the input of the next layer 2:
def forward(self, inputs):
    x = inputs
    for layer in self.layers:
        x = layer(x)
    return x
torch.fx.GraphModule satisfies this requirement and can be used for pipeline parallelism.
For the rest of the post, I focus on implementing a pipeline parallelism model with torch.fx.
The first implementation of using torch.fx for pipeline parallelism that I have seen is Merak 3. This post has been inspired by Merak’s implementation, which in turn borrows from FairScale.
Using Huggingface Transformers and torch.fx #
HF Transformers provides its own fx tracer that wraps torch.fx; we can use it to generate a GraphModule of HF Transformers models.
Generating torch.fx.GraphModule #
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.utils.fx import symbolic_trace # it replaces torch.fx.symbolic_trace
from torch.fx import GraphModule
config = AutoConfig.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_config(config)
traced: GraphModule = symbolic_trace(model)
The type of traced is torch.fx.GraphModule, which has the same structure as model, but each internal module is replaced with the base class torch.nn.Module:
model
GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(...)
      ...
    )
    ...
  )
)
traced
GraphModule(
  (transformer): Module(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): Module(
      (0): Module
      ...
    )
    ...
  )
)
GraphModule includes graph, an IR representation of the model:
for n in traced.graph.nodes:
    print(f'{n.name} = {n.op} target={n.target} args={n.args}')
input_ids = placeholder target=input_ids args=()
size = call_method target=size args=(input_ids,)
getitem = call_function target=<built-in function getitem> args=(size, -1)
view = call_method target=view args=(input_ids, -1, getitem)
size_1 = call_method target=size args=(view,)
...
transformer_ln_f = call_module target=transformer.ln_f args=(add_146,)
view_146 = call_method target=view args=(transformer_ln_f, add_2)
lm_head = call_module target=lm_head args=(view_146,)
output = output target=output args=({'logits': lm_head, 'past_key_values': ((permute_1, permute_2), (permute_5, permute_6), (permute_9, permute_10), (permute_13, permute_14), (permute_17, permute_18), (permute_21, permute_22), (permute_25, permute_26), (permute_29, permute_30), (permute_33, permute_34), (permute_37, permute_38), (permute_41, permute_42), (permute_45, permute_46))},)
The graph can be modified and used to create a new GraphModule. You can retrieve .code or call print_readable() to see the generated Python code:
traced.code
'\n\n\ndef forward(self, input_ids : torch.Tensor):\n size = input_ids.size()\n...'
traced.print_readable()
class GraphModule(torch.nn.Module):
    def forward(self, input_ids : torch.Tensor):
        # No stacktrace found for following nodes
        size = input_ids.size()
        getitem = size[-1]
        ...
        lm_head = self.lm_head(view_146); view_146 = None
        return {'logits': lm_head, 'past_key_values': ((permute_1, permute_2), (permute_5, permute_6), (permute_9, permute_10), (permute_13, permute_14), (permute_17, permute_18), (permute_21, permute_22), (permute_25, permute_26), (permute_29, permute_30), (permute_33, permute_34), (permute_37, permute_38), (permute_41, permute_42), (permute_45, permute_46))}
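As a quick aside (a sketch of mine, not from the original post), a graph, once edited, can be wrapped back into a module; the sharding code later in this post uses the same GraphModule constructor to turn each partitioned graph into its own module:
import torch.fx

# Rebuild a module from the (possibly modified) graph; editing node targets/args
# before this step is how fx transformations are typically applied.
traced.graph.lint()  # sanity-check graph integrity
rebuilt = torch.fx.GraphModule(traced, traced.graph)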
Generating torch.fx.GraphModule “for Training” #
The traced graph module does not return a loss, meaning it cannot be used for training.
To get a loss from the GraphModule, we need to pass a labels input.
def symbolic_trace(
    model: PreTrainedModel,
    input_names: Optional[List[str]] = None,
    disable_check: bool = False,
) -> GraphModule:
    """
    Performs symbolic tracing on the model.

    Args:
        model ([`PretrainedModel`]):
            The model to trace.
        input_names (`List[str]`, *optional*): <--
            The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.
        disable_check (`bool`, *optional*, defaults to `False`):
            If `True`, no check is done before trying to trace the model, this is mostly useful for debugging purposes.
    ...
symbolic_trace optionally takes input_names; according to the documentation, model.dummy_inputs.keys() is used by default, which does not include the labels key:
list(model.dummy_inputs.keys())
['input_ids'] # Dummy input keys specific to the GPT2 model
Let’s use inputs from a dataset.
from datasets import load_dataset

raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

# Preprocess the dataset (tokenize raw_dataset into tokenized_datasets) as explained here:
# https://insujang.github.io/2023-04-19/using-huggingface-transformers/#loading-a-tokenizer-and-preprocessing
def group_texts(samples):
    ...
    result["labels"] = result["input_ids"].copy()
    return result

tokenized_datasets = tokenized_datasets.map(
    group_texts, batched=True, load_from_cache_file=True
)
train_dataset = tokenized_datasets["train"]

input_names = list(next(iter(train_dataset)).keys())
traced = symbolic_trace(model, input_names=input_names)
With this preprocessing, input_names now includes 3 items: ['input_ids', 'attention_mask', 'labels']. As it includes labels, symbolic_trace generates a loss in the result as follows:
traced.print_readable()
class GraphModule(torch.nn.Module):
    def forward(self, input_ids : torch.Tensor, attention_mask : torch.Tensor, labels : torch.Tensor):
        ...
        crossentropyloss_0 = self.crossentropyloss_0(view_148, view_149); view_148 = view_149 = None
        return {'loss': crossentropyloss_0, 'logits': lm_head, 'past_key_values': ((permute_1, permute_2), (permute_5, permute_6), (permute_9, permute_10), (permute_13, permute_14), (permute_17, permute_18), (permute_21, permute_22), (permute_25, permute_26), (permute_29, permute_30), (permute_33, permute_34), (permute_37, permute_38), (permute_41, permute_42), (permute_45, permute_46))}
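As a quick sanity check (a sketch of mine, not from the original post; it assumes the train_dataset prepared above and HF’s default_data_collator), the returned loss is an ordinary tensor we can backpropagate through:
from transformers import default_data_collator

batch = default_data_collator([train_dataset[i] for i in range(2)])
out = traced(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    labels=batch["labels"],
)
out["loss"].backward()  # gradients flow through the traced graph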
This traced module can therefore be used for training. Using a traced torch.fx.GraphModule with the Trainer is simple; just replace the model with it:
# This is for training the existing model
trainer = Trainer(
    model,
    training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    data_collator=default_data_collator,
    eval_dataset=None,
    compute_metrics=None,
)
trainer.train()
TrainOutput(global_step=10, training_loss=9.759646606445312, metrics={'train_runtime': 8.3264, 'train_samples_per_second': 9.608, 'train_steps_per_second': 1.201, 'total_flos': 41806725120000.0, 'train_loss': 9.759646606445312, 'epoch': 0.03})
# Equivalent training but with the traced torch.fx.GraphModule model
trainer = Trainer(
    traced,
    training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    data_collator=default_data_collator,
    eval_dataset=None,
    compute_metrics=None,
)
trainer.train()
TrainOutput(global_step=10, training_loss=9.122493743896484, metrics={'train_runtime': 7.6985, 'train_samples_per_second': 10.392, 'train_steps_per_second': 1.299, 'total_flos': 0.0, 'train_loss': 9.122493743896484, 'epoch': 0.03})
Implementing Pipeline Parallelism with torch.fx #
This section is inspired by FairScale’s model sharding implementation and PiPPy’s split point implementation.
Disclaimer: the examples below do NOT work in a distributed pipeline parallel execution environment, since they do not properly pass inputs to the sharded models. This post only focuses on how to split the model using torch.fx.
We can arbitrarily split a torch.fx.GraphModule to distribute training across multiple GPUs or nodes.
We are going to split the model by transformer layers, following PiPPy’s HF GPT2 split point example.
The example splits the model at two kinds of points: transformer.h.N (where N is the layer index) and transformer.ln_f:
from typing import List
from transformers import PretrainedConfig

def get_split_points(config: PretrainedConfig) -> List[str]:
    split_points: List[str] = []

    # config: GPT2Config for the GPT2 model
    for i in range(config.num_hidden_layers):
        split_points.append(f"transformer.h.{i}")
    split_points.append("transformer.ln_f")

    return split_points
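For the 12-layer GPT-2 config created earlier, this yields 13 split points (a usage sketch, not from the original post):
print(get_split_points(config))
# ['transformer.h.0', 'transformer.h.1', ..., 'transformer.h.11', 'transformer.ln_f']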
During graph traversal, if we find a node whose name starts with either transformer.h.N or transformer.ln_f, we split the graph there and create another GraphModule.
Graph partitioning iterates over all nodes twice: first to find the partitioning locations, and second to create GraphModules based on that partitioning information.
We first iterate over all nodes to find the partitioning locations by defining _split_nodes:
import torch.fx
from typing import List, Dict, Tuple, Optional

def _split_nodes(
    traced: torch.fx.GraphModule, split_points: List[str]
) -> Tuple[Dict[str, int], Dict[int, List[str]]]:
    node_name_to_shard_id: Dict[str, int] = {}
    shard_id = 0
    nodes_so_far: List[torch.fx.Node] = []
    extra_outputs: Dict[int, List[str]] = {}

    for node in traced.graph.nodes:
        if node.op in [
            "placeholder",
            "get_attr",
            "call_function",
            "call_method",
            "call_module",
        ]:
            node_name_to_shard_id[node.name] = shard_id
            nodes_so_far.append(node)

            point = next(filter(lambda p: node.name.startswith(p), split_points), None)
            if point:
                # Record outputs that should be used later.
                # They will be added to the return of this shard.
                outputs = []
                for seen_node in nodes_so_far:
                    for user in seen_node.users.keys():
                        if user.name not in node_name_to_shard_id:
                            outputs.append(seen_node.name)
                # Remove duplicates while preserving order
                extra_outputs[shard_id] = list(dict.fromkeys(outputs).keys())

                shard_id += 1
                split_points.remove(point)
        elif node.op == "output":
            break

    assert len(split_points) == 0, "Sharding is not complete."
    return node_name_to_shard_id, extra_outputs
_split_nodes monotonically increases shard_id and maps every node to a specific shard_id; nodes with the same shard_id belong to the same partitioned GraphModule.
The inputs of _split_nodes can be prepared as follows; note that fx node names use underscores (e.g. transformer_h_0), so the dots in the split point names are replaced:
traced = symbolic_trace(model, input_names=input_names)
split_points = [p.replace(".", "_") for p in get_split_points(config)]
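Before sharding, the partitioning can be inspected directly (a sketch of mine, not from the original post; note that _split_nodes removes matched entries from split_points, so a copy is passed here):
# Inspect the partitioning produced by _split_nodes
node_name_to_shard_id, extra_outputs = _split_nodes(traced, list(split_points))

num_shards = max(node_name_to_shard_id.values()) + 1
print(num_shards)     # e.g. 14 shards for the 13 GPT-2 split points above
print(extra_outputs)  # per-shard names of values that later shards still need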
The main function shard_model internally uses _split_nodes to partition the model into GraphModules:
def shard_model(
    model: torch.nn.Module,
    traced: torch.fx.GraphModule,
    split_points: List[str],
) -> List[torch.fx.GraphModule]:
    module_list: List[torch.fx.GraphModule] = []
    node_name_to_shard_id, extra_outputs = _split_nodes(traced, split_points)

    prev_shard_id = 1000
    prev_node: Optional[torch.fx.Node] = None

    env: Dict[str, torch.fx.Node] = {}
    new_graph = torch.fx.Graph()

    # Iterate all nodes
    for node in traced.graph.nodes:
        if node.name in node_name_to_shard_id:
            current_shard_id = node_name_to_shard_id[node.name]
            if prev_shard_id < current_shard_id:
                assert prev_node

                with new_graph.inserting_after(prev_node):
                    if prev_shard_id in extra_outputs:
                        outputs = extra_outputs[prev_shard_id]
                        outputs = tuple([env[i] for i in outputs])
                        new_graph.output(outputs)
                    else:
                        outputs = (env[prev_node.name],)
                        new_graph.output(outputs)

                # Finalize this graph and add it to the GraphModule list
                new_graph.lint()
                module_list.append(torch.fx.GraphModule(model, new_graph))

                # Create a new graph
                new_graph = torch.fx.Graph()
                for output in outputs:
                    # Add all nodes returned by the previous graph as its inputs (placeholders)
                    node_name = env[output.name].name
                    pl_node = new_graph.create_node("placeholder", node_name)
                    env[node_name] = pl_node

        # Cut is done. Add all nodes into the current graph
        if node.op in [
            "placeholder",
            "get_attr",
            "call_function",
            "call_method",
            "call_module",
        ]:
            # Copy the nodes from the existing graph to the new graph.
            new_node = new_graph.node_copy(node, lambda x: env[x.name])
            env[node.name] = new_node
        elif node.op == "output":
            # If this is the last node, we should add an output node and add the last graph to the list.
            assert prev_node, "prev_node cannot be None"
            with new_graph.inserting_after(prev_node):
                new_node = new_graph.node_copy(node, lambda x: env[x.name])
            new_graph.lint()
            module_list.append(torch.fx.GraphModule(model, new_graph))
            break

        prev_node = new_node
        prev_shard_id = node_name_to_shard_id[node.name]

    return module_list
Here, we iterate over all nodes again and create a list of GraphModules based on their shard_ids.
Whenever the shard_id increases, the graph should be partitioned at that point: we finalize the current graph by adding its outputs, check its integrity with new_graph.lint(), and create a GraphModule from it.
If variables defined in a previous sharded module are used later, they must be forwarded; thus we have to manually add them to the outputs of the previous sharded module (new_graph.output(outputs)) and to the inputs of the next sharded module (new_graph.create_node("placeholder", ...)).
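Putting it together, a usage sketch (mine, not from the original post; per the disclaimer above, the resulting shards are for inspection rather than for distributed execution):
shards = shard_model(model, traced, split_points)
print(len(shards))       # one GraphModule per shard, e.g. 14 with the GPT-2 split points above
for shard in shards:
    print(shard.graph)   # each partition's placeholders (inputs) and its output node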