Commencer avec JAX
'JAX pour débutants'
Alimenter l’avenir du calcul numérique haute performance et de la recherche en ML
Introduction
JAX est une bibliothèque Python développée par Google pour effectuer des calculs numériques haute performance sur n’importe quel type d’appareil (CPU, GPU, TPU, etc.). Une des principales applications de JAX est la recherche et le développement en apprentissage automatique et en apprentissage profond, bien que la bibliothèque soit principalement conçue pour fournir toutes les capacités nécessaires pour effectuer des tâches de calcul scientifique à usage général (opérations sur des matrices de dimensions élevées, etc.).
En ce qui concerne plus particulièrement le calcul haute performance, JAX a été conçu pour être extrêmement rapide en étant construit sur XLA (Accélération des opérations linéaires). XLA est en fait un compilateur conçu pour accélérer les opérations d’algèbre linéaire et peut être utilisé pour travailler également avec d’autres frameworks tels que TensorFlow et Pytorch. De plus, les tableaux JAX ont été conçus pour suivre les mêmes principes que Numpy, ce qui facilite vraiment la migration du code Numpy ancien vers JAX et permet de bénéficier d’accélérations de performances grâce aux GPU et aux TPU.
Certaines des principales caractéristiques de JAX sont les suivantes :
- Compilation Just in Time (JIT) : La compilation JIT et le matériel accéléré sont ce qui permet à JAX d’être beaucoup plus rapide que Numpy simple. En utilisant la fonction jit(), il est possible de compiler et mettre en cache des fonctions personnalisées avec le noyau XLA. En utilisant la mise en cache, nous augmentons le temps d’exécution global lors de la première exécution de la fonction, puis réduisons considérablement le temps pour les exécutions suivantes. Lors de l’utilisation de la mise en cache, il est important de s’assurer de vider les caches lorsque cela est nécessaire afin d’éviter des résultats obsolètes (par exemple, des variables globales qui changent).
- Parallélisation automatique : La répartition asynchrone permet aux vecteurs JAX d’être évalués de manière paresseuse, en matérialisant le contenu uniquement lorsqu’il est accédé (le contrôle est renvoyé au programme avant la fin du calcul). De plus, afin de rendre possible l’optimisation graphique, les tableaux JAX sont immuables (des concepts similaires avec l’évaluation paresseuse et l’optimisation graphique s’appliquent à Apache Spark). La fonction pmap() peut être utilisée pour paralléliser les calculs sur plusieurs GPU/TPU.
- Vectorisation automatique : La vectorisation automatique pour paralléliser les opérations peut être effectuée à l’aide de la fonction vmap(). Pendant la vectorisation, un algorithme est transformé pour opérer avec un seul valeur à un ensemble de valeurs.
- Différenciation automatique : La fonction grad() peut être utilisée pour calculer automatiquement le gradient (dérivée) des fonctions. En particulier, la différenciation automatique de JAX permet le développement de programmes différentiels à usage général en dehors du spectre de l’apprentissage profond. Cela permet de différencier à travers la récursivité, les branches, les boucles, d’effectuer une différenciation d’ordre supérieur (par exemple, les jacobins et les hessiens) et d’utiliser à la fois la différenciation en mode direct et en mode inverse.
Par conséquent, JAX est capable de nous fournir toutes les bases nécessaires pour construire des modèles d’apprentissage profond avancés, mais pas des utilitaires de haut niveau prêts à l’emploi pour certaines des opérations d’apprentissage profond les plus courantes (par exemple, les fonctions de perte/activation, les couches, etc.). Par exemple, les paramètres du modèle appris lors de l’entraînement en apprentissage automatique peuvent être stockés dans une structure Pytree dans JAX. Compte tenu de tous les avantages offerts par JAX, différents frameworks orientés DL ont été construits dessus, tels que Haiku (utilisé par DeepMind) et Flax (utilisé par Google Brain).
- Fondamentaux de la détection d’anomalies avec la distribution gaussienne multivariée
- Apprentissage automatique à effets mixtes pour les variables catégorielles à haute cardinalité – Partie I Une comparaison empirique de différentes méthodes.
- Des scientifiques du MIT ont construit un système capable de générer des modèles d’IA pour la recherche en biologie.
Démonstration
Dans le cadre de cet article, nous allons maintenant voir comment résoudre un problème de classification simple en utilisant JAX et l’ensemble de données de classification des prix des téléphones mobiles Kaggle [1] pour prédire dans quelle fourchette de prix un téléphone se trouvera. Tout le code utilisé tout au long de cet article (et plus encore !) est disponible sur mes comptes GitHub et Kaggle .
Tout d’abord, nous devons nous assurer d’avoir JAX installé dans notre environnement.
pip install jax
À ce stade, nous sommes prêts à importer les bibliothèques et les ensembles de données nécessaires (Figure 1). Afin de simplifier notre analyse, au lieu d’utiliser toutes les classes de notre étiquette, nous filtrons les données pour n’utiliser que 2 classes et réduire le nombre de fonctionnalités.
import pandas as pdimport jax.numpy as jnpfrom jax import gradfrom sklearn.preprocessing import StandardScalerfrom sklearn.model_selection import train_test_splitfrom sklearn.metrics import classification_reportimport matplotlib.pyplot as pltdf = pd.read_csv('/kaggle/input/mobile-price-classification/train.csv')df = df.iloc[:, 10:]df = df.loc[df['price_range'] <= 1]df.head()

