import jax
import jax.numpy as jnp
from jax import grad, value_and_grad, jit
import optax
# Define a simple mean squared error loss function
def loss_fn(params, x, y):
    predictions = jnp.dot(x, params)
    return jnp.mean((predictions - y) ** 2)
# Initialize the parameters; here we assume a simple linear model with a single parameter
params = jnp.array([1.0])
# Training data: inputs x and targets y
x = jnp.array([[1.0], [2.0], [3.0]])
y = jnp.array([2.0, 4.0, 6.0])
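# Illustrative sanity check (not part of the original walkthrough): evaluate the
# loss at the initial parameters. With params = [1.0], the predictions are
# [1, 2, 3] against targets [2, 4, 6], so the loss is mean([1, 4, 9]) = 14/3 ≈ 4.67.
print("Initial loss:", loss_fn(params, x, y))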
# Set up the optimizer; here we use stochastic gradient descent with a learning rate of 0.1
optimizer = optax.sgd(0.1)
# Initialize the optimizer state
opt_state = optimizer.init(params)
# Wrap the update step in a jit-compiled function for performance
@jit
def update(params, x, y, opt_state):
    """Single optimization step."""
    # Compute the gradient of the loss with respect to the parameters
    grads = grad(loss_fn)(params, x, y)
    # Compute the parameter updates and the new optimizer state
    updates, opt_state = optimizer.update(grads, opt_state)
    # Apply the updates to the parameters
    params = optax.apply_updates(params, updates)
    return params, opt_state
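# Optional variant (a sketch, not part of the original loop): `value_and_grad`,
# imported above, returns both the loss and its gradient in a single pass, which
# is handy for logging the loss during training. The name `update_with_loss`
# is illustrative.
@jit
def update_with_loss(params, x, y, opt_state):
    """Single optimization step that also reports the current loss."""
    loss, grads = value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss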
# Run the optimization loop for a number of steps
for _ in range(100):
    params, opt_state = update(params, x, y, opt_state)
# Print out the optimized parameters
print("Optimized Parameters:", params)
# Expected output: Optimized Parameters: [2.]
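# Illustrative usage of the fitted model (not in the original): the learned
# parameter should map new inputs onto the y = 2x line, so x = 4.0 should
# yield a prediction close to 8.0.
print("Prediction for x=4:", jnp.dot(jnp.array([[4.0]]), params))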