Components
The most important class in Flambé is Component, which implements loading from YAML (using !ClassName notation) and saving state.
Loading and Dumping from YAML
A Component can be created from a YAML config representation, as seen in the Quickstart example. Let's take the previously used TextClassifier component:
!TextClassifier
embedder: !Embedder
  embedding: !torch.Embedding
    num_embeddings: 200
    embedding_dim: 300
  encoder: !PooledRNNEncoder
    input_size: 300
    rnn_type: lstm
    n_layers: 3
    hidden_size: 256
output_layer: !SoftmaxLayer
  input_size: 256
  output_size: 10
Loading and dumping objects can be done using the flambe.compile.yaml module.
from flambe.compile import yaml

# Loading from YAML into a Schema
with open("model.yaml") as f:
    text_classifier_schema = yaml.load(f)
text_classifier = text_classifier_schema()  # Compile the Schema

# Dumping object
with open("new_model.yaml", "w") as f:
    yaml.dump(text_classifier, f)
Important
Components compile to an intermediate state called Schema when calling yaml.load(). This partial representation can be compiled into the final object by calling obj() (i.e., executing __call__), as shown in the example above. For more information about this, see Delayed Initialization.
See also
For more examples of the YAML representation of an object, see understanding-configuration_label.
Saving and Loading State
While YAML represents the “architecture”, that is, how to create an instance of some class, it does not capture the state. For state, Components rely on recursive get_state() and load_state() methods that work similarly to PyTorch’s nn.Module.state_dict and nn.Module.load_state_dict:
from flambe.compile import yaml
from flambe.nlp.classification import TextClassifier

# Loading from YAML into a Schema
with open("model.yaml") as f:
    text_classifier_schema = yaml.load(f)
text_classifier = text_classifier_schema()  # Compile the Schema

# Extract the (recursive) state of the compiled object
state = text_classifier.get_state()

# Load it into another instance with the same architecture
another_text_classifier = TextClassifier(...)
another_text_classifier.load_state(state)
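The recursion can be pictured with a small, self-contained sketch. It is not Flambé's implementation: StatefulNode, params and children are hypothetical names chosen for illustration, and the dotted keys only imitate the style of PyTorch's state_dict(). Each object reports its own values and then asks its children for theirs, so a freshly built object with the same architecture can load the resulting flat dictionary.

```python
# Hypothetical sketch of recursive state handling; NOT Flambé's implementation.
from typing import Any, Dict, Optional


class StatefulNode:
    """Toy stand-in for a Component: its own values plus named child nodes."""

    def __init__(self, params: Optional[Dict[str, Any]] = None, **children: "StatefulNode") -> None:
        self.params: Dict[str, Any] = dict(params or {})
        self.children = children

    def get_state(self, prefix: str = "") -> Dict[str, Any]:
        # This node's own values, keyed by dotted path (like state_dict()).
        state = {prefix + name: value for name, value in self.params.items()}
        # Recurse into children, extending the prefix with each child's name.
        for name, child in self.children.items():
            state.update(child.get_state(prefix + name + "."))
        return state

    def load_state(self, state: Dict[str, Any], prefix: str = "") -> None:
        # Overwrite known values, then let each child pick out its own keys.
        for name in self.params:
            self.params[name] = state[prefix + name]
        for name, child in self.children.items():
            child.load_state(state, prefix + name + ".")


def build() -> StatefulNode:
    """Same 'architecture' every time, with fresh default values."""
    encoder = StatefulNode(params={"weight": 0.0})
    embedder = StatefulNode(params={"weight": 0.0}, encoder=encoder)
    return StatefulNode(params={"bias": 0.0}, embedder=embedder)


trained = build()
trained.params["bias"] = 0.5
trained.children["embedder"].children["encoder"].params["weight"] = 1.5

state = trained.get_state()
# {'bias': 0.5, 'embedder.weight': 0.0, 'embedder.encoder.weight': 1.5}

fresh = build()
fresh.load_state(state)
```

As with the YAML/state split described above, build() plays the role of the architecture while the dictionary carries the values.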
Semantic Versioning
In order to identify and describe changes in class definitions, Flambé supports opt-in semantic class versioning. (If you’re not familiar with semantic versioning, see this link.)
Each class has a class property _flambe_version to prevent conflicts when loading previously saved states.
Initially, all versions are set to 0.0.0, indicating that class versioning should not be used. Once you increment the version, Flambé starts comparing the saved class version with the version on the class at load time.
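As a rough illustration of the comparison described above, the sketch below parses "X.Y.Z" strings and classifies a saved version against the current one. It assumes _flambe_version holds a string such as "0.1.0"; the helper names and the returned labels are hypothetical, and what Flambé actually does when versions differ (warn, raise, or something else) is not covered here.

```python
# Hypothetical sketch of semantic version comparison; the function names and
# the returned labels are illustrative, not Flambé's API or behaviour.
from typing import Tuple


def parse_version(version: str) -> Tuple[int, int, int]:
    """Turn 'MAJOR.MINOR.PATCH' into a tuple of ints, e.g. '1.2.3' -> (1, 2, 3)."""
    major, minor, patch = (int(part) for part in version.split("."))
    return major, minor, patch


def compare_versions(saved: str, current: str) -> str:
    """Classify the version stored with a saved state against the class version."""
    if parse_version(current) == (0, 0, 0):
        return "skip"  # 0.0.0: the class has not opted in to versioning
    saved_v, current_v = parse_version(saved), parse_version(current)
    if saved_v == current_v:
        return "same"
    if saved_v[0] != current_v[0]:
        return "major-mismatch"  # under semver, a major bump signals breaking changes
    return "older" if saved_v < current_v else "newer"


class MyComponent:  # stand-in for a Component subclass that opts in
    _flambe_version = "0.1.0"


print(compare_versions("0.0.1", MyComponent._flambe_version))  # older
```

Comparing tuples of integers rather than raw strings avoids the classic pitfall where "0.10.0" would sort before "0.9.0".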
See also
See Adding Custom State for more information about get_state() and load_state().
Delayed Initialization
When you load Components from YAML, they are not initialized into objects immediately. Instead, they are precompiled into a Schema, which you can think of as a blueprint for how to create the object later. This mechanism allows Components to use links and grid search options.
If you load a schema directly from YAML, you can compile it into an instance by calling the schema:
from flambe.compile import yaml

with open('path/to/file.yaml') as f:
    schema = yaml.load(f)
obj = schema()
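The blueprint idea can be sketched in a few lines. ToySchema below is a hypothetical stand-in, not Flambé's Schema class: it stores a callable and its keyword arguments, and only builds objects when called, compiling nested blueprints first. Because nothing is constructed up front, the blueprint can still be inspected or edited, which is the kind of flexibility that makes links and grid search options possible.

```python
# Hypothetical sketch of delayed initialization; NOT Flambé's Schema class.
from typing import Any, Callable, Dict


class ToySchema:
    """Blueprint: remembers what to build and with which arguments."""

    def __init__(self, factory: Callable[..., Any], **kwargs: Any) -> None:
        self.factory = factory
        self.kwargs: Dict[str, Any] = kwargs

    def __call__(self) -> Any:
        # Compile nested blueprints first, then build this object.
        compiled = {
            key: value() if isinstance(value, ToySchema) else value
            for key, value in self.kwargs.items()
        }
        return self.factory(**compiled)


class Encoder:
    def __init__(self, hidden_size: int) -> None:
        self.hidden_size = hidden_size


class Classifier:
    def __init__(self, encoder: Encoder, output_size: int) -> None:
        self.encoder = encoder
        self.output_size = output_size


schema = ToySchema(Classifier, encoder=ToySchema(Encoder, hidden_size=256), output_size=10)

# Nothing has been constructed yet, so the blueprint can still be edited,
# for example to try a different hidden size before compiling.
schema.kwargs["encoder"].kwargs["hidden_size"] = 512
model = schema()  # builds Encoder(512) first, then Classifier(encoder, 10)
```

Calling the same blueprint twice builds two independent objects, since each call runs the constructors again.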
Core Components
Dataset
This object holds the training, validation and test data. Its only requirement is to have three properties, train, dev and test, each pointing to a list of examples. For convenience, we provide a TabularDataset implementation of the interface, which can load any csv or tsv format.
from flambe.dataset import TabularDataset
import numpy as np

# Random dataset
train = np.random.random((2, 100))
val = np.random.random((2, 10))
test = np.random.random((2, 10))

dataset = TabularDataset(train, val, test)
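To show how little the interface asks for, here is a hypothetical duck-typed object exposing the three properties, filled from TSV text with Python's csv module. SimpleDataset and read_rows are illustrative names; the class does not subclass any Flambé base class and is not how TabularDataset is implemented.

```python
# Hypothetical sketch: a duck-typed object with the three properties the
# Dataset interface requires. It does not subclass any Flambé class.
import csv
import io
from typing import List, Sequence

Row = List[str]


def read_rows(text: str, delimiter: str = "\t") -> List[Row]:
    """Parse CSV/TSV text into a list of examples (one list of columns per row)."""
    return [row for row in csv.reader(io.StringIO(text), delimiter=delimiter) if row]


class SimpleDataset:
    def __init__(self, train: Sequence[Row], dev: Sequence[Row], test: Sequence[Row]) -> None:
        self._train, self._dev, self._test = list(train), list(dev), list(test)

    @property
    def train(self) -> List[Row]:
        return self._train

    @property
    def dev(self) -> List[Row]:
        return self._dev

    @property
    def test(self) -> List[Row]:
        return self._test


dataset = SimpleDataset(
    train=read_rows("great movie\tpos\nterrible plot\tneg\n"),
    dev=read_rows("fine acting\tpos\n"),
    test=read_rows("boring\tneg\n"),
)
print(dataset.train)  # [['great movie', 'pos'], ['terrible plot', 'neg']]
```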
Field
A field takes raw examples and produces a torch.Tensor (or a tuple of torch.Tensor). We provide useful fields such as TextField or LabelField, which perform tokenization and numericalization.
from flambe.field import TextField
from flambe.tokenizer import WordTokenizer
import numpy as np

# Toy dataset
data = np.array(['Flambe is awesome', 'This framework rocks!'])
text_field = TextField(WordTokenizer())

# Setup the entire dataset to build vocab.
text_field.setup(data)

text_field.vocab_size  # Returns 9
text_field.process("Flambe rocks")  # Returns tensor([6, 1])
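The two steps a text field performs can be sketched in plain Python. This is a hypothetical illustration, not TextField's implementation: the regex tokenizer, the <pad>/<unk> entries and the first-seen index order are assumptions, so the indices will not necessarily match the outputs above, and a real field returns a torch.Tensor rather than a list.

```python
# Hypothetical sketch of tokenization + numericalization; NOT TextField's code.
import re
from typing import Dict, Iterable, List

PAD, UNK = "<pad>", "<unk>"  # assumed special tokens


def tokenize(text: str) -> List[str]:
    # Words and individual punctuation marks become separate tokens.
    return re.findall(r"\w+|[^\w\s]", text)


def build_vocab(texts: Iterable[str]) -> Dict[str, int]:
    # Assign indices in first-seen order after the special tokens.
    vocab = {PAD: 0, UNK: 1}
    for text in texts:
        for token in tokenize(text):
            if token not in vocab:
                vocab[token] = len(vocab)
    return vocab


def numericalize(text: str, vocab: Dict[str, int]) -> List[int]:
    # Tokens never seen during setup map to the <unk> index.
    return [vocab.get(token, vocab[UNK]) for token in tokenize(text)]


vocab = build_vocab(["Flambe is awesome", "This framework rocks!"])
print(numericalize("Flambe rocks?", vocab))  # [2, 7, 1]
```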
Sampler
A sampler produces batches of data, as an iterator. We provide a simple BaseSampler implementation, which takes a dataset as input, as well as the batch size, and produces batches of data. Each batch is a tuple of tensors, padded to the maximum length along each dimension.
from flambe.sampler import BaseSampler
from flambe.dataset import TabularDataset
import numpy as np

dataset = TabularDataset(np.random.random((2, 10)))
sampler = BaseSampler(batch_size=4)

for batch in sampler.sample(dataset):
    # Do something with batch
    ...
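A rough sketch of the batching behaviour, using plain lists instead of tensors. The sample and pad functions are hypothetical and are not BaseSampler's implementation; for brevity they pad only the sequence dimension of (token ids, label) examples, with 0 as an assumed padding value.

```python
# Hypothetical sketch of padded batching; NOT BaseSampler's implementation.
from typing import Iterator, List, Sequence, Tuple

Example = Tuple[List[int], int]  # (token ids, label)
Batch = Tuple[List[List[int]], List[int]]


def pad(sequences: Sequence[List[int]], pad_value: int = 0) -> List[List[int]]:
    """Right-pad every sequence to the length of the longest one."""
    width = max(len(seq) for seq in sequences)
    return [list(seq) + [pad_value] * (width - len(seq)) for seq in sequences]


def sample(data: Sequence[Example], batch_size: int) -> Iterator[Batch]:
    """Yield (padded token ids, labels) batches of at most batch_size examples."""
    for start in range(0, len(data), batch_size):
        tokens, labels = zip(*data[start:start + batch_size])
        yield pad(tokens), list(labels)


data = [([5, 3], 1), ([7], 0), ([2, 9, 4], 1)]
for tokens, labels in sample(data, batch_size=2):
    print(tokens, labels)
# [[5, 3], [7, 0]] [1, 0]
# [[2, 9, 4]] [1]
```

Padding per batch, rather than to the longest example in the whole dataset, keeps short batches small.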
Module
This object is the main model component interface. It must implement the forward method, as PyTorch’s nn.Module requires. We also provide additional machine learning components in the nn submodule, such as Encoder, with many different implementations of these interfaces.
Trainer
A Trainer takes as input the training and dev samplers, as well as a model and an optimizer. By default, the object keeps track of the last and best models, and each call to run is considered to be an arbitrary number of training iterations followed by a single evaluation pass over the validation set. It implements the metric() method, which returns the best metric observed so far.
Evaluator
An Evaluator evaluates a given nn.Module over a Dataset and computes the given metrics.
Script
A Script integrates a pre-written script with Flambé.
Important
For more detailed information about these Components, please refer to their documentation.
Custom Component
Custom Components should implement the run() method. This method performs a single computation step and returns a boolean indicating whether the Component should keep executing (True iff there is more work to do).
from flambe.compile import Component


class MyClass(Component):

    def __init__(self, a, b):
        super().__init__()
        ...

    def run(self) -> bool:
        ...  # perform a single computation step
        return continue_flag  # True iff there is more work to do
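The contract of run() can be exercised without Flambé. CountdownTask is a hypothetical stand-in that follows the same rule (one step per call, True while work remains), and run_to_completion is an assumed driver loop, not Flambé's own runner.

```python
# Hypothetical stand-in that follows the run() contract; it is not a Flambé Component.
from typing import List


class CountdownTask:
    def __init__(self, steps: int) -> None:
        self.remaining = steps
        self.log: List[int] = []

    def run(self) -> bool:
        # One computation step.
        self.log.append(self.remaining)
        self.remaining -= 1
        # True iff there is more work to do.
        return self.remaining > 0


def run_to_completion(task: CountdownTask) -> int:
    """Assumed driver: call run() until it returns False; return the number of calls."""
    calls = 0
    keep_going = True
    while keep_going:
        keep_going = task.run()
        calls += 1
    return calls


task = CountdownTask(steps=3)
print(run_to_completion(task), task.log)  # 3 [3, 2, 1]
```

Keeping each call to a single step lets whatever drives the loop decide when to checkpoint or stop between steps.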
Tip
We recommend always extending from an implementation of Component rather than implementing the plain interface. For example, if implementing an autoencoder, inherit from Module, or if implementing cross-validation training, inherit from Trainer.
If you would like to include custom state in the state returned by the get_state() method, see the Adding Custom State section and the Component package reference.
Then in YAML you could do the following:
!MyClass
a: val1
b: val2
# or using the registrable_factory
Flambé also provides a way of registering factory methods to be used in YAML:
from flambe.compile import Component, registrable_factory


class MyClass(Component):

    ...

    @registrable_factory
    @classmethod
    def special_factory(cls, x, y):
        a, b = do_something(x, y)  # do_something is a placeholder
        return cls(a, b)
Now you can do:
!MyClass.special_factory
x: val1
y: val2
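The idea behind registering factories can be sketched in plain Python. Everything here is hypothetical: factory, Registrable and construct only imitate the effect of registrable_factory and are not how Flambé implements it, and the arithmetic inside special_factory stands in for do_something. A marked classmethod is recorded under a "ClassName.method" key, which has the same shape as the !MyClass.special_factory tag.

```python
# Hypothetical sketch of factory registration; NOT Flambé's registrable_factory.
from typing import Any, Callable, Dict


def factory(method: classmethod) -> classmethod:
    """Mark a classmethod as usable from YAML (placed above @classmethod)."""
    method.__func__._yaml_factory = True
    return method


class Registrable:
    """Base class that records marked factories under 'ClassName.method'."""

    factories: Dict[str, Callable[..., Any]] = {}

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        for name, attr in vars(cls).items():
            if isinstance(attr, classmethod) and getattr(attr.__func__, "_yaml_factory", False):
                # Store the bound classmethod so calling it passes cls automatically.
                Registrable.factories[f"{cls.__name__}.{name}"] = getattr(cls, name)


def construct(tag: str, **kwargs: Any) -> Any:
    """Resolve a tag like 'MyClass.special_factory' and call the factory."""
    return Registrable.factories[tag](**kwargs)


class MyClass(Registrable):
    def __init__(self, a: int, b: int) -> None:
        self.a, self.b = a, b

    @factory
    @classmethod
    def special_factory(cls, x: int, y: int) -> "MyClass":
        a, b = x + y, x * y  # stand-in for do_something(x, y)
        return cls(a, b)


obj = construct("MyClass.special_factory", x=2, y=3)
print(obj.a, obj.b)  # 5 6
```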
For information on how to use your custom Component in YAML files, see Extensions.