Inférence variationnelle Les bases

Variational Inference Basics.

Nous vivons à l’ère de la quantification. Mais la quantification rigoureuse est plus facile à dire qu’à faire. Dans des systèmes complexes tels que la biologie, les données peuvent être difficiles et coûteuses à collecter. Tandis que dans des applications à enjeux élevés, telles que la médecine et la finance, il est crucial de tenir compte de l’incertitude. L’inférence variationnelle – une méthodologie à la pointe de la recherche en IA – est une façon de prendre en compte ces aspects.

Ce tutoriel vous initie aux bases : le quand, le pourquoi et le comment de l’inférence variationnelle.

Quand l’inférence variationnelle est-elle utile ?

L’inférence variationnelle est intéressante dans les trois cas d’utilisation suivants, qui sont étroitement liés :

1. si vous avez peu de données (c’est-à-dire un faible nombre d’observations),

2. si vous vous souciez de l’incertitude,

3. pour la modélisation générative.

Nous aborderons chaque cas d’utilisation dans notre exemple pratique.

1. Inférence variationnelle avec peu de données

Fig. 1: L'inférence variationnelle vous permet de faire un compromis entre les connaissances de domaine et les informations provenant d'exemples. Image de l'auteur.

Parfois, la collecte de données est coûteuse. Par exemple, les mesures d’ADN ou d’ARN peuvent facilement coûter plusieurs milliers d’euros par observation. Dans ce cas, vous pouvez programmer en dur des connaissances de domaine à la place d’échantillons supplémentaires. L’inférence variationnelle peut aider à “réduire” systématiquement les connaissances de domaine à mesure que vous collectez plus d’exemples, et à vous appuyer plus fortement sur les données (Fig. 1).

2. Inférence variationnelle pour l’incertitude

Pour les applications critiques en matière de sécurité, telles que la finance et les soins de santé, l’incertitude est importante. L’incertitude peut affecter tous les aspects du modèle, le résultat prévu étant le plus évident. Les paramètres du modèle (par exemple, les poids et les biais) sont moins évidents. Au lieu des tableaux habituels de nombres – les poids et les biais – vous pouvez doter les paramètres d’une distribution pour les rendre flous. L’inférence variationnelle vous permet d’inférer la ou les plages de valeurs raisonnables.

3. Inférence variationnelle pour la modélisation générative

Les modèles génératifs fournissent une spécification complète de la façon dont les données ont été générées. Par exemple, comment générer une image d’un chat ou d’un chien. Habituellement, il y a une représentation latente z qui porte un sens sémantique (par exemple, z décrit un chat siamois). À travers un ensemble d’étapes de transformation (non linéaires) et d’échantillonnage, z est transformé en l’image réelle x (par exemple, les valeurs de pixels du chat siamois). L’inférence variationnelle est une façon d’inférer, et d’échantillonner à partir de, l’espace sémantique latent z. Un exemple bien connu est l’auto-encodeur variationnel.

Qu’est-ce que l’inférence variationnelle ?

Fondamentalement, l’inférence variationnelle est une entreprise bayésienne [1]. Dans la perspective bayésienne, vous laissez toujours la machine apprendre à partir des données, comme d’habitude. Ce qui est différent, c’est que vous donnez au modèle un indice (une priorité) et que vous autorisez la solution (la posterior) à être plus floue. Plus concrètement, disons que vous avez un ensemble d’entraînement X = [ x ₁, x ₂,.., x ₘ ]ᵗ de m exemples. Nous utilisons le théorème de Bayes :

p ( Θ | X ) = p ( X | Θ ) p ( Θ ) / p ( X ),

