torch.fx is a PyTorch module that captures a model and applies transformations for optimization [1].
Model optimization has become increasingly important in recent years.
torch.fx enables transparent transformation without touching the original model implementation, allowing fine-grained model optimization.
Since PyTorch 2.0, TorchDynamo appears to replace the legacy fx.Tracer for tracing models.
This post focuses on the existing torch.fx module; I will write another post about TorchDynamo if I have a chance.
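To make the idea of transparent transformation concrete, here is a small, self-contained sketch (my own toy example, not part of the pipeline parallelism discussion below): trace a toy module and swap one of its ops without modifying the module's source.

```python
import torch
import torch.fx

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

gm = torch.fx.symbolic_trace(Toy())

# Replace every torch.relu call with torch.sigmoid without touching Toy's source.
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is torch.relu:
        with gm.graph.inserting_after(node):
            new_node = gm.graph.call_function(torch.sigmoid, node.args, node.kwargs)
        node.replace_all_uses_with(new_node)
        gm.graph.erase_node(node)

gm.recompile()
print(gm(torch.tensor([-1.0, 2.0])))  # computes sigmoid(x) + 1 instead of relu(x) + 1
```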
Using torch.fx to implement pipeline parallelism #
torch.fx can be used for various purposes, mostly for performance optimization.
I learned torch.fx to make models pipeline parallelizable, exploiting the characteristics of its output: GraphModule.
Pipeline parallelism requires the model to behave like a torch.nn.Sequential: the output of the previous stage becomes the input of the next stage [2], as sketched below.
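A minimal illustration of this sequential execution pattern (the stage modules below are hypothetical, for illustration only):

```python
import torch
import torch.nn as nn

# Hypothetical pipeline stages: each stage consumes exactly the previous stage's output.
stages = [nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2)]

x = torch.randn(4, 8)
for stage in stages:
    # In actual pipeline parallelism, each stage would run on a different device,
    # and activations would be sent between devices instead of staying local.
    x = stage(x)
print(x.shape)  # torch.Size([4, 2])
```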
torch.fx.GraphModule satisfies this requirement and can be used for pipeline parallelism.
For the rest of the post, I focus on implementing pipeline parallelism with torch.fx.
The first implementation of using torch.fx for pipeline parallelism that I have seen is Merak [3].
This post is inspired by Merak's implementation, which in turn borrows from FairScale.
```python
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.utils.fx import symbolic_trace  # it replaces torch.fx.symbolic_trace
from torch.fx import GraphModule

config = AutoConfig.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_config(config)
traced: GraphModule = symbolic_trace(model)
```
The type of traced is torch.fx.GraphModule, which has the same structure as model, but each internal module is replaced with the base class torch.nn.Module:
```python
traced.print_readable()
```

```python
class GraphModule(torch.nn.Module):
    def forward(self, input_ids: torch.Tensor):
        # No stacktrace found for following nodes
        size = input_ids.size()
        getitem = size[-1]
        ...
        lm_head = self.lm_head(view_146);  view_146 = None
        return {'logits': lm_head,
                'past_key_values': ((permute_1, permute_2), (permute_5, permute_6),
                                    (permute_9, permute_10), (permute_13, permute_14),
                                    (permute_17, permute_18), (permute_21, permute_22),
                                    (permute_25, permute_26), (permute_29, permute_30),
                                    (permute_33, permute_34), (permute_37, permute_38),
                                    (permute_41, permute_42), (permute_45, permute_46))}
```
The traced graph module does not return a loss, meaning it cannot be used for training.
To make the GraphModule compute a loss, we need to pass a labels input.
```python
def symbolic_trace(
    model: PreTrainedModel,
    input_names: Optional[List[str]] = None,
    disable_check: bool = False,
) -> GraphModule:
    """
    Performs symbolic tracing on the model.

    Args:
        model ([`PretrainedModel`]):
            The model to trace.
        input_names (`List[str]`, *optional*):                                      <--
            The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.
        disable_check (`bool`, *optional*, defaults to `False`):
            If `True`, no check is done before trying to trace the model, this is mostly usesul for debugging purposes.
    ...
```
symbolic_trace optionally takes input_names; according to the documentation, model.dummy_inputs.keys() are used by default, which does not include the labels key:
```python
list(model.dummy_inputs.keys())
['input_ids']  # Dummy input keys specific for the GPT2 model
```
Let’s use inputs from a dataset instead.
```python
from datasets import load_dataset

raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

# preprocess the dataset as explained here:
# https://insujang.github.io/2023-04-19/using-huggingface-transformers/#loading-a-tokenizer-and-preprocessing
def group_texts(samples):
    ...
    result["labels"] = result["input_ids"].copy()
    return result

tokenized_datasets = tokenized_datasets.map(group_texts, batched=True, load_from_cache_file=True)
train_dataset = tokenized_datasets["train"]

input_names = list(next(iter(train_dataset)).keys())
traced = symbolic_trace(model, input_names=input_names)
```
With preprocessing, input_names now includes three items: ['input_ids', 'attention_mask', 'labels']. As it includes labels, symbolic_trace generates a loss computation in the traced graph, and the traced module's output now contains a loss in addition to the logits.
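As a quick, illustrative sanity check (assuming the train_dataset and input_names from above are in scope), we can run the traced module on a single sample and confirm a loss entry appears:

```python
import torch

sample = next(iter(train_dataset))
# Build a batch of size 1 from one preprocessed sample.
inputs = {name: torch.tensor(sample[name]).unsqueeze(0) for name in input_names}

outputs = traced(**inputs)
print(outputs.keys())  # expected to include 'loss' in addition to 'logits'
```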
Using a traced torch.fx.GraphModule is simple. Replace model with it:
```python
# This is for training the existing model
trainer = Trainer(
    model,
    training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    data_collator=default_data_collator,
    eval_dataset=None,
    compute_metrics=None,
)
trainer.train()
# TrainOutput(global_step=10, training_loss=9.759646606445312,
#             metrics={'train_runtime': 8.3264, 'train_samples_per_second': 9.608,
#                      'train_steps_per_second': 1.201, 'total_flos': 41806725120000.0,
#                      'train_loss': 9.759646606445312, 'epoch': 0.03})

# Equivalent training but with the traced torch.fx.GraphModule model
trainer = Trainer(
    traced,
    training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    data_collator=default_data_collator,
    eval_dataset=None,
    compute_metrics=None,
)
trainer.train()
# TrainOutput(global_step=10, training_loss=9.122493743896484,
#             metrics={'train_runtime': 7.6985, 'train_samples_per_second': 10.392,
#                      'train_steps_per_second': 1.299, 'total_flos': 0.0,
#                      'train_loss': 9.122493743896484, 'epoch': 0.03})
```
Disclaimer: the examples below do NOT work in a distributed pipeline parallel execution environment, since they do not properly pass inputs to the sharded models. This post only focuses on how to split the model using torch.fx.
We can arbitrarily split a torch.fx.GraphModule to distribute training across multiple GPUs or nodes.
We are going to split the model by transformer layers, following PiPPy’s HF GPT2 split point example.
The example splits the model at two kinds of points: transformer.h.N (where N is the layer index) and transformer.ln_f:
```python
from typing import List
from transformers import PretrainedConfig

def get_split_points(config: PretrainedConfig) -> List[str]:
    split_points: List[str] = []

    # config: GPT2Config for the GPT2 model
    for i in range(config.num_hidden_layers):
        split_points.append(f"transformer.h.{i}")
    split_points.append("transformer.ln_f")
    return split_points
```
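For example, with the GPT-2 config loaded earlier (12 hidden layers by default), this should produce 13 split points:

```python
split_points = get_split_points(config)
print(split_points[:2], split_points[-1])
# ['transformer.h.0', 'transformer.h.1'] transformer.ln_f
print(len(split_points))  # 13
```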
During graph traversal, if we find a node whose name starts with either transformer.h.N or transformer.ln_f, we split the graph and create another GraphModule.
Graph partitioning iterates over all nodes twice: first to find the proper partitioning locations, and second to create GraphModules based on the partitioning information.
We first iterate over all nodes to find the partitioning locations by defining _split_nodes:
```python
import torch.fx
from typing import List, Dict, Tuple, Optional

def _split_nodes(
    traced: torch.fx.GraphModule, split_points: List[str]
) -> Tuple[Dict[str, int], Dict[int, List[str]]]:
    node_name_to_shard_id: Dict[str, int] = {}
    shard_id = 0
    nodes_so_far: List[torch.fx.Node] = []
    extra_outputs: Dict[int, List[str]] = {}

    for node in traced.graph.nodes:
        if node.op in [
            "placeholder",
            "get_attr",
            "call_function",
            "call_method",
            "call_module",
        ]:
            node_name_to_shard_id[node.name] = shard_id
            nodes_so_far.append(node)

            point = next(filter(lambda p: node.name.startswith(p), split_points), None)
            if point:
                # Record outputs that should be used later;
                # they will be added to the return of this shard.
                outputs = []
                for n in nodes_so_far:
                    for user in n.users.keys():
                        if user.name not in node_name_to_shard_id:
                            outputs.append(n.name)

                # Remove duplicates while preserving order.
                extra_outputs[shard_id] = list(dict.fromkeys(outputs).keys())

                shard_id += 1
                split_points.remove(point)
        elif node.op == "output":
            break

    assert len(split_points) == 0, "Sharding is not complete."
    return node_name_to_shard_id, extra_outputs
```
_split_nodes monotonically increases shard_id and maps every node to a specific shard_id; nodes with the same shard_id belong to the same partitioned GraphModule.
The inputs of _split_nodes can be prepared as follows.
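The exact preparation code is not shown here; below is a minimal sketch. One assumption to call out: torch.fx node names replace "." with "_" (e.g. the module path transformer.h.0.ln_1 becomes the node name transformer_h_0_ln_1), so the split-point strings are normalized before matching against node.name.

```python
# Sketch only: reproduces the objects from the earlier sections.
config = AutoConfig.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_config(config)
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask", "labels"])

# Assumption: normalize "transformer.h.0" -> "transformer_h_0" to match fx node names.
split_points = [p.replace(".", "_") for p in get_split_points(config)]

# _split_nodes removes found split points from the list it receives, so pass a copy.
node_name_to_shard_id, extra_outputs = _split_nodes(traced, list(split_points))
```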
The main function shard_model internally uses _split_nodes to partition the model into GraphModules:
```python
def shard_model(
    model: torch.nn.Module,
    traced: torch.fx.GraphModule,
    split_points: List[str],
) -> List[torch.fx.GraphModule]:
    module_list: List[torch.fx.GraphModule] = []
    node_name_to_shard_id, extra_outputs = _split_nodes(traced, split_points)

    prev_shard_id = 1000
    prev_node: Optional[torch.fx.Node] = None

    env: Dict[str, torch.fx.Node] = {}
    new_graph = torch.fx.Graph()

    # Iterate all nodes
    for node in traced.graph.nodes:
        if node.name in node_name_to_shard_id:
            current_shard_id = node_name_to_shard_id[node.name]
            if prev_shard_id < current_shard_id:
                assert prev_node

                with new_graph.inserting_after(prev_node):
                    if prev_shard_id in extra_outputs:
                        outputs = extra_outputs[prev_shard_id]
                        outputs = tuple([env[i] for i in outputs])
                        new_graph.output(outputs)
                    else:
                        # A single output node, wrapped in a tuple.
                        outputs = (env[prev_node.name],)
                        new_graph.output(outputs)

                # finalize this graph into GraphModule list
                new_graph.lint()
                module_list.append(torch.fx.GraphModule(model, new_graph))

                # Create a new graph
                new_graph = torch.fx.Graph()
                for output in outputs:
                    # Add all nodes in return of the previous graph to its input
                    node_name = env[output.name].name
                    pl_node = new_graph.create_node("placeholder", node_name)
                    env[node_name] = pl_node

        # Cut is done. Add all nodes into the current graph
        if node.op in [
            "placeholder",
            "get_attr",
            "call_function",
            "call_method",
            "call_module",
        ]:
            # Copy the nodes from the existing graph to the new graph.
            new_node = new_graph.node_copy(node, lambda x: env[x.name])
            env[node.name] = new_node
        elif node.op == "output":
            # If this is the last node, we should add an output node
            # and add the last graph to the list.
            assert prev_node, "prev_node cannot be None"

            with new_graph.inserting_after(prev_node):
                new_node = new_graph.node_copy(node, lambda x: env[x.name])
            new_graph.lint()
            module_list.append(torch.fx.GraphModule(model, new_graph))
            break

        prev_node = new_node
        prev_shard_id = node_name_to_shard_id[node.name]

    return module_list
```
Here, we iterate over all nodes again and create a list of GraphModules based on shard_ids.
When shard_id increases, the graph should be partitioned there: the function finalizes the current graph by adding its outputs, checks its integrity with new_graph.lint(), and creates a GraphModule.
If some variables defined in previous sharded modules are used later, they must be forwarded; thus we manually add them to the outputs of the previous sharded module (new_graph.output(outputs)) and to the inputs of the next sharded module (new_graph.create_node("placeholder", ...)).
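Finally, a hedged usage sketch that ties the pieces together, under the same assumptions as the earlier sketch (model, traced, and config are in scope, split points are underscore-normalized, and everything runs in a single process, per the disclaimer above):

```python
split_points = [p.replace(".", "_") for p in get_split_points(config)]
shards = shard_model(model, traced, split_points)

print(f"{len(shards)} pipeline stages")
for i, shard in enumerate(shards):
    # Each shard is a torch.fx.GraphModule whose outputs feed the next shard's placeholders.
    print(f"stage {i}: {len(shard.graph.nodes)} nodes")
```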