Skorch on Ray Train
ray-skorch
Distributed skorch on Ray Train
A skorch-based wrapper for Ray Train. Experimental!
- Run `pip install -e .` to install the necessary packages.
- Before pushing, run `./format.sh` to make sure lint changes are applied.
- The current working examples can be found in `examples`.
:warning: `RayTrainNeuralNet` and the rest of this package are experimental and not production ready. In particular, validation and error handling may be spotty. If you encounter any problems or have any suggestions, please open an issue on GitHub.
Known issues & missing features
- Only numpy arrays, pandas dataframes and Ray Data Datasets are supported as inputs.
- Compatibility with scikit-learn hyperparameter tuners is not tested.
Basic example
The only breaking API difference compared to skorch is the addition of a new `num_workers` argument, controlling how many Ray workers to use for training. Please refer to the docstrings for more information on other changes.
With numpy/pandas
```python
import numpy as np
from sklearn.datasets import make_classification
from torch import nn

from ray_skorch import RayTrainNeuralNet

X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X = X.astype(np.float32)
y = y.astype(np.int64)


class MyModule(nn.Module):
    def __init__(self, num_units=10, nonlin=nn.ReLU()):
        super(MyModule, self).__init__()

        self.dense0 = nn.Linear(20, num_units)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(0.5)
        self.dense1 = nn.Linear(num_units, num_units)
        self.output = nn.Linear(num_units, 2)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = self.nonlin(self.dense1(X))
        X = self.softmax(self.output(X))
        return X


net = RayTrainNeuralNet(
    MyModule,
    num_workers=2,  # the only new mandatory argument
    criterion=nn.CrossEntropyLoss,
    max_epochs=10,
    lr=0.1,
    # required for classification loss funcs
    iterator_train__unsqueeze_label_tensor=False,
    iterator_valid__unsqueeze_label_tensor=False,
)

net.fit(X, y)

# predict_proba returns a ray.data.Dataset
y_proba = net.predict_proba(X).to_pandas()
print(y_proba)
```
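The frame returned by `to_pandas()` above holds one probability column per class. If you need hard class labels, they can be recovered with an argmax over the class axis. A minimal pandas-only sketch, using a small hypothetical probabilities frame in place of the real model output (no Ray required):

```python
import numpy as np
import pandas as pd

# Hypothetical probabilities standing in for
# net.predict_proba(X).to_pandas(): one column per class, one row per sample.
y_proba = pd.DataFrame(
    [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]],
    columns=[0, 1],
)

# Class predictions are the argmax over the class axis.
y_pred = np.argmax(y_proba.to_numpy(), axis=1)
print(y_pred)  # [0 1 0]
```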
With Ray Data
```python
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from torch import nn

from ray.data import from_pandas

from ray_skorch import RayTrainNeuralNet

X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X = pd.DataFrame(X.astype(np.float32))
y = pd.Series(y.astype(np.int64))

X_pred = X.copy()
X["target"] = y
X = from_pandas(X)

# ensure no target column is in data for prediction
X_pred = from_pandas(X_pred)


class MyModule(nn.Module):
    def __init__(self, num_units=10, nonlin=nn.ReLU()):
        super(MyModule, self).__init__()

        self.dense0 = nn.Linear(20, num_units)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(0.5)
        self.dense1 = nn.Linear(num_units, num_units)
        self.output = nn.Linear(num_units, 2)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = self.nonlin(self.dense1(X))
        X = self.softmax(self.output(X))
        return X


net = RayTrainNeuralNet(
    MyModule,
    num_workers=2,  # the only new mandatory argument
    criterion=nn.CrossEntropyLoss,
    max_epochs=10,
    lr=0.1,
    # required for classification loss funcs
    iterator_train__unsqueeze_label_tensor=False,
    iterator_valid__unsqueeze_label_tensor=False,
)

net.fit(X, "target")

# predict_proba returns a ray.data.Dataset
y_proba = net.predict_proba(X_pred).to_pandas()
print(y_proba)
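In the Ray Data example above, features and label travel in a single table: `fit(X, "target")` names the label column, while the prediction dataset must not contain it. The bookkeeping can be sketched with plain pandas (hypothetical toy data, no Ray required):

```python
import numpy as np
import pandas as pd

# Toy features and labels, hypothetical data for illustration.
X = pd.DataFrame(
    np.arange(6, dtype=np.float32).reshape(3, 2), columns=["f0", "f1"]
)
y = pd.Series([0, 1, 0], name="target")

# For fitting, the label rides along as a column of the same table;
# fit(X, "target") then identifies that column as the label.
combined = X.copy()
combined["target"] = y

# For prediction, the label column must be absent.
X_pred = combined.drop(columns=["target"])

print(list(combined.columns))  # ['f0', 'f1', 'target']
print(list(X_pred.columns))    # ['f0', 'f1']
```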
Built Distribution

`ray_skorch-0.0.1-py3-none-any.whl` (30.8 kB)

Hashes for `ray_skorch-0.0.1-py3-none-any.whl`:

Algorithm | Hash digest
---|---
SHA256 | 772ea84902e3c5f629c66f7ecc1a57ddc9fd272a0a4e6ebff0270e2d3dbdaf95
MD5 | 3109c4cb59183243f60dd6174b757c14
BLAKE2b-256 | 8e3e9dc43ef74f56612f3a1fa6008fb2720c41e037a67897ae48f01e7894337f