Une fois le jeu de données nettoyé, nous pouvons maintenant le diviser en sous-ensembles d’entraînement et de test et standardiser les caractéristiques d’entrée afin de nous assurer qu’elles se situent toutes dans les mêmes plages. À ce stade, les données d’entrée sont également converties en tableaux JAX.
X = df.iloc[:, :-1]y = df.iloc[:, -1]X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, stratify=y)X_train, X_test, y_train, Y_test = jnp.array(X_train), jnp.array(X_test), \ jnp.array(y_train), jnp.array(y_test)scaler = StandardScaler()scaler.fit(X_train)X_train = scaler.transform(X_train)X_test = scaler.transform(X_test)
Afin de prédire la plage de prix des téléphones, nous allons créer un modèle de régression logistique à partir de zéro. Pour ce faire, nous avons d’abord besoin de créer quelques fonctions d’aide (une pour créer la fonction d’activation Sigmoid, et une autre pour la fonction de perte binaire).
def activation(r): return 1 / (1 + jnp.exp(-r))def loss(c, w, X, y, lmbd=0.1): p = activation(jnp.dot(X, w) + c) loss = jnp.sum(y * jnp.log(p) + (1 - y) * jnp.log(1 - p)) / y.size reg = 0.5 * lmbd * (jnp.dot(w, w) + c * c) return - loss + reg
Nous sommes maintenant prêts à créer notre boucle d’entraînement et à tracer les résultats (Figure 2).
n_iter, eta = 100, 1e-1w = 1.0e-5 * jnp.ones(X.shape[1])c = 1.0history = [float(loss(c, w, X_train, y_train))]for i in range(n_iter): c_current = c c -= eta * grad(loss, argnums=0)(c_current, w, X_train, y_train) w -= eta * grad(loss, argnums=1)(c_current, w, X_train, y_train) history.append(float(loss(c, w, X_train, y_train)))

Une fois satisfaits des résultats, nous pouvons alors tester le modèle sur notre ensemble de test (Figure 3).
y_pred = jnp.array(activation(jnp.dot(X_test, w) + c))y_pred = jnp.where(y_pred > 0.5, 1, 0) print(classification_report(y_test, y_pred))

Conclusion
Comme le démontre cet exemple succinct, JAX dispose d’une API très intuitive qui suit de près les conventions de Numpy tout en permettant d’utiliser le même code pour une utilisation CPU/GPU/TPU. En utilisant ces blocs de construction, il est alors possible de créer des modèles d’apprentissage profond hautement personnalisables optimisés par conception pour les performances.
Contacts
Si vous souhaitez rester informé de mes derniers articles et projets, suivez-moi sur VoAGI et abonnez-vous à ma liste de diffusion. Voici quelques-uns de mes coordonnées :
- Site Web Personnel
- Profil VoAGI
- GitHub
- Kaggle
Bibliographie
[1] “Classification des prix des téléphones mobiles” (ABHISHEK SHARMA). Consulté sur : https://thecleverprogrammer.com/2021/03/05/mobile-price-classification-with-machine-learning/ (Licence MIT : https://github.com/alifrmf/Mobile-Price-Prediction-Classification-Analysis/tree/main )
We will continue to update IPGirl; if you have any questions or suggestions, please contact us!
Was this article helpful?
93 out of 132 found this helpful
Related articles
- Apprendre le langage des molécules pour prédire leurs propriétés
- Rencontrez JourneyDB un ensemble de données à grande échelle comprenant 4 millions d’images variées et de haute qualité générées, sélectionnées pour la compréhension visuelle multimodale.
- Commencer avec Amazon SageMaker Ground Truth
- Ce document AI présente DreamDiffusion un modèle de pensées vers image pour générer directement des images de haute qualité à partir des signaux cérébraux EEG.
- Une approche fondée pour faire évoluer le choix et le contrôle du contenu web
- Intégrez les plateformes SaaS avec Amazon SageMaker pour permettre des applications alimentées par l’apprentissage automatique.
- Annonce du premier défi de désapprentissage automatique