Source code for flambe.nlp.transformers.field

from typing import Optional, Union, List, Dict, Any, Tuple

import torch
from transformers import AutoTokenizer

from flambe.field import Field

class PretrainedTransformerField(Field):
    """Field integration of the transformers library.

    Instantiate this object using any alias available in the
    `transformers` library; see that library's documentation for the
    full list of pretrained tokenizers. A brief usage sketch follows
    the class definition below.

    """

    def __init__(self,
                 alias: str,
                 cache_dir: Optional[str] = None,
                 max_len_truncate: int = 500,
                 add_special_tokens: bool = True,
                 **kwargs) -> None:
        """Initialize a pretrained tokenizer.

        Parameters
        ----------
        alias: str
            Alias of a pretrained tokenizer.
        cache_dir: str, optional
            A directory where to cache the downloaded vocabularies.
        max_len_truncate: int, default = 500
            Truncates the tokenized sequence to this length. Several
            pretrained models crash on sequences longer than 500
            tokens, hence the default of 500.
        add_special_tokens: bool, optional
            Add the special tokens to the inputs. Default ``True``.

        """
        self._tokenizer = AutoTokenizer.from_pretrained(alias, cache_dir=cache_dir, **kwargs)
        self.max_len_truncate = max_len_truncate
        self.add_special_tokens = add_special_tokens
    @property
    def padding_idx(self) -> int:
        """Get the padding index.

        Returns
        -------
        int
            The padding index in the vocabulary

        """
        pad_token = self._tokenizer.pad_token
        return self._tokenizer.convert_tokens_to_ids(pad_token)
    def vocab_size(self) -> int:
        """Get the vocabulary length.

        Returns
        -------
        int
            The length of the vocabulary

        """
        return len(self._tokenizer)
    def process(self,  # type: ignore
                example: Union[str, Tuple[Any], List[Any], Dict[Any, Any]]) \
            -> Union[torch.Tensor, Tuple[torch.Tensor, ...],
                     List[torch.Tensor], Dict[str, torch.Tensor]]:
        """Process an example, and create a Tensor.

        Parameters
        ----------
        example: Union[str, Tuple[Any], List[Any], Dict[Any, Any]]
            The example to process: a single string, or a tuple, list,
            or dict of examples, which is processed recursively.

        Returns
        -------
        Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor], Dict[str, torch.Tensor]]
            The processed example, tokenized and numericalized, in the
            same container structure as the input.

        """
        # Special case: containers of examples are processed
        # recursively, preserving their structure.
        if isinstance(example, (list, tuple)):
            return [self.process(e) for e in example]  # type: ignore
        elif isinstance(example, dict):
            return {key: self.process(val) for key, val in example.items()}  # type: ignore

        # Tokenize and numericalize a single string, then truncate.
        tokens = self._tokenizer.encode(example, add_special_tokens=self.add_special_tokens)
        if self.max_len_truncate is not None:
            tokens = tokens[:self.max_len_truncate]
        return torch.tensor(tokens)
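
# ---------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hedged
# example of how this field might be exercised. It assumes `flambe` and
# `transformers` are installed and that the "bert-base-uncased" alias
# can be downloaded; the alias and example strings are illustrative.
if __name__ == "__main__":
    field = PretrainedTransformerField("bert-base-uncased", max_len_truncate=128)

    # A single string becomes a 1-D tensor of token ids (including
    # special tokens such as [CLS]/[SEP] when add_special_tokens=True).
    single = field.process("flambe wraps pretrained tokenizers")
    print(single.shape)

    # Lists, tuples, and dicts are processed recursively, mirroring the
    # container structure of the input.
    pair = field.process(["first example", "second example"])
    by_split = field.process({"train": "a", "dev": "b"})
    print([t.shape for t in pair], {k: v.shape for k, v in by_split.items()})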
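
    # Downstream, `padding_idx` and `vocab_size()` are the typical hooks:
    # pad variable-length batches with the correct index, and size an
    # embedding table to the tokenizer's vocabulary. `pad_sequence` is
    # the standard torch utility; nothing here is flambe-specific.
    from torch.nn.utils.rnn import pad_sequence

    padded = pad_sequence(pair, batch_first=True,
                          padding_value=field.padding_idx)
    embedding = torch.nn.Embedding(field.vocab_size(), 128,
                                   padding_idx=field.padding_idx)
    print(padded.shape, embedding(padded).shape)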