from typing import Tuple, Union

import torch
from torch.nn import Module
from torch.optim.optimizer import Optimizer

from piepline.train_config.train_config import BaseTrainConfig
from piepline.utils.utils import dict_recursive_bypass

__all__ = ['DataProcessor', 'TrainDataProcessor']
class DataProcessor:
    """
    DataProcessor manages the model, data processing and device choosing.

    Args:
        model (Module): model that will be used to process data
        device (torch.device): device to pass data to for processing. If ``None``,
            data is left where it already is
    """

    def __init__(self, model: Module, device: torch.device = None):
        # Reserved slot; presumably wired up by external checkpoint machinery — not used here
        self._checkpoints_manager = None
        self._model = model
        self._device = device
        # Overridable hooks (see the set_* methods below):
        # pick the model input out of a dataset item, preprocess it, move it to the device
        self._pick_model_input = lambda data: data['data']
        self._data_preprocess = lambda data: data
        self._data_to_device = self._pass_object_to_device

    def model(self) -> Module:
        """
        Get current module
        """
        return self._model

    def predict(self, data: Union[torch.Tensor, dict]) -> object:
        """
        Run inference on ``data`` with the model in eval mode and gradients disabled.

        :param data: data as :class:`torch.Tensor` or dict with key ``data``
        :return: processed output
        :rtype: the model output type
        """
        self.model().eval()
        with torch.no_grad():
            # pipeline: pick input -> preprocess -> move to device -> forward pass
            output = self._model(self._data_to_device(self._data_preprocess(self._pick_model_input(data))))
        return output

    def set_data_to_device(self, data_to_device: callable) -> 'DataProcessor':
        """
        Set callback, that moves preprocessed data to the target device.

        By default :meth:`_pass_object_to_device` is used.

        Args:
            data_to_device (callable): callable that gets one parameter (the
                preprocessed data) and returns it placed on the desired device

        Returns:
            self object
        """
        self._data_to_device = data_to_device
        return self

    def set_data_preprocess(self, data_preprocess: callable) -> 'DataProcessor':
        """
        Set callback, that will get output from :mod:`DataLoader` and return preprocessed data.
        For example may be used for pass data to device.

        Default mode:

        .. highlight:: python
        .. code-block:: python

            :meth:`_pass_data_to_device`

        Args:
            data_preprocess (callable): preprocess callable. This callback need to get one parameter: dataset output

        Returns:
            self object

        Examples:

        .. highlight:: python
        .. code-block:: python

            from piepline.utils import dict_recursive_bypass
            data_processor.set_data_preprocess(lambda data: dict_recursive_bypass(data, lambda v: v.cuda()))
        """
        self._data_preprocess = data_preprocess
        return self

    def _pass_object_to_device(self, data) -> Union[torch.Tensor, dict]:
        """
        Internal method that passes data to the specified device.

        :param data: data of any object type. It is moved to the device if it is an
            instance of :class:`torch.Tensor` or a dict (bypassed recursively);
            otherwise it is returned unchanged
        :return: data placed on the target device
        """
        if self._device is None:
            return data

        if isinstance(data, dict):
            return dict_recursive_bypass(data, lambda v: v.to(self._device))
        elif isinstance(data, torch.Tensor):
            return data.to(self._device)
        else:
            return data