pour inférer une plage – une distribution – de solutions Θ. Par opposition à l’approche classique de l’apprentissage automatique, où nous minimisons une perte ℒ( Θ, X ) = ln p ( X | Θ ) pour trouver une solution spécifique Θ. L’inférence bayésienne consiste à trouver un moyen de déterminer p ( Θ | X ) : la distribution postérieure des paramètres Θ étant donné l’ensemble d’entraînement X. En général, c’est un problème difficile. En pratique, deux méthodes sont utilisées pour résoudre p ( Θ | X) : (i) en utilisant la simulation (chaîne de Markov Monte Carlo) ou (ii) par optimisation.

L’inférence variationnelle concerne l’option (ii).

La borne inférieure de l’évidence (ELBO)

Fig. 2: Schéma de l'inférence variationnelle. Nous cherchons une distribution q(Θ) qui est proche de p(Θ|X). Image by Author.

L’idée derrière l’inférence variationnelle est de chercher une distribution q ( Θ ) qui est un substitut (un proxy) pour p ( Θ | X ). Nous essayons ensuite de rendre q ( Θ|Φ ) similaire à p ( Θ | X ) en modifiant les valeurs de Φ (Fig. 2). Cela se fait en maximisant la borne inférieure de l’évidence (ELBO) :

ℒ ( Φ ) = E[ln p ( X , Θ ) — ln q ( Θ|Φ) ],

où l’espérance E[·] est prise sur q ( Θ|Φ ). À première vue, il semble que nous devons être prudents lors de la prise de dérivées (par rapport à Φ ) en raison de la dépendance de E[·] sur q ( Θ|Φ ). Heureusement, les packages d’autograd comme JAX prennent en charge des astuces de reparamétrisation [2] qui vous permettent de prendre directement des dérivées à partir d’échantillons aléatoires (par exemple, de la distribution gamma) au lieu de vous appuyer sur des approches variationnelles à boîte noire à variance élevée [3]. En bref : estimez ∇ℒ(Φ) avec un lot [ Θ ₁, Θ ₂,..] ~ q ( Θ|Φ ) et laissez votre package d’autograd s’occuper des détails.

Inférence variationnelle à partir de zéro

Fig. 3: Exemple d'une image de zéro manuscrit provenant de l'ensemble de données de chiffres de sci-kit learn. Image by Author.

Pour consolider notre compréhension, implémentons l’inférence variationnelle à partir de zéro en utilisant JAX. Dans cet exemple, vous entraînerez un modèle génératif sur des chiffres manuscrits de sci-kit learn. Vous pouvez suivre le notebook de Colab.

Pour simplifier, nous n’analyserons que le chiffre « zéro ».

from sklearn import datasetsdigits = datasets.load_digits()is_zero = digits.target == 0X_train = digits.images[is_zero]# Aplatir la grille d'image en un vecteur.n_pixels = 64  # 8 par 8.X_train = X_train.reshape((-1, n_pixels))

Chaque image est un tableau 8 par 8 de valeurs de pixels discrètes allant de 0 à 16. Comme les pixels sont des données de comptage, modélisons les pixels, x, en utilisant la distribution de Poisson avec une priori gamma pour le taux Θ. Le taux Θ détermine l’intensité moyenne des pixels. Ainsi, la distribution conjointe est donnée par :

p ( x , Θ ) = Poisson( x | Θ ) Gamma( Θ | a , b ),

a et b sont la forme et le taux de la distribution gamma.

Fig. 4: La connaissance du domaine du chiffre « zéro » est utilisée comme priorité. Image de l'auteur.

Le prior — dans ce cas, Gamma( Θ | a , b ) — est l’endroit où vous infusez votre connaissance du domaine (cas d’utilisation 1.). Par exemple, vous pouvez avoir une idée de ce à quoi ressemble en moyenne le chiffre zéro (Fig. 4). Vous pouvez utiliser cette information a priori pour guider votre choix de a et b . Pour utiliser la Figure 4 comme information a priori — appelons-la x ₀ — et pondérez son importance en tant que deux exemples, alors définissez a = 2 x ₀ ; b = 2.

Écrit en Python, cela ressemble à :

