import pickle
import os
import time
import json
[docs]class ProgressState:
def __init__(self, name, save_path, dependency_dag, factories_num=0):
self.fname = "state.pkl"
self.name = name
self.save_path = save_path
self.total = len(dependency_dag)
self.done = 0
# 0 if running local,
# Number of factories in case running remotely
self.factories_num = factories_num
# 0 -> PENDING
# 1 -> RUNNING
# 2 -> SUCCESS
# 3 -> FAILURE
# 4 -> MASKED
self.block_state = {}
for b in dependency_dag.keys():
self.block_state[b] = 0
self.time_lapses = {}
self.variants = {}
self.finished = False
self.prev_time = None
self.dependency_dag = dependency_dag
self._save()
[docs] def checkpoint_start(self, block_id):
self.prev_time = time.time()
self.block_state[block_id] = 1
self._save()
[docs] def refresh(self):
self.variants = {}
for b in self.dependency_dag.keys():
self.variants[b] = []
block_path = os.path.join(self.save_path, b)
if os.path.exists(block_path):
variants = filter(
lambda x: os.path.isdir(os.path.join(block_path, x)), os.listdir(block_path)
)
for v in variants:
# Check if hparams were specified.
# hparams follows this syntax:
# "hparam1=value1,hparam2=value2"
if '=' in v:
v = v[:v.find("-") - 5] # Substract year of date. TODO replace for regex
hparams = {}
for h in v.split(','):
k, v = h.split('=')
hparams[k] = v
self.variants[b].append(hparams)
[docs] def checkpoint_end(self, block_id, checkpoints, block_success):
curr_time = time.time()
if self.prev_time:
self.time_lapses[block_id] = curr_time - self.prev_time
self.prev_time = None
self.done += 1
self.block_state[block_id] = 2 if block_success else 3
self._save()
[docs] def finish(self):
self.finished = True
self._save()
[docs] def _save(self):
with open(os.path.join(self.save_path, self.fname), 'wb') as f:
pickle.dump(self, f)
[docs] def toJSON(self):
self.refresh()
return json.dumps(self, default=lambda o: o.__dict__,
sort_keys=True, indent=4)