"Adapter" refers to a set of newly introduced weights, typically within the layers of a transformer model. Adapters provide an alternative to fully fine-tuning the model for each downstream task, while maintaining performance. They also have the added benefit of requiring as little as 1MB of storage space per task! Learn More!
Adapters, being self-contained modular units, allow for easy extension and composition. This opens up opportunities to compose adapters to solve new tasks.
AdapterHub builds on the HuggingFace transformers framework, requiring as little as two additional lines of code to train adapters for a downstream task.
Loading existing adapters from our repository is as simple as adding one additional line of code:
from adapters import AutoAdapterModel

model = AutoAdapterModel.from_pretrained("bert-base-uncased")
model.load_adapter("sentiment/sst-2@ukp")
model.set_active_adapters("sst-2")
The SST adapter is lightweight: it is only 3MB! At the same time, it achieves results that are on par with fully fine-tuned BERT. We can now leverage the SST adapter to predict the sentiment of sentences:
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokens = tokenizer.tokenize("AdapterHub is awesome!")
input_tensor = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])
outputs = model(input_tensor)
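To turn the raw output into a sentiment label, we can take the argmax over the logits of the sst-2 head. A minimal sketch, assuming the common SST-2 convention that index 1 corresponds to the positive class (check the head's label mapping to be sure):

# Sketch: decode the prediction of the sst-2 head.
# The 0 = negative / 1 = positive mapping is an assumption; verify it against
# the head's label mapping before relying on it.
predicted_class = outputs.logits.argmax(dim=-1).item()
print("positive" if predicted_class == 1 else "negative")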
Training a new task adapter requires only a few modifications compared to fully fine-tuning a model with Hugging Face's Trainer.
We first load a pre-trained model, e.g., roberta-base, and add a new task adapter:
model = AutoAdapterModel.from_pretrained('roberta-base')
model.add_adapter("sst-2")
model.train_adapter("sst-2")
By calling train_adapter("sst-2") we freeze all transformer parameters except for the parameters of the sst-2 adapter.
Before training we add a new classification head to our model:
model.add_classification_head("sst-2", num_labels=2)
model.set_active_adapters("sst-2")
The weights of this classification head can be stored together with the adapter weights to allow for full reproducibility.
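For example, the adapter and its head can be saved together in a single call; a sketch, where the output path is illustrative:

# Sketch: save the sst-2 adapter together with its classification head
# ("./sst-2-adapter" is an illustrative path)
model.save_adapter("./sst-2-adapter", "sst-2", with_head=True)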
The method call model.set_active_adapters("sst-2") registers the sst-2 adapter as the default for training. This also supports adapter stacking and adapter fusion!
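As a quick illustration of stacking, two adapters can be activated on top of each other with the Stack composition block; this is only a sketch, and "lang-adapter" is a hypothetical adapter that does not exist in this walkthrough:

from adapters import Stack

# Hypothetical example: stack a language adapter below the sst-2 task adapter
model.set_active_adapters(Stack("lang-adapter", "sst-2"))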
We can then train our adapter using the Hugging Face Trainer:
trainer.train()
model.save_all_adapters('output-path')
A learning rate of lr=0.0001 works well for most settings. That's it! model.save_all_adapters('output-path') exports all adapters. Consider sharing them on AdapterHub!
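For reference, a minimal Trainer setup with this learning rate could look as follows; the datasets and the remaining training arguments are illustrative assumptions, not part of the example above:

from transformers import Trainer, TrainingArguments

# Sketch of a Trainer setup; train_dataset and eval_dataset are assumed to be
# tokenized datasets prepared elsewhere (e.g. SST-2).
training_args = TrainingArguments(
    output_dir="output-path",
    learning_rate=1e-4,              # the learning rate recommended above
    num_train_epochs=3,              # illustrative value
    per_device_train_batch_size=32,  # illustrative value
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)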
Using AdapterFusion, we can combine the knowledge of multiple pre-trained adapters on a downstream task. First, we load a pre-trained model and a couple of pre-trained adapters. As we discard the prediction heads of the pre-trained adapters, we add a new head afterwards.
from adapters import AutoAdapterModel, Fuse

model = AutoAdapterModel.from_pretrained("bert-base-uncased")
model.load_adapter("nli/multinli@ukp", load_as="multinli", with_head=False)
model.load_adapter("sts/qqp@ukp", with_head=False)
model.load_adapter("nli/qnli@ukp", with_head=False)
model.add_classification_head("cb")
On top of the loaded adapters, we add a new fusion layer using add_adapter_fusion(). For this purpose, we first define the adapter setup using the Fuse composition block.
During training, only the weights of the fusion layer will be updated. We ensure this by first activating all adapters in the setup and then calling train_adapter_fusion():
adapter_setup = Fuse("multinli", "qqp", "qnli")
model.add_adapter_fusion(adapter_setup)
model.set_active_adapters(adapter_setup)
model.train_adapter_fusion(adapter_setup)
From here on, the training procedure is identical to training a single adapter or a full model. Check out the full working example in the Colab notebook.
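After training, the fusion layer can be persisted alongside the adapters it combines; a sketch, assuming the adapter_setup defined above and an illustrative output path:

# Sketch: save the trained fusion layer and the fused adapters
model.save_adapter_fusion("output-path", adapter_setup)
model.save_all_adapters("output-path")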
AdapterDrop allows us to remove adapters on lower layers during training and inference. This can be realised with the skip_layers argument, which specifies the layers for which adapters should be skipped during a forward pass. In order to train a model with AdapterDrop, we specify a callback for the Trainer class that sets the skip_layers argument to the layers that should be skipped in each step as follows:
import numpy as np
from transformers import TrainerCallback

class AdapterDropTrainerCallback(TrainerCallback):
    def on_step_begin(self, args, state, control, **kwargs):
        # Randomly choose how many of the lower layers to skip in this training step
        skip_layers = list(range(np.random.randint(0, 11)))
        kwargs['model'].set_active_adapters("rotten_tomatoes", skip_layers=skip_layers)

    def on_evaluate(self, args, state, control, **kwargs):
        # Deactivate skipping layers during evaluation (otherwise it would use the
        # previous randomly chosen skip_layers and thus yield results not comparable
        # across different epochs)
        kwargs['model'].set_active_adapters("rotten_tomatoes", skip_layers=None)
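The callback can then be registered with an existing Trainer instance; here the trainer variable is assumed to be set up as in the training example above:

# Register the AdapterDrop callback and train as usual
trainer.add_callback(AdapterDropTrainerCallback())
trainer.train()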
Check out the AdapterDrop Colab Notebook for further details.
During inference, it might be beneficial to pass the input data through several different adapters to compare the results or predict different attributes in one forward pass. The Parallel block enables us to do this. When the Parallel block is used with a model class supporting prediction heads (such as AutoAdapterModel), each adapter also has a corresponding head.
model = AutoAdapterModel.from_pretrained("bert-base-uncased") model.add_adapter("task1") model.add_adapter("task2") model.add_classification_head("task1", num_labels=3) model.add_classification_head("task2", num_labels=5) model.set_active_adapters(Parallel("task1", "task2")
A forward pass through the model with the Parallel block is equivalent to two single forward passes: one through the model with the task1 adapter and head activated, and one through the model with the task2 adapter and head. The output is returned as a MultiHeadOutput, which acts as a list of the head outputs with an additional loss attribute. The loss attribute is the sum of the losses of the individual outputs.
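As a quick illustration, both head outputs can be inspected after a single forward pass; the tokenizer and the input sentence below are illustrative assumptions, not part of the example above:

from transformers import AutoTokenizer

# Illustrative forward pass through both adapters and heads at once
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("AdapterHub is awesome!", return_tensors="pt")
outputs = model(**inputs)

task1_logits = outputs[0].logits  # logits of the task1 head (3 classes)
task2_logits = outputs[1].logits  # logits of the task2 head (5 classes)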
If you use the Adapters library in your work, please consider citing our library paper: Adapters: A Unified Library for Parameter-Efficient and Modular Transfer Learning
@inproceedings{poth-etal-2023-adapters,
    title = "Adapters: A Unified Library for Parameter-Efficient and Modular Transfer Learning",
    author = {Poth, Clifton and Sterz, Hannah and Paul, Indraneil and Purkayastha, Sukannya and Engl{\"a}nder, Leon and Imhof, Timo and Vuli{\'c}, Ivan and Ruder, Sebastian and Gurevych, Iryna and Pfeiffer, Jonas},
    booktitle = "Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing: System Demonstrations",
    month = dec,
    year = "2023",
    address = "Singapore",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2023.emnlp-demo.13",
    pages = "149--160",
}
Alternatively, for the Hub infrastructure and adapters uploaded by the AdapterHub team, please consider citing our initial paper: AdapterHub: A Framework for Adapting Transformers
@inproceedings{pfeiffer2020AdapterHub,
    title = {AdapterHub: A Framework for Adapting Transformers},
    author = {Jonas Pfeiffer and Andreas R\"uckl\'{e} and Clifton Poth and Aishwarya Kamath and Ivan Vuli\'{c} and Sebastian Ruder and Kyunghyun Cho and Iryna Gurevych},
    booktitle = {Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP 2020): Systems Demonstrations},
    year = {2020},
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2020.emnlp-demos.7",
    pages = "46--54",
}