import jax.numpy as jnpimport jax.scipy as jsp# Hyperparameters of the model.a = 2. * x_domain_knowledgeb = 2.def log_joint(θ):  log_likelihood = jnp.sum(jsp.stats.gamma.logpdf(θ, a, scale=1./b))  log_likelihood += jnp.sum(jsp.stats.poisson.logpmf(X_train, θ))  return log_likelihood

Notez que nous avons utilisé l’implémentation de numpy et scipy de JAX, afin que nous puissions prendre des dérivées.

Ensuite, nous devons choisir une distribution de substitution q ( Θ|Φ ). Pour vous rappeler, notre objectif est de changer Φ afin que la distribution de substitution q ( Θ|Φ ) corresponde à p ( Θ|X) . Ainsi, le choix de q ( Θ ) détermine le niveau d’approximation (nous supprimons la dépendance de Φ lorsque le contexte le permet). À des fins d’illustration, choisissons une distribution variationnelle composée de (un produit de) gamma :

q ( Θ|Φ ) = Gamma( Θ | α , β ),

où nous avons utilisé le raccourci Φ = { α , β }.

Ensuite, pour implémenter la limite inférieure de l’évidence ℒ ( Φ ) = E[ln p ( X , Θ ) — ln q ( Θ|Φ )], écrivons d’abord le terme à l’intérieur des crochets d’expectation :

@partial(vmap, in_axes=(0, None, None))def evidence_lower_bound(θ_i, alpha, inv_beta):  elbo = log_joint(θ_i) - jnp.sum(jsp.stats.gamma.logpdf(θ_i, alpha, scale=inv_beta))  return elbo

Ici, nous avons utilisé le vmap de JAX pour vectoriser la fonction afin que nous puissions l’exécuter sur un lot [ Θ ₁, Θ ₂,.., Θ ₁₂₈]ᵗ.

Pour compléter la mise en œuvre de ℒ ( Φ ), nous moyennons la fonction ci-dessus sur des réalisations de la distribution variationnelle Θ ᵢ ~ q ( Θ ) :

def loss(Φ: dict, key):  """Estimation stochastique de la limite inférieure de l'évidence."""  alpha = jnp.exp(Φ['log_alpha'])  inv_beta = jnp.exp(-Φ['log_beta'])  # Échantillonner un lot à partir de la distribution variationnelle q.  batch_size = 128  batch_shape = [batch_size, n_pixels]  θ_samples = random.gamma(key, alpha , shape=batch_shape) * inv_beta    # Calculer l'estimation Monte Carlo de la limite inférieure de l'évidence.  elbo_loss = jnp.mean(evidence_lower_bound(θ_samples, alpha, inv_beta))    # Transformer l'ELBO en perte.  return -elbo_loss

Voici quelques points à remarquer à propos des arguments :

  • Nous avons empaqueté Φ en tant que dictionnaire (ou techniquement, une pytree) contenant ln( α ) et ln( β ). Cette astuce garantit que α >0 et β >0, une exigence imposée par la distribution gamma lors de l’optimisation.
  • La perte est une estimation aléatoire de l’ELBO. En JAX, nous avons besoin d’une nouvelle clé de générateur de nombres pseudo-aléatoires (PRNG) à chaque fois que nous échantillonnons. Dans ce cas, nous utilisons la clé pour échantillonner [ Θ ₁, Θ ₂,.., Θ ₁₂₈]ᵗ.

Cela complète la spécification du modèle p ( x , Θ) , de la distribution variationnelle q ( Θ ) et de la perte ℒ ( Φ ).

Entraînement du modèle

Ensuite, nous minimisons la perte ℒ ( Φ ) en variant Φ = { α , β } afin que q ( Θ|Φ ) corresponde à postérieur p ( Θ | X ). Comment ? En utilisant la descente de gradient à l’ancienne ! Pour plus de commodité, nous utilisons l’optimiseur Adam d’Optax et initialisons les paramètres avec le prior α = a et β = b [n’oubliez pas, le prior était Gamma( Θ | a , b ) et a codifié notre connaissance du domaine].

