
A TensorFlow 2.0 implementation of the Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction


TensorFlow 2 DA-RNN

A TensorFlow 2 (Keras) implementation of the Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction.

Paper: https://arxiv.org/abs/1704.02971

Install

For TensorFlow 2

pip install "da-rnn[keras]"

For PyTorch

pip install "da-rnn[torch]"

Usage

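For TensorFlow 2

from da_rnn.keras import DARNN

# train_ds and val_ds are datasets (e.g. tf.data.Dataset) yielding
# (features, labels) batches shaped as described under "Data Processing";
# inputs is a batch of features
model = DARNN(T=10, m=128)

# A Keras model must be compiled before fit(); the optimizer and loss
# below are example choices, not settings prescribed by this project
model.compile(optimizer='adam', loss='mse')

# Train
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=100,
    verbose=1
)

# Predict
y_hat = model(inputs)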

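For PyTorch (with Poutyne)

import torch
from poutyne import Model
from da_rnn.torch import DARNN

# train_loader and val_loader are iterables of (features, labels) batches,
# e.g. torch.utils.data.DataLoader objects, shaped as described under
# "Data Processing"; inputs is a tensor of features
darnn = DARNN(n=50, T=10, m=128)

# Poutyne's Model needs an optimizer and a loss function; 'adam' and 'mse'
# are example choices, not settings prescribed by this project
model = Model(darnn, 'adam', 'mse')

# Train: fit_generator() consumes batch iterables such as DataLoaders
# (Model.fit() expects separate x and y arrays instead)
model.fit_generator(
    train_loader,
    valid_generator=val_loader,
    epochs=100,
    verbose=True
)

# Predict by calling the underlying network in evaluation mode
darnn.eval()
with torch.no_grad():
    y_hat = darnn(inputs)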

Python Docstring Notations

The docstrings of this project's methods use the following notation convention:

variable_{subscript}__{superscript}

For example:

  • y_T__i denotes y_T^i, the i-th prediction value at time T.
  • alpha_t__k denotes α_t^k, the attention weight measuring the importance of the k-th input feature (driving series) at time t.
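In the paper, alpha_t__k comes from the input attention mechanism: the scores e_t__k computed for the n driving series at time t are normalised with a softmax, alpha_t__k = exp(e_t__k) / sum_i exp(e_t__i), and the encoder then receives each input feature x_t__k scaled by its weight. The NumPy sketch below illustrates only that normalisation and weighting step, taking the scores as given; the function names are hypothetical and are not part of the da_rnn API.

```python
import numpy as np


def input_attention_weights(e_t: np.ndarray) -> np.ndarray:
    """Softmax over the last axis: alpha_t__k = exp(e_t__k) / sum_i exp(e_t__i).

    e_t: attention scores of shape (..., n), one per driving series.
    (Hypothetical helper, not part of da_rnn.)
    """
    # Shift by the max for numerical stability (softmax is shift-invariant)
    z = np.exp(e_t - e_t.max(axis=-1, keepdims=True))
    return z / z.sum(axis=-1, keepdims=True)


def weighted_input(alpha_t: np.ndarray, x_t: np.ndarray) -> np.ndarray:
    """Scale each feature by its weight: (alpha_t__1 * x_t__1, ..., alpha_t__n * x_t__n)."""
    return alpha_t * x_t


# Example: scores for n = 3 driving series at a single time step t
e_t = np.array([2.0, 0.5, -1.0])
x_t = np.array([10.0, 20.0, 30.0])
alpha_t = input_attention_weights(e_t)
print(alpha_t.round(3), alpha_t.sum())
print(weighted_input(alpha_t, x_t))
```

Because the weights of one time step always sum to 1, a larger alpha_t__k means the k-th driving series contributes relatively more to the encoder input at time t.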

TensorFlow 2: DARNN(T, m, p, y_dim=1)

PyTorch: DARNN(n, T, m, p, y_dim=1)

The following (hyper)parameter names follow the paper, except y_dim, which the paper does not use.

  • n (int, PyTorch only): the input size, i.e. the number of driving series (input features per time step)
  • T (int): the length (number of time steps) of the window
  • m (int): the size of the encoder hidden state
  • p (int): the size of the decoder hidden state
  • y_dim (int, default 1): the prediction dimension

Returns the DA-RNN model instance.
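To make the roles of these hyperparameters concrete, the sketch below lists the shapes of the main quantities in the paper's formulation for one batch: the input attention weights alpha_t span the n driving series, the encoder hidden state h_t has m units, the temporal attention weights beta_t span the T encoder states, and the decoder hidden state d_t has p units. It is a plain-Python illustration of the paper's definitions, not a description of how da_rnn lays out its tensors internally; darnn_shapes is a hypothetical helper.

```python
def darnn_shapes(batch_size: int, n: int, T: int, m: int, p: int, y_dim: int = 1) -> dict:
    """Shapes of key DA-RNN quantities for one batch (paper notation).

    Hypothetical helper for illustration only; not part of da_rnn.
    """
    return {
        # n driving-series values plus y_dim target values at each of the T steps
        "input window": (batch_size, T, n + y_dim),
        # one input attention weight per driving series
        "alpha_t (input attention)": (batch_size, n),
        "h_t (encoder hidden state)": (batch_size, m),
        # one temporal attention weight per encoder state h_1 .. h_T
        "beta_t (temporal attention)": (batch_size, T),
        # weighted sum of the encoder hidden states
        "c_t (context vector)": (batch_size, m),
        "d_t (decoder hidden state)": (batch_size, p),
        "y_hat (prediction)": (batch_size, y_dim),
    }


for name, shape in darnn_shapes(batch_size=32, n=50, T=10, m=128, p=128).items():
    print(f"{name:30} {shape}")
```

Reading m and p as hidden-state sizes explains why they can be chosen independently of n and T: they set the capacity of the encoder and decoder LSTMs, not the shape of the data.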

Data Processing

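Each feature item (batch) of the dataset should have shape (batch_size, T, n + y_dim), where n is the number of driving series: each of the T time steps holds the n driving-series values plus y_dim target values.

Each label item of the dataset should have shape (batch_size, y_dim).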
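One way to produce such items from two aligned series is a sliding window. The NumPy sketch below is an illustration under stated assumptions: make_windows is a hypothetical helper, the driving series are placed before the target column(s), and the label of each window is the target value one step after it. Check both choices against the model code before relying on them.

```python
import numpy as np


def make_windows(driving: np.ndarray, target: np.ndarray, T: int):
    """Slice aligned series into (features, labels) arrays (hypothetical helper).

    driving: shape (num_steps, n), the n driving series
    target:  shape (num_steps, y_dim), the target series
    Returns features of shape (num_windows, T, n + y_dim) and labels of
    shape (num_windows, y_dim), with num_windows = num_steps - T.
    """
    if len(driving) != len(target):
        raise ValueError("driving and target must have the same number of steps")
    num_windows = len(driving) - T
    if num_windows <= 0:
        raise ValueError("need more than T time steps")
    # Assumption: driving series first, target column(s) last
    series = np.concatenate([driving, target], axis=1)
    # Window i covers time steps i .. i+T-1
    features = np.stack([series[i:i + T] for i in range(num_windows)])
    # Assumption: the label of window i is the target value at step i+T
    labels = target[T:]
    return features.astype(np.float32), labels.astype(np.float32)


rng = np.random.default_rng(0)
X, y = make_windows(rng.normal(size=(100, 50)), rng.normal(size=(100, 1)), T=10)
print(X.shape, y.shape)  # (90, 10, 51) (90, 1)
```

The resulting arrays can then be batched, for example with tf.data.Dataset.from_tensor_slices((X, y)).batch(32) for Keras, or, after converting them with torch.from_numpy, wrapped in torch.utils.data.TensorDataset and DataLoader for PyTorch.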

Development

Install dependencies:

make install

Run the notebook:

cd notebook
jupyter lab

TODO

  • Remove the hardcoded prediction dimensionality (currently fixed at 1)

License

MIT



Download files


Source distribution: da-rnn-1.0.2.tar.gz (8.5 kB)

Built distribution: da_rnn-1.0.2-py3-none-any.whl (16.4 kB, Python 3)
