
Torch FX Transformation and Pipeline Parallelism

Torch fx #

torch.fx is a PyTorch module that captures a model and applies transformations for optimization 1. Model optimization has become increasingly important, and torch.fx enables transparent transformation without touching the original model implementation, allowing fine-grained model optimization.
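
As a minimal sketch (the toy module below is made up for illustration), tracing a small model captures its operations into an editable graph plus regenerated Python code:

import torch
import torch.fx

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))

# symbolic_trace returns a GraphModule: a Graph (IR) plus a generated forward().
traced = torch.fx.symbolic_trace(TinyModel())
print(traced.graph)  # placeholder -> call_module (linear) -> call_function (relu) -> output
print(traced.code)   # the regenerated forward() as Python source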

Since PyTorch 2.0, TorchDynamo appears to replace the legacy fx.Tracer for tracing models.

[Figure: TorchDynamo]
This post focuses on the existing torch.fx module; I will write another post about TorchDynamo if I have a chance.

Using torch.fx to implement pipeline parallelism #

torch.fx can be used for various purposes, mostly performance optimization. I learned torch.fx to make models pipeline parallelizable, taking advantage of the type of its output: GraphModule.

Pipeline parallelism requires the model to behave like a torch.nn.Sequential: the output of each layer is the input of the next layer 2:

def forward(self, inputs):
    x = inputs
    for layer in self.layers:
        x = layer(x)
    return x
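
In a pipeline-parallel setup each stage additionally lives on its own device, so activations have to be moved between stages. A naive single-process sketch (no micro-batching or scheduling; the device placement is illustrative):

import torch

# Hypothetical: two stages of a sequential model placed on two GPUs.
stages = [
    torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).to("cuda:0"),
    torch.nn.Sequential(torch.nn.Linear(8, 8)).to("cuda:1"),
]

def pipeline_forward(x):
    for i, stage in enumerate(stages):
        x = x.to(f"cuda:{i}")  # move the activation to this stage's device
        x = stage(x)
    return x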

torch.fx.GraphModule satisfies this requirement and can be used for pipeline parallelism. For the rest of the post, I focus on implementing pipeline parallelism with torch.fx.

The first implementation of torch.fx-based pipeline parallelism that I have seen is Merak 3. This post is inspired by Merak’s implementation, which in turn borrows from FairScale.

Using Huggingface Transformers and torch.fx #

HF Transformers provides its own fx tracer that wraps torch.fx; we can use it to generate a GraphModule from HF Transformers models.

Generating torch.fx.GraphModule #

from transformers import AutoConfig, AutoModelForCausalLM
from transformers.utils.fx import symbolic_trace # it replaces torch.fx.symbolic_trace
from torch.fx import GraphModule

config = AutoConfig.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_config(config)

traced: GraphModule = symbolic_trace(model)

The type of traced is torch.fx.GraphModule, which has the same structure as model, but each internal module is replaced with the base class torch.nn.Module:

model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
        (0-11): 12 x GPT2Block(...)
        ...
    )
    ...
  )
)
traced

GraphModule(
  (transformer): Module(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): Module(
        (0): Module
        ...
    )
    ...
  )
)

GraphModule includes graph, an IR representation of the model:

for n in traced.graph.nodes:
    print(f'{n.name} = {n.op} target={n.target} args={n.args}')

input_ids = placeholder target=input_ids args=()
size = call_method target=size args=(input_ids,)
getitem = call_function target=<built-in function getitem> args=(size, -1)
view = call_method target=view args=(input_ids, -1, getitem)
size_1 = call_method target=size args=(view,)
...
transformer_ln_f = call_module target=transformer.ln_f args=(add_146,)
view_146 = call_method target=view args=(transformer_ln_f, add_2)
lm_head = call_module target=lm_head args=(view_146,)
output = output target=output args=({'logits': lm_head, 'past_key_values': ((permute_1, permute_2), (permute_5, permute_6), (permute_9, permute_10), (permute_13, permute_14), (permute_17, permute_18), (permute_21, permute_22), (permute_25, permute_26), (permute_29, permute_30), (permute_33, permute_34), (permute_37, permute_38), (permute_41, permute_42), (permute_45, permute_46))},)

which can be modified and used to create a new GraphModule. You can retrieve .code or call print_readable() to see the generated Python code:

traced.code

'\n\n\ndef forward(self, input_ids : torch.Tensor):\n    size = input_ids.size()\n...'
traced.print_readable()

class GraphModule(torch.nn.Module):
    def forward(self, input_ids : torch.Tensor):
        # No stacktrace found for following nodes
        size = input_ids.size()
        getitem = size[-1]
        ...
        lm_head = self.lm_head(view_146);  view_146 = None
        return {'logits': lm_head, 'past_key_values': ((permute_1, permute_2), (permute_5, permute_6), (permute_9, permute_10), (permute_13, permute_14), (permute_17, permute_18), (permute_21, permute_22), (permute_25, permute_26), (permute_29, permute_30), (permute_33, permute_34), (permute_37, permute_38), (permute_41, permute_42), (permute_45, permute_46))}
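
To show what "modified" means in practice, here is a minimal sketch (applied to a toy module rather than the GPT-2 trace) that rewrites every torch.relu call to torch.nn.functional.gelu and regenerates the forward code:

import torch
import torch.fx

def swap_relu_for_gelu(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        # call_function nodes hold the callable in node.target
        if node.op == "call_function" and node.target is torch.relu:
            node.target = torch.nn.functional.gelu
    gm.graph.lint()  # check graph integrity
    gm.recompile()   # regenerate gm.code / gm.forward from the modified graph
    return gm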

Generating torch.fx.GraphModule “for Training” #

The traced graph module does not return a loss, meaning it cannot be used for training. To get a loss from the GraphModule, we need to pass a labels input.

def symbolic_trace(
    model: PreTrainedModel,
    input_names: Optional[List[str]] = None,
    disable_check: bool = False,
) -> GraphModule:
    """
    Performs symbolic tracing on the model.

    Args:
        model ([`PretrainedModel`]):
            The model to trace.
        input_names (`List[str]`, *optional*): <--
            The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.
        disable_check (`bool`, *optional*, defaults to `False`):
            If `True`, no check is done before trying to trace the model; this is mostly useful for debugging purposes.
        ...

symbolic_trace optionally takes input_names; according to the documentation, model.dummy_inputs.keys() is used by default, which does not include the labels key:

list(model.dummy_inputs.keys())
['input_ids'] # Dummy input keys specific for the GPT2 model

Let’s use inputs from a dataset.

from datasets import load_dataset
raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
# Tokenize the dataset (producing `tokenized_datasets`) and group texts, as explained here:
# https://insujang.github.io/2023-04-19/using-huggingface-transformers/#loading-a-tokenizer-and-preprocessing
def group_texts(samples):
    ...
    result["labels"] = result["input_ids"].copy()
    return result

tokenized_datasets = tokenized_datasets.map(
    group_texts, batched=True, load_from_cache_file=True
)

train_dataset = tokenized_datasets["train"]
input_names = list(next(iter(train_dataset)).keys())

traced = symbolic_trace(model, input_names=input_names)

With preprocessing, input_names now includes 3 items: ['input_ids', 'attention_mask', 'labels']. As it includes labels, symbolic_trace adds a loss to the output as follows:

traced.print_readable()

class GraphModule(torch.nn.Module):
    def forward(self, input_ids : torch.Tensor, attention_mask : torch.Tensor, labels : torch.Tensor):
        ...
        crossentropyloss_0 = self.crossentropyloss_0(view_148, view_149);  view_148 = view_149 = None
        return {'loss': crossentropyloss_0, 'logits': lm_head, 'past_key_values': ((permute_1, permute_2), (permute_5, permute_6), (permute_9, permute_10), (permute_13, permute_14), (permute_17, permute_18), (permute_21, permute_22), (permute_25, permute_26), (permute_29, permute_30), (permute_33, permute_34), (permute_37, permute_38), (permute_41, permute_42), (permute_45, permute_46))}

which can be used for training.
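
You can also call the traced module directly; a quick sanity check (the batch construction is illustrative):

import torch

sample = next(iter(train_dataset))
out = traced(
    input_ids=torch.tensor([sample["input_ids"]]),
    attention_mask=torch.tensor([sample["attention_mask"]]),
    labels=torch.tensor([sample["labels"]]),
)
print(out["loss"])  # a scalar tensor; out["loss"].backward() works as usual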

Using a traced torch.fx.GraphModule is simple. Replace model with it:

# This is for training the existing model
trainer = Trainer(
    model,
    training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    data_collator=default_data_collator,
    eval_dataset=None,
    compute_metrics=None
)
trainer.train()

TrainOutput(global_step=10, training_loss=9.759646606445312, metrics={'train_runtime': 8.3264, 'train_samples_per_second': 9.608, 'train_steps_per_second': 1.201, 'total_flos': 41806725120000.0, 'train_loss': 9.759646606445312, 'epoch': 0.03})

# Equivalent training but with the traced torch.fx.GraphModule model
trainer = Trainer(
    traced,
    training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    data_collator=default_data_collator,
    eval_dataset=None,
    compute_metrics=None
)
trainer.train()

TrainOutput(global_step=10, training_loss=9.122493743896484, metrics={'train_runtime': 7.6985, 'train_samples_per_second': 10.392, 'train_steps_per_second': 1.299, 'total_flos': 0.0, 'train_loss': 9.122493743896484, 'epoch': 0.03})

Implementing Pipeline Parallelism with torch.fx #

This section is inspired by Fairscale’s model sharding implementation and PiPPy’s split point implementation.

Disclaimer: the examples below do NOT work in a distributed pipeline-parallel execution environment since they do not properly pass inputs to the sharded models. This post only focuses on how to split the model using torch.fx.

We can arbitrarily split a torch.fx.GraphModule to distribute training across multiple GPUs or nodes. We are going to split the model by transformer layers, following PiPPy’s HF GPT2 split point example. The example splits the model at two kinds of points: transformer.h.N (where N is the layer index) and transformer.ln_f:

from typing import List
from transformers import PretrainedConfig
def get_split_points(config: PretrainedConfig) -> List[str]:
    split_points: List[str] = []
    # config: GPT2Config for GPT2 model
    for i in range(config.num_hidden_layers):
        split_points.append(f"transformer.h.{i}")
    split_points.append("transformer.ln_f")
    return split_points
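
For the 12-layer GPT-2 config used above, this yields 13 split points:

print(get_split_points(config))
# ['transformer.h.0', 'transformer.h.1', ..., 'transformer.h.11', 'transformer.ln_f']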

During graph traversal, if we find a node whose name starts with one of the split points (transformer.h.N or transformer.ln_f; dots become underscores in fx node names), we split the graph there and start another GraphModule.

Graph partitioning iterates over all nodes twice: first to find the partition boundaries, and second to create GraphModules from that partitioning information. We first iterate over all nodes to find the boundaries by defining _split_nodes:

import torch.fx
from typing import List, Dict, Tuple, Optional

def _split_nodes(
        traced: torch.fx.GraphModule, split_points: List[str]
    ) -> Tuple[Dict[str, int], Dict[int, List[str]]]:
    node_name_to_shard_id: Dict[str, int] = {}
    shard_id = 0

    nodes_so_far: List[torch.fx.Node] = []
    extra_outputs: Dict[int, List[str]] = {}

    for node in traced.graph.nodes:
        if node.op in [
            "placeholder",
            "get_attr",
            "call_function",
            "call_method",
            "call_module",
        ]:
            node_name_to_shard_id[node.name] = shard_id
            nodes_so_far.append(node)

            point = next(filter(lambda p: node.name.startswith(p), split_points), None)
            if point:
                # Record outputs that should be used later.
                # they will be added in return of this shard.
                outputs = []
                for n in nodes_so_far:  # use a different name to avoid shadowing the outer loop's `node`
                    for user in n.users.keys():
                        if user.name not in node_name_to_shard_id:
                            outputs.append(n.name)

                # Remove duplicate
                extra_outputs[shard_id] = list(dict.fromkeys(outputs).keys())

                shard_id += 1
                split_points.remove(point)
    
        elif node.op == "output":
            break

    assert len(split_points) == 0, "Sharding is not complete."

    return node_name_to_shard_id, extra_outputs

_split_nodes monotonically increases shard_id and maps every node to a specific shard_id; nodes with the same shard_id belong to the same partitioned GraphModule. The inputs of _split_nodes can be prepared as:

traced = symbolic_trace(model, input_names=input_names)
split_points = [p.replace(".", "_") for p in get_split_points(config)]
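
To inspect the partitioning before building the modules, you can call _split_nodes directly; since it mutates split_points (via split_points.remove), pass a copy (a quick sketch):

from collections import Counter

node_name_to_shard_id, extra_outputs = _split_nodes(traced, list(split_points))
print(Counter(node_name_to_shard_id.values()))  # how many nodes were assigned to each shard_id
print(extra_outputs[0])  # node names that shard 0 must forward to later shards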

The main function shard_model internally uses _split_nodes to partition the model into GraphModules:

def shard_model(
        model: torch.nn.Module,
        traced: torch.fx.GraphModule,
        split_points: List[str]
    ) -> List[torch.fx.GraphModule]:
    module_list: List[torch.fx.GraphModule] = []
    
    node_name_to_shard_id, extra_outputs = _split_nodes(traced, split_points)

    prev_shard_id = 1000
    prev_node: Optional[torch.fx.Node] = None
    env: Dict[str, torch.fx.Node] = {}

    new_graph = torch.fx.Graph()
    
    # Iterate all nodes
    for node in traced.graph.nodes:
        if node.name in node_name_to_shard_id:
            current_shard_id = node_name_to_shard_id[node.name]
            if prev_shard_id < current_shard_id:
                assert prev_node

                with new_graph.inserting_after(prev_node):
                    if prev_shard_id in extra_outputs:
                        outputs = extra_outputs[prev_shard_id]
                        outputs = tuple([env[i] for i in outputs])
                        new_graph.output(outputs)
                    else:
                        outputs = (env[prev_node.name],)  # wrap in a tuple; a torch.fx.Node is not iterable
                        new_graph.output(outputs)
                
                # finalize this graph into GraphModule list
                new_graph.lint()
                module_list.append(torch.fx.GraphModule(model, new_graph))

                # Create a new graph
                new_graph = torch.fx.Graph()
                for output in outputs:
                    # Add all nodes in return of the previous graph to its input
                    node_name = env[output.name].name
                    pl_node = new_graph.create_node("placeholder", node_name)
                    env[node_name] = pl_node

        # Cut is done. Add all nodes into the current graph
        if node.op in [
            "placeholder",
            "get_attr",
            "call_function",
            "call_method",
            "call_module",
        ]:
            # Copy the nodes from the existing graph to the new graph.
            new_node = new_graph.node_copy(node, lambda x: env[x.name])
            env[node.name] = new_node
        elif node.op == "output":
            # If this is the last node, we should add an output node and add the last graph to the list.
            assert prev_node, "prev_node cannot be None"
            with new_graph.inserting_after(prev_node):
                new_node = new_graph.node_copy(node, lambda x: env[x.name])
            new_graph.lint()
            module_list.append(torch.fx.GraphModule(model, new_graph))
            break

        prev_node = new_node
        prev_shard_id = node_name_to_shard_id[node.name]

    return module_list

Here, we iterate over all nodes again and create a list of GraphModules based on their shard_ids. Whenever the shard_id increases, the graph should be partitioned there: we finalize the current graph by adding its outputs, check its integrity with new_graph.lint(), and create a GraphModule. If variables defined in a previous sharded module are used later, they must be forwarded; thus we manually add them as outputs of the previous sharded module (new_graph.output(outputs)) and as placeholder inputs of the next sharded module (new_graph.create_node("placeholder", ...)).
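
Putting it together, a quick sketch of producing and checking the shards (counts assume the 13 GPT-2 split points above):

shards = shard_model(model, traced, split_points)
print(len(shards))  # expected: number of split points + 1 = 14
for i, shard in enumerate(shards):
    # placeholders show which inputs each stage expects from the previous stage
    print(i, [n.name for n in shard.graph.nodes if n.op == "placeholder"])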

See the full notebook here.


  1. Torch.fx: Practical Program Capture and Transformation for Deep Learning in Python ↩︎

  2. DeepSpeed: Expressing Pipeline Models ↩︎

  3. Merak: An Efficient Distributed DNN Training Framework With Automated 3D Parallelism for Giant Foundation Models ↩︎