# Initialise parameters using prior.Φ = {    'log_alpha': jnp.log(a),    'log_beta': jnp.full(fill_value=jnp.log(b), shape=[n_pixels]),}loss_val_grad = jit(jax.value_and_grad(loss))optimiser = optax.adam(learning_rate=0.2)opt_state = optimiser.init(Φ)

Ici, nous utilisons value_and_grad pour évaluer simultanément l’ELBO et sa dérivée. Pratique pour surveiller la convergence ! Nous compilons ensuite la fonction résultante avec jit pour la rendre plus rapide.

Enfin, nous allons entraîner le modèle pendant 5000 étapes. Étant donné que la perte est aléatoire, pour chaque évaluation, nous devons fournir une clé de générateur de nombres pseudo-aléatoires (PRNG). Nous faisons cela en allouant 5000 clés avec random.split .

n_iter = 5_000keys = random.split(random.PRNGKey(42), num=n_iter)for i, key in enumerate(keys):  elbo, grads = loss_val_grad(Φ, key)  updates, opt_state = optimiser.update(grads, opt_state)  Φ = optax.apply_updates(Φ, updates)

Félicitations ! Vous avez réussi à former votre premier modèle en utilisant l’inférence variationnelle !

Vous pouvez accéder au notebook avec le code complet ici sur Colab .

Résultats

Fig. 5: Comparison of variational distribution with exact posterior distribution. Image by Author.

Revenons en arrière et apprécions ce que nous avons construit (Fig. 5). Pour chaque pixel, la substitution q ( Θ ) décrit l’incertitude sur l’intensité moyenne des pixels (cas d’utilisation 2.). En particulier, notre choix de q ( Θ ) capture deux éléments complémentaires :

  • L’intensité typique du pixel.
  • La variabilité de l’intensité d’une image à une autre.

Il s’avère que la distribution conjointe p ( x , Θ ) que nous avons choisie a une solution exacte :

p ( Θ|X) = Gamma( Θ | a + Σ x ᵢ, m + b ),

où m est le nombre d’échantillons dans l’ensemble d’entraînement X. Ici, nous voyons explicitement comment la connaissance du domaine – codifiée en a et b – est réduite à mesure que nous recueillons plus d’exemples xᵢ.

Nous pouvons facilement comparer la forme apprise α et le taux β avec les vraies valeurs a + Σ xᵢ et m + b. Dans la Fig. 4, nous comparons les distributions – q (Θ|Φ) versus p (Θ|X) — pour deux pixels spécifiques. Et voilà, une correspondance parfaite!

Bonus: génération d’images synthétiques

Fig. 6: Images synthétiques générées à l'aide de l'inférence variationnelle. Image de l'auteur.

L’inférence variationnelle est excellente pour la modélisation générative (cas d’utilisation 3.). Avec la distribution postérieure de substitution q (Θ) en main, la génération de nouvelles images synthétiques est triviale. Les deux étapes sont les suivantes:

  • Échantillonner les intensités de pixels Θ ~ q (Θ).
# Extraire les paramètres de q.alpha = jnp.exp(Φ['log_alpha'])inv_beta = jnp.exp(-Φ['log_beta'])# 1) Générer des intensités de pixels pour 10 images.key_θ, key_x = random.split(key)m_new_images = 10new_batch_shape = [m_new_images, n_pixels]θ_samples = random.gamma(key_θ, alpha , shape=new_batch_shape) * inv_beta
  • Échantillonner des images en utilisant x ~ Poisson (x | Θ).
# 2) Échantillonner des images à partir des intensités.X_synthetic = random.poisson(key_x, θ_samples)