class TrainDataProcessor(DataProcessor):
    """
    TrainDataProcessor does everything :class:`DataProcessor` does and additionally
    drives the training process (loss computation, backward pass, optimizer step).

    :param train_config: train config that provides the model, loss and optimizer
    :param device: device to pass data to for processing (optional)
    """

    class TDPException(Exception):
        """Exception type raised by :class:`TrainDataProcessor` internals."""

        def __init__(self, msg):
            self._msg = msg

        def __str__(self):
            return self._msg

    def __init__(self, train_config: 'BaseTrainConfig', device: torch.device = None):
        super().__init__(train_config.model(), device)

        # Overridable hooks for the target, mirroring the data hooks of the base class
        self._pick_target = lambda data: data['target']
        self._target_preprocess = lambda data: data
        self._target_to_device = self._pass_object_to_device

        self._criterion = train_config.loss()
        self._optimizer = train_config.optimizer()

    def optimizer(self) -> Optimizer:
        """
        Get current optimizer
        """
        return self._optimizer

    def predict(self, data, is_train=False) -> Union[torch.Tensor, dict]:
        """
        Make predict by data. If ``is_train`` is ``True`` - this operation will compute gradients. If
        ``is_train`` is ``False`` - this will work with ``model.eval()`` and ``torch.no_grad``

        :param data: data in dict
        :param is_train: is data processor need train on data or just predict
        :return: processed output
        :rtype: model return type
        """
        if is_train:
            self.model().train()
            # same pipeline as the base class, but without no_grad so gradients flow
            output = self._model(self._data_to_device(self._data_preprocess(self._pick_model_input(data))))
        else:
            output = super().predict(data)

        return output

    def process_batch(self, batch: dict, is_train: bool) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Process one batch of data

        Args:
            batch (dict): contains 'data' and 'target' keys. The values for key must be instance of torch.Tensor or dict
            is_train (bool): is batch process for train

        Returns:
            tuple of :class:`torch.Tensor` of losses, predicts and targets with shape (N, ...) where N is batch size
        """
        if is_train:
            self._optimizer.zero_grad()

        res = self.predict(batch, is_train)
        target = self._target_to_device(self._target_preprocess(self._pick_target(batch)))
        loss = self._criterion(res, target)

        if is_train:
            loss.backward()
            self._optimizer.step()

        return loss, res, target

    def update_lr(self, lr: float) -> None:
        """
        Update learning rate straight to optimizer

        :param lr: target learning rate
        """
        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr

    def get_lr(self) -> float:
        """
        Get learning rate from optimizer.

        Returns the 'lr' of the first parameter group (all groups are assumed to
        share one rate, since :meth:`update_lr` sets them uniformly).
        """
        for param_group in self._optimizer.param_groups:
            return param_group['lr']

    def get_state(self) -> dict:
        """
        Get model and optimizer state dicts

        :return: dict with keys [weights, optimizer]
        """
        # BUG FIX: self._model is the Module itself (set in DataProcessor.__init__),
        # so state_dict() must be called on it directly; the previous
        # `self._model.model().state_dict()` raised AttributeError for standard modules
        return {'weights': self._model.state_dict(), 'optimizer': self._optimizer.state_dict()}

    def save_state(self, path: str) -> None:
        """
        Save the optimizer state dict to ``path``.

        NOTE(review): only the optimizer state is written here — model weights are
        not included; presumably they are saved elsewhere by the checkpoint machinery.
        """
        torch.save(self.optimizer().state_dict(), path)

    def set_pick_target(self, pick_target: callable) -> 'DataProcessor':
        """
        Set callback, that will get output from :mod:`DataLoader` and return target.

        Default mode:

        .. highlight:: python
        .. code-block:: python

            lambda data: data['target']

        Args:
            pick_target (callable): pick target callable. This callback need to get one parameter: dataset output

        Returns:
            self object

        Examples:

        .. highlight:: python
        .. code-block:: python

            data_processor.set_pick_target(lambda data: data['target'])
            data_processor.set_pick_target(lambda data: data[1])
        """
        self._pick_target = pick_target
        return self

    def set_target_preprocess(self, target_preprocess: callable) -> 'DataProcessor':
        """
        Set callback, that will get output from :mod:`DataLoader` and return preprocessed target.
        For example may be used for pass target to device.

        Default mode:

        .. highlight:: python
        .. code-block:: python

            :meth:`_pass_target_to_device`

        Args:
            target_preprocess (callable): preprocess callable. This callback need to get one parameter: targetset output

        Returns:
            self object

        Examples:

        .. highlight:: python
        .. code-block:: python

            from piepline.utils import dict_recursive_bypass
            target_processor.set_target_preprocess(lambda target: dict_recursive_bypass(target, lambda v: v.cuda()))
        """
        self._target_preprocess = target_preprocess
        return self

    def set_target_to_device(self, target_to_device: callable) -> 'DataProcessor':
        """
        Set callback, that moves the preprocessed target to the target device.

        By default :meth:`_pass_object_to_device` is used.

        Args:
            target_to_device (callable): callable that gets one parameter (the
                preprocessed target) and returns it placed on the desired device

        Returns:
            self object
        """
        self._target_to_device = target_to_device
        return self