Neural Network Genetic Algorithm library used for deep learning problems

Project description

nnGA Library - Neural Network Genetic Algorithm Library (v0.0.5)

Off-the-shelf genetic algorithm library for deep learning problems.

License

Our code is released under the MIT license (refer to the LICENSE file for details).

Requirements

To use the library you need at least Python 3.6. Examples may require additional libraries.

Other required dependencies:

  • NumPy
  • Neptune

Usage/Examples

You can install the library by typing pip install faris-lab-train-model.

The following example shows how to use the library together with Neptune:

import neptune.new as neptune
from nnga import nnGA, GaussianInitializationStrategy, \
    GaussianMutationStrategy, BasicCrossoverStrategy, \
    PopulationParameters

def make_network(parameters=None):
    ''' Function that creates a network given a set of parameters '''
    neural_network = ...
    return neural_network


def fitness(idx, parameters):
    ''' Fitness function to evaluate a set of parameters '''
    # Evaluate parameters
    network = make_network(parameters)
    return evaluate_network(network)


if __name__ == '__main__':
    # Initialize GA parameters
    network = make_initial_network()
    network_structure = [list(layer.shape) for layer in network]  # List of lists, each holding the shape of one layer
    
    # Population parameters
    population = PopulationParameters(population_size=200)
    
    # Mutation strategy
    mutation = GaussianMutationStrategy(network_structure, 1e-1)
    
    # Crossover strategy
    crossover = BasicCrossoverStrategy(network_structure)
    
    # Initialization strategy
    init = GaussianInitializationStrategy(
        mean=0., std=1., network_structure=network_structure)

    ga = nnGA(
        epochs=50,  # Number of epochs
        fitness_function=fitness,
        population_parameters=population,
        mutation_strategy=mutation,
        initialization_strategy=init,
        crossover_strategy=crossover,
        num_processors=8)  # Number of cores

    # Start a Neptune run, execute the GA, then close the run
    run = neptune.init(project="common/quickstarts",
                       api_token="ANONYMOUS")
    network_parameters, best_result, results = ga.run()
    run.stop()
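
The helper functions make_initial_network, make_network and evaluate_network used in the example are not defined there. As a rough sketch of what they might look like, the block below uses NumPy only. Everything in it is an assumption made for illustration: the XOR toy data, the two-layer architecture in LAYER_SHAPES, the choice to represent a network as a list of weight matrices, and the convention that a higher fitness is better. Check the nnGA source for the parameter format it actually passes to the fitness function and for whether it maximizes or minimizes.

```python
import numpy as np

# Hypothetical toy task (XOR); replace with your own data.
X = np.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
Y = np.array([[0.], [1.], [1.], [0.]])

# Assumed architecture: two weight matrices, no biases.
LAYER_SHAPES = [(2, 4), (4, 1)]


def make_initial_network():
    """Return a list of zero-filled weight matrices, one per layer."""
    return [np.zeros(shape) for shape in LAYER_SHAPES]


def make_network(parameters=None):
    """Build a forward function from a list of weight matrices."""
    weights = make_initial_network() if parameters is None else parameters

    def forward(x):
        # tanh on hidden layers, linear output layer
        for w in weights[:-1]:
            x = np.tanh(x @ w)
        return x @ weights[-1]

    return forward


def evaluate_network(network):
    """Negative mean squared error on the toy data (higher is better)."""
    return -float(np.mean((network(X) - Y) ** 2))


# Same expression as in the example: one shape per layer.
network_structure = [list(layer.shape) for layer in make_initial_network()]
```

With these definitions, network_structure is a list with one shape entry per weight matrix, which is the form the strategies in the example receive.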

In general, the code has the following structure:

from nnga import nnGA, GaussianInitializationStrategy, \
    GaussianMutationStrategy, BasicCrossoverStrategy, \
    PopulationParameters

def make_network(parameters=None):
    ''' Function that creates a network given a set of parameters '''
    neural_network = ...
    return neural_network


def fitness(idx, parameters):
    ''' Fitness function to evaluate a set of parameters '''
    # Evaluate parameters
    network = make_network(parameters)
    return evaluate_network(network)


if __name__ == '__main__':
    # Initialize GA parameters
    network = make_initial_network()
    network_structure = [list(layer.shape) for layer in network]  # List of tuples, containing the shape of each layer
    
    # Population parameters
    population = PopulationParameters(population_size=200)
    
    # Mutation strategy
    mutation = GaussianMutationStrategy(network_structure, 1e-1)
    
    # Crossover strategy
    crossover = BasicCrossoverStrategy(network_structure)
    
    # Initialization strategy
    init = GaussianInitializationStrategy(
        mean=0., std=1., network_structure=network_structure)

    ga = nnGA(
        epochs=50,  # Number of epochs
        fitness_function=fitness,
        population_parameters=population,
        mutation_strategy=mutation,
        initialization_strategy=init,
        crossover_strategy=crossover,
        num_processors=8)  # Number of cores

    # Run GA
    network_parameters, best_result, results = ga.run()
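
To give an intuition for the three strategy types configured above, here is a small NumPy sketch of Gaussian initialization, Gaussian mutation and a simple crossover inside an elitist loop. It is not nnGA's implementation: the function names, the uniform per-weight crossover, the elite size and the toy fitness are all assumptions chosen to keep the example short.

```python
import numpy as np

rng = np.random.default_rng(0)
network_structure = [[2, 4], [4, 1]]  # assumed layer shapes


def gaussian_init(structure, mean=0.0, std=1.0):
    """Draw every weight of every layer from N(mean, std^2)."""
    return [rng.normal(mean, std, size=tuple(shape)) for shape in structure]


def gaussian_mutation(individual, std=1e-1):
    """Add zero-mean Gaussian noise to every weight."""
    return [layer + rng.normal(0.0, std, size=layer.shape)
            for layer in individual]


def uniform_crossover(parent_a, parent_b):
    """Pick each weight from one of the two parents with equal probability."""
    child = []
    for a, b in zip(parent_a, parent_b):
        mask = rng.random(a.shape) < 0.5
        child.append(np.where(mask, a, b))
    return child


def toy_fitness(individual):
    """Toy objective (higher is better): negative squared norm of the weights."""
    return -sum(float(np.sum(layer ** 2)) for layer in individual)


population_size, elite_size, epochs = 20, 5, 30
population = [gaussian_init(network_structure) for _ in range(population_size)]
initial_best = max(toy_fitness(ind) for ind in population)

for _ in range(epochs):
    # Keep the best individuals unchanged (elitism) ...
    population.sort(key=toy_fitness, reverse=True)
    elite = population[:elite_size]
    # ... and refill the population with mutated offspring of elite pairs.
    children = []
    while len(elite) + len(children) < population_size:
        i, j = rng.choice(elite_size, size=2, replace=False)
        children.append(gaussian_mutation(uniform_crossover(elite[i], elite[j])))
    population = elite + children

best = max(population, key=toy_fitness)
final_best = toy_fitness(best)
```

Because the elite is carried over unchanged, the best fitness in this sketch can never decrease from one epoch to the next.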

License: MIT
