Using Torch Ranker Agent
Authors: Emily Dinan
TorchRankerAgent is an abstract parent class for PyTorch models that rank responses from a set of candidates. It inherits from TorchAgent and contains boilerplate code for training and evaluating ranking models.
Several existing models in ParlAI inherit from TorchRankerAgent. Try some of the examples below:
Train a Bag-of-words Ranker model on ConvAI2:
parlai train_model --model examples/tra --task convai2 --model-file /tmp/test --batchsize 32
Train a Transformer Ranker model on ConvAI2:
parlai train_model --model transformer/ranker --task convai2 --model-file /tmp/tr_convai2_test
Train a Memory Network model on Daily Dialog:
parlai train_model --model memnn --task dailydialog --model-file /tmp/memnn_dd_test --batchsize 20 --candidates batch --eval-candidates batch
Train a BERT-based Bi-Encoder ranker model on Twitter:
parlai train_model --model bert_ranker/bi_encoder_ranker --task twitter --model-file /tmp/bert_twitter_test --batchsize 10 --candidates batch --eval-candidates batch --data-parallel True
Creating a Model
To write a ranking model that inherits from TorchRankerAgent, you must implement the following functions (an example is available at parlai/agents/examples/tra.py):
from parlai.core.torch_ranker_agent import TorchRankerAgent


class MyRankerAgent(TorchRankerAgent):  # hypothetical name for your subclass
    def score_candidates(self, batch, cand_vecs, cand_encs=None):
        """This function takes in a Batch object as well as a Tensor of
        candidate vectors. It must return a list of scores corresponding to
        the likelihood that the candidate vector at that index is the proper
        response. If `cand_encs` is not None (when we cache the encoding of
        the candidate vectors), you may use these instead of calling
        self.model on `cand_vecs`.
        """
        pass

    def build_model(self):
        """This function is required to build the model and assign it to
        the object `self.model`.
        """
        pass
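The stubs above leave scoring abstract. As a rough illustration of the contract (one score per candidate, higher meaning a better response), here is a minimal pure-Python sketch that scores candidates by the overlap of word-count vectors. It does not use ParlAI or PyTorch: the functions `encode`, `dot` and `bow_score_candidates`, the whitespace tokenizer, and the list-of-lists output are assumptions for illustration, whereas a real TorchRankerAgent subclass would work on tensors and use a learned encoder.

```python
from collections import Counter
from typing import List


def encode(text: str) -> Counter:
    # Hypothetical encoder: lowercase whitespace tokens mapped to word counts.
    return Counter(text.lower().split())


def dot(a: Counter, b: Counter) -> float:
    # Dot product of two sparse count vectors; missing words count as 0.
    return float(sum(count * b[word] for word, count in a.items()))


def bow_score_candidates(contexts: List[str], cands: List[List[str]]) -> List[List[float]]:
    """Return scores[i][j], the score of candidate j for context i (higher is better)."""
    scores = []
    for context, cand_list in zip(contexts, cands):
        ctx_enc = encode(context)
        scores.append([dot(ctx_enc, encode(c)) for c in cand_list])
    return scores


if __name__ == "__main__":
    contexts = ["do you like cats", "what is your favorite color"]
    shared = ["i like cats a lot", "blue is my favorite color"]
    scores = bow_score_candidates(contexts, [shared, shared])
    print(scores)  # [[2.0, 0.0], [0.0, 3.0]]
    # The predicted response for each context is its highest-scoring candidate.
    print([row.index(max(row)) for row in scores])  # [0, 1]
```

Ranking then amounts to sorting each row of scores; the highest-scoring index is the model's prediction.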
Training a Model
The --candidates flag determines which candidates to rank during training. The possible sources of candidates are batch, batch-all-cands, inline, and fixed:
batch – Use all labels in the batch as the candidate set (with all but the example’s label being treated as negatives).
With this setting, the candidate set is identical for all examples in a batch. This option may be undesirable if it is possible for duplicate labels to occur in a batch, since the second instance of the correct label will be treated as a negative.
batch-all-cands – Use all inline candidates in the batch as the candidate set.
This can result in a very large number of candidates, so the candidates are deduplicated.
Just like with ‘batch’, the candidate set is identical for all examples in a batch.
inline – If each example comes with a list of possible label candidates, use those. Each teacher act for the task should contain the field ‘label_candidates’.
With this setting, each example will have its own candidate set.
fixed – Use a global candidate list provided by the user. If self.fixed_candidates is not None, the same set of fixed candidates is used for all examples.
This setting is not recommended for training unless the universe of possible candidates is very small. To use this, add the path to your text file with the candidates to the flag -fcp (--fixed-candidates-path).
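The four candidate sources can be summarized with a small pure-Python sketch that returns, for each example, its candidate set and the index of the correct label within it. This is not ParlAI's internal implementation (which operates on tensors); the function names `build_candidates` and `load_fixed_candidates`, the simplified dict examples with `label` and `label_candidates` keys, the one-candidate-per-line file format, and the use of -1 for a label missing from its set are all assumptions for illustration.

```python
from typing import Dict, Iterable, List, Optional, Tuple


def load_fixed_candidates(lines: Iterable[str]) -> List[str]:
    # Assumption: the candidates file holds one candidate per line.
    return [line.strip() for line in lines if line.strip()]


def build_candidates(
    batch: List[Dict],
    source: str,
    fixed: Optional[List[str]] = None,
) -> Tuple[List[List[str]], List[int]]:
    """Return (candidate set per example, index of the correct label in each set)."""
    labels = [ex["label"] for ex in batch]
    if source == "batch":
        # All labels in the batch form one shared set; other labels are negatives.
        cand_sets = [labels] * len(batch)
        # Example i's own label is the positive, so a duplicate label elsewhere
        # in the batch is still scored as a negative.
        return cand_sets, list(range(len(batch)))
    if source == "batch-all-cands":
        # Union of every example's inline candidates, deduplicated in order.
        shared = list(dict.fromkeys(c for ex in batch for c in ex["label_candidates"]))
        cand_sets = [shared] * len(batch)
    elif source == "inline":
        # Each example keeps its own candidate list.
        cand_sets = [list(ex["label_candidates"]) for ex in batch]
    elif source == "fixed":
        if fixed is None:
            raise ValueError("source 'fixed' requires a candidate list")
        # One global list shared by every example.
        cand_sets = [fixed] * len(batch)
    else:
        raise ValueError(f"unknown candidate source: {source}")
    # Assumption for this sketch: -1 marks a label missing from its set.
    targets = [c.index(lab) if lab in c else -1 for c, lab in zip(cand_sets, labels)]
    return cand_sets, targets
```

For a batch whose labels are "hi", "hello", "hi", the batch source gives targets [0, 1, 2]: the third example's positive is index 2 even though index 0 holds the identical string, which is the duplicate-label pitfall noted for that option. A candidates file could be read with `with open(path) as f: fixed = load_fixed_candidates(f)`.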
Tracking ranking metrics
During training, we omit some ranking metrics (like hits@k) for the sake of speed. To get these ranking metrics, use the flag
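For reference, hits@k is the fraction of examples whose correct candidate appears among the k highest-scoring candidates. A minimal sketch, assuming scores are given as one list per example and the index of the correct candidate is known; the function name `hits_at_k` is hypothetical and this does not reproduce ParlAI's own metric code.

```python
from typing import List


def hits_at_k(scores: List[List[float]], targets: List[int], k: int) -> float:
    """Fraction of examples whose correct candidate is among the k highest scores."""
    hits = 0
    for row, target in zip(scores, targets):
        # Candidate indices by descending score; ties keep their original order.
        ranked = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
        if target in ranked[:k]:
            hits += 1
    return hits / len(scores)


if __name__ == "__main__":
    scores = [[0.1, 0.9, 0.3], [0.8, 0.2, 0.5], [0.2, 0.3, 0.1]]
    targets = [1, 2, 2]
    for k in (1, 2, 3):
        print(k, hits_at_k(scores, targets, k))
```

Here the correct candidates are ranked first, second and third respectively, so hits@1, hits@2 and hits@3 come out to 1/3, 2/3 and 1.0.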
Evaluating a Model
Evaluating on a fixed candidate set
As during training, you must add the path to your text file with the candidates to the flag -fcp. For many models, it’s convenient to cache the encoding of the candidates when that encoding is independent of the context. To do this and save the encodings to a file, set the flag --encode-candidate-vecs True. This requires implementing the function that takes in a batch of padded candidates and outputs a batch of candidates encoded with the model.
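The caching idea can be sketched in plain Python: encode the fixed candidate list once, score any number of contexts against the stored encodings, and optionally persist them. Everything here is an illustrative assumption (the class `CachedCandidateScorer`, the toy `toy_encode` encoder, dot-product scoring, and JSON as the file format); it is not how ParlAI stores its cached vectors. The shortcut is only valid because `toy_encode` never looks at the context.

```python
import io
import json
from typing import Callable, List, TextIO

Vector = List[float]


def toy_encode(text: str) -> Vector:
    # Hypothetical context-independent encoder: [word count, character count].
    return [float(len(text.split())), float(len(text))]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


class CachedCandidateScorer:
    """Holds precomputed encodings of a fixed candidate list.

    Reusing them is only correct if a candidate's encoding ignores the context.
    """

    def __init__(self, cands: List[str], cand_encs: List[Vector]):
        self.cands = cands
        self.cand_encs = cand_encs

    @classmethod
    def build(cls, cands: List[str], encode: Callable[[str], Vector]) -> "CachedCandidateScorer":
        # The expensive step: run the encoder over every candidate exactly once.
        return cls(cands, [encode(c) for c in cands])

    def score(self, context_enc: Vector) -> List[float]:
        # Each new context only needs dot products against the cached encodings.
        return [dot(context_enc, e) for e in self.cand_encs]

    def save(self, f: TextIO) -> None:
        json.dump({"cands": self.cands, "cand_encs": self.cand_encs}, f)

    @classmethod
    def load(cls, f: TextIO) -> "CachedCandidateScorer":
        data = json.load(f)
        return cls(data["cands"], data["cand_encs"])


if __name__ == "__main__":
    scorer = CachedCandidateScorer.build(["hi", "how are you"], toy_encode)
    buf = io.StringIO()  # stands in for a file on disk
    scorer.save(buf)
    buf.seek(0)
    restored = CachedCandidateScorer.load(buf)
    print(restored.score([1.0, 0.0]))  # [1.0, 3.0]
```

The cost of encoding is paid once per candidate rather than once per (context, candidate) pair, which is where the savings come from on large fixed candidate sets.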
Evaluating on “vocab” candidates
In addition to the options above for evaluating a model, we also have the option of evaluating on “vocab” candidates. This is one global candidate list, extracted from the vocabulary with the exception of