# Source code for flambe.nlp.transformers.utils
from typing import Optional, Type, Any
import torch
import pytorch_transformers as pt
from flambe.field import Field
from flambe.nn import Module
class TransformerTextField(Field):
    """A ``Field`` backed by a pretrained transformer tokenizer.

    Tokenizes and numericalizes raw text with a ``pytorch_transformers``
    tokenizer. The concrete tokenizer class is expected to be provided
    by subclasses through the ``_cls`` attribute — TODO confirm against
    the sibling modules that subclass this field.
    """

    def __init__(self,
                 alias: str,
                 cache_dir: Optional[str] = None,
                 max_len_truncate: Optional[int] = None,
                 add_special_tokens: bool = True, **kwargs) -> None:
        """Initialize from a pretrained tokenizer.

        Parameters
        ----------
        alias: str
            Alias of a pretrained tokenizer.
        cache_dir: str, optional
            A directory where to cache the downloaded vocabularies.
        max_len_truncate: int, optional
            If given, truncates the length of the tokenized sequence.
        add_special_tokens: bool, optional
            Whether to add the model's special tokens (e.g. ``[CLS]``,
            ``[SEP]``) when encoding. Default ``True``.

        """
        self._tokenizer = self._cls.from_pretrained(alias, cache_dir=cache_dir, **kwargs)
        self.max_len_truncate = max_len_truncate
        self.add_special_tokens = add_special_tokens

    @property
    def padding_idx(self) -> int:
        """Get the padding index.

        Returns
        -------
        int
            The padding index in the vocabulary

        """
        pad_token = self._tokenizer.pad_token
        return self._tokenizer.convert_tokens_to_ids(pad_token)

    @property
    def vocab_size(self) -> int:
        """Get the vocabulary length.

        Returns
        -------
        int
            The length of the vocabulary

        """
        return len(self._tokenizer)

    def process(self, example: str) -> torch.Tensor:  # type: ignore
        """Process an example, and create a Tensor.

        Parameters
        ----------
        example: str
            The example to process, as a single string

        Returns
        -------
        torch.Tensor
            The processed example, tokenized and numericalized

        """
        tokens = self._tokenizer.encode(example, add_special_tokens=self.add_special_tokens)
        if self.max_len_truncate is not None:
            # Truncation happens after encoding, so a trailing special
            # token (e.g. [SEP]) may be cut off — this mirrors the
            # original behavior.
            tokens = tokens[:self.max_len_truncate]
        return torch.tensor(tokens)
class TransformerEmbedder(Module):
    """A ``Module`` wrapping a pretrained transformer encoder.

    The concrete ``pytorch_transformers`` model class is expected to be
    provided by subclasses through the ``_cls`` attribute — TODO confirm
    against the sibling modules that subclass this embedder.
    """

    def __init__(self,
                 alias: str,
                 cache_dir: Optional[str] = None,
                 padding_idx: Optional[int] = None,
                 pool: bool = False, **kwargs) -> None:
        """Initialize from a pretrained model.

        Parameters
        ----------
        alias: str
            Alias of a pretrained model.
        cache_dir: str, optional
            A directory where to cache the downloaded vocabularies.
        padding_idx: int, optional
            The padding index used to compute the attention mask.
        pool: bool, optional
            Whether to return the pooled output or the full sequence
            encoding. Default ``False``.

        """
        super().__init__()
        embedder = self._cls.from_pretrained(alias, cache_dir=cache_dir, **kwargs)
        # Keep the config accessible: __getattr__ below falls back to it
        # for attribute lookups (e.g. hidden_size).
        self.config = embedder.config
        self._embedder = embedder
        self.padding_idx = padding_idx
        self.pool = pool

    def forward(self,
                data: torch.Tensor,
                token_type_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.Tensor] = None,
                head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Perform a forward pass through the network.

        If pool was provided, will only return the pooled output
        of shape [B x H]. Otherwise, returns the full sequence encoding
        of shape [B x S x H].

        Parameters
        ----------
        data : torch.Tensor
            The input data of shape [B x S]
        token_type_ids : Optional[torch.Tensor], optional
            Segment token indices to indicate first and second portions
            of the inputs. Indices are selected in ``[0, 1]``: ``0``
            corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token. Has shape [B x S]
        attention_mask : Optional[torch.Tensor], optional
            FloatTensor of shape [B x S]. Masked values should
            be 0 for padding tokens, 1 otherwise.
        position_ids : Optional[torch.Tensor], optional
            Indices of positions of each input sequence tokens
            in the position embedding. Defaults to the order given
            in the input. Has shape [B x S].
        head_mask : Optional[torch.Tensor], optional
            Mask to nullify selected heads of the self-attention
            modules. Should be 0 for heads to mask, 1 otherwise.
            Has shape [num_layers x num_heads]

        Returns
        -------
        torch.Tensor
            If pool is True, returns a tensor of shape [B x H],
            else returns an encoding for each token in the sequence
            of shape [B x S x H].

        """
        # Derive the attention mask from the padding index when the
        # caller did not supply one explicitly.
        if attention_mask is None and self.padding_idx is not None:
            attention_mask = (data != self.padding_idx).float()

        outputs = self._embedder(data,
                                 token_type_ids=token_type_ids,
                                 attention_mask=attention_mask,
                                 position_ids=position_ids,
                                 head_mask=head_mask)
        # pytorch_transformers models return a tuple whose first element
        # is the sequence output and second the pooled output.
        output = outputs[0] if not self.pool else outputs[1]
        return output

    def __getattr__(self, name: str) -> Any:
        """Override getattr to inspect config.

        Falls back to the wrapped model's config for attributes that are
        not found on the module itself (e.g. ``hidden_size``).

        Parameters
        ----------
        name : str
            The attribute to fetch

        Returns
        -------
        Any
            The attribute

        Raises
        ------
        AttributeError
            If the attribute is found neither on the module nor on the
            config.

        """
        try:
            return super().__getattr__(name)
        except AttributeError as e:
            # Access __dict__ directly to avoid re-entering __getattr__.
            config = self.__dict__['config']
            if hasattr(config, name):
                return getattr(config, name)
            else:
                raise e