Vous pouvez voir le résultat dans la Fig. 6. Remarquez que le caractère “zéro” est légèrement moins net que prévu. Cela faisait partie de nos hypothèses de modélisation : nous avons modélisé les pixels comme étant mutuellement indépendants plutôt que corrélés. Pour prendre en compte les corrélations de pixels, vous pouvez étendre le modèle pour regrouper les intensités de pixels : cela s’appelle la factorisation de Poisson [4].

Résumé

Dans ce tutoriel, nous avons introduit les bases de l’inférence variationnelle et l’avons appliquée à un exemple simplifié : apprendre le chiffre manuscrit zéro. Grâce à autograd, la mise en œuvre de l’inférence variationnelle à partir de zéro ne prend que quelques lignes de Python.

L’inférence variationnelle est particulièrement puissante si vous disposez de peu de données. Nous avons vu comment infuser et échanger la connaissance du domaine avec l’information provenant des données. La distribution de substitution inférée q (Θ) donne une représentation “floue” des paramètres du modèle, plutôt qu’une valeur fixe. C’est idéal si vous vous trouvez dans une application à enjeux élevés où l’incertitude est importante! Enfin, nous avons démontré la modélisation générative. La génération d’échantillons synthétiques est facile une fois que vous pouvez échantillonner à partir de q (Θ).

En résumé, cela en fait un composant essentiel de la boîte à outils de la science des données.

En exploitant la puissance de l’inférence variationnelle, nous pouvons aborder des problèmes complexes, ce qui nous permet de prendre des décisions éclairées, de quantifier les incertitudes et en fin de compte de débloquer le véritable potentiel de la science des données.

Remerciements

Je tiens à remercier Dorien Neijzen et Martin Banchero pour la relecture.

Références :

[1] Blei, David M., Alp Kucukelbir, and Jon D. McAuliffe. “Variational inference: A review for statisticians.” Journal of the American statistical Association 112.518 (2017): 859–877.

[2] Figurnov, Mikhail, Shakir Mohamed, and Andriy Mnih. “Implicit reparameterization gradients.” Advances in neural information processing systems 31 (2018).

[3] Ranganath, Rajesh, Sean Gerrish, and David Blei. “Black box variational inference.” Artificial intelligence and statistics. PMLR, 2014.

[4] Gopalan, Prem, Jake M. Hofman et David M. Blei. “Recommandation évolutive avec factorisation de Poisson.” Article préliminaire arXiv arXiv:1311.1704 (2013).

We will continue to update IPGirl; if you have any questions or suggestions, please contact us!

Share:

Was this article helpful?

93 out of 132 found this helpful

Discover more

AI

Surveillance des données non structurées pour LLM et NLP

Une fois que vous avez déployé une solution basée sur la PNL ou le LLM, vous avez besoin d'un moyen de la surveiller....

Apprentissage automatique

PDG de NVIDIA Les créateurs seront boostés par l'IA générative

La création d’IA générative va “booster” les créateurs dans tous les secteurs et types de contenus,...

AI

Dévoiler l'avenir 10 outils AI de pointe que vous ne pouvez tout simplement pas ignorer (Juin 2023)

LOVO AI LOVO AI Text to Speech permet aux utilisateurs de créer des voix off professionnelles dans 100 langues. Cette...

AI

La vision du Premier ministre Modi pour la réglementation de l'IA en Inde Sommet B20 2023

Alors que le sommet B20 Inde 2023 touchait à sa fin à Delhi, les échos des paroles du Premier ministre Narendra Modi ...

AI

Chattez avec des PDF | Donnez du pouvoir à l'interaction textuelle avec Python et OpenAI

Introduction Dans un monde rempli d’informations, les documents PDF sont devenus un élément essentiel pour part...

AI

10 Types d'algorithmes de regroupement en apprentissage automatique

Introduction Avez-vous déjà pensé à la façon dont de vastes volumes de données peuvent être démêlés, révélant des mot...