A Keras model zoo with pretrained weights.

KIMM

Keras Image Models

Introduction

Keras Image Models (kimm) is a collection of image models, blocks and layers written in Keras 3. The goal is to offer SOTA models with pretrained weights in a user-friendly manner.

KIMM offers:

  • 🚀 A model zoo in which almost all models come with weights pretrained on ImageNet.
  • 🧰 APIs to export models to .tflite and .onnx.
  • 🔧 Support for the reparameterization technique.
  • ✨ Built-in feature extraction.

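As a rough illustration of the idea behind reparameterization (this is not kimm's implementation): a block trained with several parallel linear branches can be collapsed into a single equivalent operation for inference. The sketch below uses small plain-Python matrices and hypothetical helper names (`matvec`, `multi_branch`, `fuse`). In architectures such as RepVGG and MobileOne the branches are convolutions with batch normalization, which this sketch omits.

```python
# Simplified illustration of structural reparameterization.
# Assumption: plain matrices stand in for the conv + batch-norm branches
# used by real models; all helper names here are hypothetical.


def matvec(w, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]


def multi_branch(w_a, w_b, x):
    """Training-time form: two linear branches plus an identity shortcut."""
    y_a = matvec(w_a, x)
    y_b = matvec(w_b, x)
    return [a + b + x_i for a, b, x_i in zip(y_a, y_b, x)]


def fuse(w_a, w_b):
    """Inference-time form: fold both branches and the identity into one matrix."""
    n = len(w_a)
    return [
        [w_a[i][j] + w_b[i][j] + (1.0 if i == j else 0.0) for j in range(n)]
        for i in range(n)
    ]


w_a = [[0.5, -1.0], [2.0, 0.25]]
w_b = [[1.5, 0.0], [-0.5, 1.0]]
x = [3.0, -2.0]

y_branches = multi_branch(w_a, w_b, x)  # three operations
w_fused = fuse(w_a, w_b)
y_fused = matvec(w_fused, x)  # one operation, same result

print(y_branches, y_fused)
```

Because every branch is linear, summing the weights gives the same output as summing the branch outputs, so the fused form needs fewer operations at inference time.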

Usage

  • kimm.list_models
  • kimm.models.*.available_feature_keys
  • kimm.models.*(...)
  • kimm.models.*(..., feature_extractor=True, feature_keys=[...])
  • kimm.utils.get_reparameterized_model
  • kimm.export.export_tflite
  • kimm.export.export_onnx

The script below walks through each of these APIs:

import keras
import kimm
import numpy as np


# List available models
print(kimm.list_models("mobileone", weights="imagenet"))
# ['MobileOneS0', 'MobileOneS1', 'MobileOneS2', 'MobileOneS3']

# Initialize model with pretrained ImageNet weights
x = keras.random.uniform([1, 224, 224, 3])
model = kimm.models.MobileOneS0()
y = model.predict(x)
print(y.shape)
# (1, 1000)

# Get reparameterized model by kimm.utils.get_reparameterized_model
reparameterized_model = kimm.utils.get_reparameterized_model(model)
y2 = reparameterized_model.predict(x)
np.testing.assert_allclose(
    keras.ops.convert_to_numpy(y), keras.ops.convert_to_numpy(y2), atol=1e-5
)

# Export model to tflite format
kimm.export.export_tflite(reparameterized_model, 224, "model.tflite")

# Export model to onnx format (note: must be "channels_first" format)
# kimm.export.export_onnx(reparameterized_model, 224, "model.onnx")

# List available feature keys of the model class
print(kimm.models.MobileOneS0.available_feature_keys)
# ['STEM_S2', 'BLOCK0_S4', 'BLOCK1_S8', 'BLOCK2_S16', 'BLOCK3_S32']

# Enable feature extraction by setting `feature_extractor=True`
# `feature_keys` can be optionally specified
model = kimm.models.MobileOneS0(
    feature_extractor=True, feature_keys=["BLOCK2_S16", "BLOCK3_S32"]
)
features = model.predict(x)
for feature_name, feature in features.items():
    print(feature_name, feature.shape)
# BLOCK2_S16 (1, 14, 14, 256)
# BLOCK3_S32 (1, 7, 7, 1024)
# TOP (1, 1000)
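
The feature key names printed above appear to encode the downsampling stride in their `_S<n>` suffix: with a 224×224 input, `BLOCK2_S16` yields a 14×14 map and `BLOCK3_S32` a 7×7 map. This reading is an assumption inferred from the shapes shown, not something stated by kimm itself. Under that assumption, the expected spatial size of each feature map can be sketched in plain Python; the helper names are hypothetical.

```python
import re

# Assumption: a trailing "_S<n>" in a feature key denotes a total stride of n.
# This is inferred from the shapes printed in the usage example.


def stride_from_feature_key(key):
    """Return the stride encoded in a key such as 'BLOCK2_S16', or None."""
    match = re.search(r"_S(\d+)$", key)
    return int(match.group(1)) if match else None


def expected_spatial_size(input_size, key):
    """Spatial size of the feature map, or None for keys without a stride (e.g. 'TOP')."""
    stride = stride_from_feature_key(key)
    return None if stride is None else input_size // stride


keys = ["STEM_S2", "BLOCK0_S4", "BLOCK1_S8", "BLOCK2_S16", "BLOCK3_S32", "TOP"]
sizes = {key: expected_spatial_size(224, key) for key in keys}

for key, size in sizes.items():
    print(key, size)
```

A helper like this can be handy when choosing `feature_keys` for a detection or segmentation head that expects feature maps at particular strides.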

Installation

pip install keras kimm -U

Quickstart

Image classification using the model pretrained on ImageNet

Using kimm.models.VisionTransformerTiny16:

[Image: african_elephant]
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step
Predicted: [('n02504458', 'African_elephant', 0.6895825), ('n01871265', 'tusker', 0.17934209), ('n02504013', 'Indian_elephant', 0.12927249)]
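
The decoded output above is a list of (WordNet ID, label, score) tuples sorted by score. A minimal sketch of post-processing such a list in plain Python follows. The values are copied from the output shown, and the variable names are illustrative only.

```python
# Predictions copied from the output above: (WordNet ID, label, score).
predictions = [
    ("n02504458", "African_elephant", 0.6895825),
    ("n01871265", "tusker", 0.17934209),
    ("n02504013", "Indian_elephant", 0.12927249),
]

# Pick the most likely class and the probability mass covered by the top 3.
best_id, best_label, best_score = max(predictions, key=lambda p: p[2])
top3_mass = sum(score for _, _, score in predictions)

for wnid, label, score in predictions:
    print(f"{label:<20} {wnid}  {score:.1%}")
print(f"best: {best_label}, top-3 mass: {top3_mass:.4f}")
```

Here the top three classes account for almost all of the probability mass, so the model is confident the image shows some kind of elephant even though it splits its score across related labels.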

An end-to-end example: fine-tuning an image classification model on a cats vs. dogs dataset

Using kimm.models.EfficientNetLiteB0:

[Images: kimm_prediction_0, kimm_prediction_1]

Reference: Transfer learning & fine-tuning (keras.io)

Grad-CAM

Using kimm.models.MobileViTS:

[Image: grad_cam]

Reference: Grad-CAM class activation visualization (keras.io)

Model Zoo

| Model | Paper | Weights are ported from | API |
| --- | --- | --- | --- |
| ConvMixer | ICLR 2022 Submission | timm | kimm.models.ConvMixer* |
| ConvNeXt | CVPR 2022 | timm | kimm.models.ConvNeXt* |
| DenseNet | CVPR 2017 | timm | kimm.models.DenseNet* |
| EfficientNet | ICML 2019 | timm | kimm.models.EfficientNet* |
| EfficientNetLite | ICML 2019 | timm | kimm.models.EfficientNetLite* |
| EfficientNetV2 | ICML 2021 | timm | kimm.models.EfficientNetV2* |
| GhostNet | CVPR 2020 | timm | kimm.models.GhostNet* |
| GhostNetV2 | NeurIPS 2022 | timm | kimm.models.GhostNetV2* |
| HGNet | | timm | kimm.models.HGNet* |
| HGNetV2 | | timm | kimm.models.HGNetV2* |
| InceptionNeXt | arXiv 2023 | timm | kimm.models.InceptionNeXt* |
| InceptionV3 | CVPR 2016 | timm | kimm.models.InceptionV3 |
| LCNet | arXiv 2021 | timm | kimm.models.LCNet* |
| MobileNetV2 | CVPR 2018 | timm | kimm.models.MobileNetV2* |
| MobileNetV3 | ICCV 2019 | timm | kimm.models.MobileNetV3* |
| MobileOne | CVPR 2023 | timm | kimm.models.MobileOne* |
| MobileViT | ICLR 2022 | timm | kimm.models.MobileViT* |
| MobileViTV2 | arXiv 2022 | timm | kimm.models.MobileViTV2* |
| RegNet | CVPR 2020 | timm | kimm.models.RegNet* |
| RepVGG | CVPR 2021 | timm | kimm.models.RepVGG* |
| ResNet | CVPR 2016 | timm | kimm.models.ResNet* |
| TinyNet | NeurIPS 2020 | timm | kimm.models.TinyNet* |
| VGG | ICLR 2015 | timm | kimm.models.VGG* |
| ViT | ICLR 2021 | timm | kimm.models.VisionTransformer* |
| Xception | CVPR 2017 | keras | kimm.models.Xception |
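
The `*` in the API column stands for the variant suffix of each model family. As a rough sketch of how such a pattern resolves to concrete class names, the snippet below matches patterns against model names that appear elsewhere in this document, using the standard-library `fnmatch` module. It does not call kimm, and `match_models` is a hypothetical helper.

```python
from fnmatch import fnmatchcase

# Model names taken from the examples in this document (not a complete list).
known_models = [
    "MobileOneS0",
    "MobileOneS1",
    "MobileOneS2",
    "MobileOneS3",
    "VisionTransformerTiny16",
    "EfficientNetLiteB0",
    "MobileViTS",
    "InceptionV3",
    "Xception",
]


def match_models(pattern, names=known_models):
    """Return the names matching a shell-style, case-sensitive pattern."""
    return [name for name in names if fnmatchcase(name, pattern)]


mobileone = match_models("MobileOne*")
efficientnet = match_models("EfficientNet*")
xception = match_models("Xception")

print(mobileone)
print(efficientnet)
print(xception)
```

Note that a broad pattern such as `EfficientNet*` also matches `EfficientNetLite*` names, so the more specific pattern is the safer choice when only one family is wanted.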

The scripts used to convert and export the ported weights can be found in tools/convert_*.py.

License

Please refer to timm as this project is built upon it.

kimm Code

The code here is licensed under Apache 2.0.

Acknowledgements

Thanks for these awesome projects that were used in kimm

Citing

BibTeX

@misc{rw2019timm,
  author = {Ross Wightman},
  title = {PyTorch Image Models},
  year = {2019},
  publisher = {GitHub},
  journal = {GitHub repository},
  doi = {10.5281/zenodo.4414861},
  howpublished = {\url{https://github.com/rwightman/pytorch-image-models}}
}
@misc{hy2024kimm,
  author = {Hongyu Chiu},
  title = {Keras Image Models},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/james77777778/kimm}}
}
