
NeuroTorch: A PyTorch-based framework for deep learning in neuroscience.




Description

It's time to bring deep learning and neuroscience together. In this library, we offer machine learning tools to neuroscientists and neuroscience tools to computer scientists. These two domains were created to be one.

Current Version (v0.0.1-alpha)

What can we do with NeuroTorch in the current version?

  • Image classification with spiking networks.
  • Classification of spiking time series with spiking networks.
  • Time series classification with spiking or Wilson-Cowan networks.
  • Reconstruction/Prediction of time series with Wilson-Cowan networks.
  • Reconstruction/Prediction of continuous time series with spiking networks.
  • Backpropagation Through Time.
  • Anything you are able to do using the modules already created.
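As a rough illustration of what a spiking neuron computes, here is a minimal pure-Python sketch of a single leaky integrate-and-fire (LIF) neuron, without the recurrent connections that `nt.LIFLayer` supports in the quick usage preview. The function name `simulate_lif`, the parameters `dt`, `tau_m`, `threshold` and `v_reset`, and their default values are illustrative assumptions; this is not the implementation of NeuroTorch's `LIFLayer`.

```python
import math
from typing import List, Sequence


def simulate_lif(
    input_current: Sequence[float],
    dt: float = 1e-3,
    tau_m: float = 20e-3,
    threshold: float = 1.0,
    v_reset: float = 0.0,
) -> List[int]:
    """Simulate one LIF neuron and return its spike train (0 or 1 per time step).

    Hypothetical sketch: the membrane follows tau_m * dv/dt = -v + I(t), with
    the input resistance folded into I so that a constant input I drives v toward I.
    """
    # Exponential-Euler factor: exact for an input held constant during a step.
    decay = math.exp(-dt / tau_m)
    v = v_reset
    spikes: List[int] = []
    for i_t in input_current:
        # Leaky integration: v relaxes toward the current input value.
        v = decay * v + (1.0 - decay) * i_t
        if v >= threshold:
            spikes.append(1)
            v = v_reset  # hard reset after each spike
        else:
            spikes.append(0)
    return spikes


if __name__ == "__main__":
    for current in (0.5, 1.5, 3.0):
        print(current, sum(simulate_lif([current] * 1000)))
```

A constant input below the threshold never makes the neuron fire, while stronger inputs make it fire more often. Because the threshold is a step function whose gradient is zero almost everywhere, spiking-network libraries commonly substitute a surrogate derivative when training with Backpropagation Through Time.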

Next Version (v0.0.1-beta)

Upcoming Version (v0.0.1)

  • Reinforcement Learning.

NeuroTorch is designed to be easy to use, so you can do simple things in a few lines of code. Moreover, NeuroTorch is modular, so you can adapt it to your needs relatively quickly. Thanks, and stay tuned: more is coming!

This package is part of a postgraduate research project realized by Jérémie Gince and supervised by Simon Hardy and Patrick Desrosiers. Our work was supported by: (1) UNIQUE, a FRQNT-funded research center, (2) the Sentinelle Nord program of Université Laval, funded by the Canada First Research Excellence Fund, and (3) NSERC.

Important Links

Installation

Using pip

pip install neurotorch

With wheel:

  1. Download the .whl file here.
  2. Copy the path of this file on your computer.
  3. Install it with: pip install [path].whl

With pip+git:

pip install git+https://github.com/NeuroTorch/NeuroTorch
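To confirm that one of the installation routes above worked, a small standard-library check can report the installed version. This sketch assumes Python 3.8 or later (for `importlib.metadata`) and that the distribution is registered under the name `neurotorch`, as in the pip command; the helper name `installed_version` is hypothetical.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str = "neurotorch") -> Optional[str]:
    """Return the installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_version()
    if version is None:
        print("NeuroTorch is not installed in this environment.")
    else:
        print(f"NeuroTorch {version} is installed.")
```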

Tutorials / Applications

See the README of the tutorials folder here.

Image classification with spiking networks (Mnist/Fashion-Mnist)

Classification of spiking time series (Heidelberg)

Time series classification with spiking networks

Sorry, it's a work in progress, so it's not published yet.

Time series classification with Wilson-Cowan
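Wilson-Cowan models describe the mean firing activity of interacting neuronal populations rather than individual spikes. The sketch below integrates one common form, tau * dx_i/dt = -x_i + (1 - r * x_i) * sigmoid(sum_j W_ij * x_j - mu_i), with forward Euler in pure Python. The exact equation, the names `wilson_cowan_step` and `simulate_wilson_cowan`, and all parameters are assumptions for illustration; NeuroTorch's own Wilson-Cowan layer may be parameterized differently.

```python
import math
from typing import List, Sequence


def sigmoid(z: float) -> float:
    """Numerically stable logistic function mapping any real drive into (0, 1)."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def wilson_cowan_step(
    x: Sequence[float],
    weights: Sequence[Sequence[float]],
    mu: Sequence[float],
    r: float,
    dt: float,
    tau: float,
) -> List[float]:
    """One forward-Euler step of tau * dx/dt = -x + (1 - r x) * sigmoid(W x - mu)."""
    n = len(x)
    new_x = []
    for i in range(n):
        # Net input to population i from all populations, minus its threshold mu_i.
        drive = sum(weights[i][j] * x[j] for j in range(n)) - mu[i]
        dxdt = (-x[i] + (1.0 - r * x[i]) * sigmoid(drive)) / tau
        new_x.append(x[i] + dt * dxdt)
    return new_x


def simulate_wilson_cowan(
    x0: Sequence[float],
    weights: Sequence[Sequence[float]],
    mu: Sequence[float],
    r: float = 0.0,
    dt: float = 0.1,
    tau: float = 1.0,
    n_steps: int = 100,
) -> List[List[float]]:
    """Return the trajectory [x(0), x(dt), ..., x(n_steps * dt)]."""
    trajectory = [list(x0)]
    for _ in range(n_steps):
        trajectory.append(wilson_cowan_step(trajectory[-1], weights, mu, r, dt, tau))
    return trajectory
```

Fitting `weights` and `mu` so that such a trajectory matches recorded activity is, roughly, the kind of time series reconstruction/prediction task listed in the current version's features.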

Quick usage preview

import pprint

import torch

import neurotorch as nt

# NOTE: `get_dataloaders` is not part of NeuroTorch. It is a user-defined helper
# (see the tutorials) returning a dict of DataLoaders keyed by "train", "val" and
# "test", whose datasets expose `n_units` and `n_classes`.
n_hidden_neurons = 128
checkpoint_folder = "./checkpoints/checkpoint_000"
checkpoint_manager = nt.CheckpointManager(checkpoint_folder)
dataloaders = get_dataloaders(
    batch_size=256,
    train_val_split_ratio=0.95,
)

# Recurrent LIF hidden layer followed by a leaky-integrator readout layer.
network = nt.SequentialModel(
    layers=[
        nt.LIFLayer(
            # Inputs have a variable time dimension and `n_units` features per step.
            input_size=nt.Size([
                nt.Dimension(None, nt.DimensionProperty.TIME),
                nt.Dimension(dataloaders["test"].dataset.n_units, nt.DimensionProperty.NONE),
            ]),
            output_size=n_hidden_neurons,
            use_recurrent_connection=True,
        ),
        nt.SpyLILayer(output_size=dataloaders["test"].dataset.n_classes),
    ],
    name="Network",
    checkpoint_folder=checkpoint_folder,
).build()

# The checkpoint manager is passed as a callback so checkpoints are saved during training.
trainer = nt.ClassificationTrainer(
    model=network,
    optimizer=torch.optim.Adam(network.parameters(), lr=1e-3),
    callbacks=[checkpoint_manager],
    verbose=True,
)
# LAST_ITR: load the most recent checkpoint before training, if one exists.
training_history = trainer.train(
    dataloaders["train"],
    dataloaders["val"],
    n_iterations=100,
    load_checkpoint_mode=nt.LoadCheckpointMode.LAST_ITR,
)
training_history.plot(show=True)

# Reload the best checkpoint, then evaluate it on every split.
network.load_checkpoint(
    checkpoint_manager.checkpoints_meta_path, nt.LoadCheckpointMode.BEST_ITR, verbose=True
)
predictions = {
    k: nt.metrics.ClassificationMetrics.compute_y_true_y_pred(
        network, dataloader, verbose=True, desc=f"{k} predictions"
    )
    for k, dataloader in dataloaders.items()
}
accuracies = {
    k: nt.metrics.ClassificationMetrics.accuracy(network, y_true=y_true, y_pred=y_pred)
    for k, (y_true, y_pred) in predictions.items()
}
pprint.pprint(accuracies)
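The end of the preview turns each split into a pair `(y_true, y_pred)` and then into an accuracy. The plain-Python sketch below shows what that reduction amounts to, along with one common way of turning a leaky-integrator readout into a class label: take each output unit's peak membrane potential over time and pick the largest. The max-over-time rule is an assumption (a convention common in spiking-network classification), not a confirmed detail of `nt.ClassificationTrainer` or `nt.metrics.ClassificationMetrics`, and the helper names are hypothetical.

```python
from typing import Dict, Sequence, Tuple


def predict_from_membrane(trace: Sequence[Sequence[float]]) -> int:
    """Class whose readout membrane peaks highest; trace[t][c] is class c at step t."""
    n_classes = len(trace[0])
    peaks = [max(step[c] for step in trace) for c in range(n_classes)]
    # max() returns the first maximal item, so ties resolve to the lowest class index.
    return max(range(n_classes), key=lambda c: peaks[c])


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Fraction of samples whose predicted label matches the true label."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if len(y_true) == 0:
        raise ValueError("cannot compute the accuracy of an empty split")
    correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return correct / len(y_true)


def accuracies_per_split(
    predictions: Dict[str, Tuple[Sequence[int], Sequence[int]]]
) -> Dict[str, float]:
    """Mirror of the `accuracies` dict in the preview: split name -> accuracy."""
    return {k: accuracy(y_true, y_pred) for k, (y_true, y_pred) in predictions.items()}
```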

Found a bug or have a feature request?

Thanks

  • Anthony Drouin who helped develop the Wilson-Cowan application during his 2022 summer internship.
  • Antoine Légaré who made the awesome logo of NeuroTorch.
  • My dog Chewy, who was a great help throughout the whole development.

License

Apache License 2.0

Citation

@misc{Gince2022,
  title={NeuroTorch: Deep Learning Python Library for Machine Learning and Neuroscience.},
  author={Jérémie Gince},
  year={2022},
  publisher={Université Laval},
  url={https://github.com/NeuroTorch},
}



Download files

Source Distribution

NeuroTorch-0.0.1a0.tar.gz (86.1 kB)

Built Distribution

NeuroTorch-0.0.1a0-py3-none-any.whl (105.1 kB, Python 3)
