Creating an Agent¶
Author: Alexander Holden Miller
In this tutorial, we’ll be setting up an agent which learns from the data it sees to produce the right answers.
For this agent, we’ll be implementing a simple GRU Seq2Seq agent based on Sequence to Sequence Learning with Neural Networks (Sutskever et al. 2014) and Sean Robertson’s Seq2Seq PyTorch tutorial.
Part 1: Naming Things¶
In order to make programmatic importing easier, we use a simple naming scheme for our models, so that on the command line we can just type "--model seq2seq" to load up the seq2seq model.
To this end, we create a folder under parlai/agents with the name seq2seq, and then put an empty __init__.py file there along with seq2seq.py. Then, we name our agent "Seq2seqAgent".
This way, "--model seq2seq" can translate to "parlai.agents.seq2seq.seq2seq:Seq2seqAgent". Underscores in the name become capitals in the class name: "--model local_human" resides at "parlai.agents.local_human.local_human:LocalHumanAgent". If you need to put a model at a different path, you can specify the full path on the command line in the format above (with a colon in front of the class name). For example, "--model parlai.agents.remote_agent.remote_agent:ParsedRemoteAgent".
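To see roughly how this convention maps a name onto a class (a simplified sketch, not ParlAI's actual loading code), the translation could be done like this:

import importlib

def name_to_agent_class(model_name):
    # simplified sketch of the "--model" naming convention described above;
    # ParlAI's real loader handles more cases than this
    if ':' in model_name:
        # a full path was given, e.g.
        # "parlai.agents.remote_agent.remote_agent:ParsedRemoteAgent"
        module_path, class_name = model_name.split(':')
    else:
        # "seq2seq" -> module "parlai.agents.seq2seq.seq2seq",
        # class "Seq2seqAgent"; "local_human" -> "LocalHumanAgent"
        module_path = 'parlai.agents.{m}.{m}'.format(m=model_name)
        class_name = ''.join(w.capitalize() for w in model_name.split('_')) + 'Agent'
    module = importlib.import_module(module_path)
    return getattr(module, class_name)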
Part 2: Main Agent Methods¶
First off, generally we should inherit from the Agent class in parlai.core.agents. This provides us with some default implementations (often just pass) of some utility functions like shutdown.

First, let's focus on the primary functions to implement: __init__, observe, and act.
The standard initialization parameters for agents are a dict of command-line parameters opt and an optional dict of shared parameters called shared.
For our Seq2Seq model we’ll call our parent init method, which does a few basic operations like setting self.observation to None and creating a deep copy of the opt dict.
Then, we do a check to see if the shared parameter is set. When it is not None, it’s telling this instance to initialize with this particular state, as this instance will be used either for batched or hogwild training (depending on your preference). We’ll take a quick digression to describe how batching is set up.
Batching Example¶
Let's say we are training our seq2seq model on babi:task10k:1. What happens behind the scenes for a batch size of 4 is that we actually create four shared versions of the bAbI Task10k teacher, and four shared versions of the seq2seq agent. These shared versions are initialized from the originals: for the bAbI teachers, they inherit the data from their parent agent, but they each have their own local state such as the current example they're showing or how far through a bAbI episode they are (bAbI task 1 has five examples per episode). For the seq2seq agent, each shared agent is keeping track of the previous examples they've seen in this same episode, since each observation does not repeat previously seen but related information; the agent has to remember it.
For example, in the first example the agent could get something like the following: “John is in the bathroom. Mary is in the kitchen. Where is Mary?” And in the second example in the episode, the agent could get: “Mary picked up the milk. Mary went to the hallway. Where is John?” Here, the answer is in the first example’s context, so the agent had to remember it.
Observations are generated by calling the act function on each teacher, then passing those observations to each agent by calling the observe function of the shared agents. The agents are free to transform the previous observation (for example, prepending previously seen text from the same episode, if applicable). These transformed observations are packed into a list, which is then passed to the batch_act function our agent implements. We can implement batch_act differently from the simple act function to take advantage of the effects of batching over multiple examples when executing or updating our model.
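To make that flow concrete, here is a rough sketch of one batched step (the helper name and signature are hypothetical; the real ParlAI worlds handle more bookkeeping than this):

def batched_step(teachers, shared_agents, main_agent):
    # one step over a batch: each shared teacher produces an example, the
    # corresponding shared agent observes (and may transform) it, and the
    # main agent replies to every example at once via batch_act
    observations = []
    for teacher, agent in zip(teachers, shared_agents):
        obs = teacher.act()
        observations.append(agent.observe(obs))
    return main_agent.batch_act(observations)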
Since our agent's shared instances will only be used to keep track of state particular to their sequence of examples in the batch, we have barely anything to do when setting them up: we just initialize the self.episode_done flag so we know whether we are in the middle of an episode or not.
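In code, the shared case amounts to almost nothing (this is just a condensed view of the full __init__ shown in Part 4):

def __init__(self, opt, shared=None):
    super().__init__(opt, shared)
    if not shared:
        # full setup of the dictionary, model modules, and optimizers
        # happens here
        pass
    # both the original agent and its shared copies track episode state
    self.episode_done = True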
The full initialization of the model is included further below, but it is specific to this particular implementation. Let's talk more about the primary agent functions we need to define first.
Observing and Acting¶
Let's take a look at the observe function. Here, we can modify the observation dict if necessary, and then return it to be queued for batching.
In this version, we first make a deep copy of the observation. Then, if this is not the first entry in an episode (some datasets like SQuAD have only one entry for every episode, but others like bAbI have multiple), then we prepend the previous text to the current text. We use a newline to separate them in case the model wants to recognize the difference between different lines.
Then, we store whether this is the last entry in the episode so that we’ll be ready to reset next time if we need to.
def observe(self, observation):
    observation = copy.deepcopy(observation)
    if not self.episode_done:
        # if the last example wasn't the end of an episode, then we need to
        # recall what was said in that example
        prev_dialogue = self.observation['text']
        observation['text'] = prev_dialogue + '\n' + observation['text']
    self.observation = observation
    self.episode_done = observation['episode_done']
    return observation
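For instance, with two made-up bAbI-style turns in the same episode (and assuming opt holds the usual command-line options), the second observation comes back with the first turn's text prepended:

agent = Seq2seqAgent(opt)

# first turn of an episode: nothing to prepend yet
agent.observe({'text': 'John went to the bathroom. Where is John?',
               'labels': ['bathroom'], 'episode_done': False})

# second turn: the previous text is prepended, separated by a newline
obs = agent.observe({'text': 'Mary went to the hallway. Where is Mary?',
                     'labels': ['hallway'], 'episode_done': True})
print(obs['text'])
# John went to the bathroom. Where is John?
# Mary went to the hallway. Where is Mary?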
Next up is the act function. Since we are going to implement a batched version, we'll just call the batched version from our single-example act to reduce code duplication. The performance hit here won't matter much since we'll only use a batch size of one when debugging.
def act(self):
    # call batch_act with this batch of one
    return self.batch_act([self.observation])[0]
Now it’s time for the batch_act function. This function gets a list of length batchsize of observations and returns a list of the same length with this agent’s replies.
We’ll follow this loose format:
- Set up our list of dicts to send back as replies, with the agent’s ID set.
- Convert the incoming observations into tensors to feed into our model.
- Produce predictions on the input text using the model. If labels were provided, update the model as well.
- Unpack the predictions into the reply dicts and return them.
def batch_act(self, observations):
    batchsize = len(observations)
    # initialize a table of replies with this agent's id
    batch_reply = [{'id': self.getID()} for _ in range(batchsize)]

    # convert the observations into batches of inputs and targets
    # valid_inds tells us the indices of all valid examples
    # e.g. for input [{}, {'text': 'hello'}, {}, {}], valid_inds is [1]
    # since the other three elements had no 'text' field
    xs, ys, valid_inds = self.batchify(observations)

    if len(xs) == 0:
        # no valid examples, just return the empty responses we set up
        return batch_reply

    # produce predictions either way, but use the targets if available
    predictions = self.predict(xs, ys)

    for i in range(len(predictions)):
        # map the predictions back to non-empty examples in the batch
        # we join with spaces since we produce tokens one at a time
        batch_reply[valid_inds[i]]['text'] = ' '.join(
            c for c in predictions[i] if c != self.EOS)

    return batch_reply
Since the implementations of batchify and predict are particular to our model, we'll table those for now. Next up, we'll cover some of the other methods in the Agent API.
Part 3: Extended Agent API¶
There are a few other useful methods you may want to define in your agent to take care of additional functionality you might want during training. Many of these functions will be automatically called if you use our example training function to train your model.
shutdown()¶
This function allows your model to do any final wrapup, such as writing any last logging info, saving an end-state version of the model if desired, or closing any open connections.
Our seq2seq model doesn’t implement this, but the agents in parlai/agents/remote_agent use this to close their open TCP connection after sending a shutdown signal through.
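If your agent does need cleanup, a minimal override might look like the following (a hypothetical example; self.save here is an assumed helper, not something the seq2seq agent defines):

def shutdown(self):
    # hypothetical: persist a final snapshot of the model before exiting
    if self.opt.get('model_file'):
        self.save(self.opt['model_file'])
    super().shutdown()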
Part 4: Finishing the Seq2Seq model¶
Here we'll take a look at the full details of __init__, batchify, predict, and more.
Full __init__()¶
Here's the full code to get the initialization of our model working. You might prefer to define the model as a separate class, but since this model is so simple we define its modules in-line here.
from parlai.core.agents import Agent
from parlai.core.dict import DictionaryAgent

import torch
from torch import optim
from torch.autograd import Variable
import torch.nn as nn

import copy
import random


class Seq2seqAgent(Agent):

    def __init__(self, opt, shared=None):
        # initialize defaults first
        super().__init__(opt, shared)
        if not shared:
            # this is not a shared instance of this class, so do full
            # initialization. if shared is set, only set up shared members.
            self.dict = DictionaryAgent(opt)
            self.id = 'Seq2Seq'
            # we use EOS markers to break input and output and end our output
            self.EOS = self.dict.eos_token
            self.observation = {'text': self.EOS, 'episode_done': True}
            self.EOS_TENSOR = torch.LongTensor(self.dict.parse(self.EOS))

            # store important params directly
            hsz = opt['hiddensize']
            self.hidden_size = hsz
            self.num_layers = opt['numlayers']
            self.learning_rate = opt['learningrate']
            self.longest_label = 1

            # set up modules
            self.criterion = nn.NLLLoss()
            # lookup table stores word embeddings
            self.lt = nn.Embedding(len(self.dict), hsz, padding_idx=0,
                                   scale_grad_by_freq=True)
            # encoder captures the input text
            self.encoder = nn.GRU(hsz, hsz, opt['numlayers'])
            # decoder produces our output states
            self.decoder = nn.GRU(hsz, hsz, opt['numlayers'])
            # linear layer helps us produce outputs from final decoder state
            self.h2o = nn.Linear(hsz, len(self.dict))
            # dropout on the linear layer helps us generalize
            self.dropout = nn.Dropout(opt['dropout'])
            # softmax maps output scores to probabilities
            self.softmax = nn.LogSoftmax()

            # set up optims for each module
            lr = opt['learningrate']
            self.optims = {
                'lt': optim.SGD(self.lt.parameters(), lr=lr),
                'encoder': optim.SGD(self.encoder.parameters(), lr=lr),
                'decoder': optim.SGD(self.decoder.parameters(), lr=lr),
                'h2o': optim.SGD(self.h2o.parameters(), lr=lr),
            }

            # check for cuda
            self.use_cuda = not opt.get('no_cuda') and torch.cuda.is_available()
            if self.use_cuda:
                print('[ Using CUDA ]')
                torch.cuda.set_device(opt['gpu'])
                self.cuda()

        self.episode_done = True
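One detail glossed over above is where options like opt['hiddensize'], opt['numlayers'], opt['learningrate'], and opt['dropout'] come from: agents typically register their own flags via an add_cmdline_args static method. A sketch for this agent might look like the following (the flag names and defaults here are illustrative, not necessarily ParlAI's exact choices):

@staticmethod
def add_cmdline_args(argparser):
    # register the model-specific hyperparameters on the command-line parser
    agent = argparser.add_argument_group('Seq2Seq Arguments')
    agent.add_argument('-hs', '--hiddensize', type=int, default=128,
                       help='size of the hidden layers and embeddings')
    agent.add_argument('-nl', '--numlayers', type=int, default=2,
                       help='number of hidden layers')
    agent.add_argument('-lr', '--learningrate', type=float, default=0.5,
                       help='learning rate for SGD')
    agent.add_argument('-dr', '--dropout', type=float, default=0.1,
                       help='dropout rate on the output layer')
    agent.add_argument('--no-cuda', action='store_true', default=False,
                       help='disable GPUs even if available')
    agent.add_argument('--gpu', type=int, default=-1,
                       help='which GPU device to use')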
batchify()¶
The batchify function takes in a list of observations and turns them into tensors to use with our model.
def batchify(self, observations):
    """Convert a list of observations into input & target tensors."""
    # valid examples
    exs = [ex for ex in observations if 'text' in ex]
    # the indices of the valid (non-empty) tensors
    valid_inds = [i for i, ex in enumerate(observations) if 'text' in ex]

    # set up the input tensors
    batchsize = len(exs)
    # tokenize the text
    parsed = [self.parse(ex['text']) for ex in exs]
    max_x_len = max([len(x) for x in parsed])
    xs = torch.LongTensor(batchsize, max_x_len).fill_(0)
    # pack the data to the right side of the tensor for this model
    for i, x in enumerate(parsed):
        offset = max_x_len - len(x)
        for j, idx in enumerate(x):
            xs[i][j + offset] = idx
    if self.use_cuda:
        xs = xs.cuda(non_blocking=True)
    xs = Variable(xs)

    # set up the target tensors
    ys = None
    if 'labels' in exs[0]:
        # randomly select one of the labels to update on, if multiple
        # append EOS to each label
        labels = [random.choice(ex['labels']) + ' ' + self.EOS for ex in exs]
        parsed = [self.parse(y) for y in labels]
        max_y_len = max(len(y) for y in parsed)
        ys = torch.LongTensor(batchsize, max_y_len).fill_(0)
        for i, y in enumerate(parsed):
            for j, idx in enumerate(y):
                ys[i][j] = idx
        if self.use_cuda:
            ys = ys.cuda(non_blocking=True)
        ys = Variable(ys)
    return xs, ys, valid_inds
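As a quick illustration of the right-side packing above (a made-up two-example batch, with word indices already looked up):

import torch

# index lists of unequal length, as parse() might return them
parsed = [[4, 5, 6], [7, 8]]
max_x_len = max(len(x) for x in parsed)
xs = torch.LongTensor(len(parsed), max_x_len).fill_(0)
for i, x in enumerate(parsed):
    offset = max_x_len - len(x)
    for j, idx in enumerate(x):
        xs[i][j + offset] = idx
# xs is now [[4, 5, 6],
#            [0, 7, 8]] -- shorter inputs are left-padded with the padding
# index 0, so every row ends with its last real token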
predict()¶
The predict function returns an output from our model. If the targets are provided, then it also updates the model. The predictions will be biased in this case, since we condition each token on the true label token (teacher forcing), but we are okay with that: it just improves the F1 scores reported during training.
def predict(self, xs, ys=None):
    """Produce a prediction from our model. Update the model using the
    targets if available.
    """
    batchsize = len(xs)

    # first encode context
    xes = self.lt(xs).t()
    h0 = torch.zeros(self.num_layers, batchsize, self.hidden_size)
    if self.use_cuda:
        h0 = h0.cuda(non_blocking=True)
    h0 = Variable(h0)
    _output, hn = self.encoder(xes, h0)

    # next we use EOS as an input to kick off our decoder
    x = Variable(self.EOS_TENSOR)
    xe = self.lt(x).unsqueeze(1)
    xes = xe.expand(xe.size(0), batchsize, xe.size(2))

    # list of output tokens for each example in the batch
    output_lines = [[] for _ in range(batchsize)]

    if ys is not None:
        # update the model based on the labels
        self.zero_grad()
        loss = 0
        # keep track of longest label we've ever seen
        self.longest_label = max(self.longest_label, ys.size(1))
        for i in range(ys.size(1)):
            output, hn = self.decoder(xes, hn)
            preds, scores = self.hidden_to_idx(output, drop=True)
            y = ys.select(1, i)
            loss += self.criterion(scores, y)
            # use the true token as the next input instead of predicted
            # this produces a biased prediction but better training
            xes = self.lt(y).unsqueeze(0)
            for b in range(batchsize):
                # convert the output scores to tokens
                token = self.v2t([preds.data[b][0]])
                output_lines[b].append(token)
        loss.backward()
        self.update_params()
    else:
        # just produce a prediction without training the model
        done = [False for _ in range(batchsize)]
        total_done = 0
        max_len = 0

        while total_done < batchsize and max_len < self.longest_label:
            # keep producing tokens until we hit EOS or max length for each
            # example in the batch
            output, hn = self.decoder(xes, hn)
            preds, scores = self.hidden_to_idx(output, drop=False)

            xes = self.lt(preds.t())
            max_len += 1
            for b in range(batchsize):
                if not done[b]:
                    # only add more tokens for examples that aren't done yet
                    token = self.v2t(preds.data[b])
                    if token == self.EOS:
                        # if we produced EOS, we're done
                        done[b] = True
                        total_done += 1
                    else:
                        output_lines[b].append(token)
    return output_lines
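Note that predict leans on a few small helpers this tutorial doesn't show: parse and v2t convert between strings and index lists via the dictionary, zero_grad and update_params loop over self.optims, and hidden_to_idx maps a decoder state to token predictions and scores. As a rough sketch of the last one, consistent with the modules set up in __init__ (an assumption about its shape, not necessarily the exact ParlAI implementation):

def hidden_to_idx(self, hidden, drop=False):
    """Convert decoder output states into token indices and log-prob scores."""
    # hidden comes in with shape (1, batchsize, hidden_size)
    if hidden.size(0) > 1:
        raise RuntimeError('bad dimensions of tensor:', hidden)
    hidden = hidden.squeeze(0)
    scores = self.h2o(hidden)          # project to vocabulary size
    if drop:
        scores = self.dropout(scores)  # dropout only while training
    scores = self.softmax(scores)      # log-probabilities over the vocabulary
    _max_score, idx = scores.max(1)    # greedy choice of the best token
    return idx, scores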