9. An Asset Pricing Problem#
GPU
This lecture was built using a machine with JAX installed and access to a GPU.
To run this lecture on Google Colab, click on the “play” icon top right, select Colab, and set the runtime environment to include a GPU.
To run this lecture on your own machine, you need to install Google JAX.
9.1. Overview#
In this lecture we consider some asset pricing problems and use them to illustrate some foundations of JAX programming.
The main difference from the lecture Asset Pricing: The Lucas Asset Pricing Model, which also considers asset prices, is that the state spaces will be discrete and multi-dimensional.
Most of the heavy lifting is done through routines from linear algebra.
Along the way, we will show how to solve some memory-intensive problems with large state spaces.
We do this using elegant techniques made available by JAX, involving the use of linear operators to avoid instantiating large matrices.
If you wish to skip all motivation and move straight to the first equation we plan to solve, you can jump to (9.13).
The code outputs below are generated by a machine connected to the following GPU:
!nvidia-smi
Mon May 19 04:15:36 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.51.03 Driver Version: 575.51.03 CUDA Version: 12.9 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 Tesla T4 Off | 00000000:00:1E.0 Off | 0 |
| N/A 32C P0 32W / 70W | 0MiB / 15360MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
+-----------------------------------------------------------------------------------------+
In addition to JAX and Anaconda, this lecture will need the following libraries:
!pip install quantecon
Requirement already satisfied: quantecon in /home/runner/miniconda3/envs/quantecon/lib/python3.12/site-packages (0.8.0)
Requirement already satisfied: numba>=0.49.0 in /home/runner/miniconda3/envs/quantecon/lib/python3.12/site-packages (from quantecon) (0.60.0)
Requirement already satisfied: numpy>=1.17.0 in /home/runner/miniconda3/envs/quantecon/lib/python3.12/site-packages (from quantecon) (1.26.4)
Requirement already satisfied: requests in /home/runner/miniconda3/envs/quantecon/lib/python3.12/site-packages (from quantecon) (2.32.3)
Requirement already satisfied: scipy>=1.5.0 in /home/runner/miniconda3/envs/quantecon/lib/python3.12/site-packages (from quantecon) (1.13.1)
Requirement already satisfied: sympy in /home/runner/miniconda3/envs/quantecon/lib/python3.12/site-packages (from quantecon) (1.13.2)
Requirement already satisfied: llvmlite<0.44,>=0.43.0dev0 in /home/runner/miniconda3/envs/quantecon/lib/python3.12/site-packages (from numba>=0.49.0->quantecon) (0.43.0)
Requirement already satisfied: charset-normalizer<4,>=2 in /home/runner/miniconda3/envs/quantecon/lib/python3.12/site-packages (from requests->quantecon) (3.3.2)
Requirement already satisfied: idna<4,>=2.5 in /home/runner/miniconda3/envs/quantecon/lib/python3.12/site-packages (from requests->quantecon) (3.7)
Requirement already satisfied: urllib3<3,>=1.21.1 in /home/runner/miniconda3/envs/quantecon/lib/python3.12/site-packages (from requests->quantecon) (2.2.3)
Requirement already satisfied: certifi>=2017.4.17 in /home/runner/miniconda3/envs/quantecon/lib/python3.12/site-packages (from requests->quantecon) (2024.8.30)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /home/runner/miniconda3/envs/quantecon/lib/python3.12/site-packages (from sympy->quantecon) (1.3.0)
Below we use the following imports
import scipy
import quantecon as qe
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
from collections import namedtuple
from time import time
We will use 64-bit floats with JAX in order to increase precision.
jax.config.update("jax_enable_x64", True)
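By default JAX uses 32-bit floats, so it is worth confirming that the switch has taken effect:

print(jnp.ones(3).dtype)   # float64 once x64 mode is enabled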
9.2. Pricing a single payoff#
Suppose, at time $t$, we have an asset that pays a random amount $D_{t+1}$ at time $t+1$ and nothing after that.

The simplest way to price this asset is to use “risk-neutral” asset pricing, which asserts that the price of the asset at time $t$ should be

$$P_t = \beta \, \mathbb{E}_t D_{t+1} \tag{9.1}$$

Here $\mathbb{E}_t$ is expectation conditional on time-$t$ information and $\beta \in (0, 1)$ is a constant discount factor.
Roughly speaking, (9.1) says that the cost (i.e., price) equals expected benefit.
The discount factor is introduced because most people prefer payments now to payments in the future.
One problem with this very simple model is that it does not take into account attitudes to risk.
For example, investors often demand higher rates of return for holding risky assets.
This feature of asset prices cannot be captured by risk neutral pricing.
Hence we modify (9.1) to

$$P_t = \mathbb{E}_t M_{t+1} D_{t+1} \tag{9.2}$$

In this expression, $M_{t+1}$ is the stochastic discount factor.
In essence, allowing discounting to become a random variable gives us the flexibility to combine temporal discounting and attitudes to risk.
We leave further discussion to other lectures because our aim is to move to the computational problem.
9.3. Pricing a cash flow#
Now let’s try to price an asset like a share, which delivers a cash flow $D_t, D_{t+1}, \ldots$ over time.
We will call these payoffs “dividends”.
If we buy the share, hold it for one period and sell it again, we receive one dividend and our payoff is $D_{t+1} + P_{t+1}$.

Therefore, by (9.2), the price should be

$$P_t = \mathbb{E}_t M_{t+1} (D_{t+1} + P_{t+1}) \tag{9.3}$$
Because prices generally grow over time, which complicates analysis, it will be easier for us to solve for the price-dividend ratio $V_t := P_t / D_t$.
Let’s write down an expression that this ratio should satisfy.
We can divide both sides of (9.3) by $D_t$ to get

$$V_t = \mathbb{E}_t \left[ M_{t+1} \frac{D_{t+1}}{D_t} (1 + V_{t+1}) \right] \tag{9.4}$$

We can also write this as

$$V_t = \mathbb{E}_t \left[ M_{t+1} \exp(G^d_{t+1}) (1 + V_{t+1}) \right] \tag{9.5}$$

where

$$G^d_{t+1} := \ln \frac{D_{t+1}}{D_t}$$

is the growth rate of dividends.
Our aim is to solve (9.5) but before that we need to specify

1. the stochastic discount factor $M_{t+1}$, and
2. the growth rate of dividends $G^d_{t+1}$.
9.4. Choosing the stochastic discount factor#
We will adopt the stochastic discount factor described in Asset Pricing: The Lucas Asset Pricing Model, which has the form

$$M_{t+1} = \beta \frac{u'(C_{t+1})}{u'(C_t)} \tag{9.6}$$

where $u$ is a utility function and $C_t$ is time-$t$ consumption of a representative consumer.
For utility, we’ll assume the constant relative risk aversion (CRRA) specification

$$u(c) = \frac{c^{1-\gamma}}{1-\gamma}, \qquad \gamma > 0 \tag{9.7}$$
Inserting the CRRA specification into (9.6) and letting

$$G^c_{t+1} := \ln \frac{C_{t+1}}{C_t}$$

denote the growth rate of consumption, we obtain

$$M_{t+1} = \beta \left( \frac{C_{t+1}}{C_t} \right)^{-\gamma} = \beta \exp\left(-\gamma G^c_{t+1}\right) \tag{9.8}$$
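To make (9.8) concrete, here it is as a function of consumption growth, with illustrative parameter values:

β, γ = 0.98, 2.5

def sdf(g_c):
    "The stochastic discount factor in (9.8) as a function of consumption growth."
    return β * np.exp(-γ * g_c)

print(sdf(0.01))   # higher consumption growth implies a smaller discount factor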
9.5. Solving for the price-dividend ratio#
Substituting (9.8) into (9.5) gives the price-dividend ratio formula

$$V_t = \beta \, \mathbb{E}_t \left[ \exp\left(G^d_{t+1} - \gamma G^c_{t+1}\right) (1 + V_{t+1}) \right] \tag{9.9}$$
We assume there is a Markov chain $\{X_t\}$, which we call the state process, such that

$$G^c_{t+1} = \mu_c + X_t + \sigma_c \, \epsilon^c_{t+1},
\qquad
G^d_{t+1} = \mu_d + X_t + \sigma_d \, \epsilon^d_{t+1}$$

Here $\{\epsilon^c_t\}$ and $\{\epsilon^d_t\}$ are IID standard normal shocks, independent of each other and of $\{X_t\}$.

We can think of $\{X_t\}$ as an underlying state, such as the business cycle, that affects both consumption and dividend growth.

We let $P$ be the stochastic matrix that governs $\{X_t\}$ and assume $\{X_t\}$ takes values in some finite set $S$.
We guess that the price-dividend ratio is a fixed function of this state process, so that $V_t = v(X_t)$ for some function $v$.

This means that $V_{t+1} = v(X_{t+1})$.
By (9.9), the unknown function $v$ satisfies the equation

$$v(X_t) = \beta \, \mathbb{E}_t \left\{ \exp\left[a + (1-\gamma) X_t + \sigma_d \epsilon^d_{t+1} - \gamma \sigma_c \epsilon^c_{t+1}\right] \left(1 + v(X_{t+1})\right) \right\} \tag{9.10}$$

where $a := \mu_d - \gamma \mu_c$.
Since the shocks $\epsilon^c_{t+1}$ and $\epsilon^d_{t+1}$ are independent of $\{X_t\}$, we can integrate them out.

We use the following property of lognormal distributions: if $Y = \exp(c \epsilon)$ for some constant $c$ and $\epsilon \sim N(0, 1)$, then $\mathbb{E} Y = \exp(c^2 / 2)$.

This yields

$$v(X_t) = \beta \, \mathbb{E}_t \left\{ \exp\left[a + (1-\gamma) X_t + \frac{\sigma_d^2 + \gamma^2 \sigma_c^2}{2}\right] \left(1 + v(X_{t+1})\right) \right\} \tag{9.11}$$
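We can verify this lognormal property numerically with a quick Monte Carlo experiment (the value of $c$ is arbitrary):

c = 0.5
key = jax.random.PRNGKey(0)
ε = jax.random.normal(key, (1_000_000,))
print(jnp.exp(c * ε).mean())    # Monte Carlo estimate of E exp(c ε)
print(jnp.exp(c**2 / 2))        # closed-form value exp(c²/2)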
Conditioning on $X_t = x$, we can write this as

$$v(x) = \beta \sum_{y \in S} \exp\left[a + (1-\gamma) x + \frac{\sigma_d^2 + \gamma^2 \sigma_c^2}{2}\right] \left(1 + v(y)\right) P(x, y) \tag{9.12}$$

for all $x \in S$.
Suppose $S = \{x_1, \ldots, x_N\}$.

Then we can think of $v$ as an $N$-vector and, using square brackets for indices on arrays, write

$$v[i] = \beta \sum_{j=1}^{N} \exp\left[a + (1-\gamma) x[i] + \frac{\sigma_d^2 + \gamma^2 \sigma_c^2}{2}\right] \left(1 + v[j]\right) P[i, j] \tag{9.13}$$

for $i = 1, \ldots, N$.
Equivalently, we can write

$$v[i] = \sum_{j=1}^{N} K[i, j] \left(1 + v[j]\right) \tag{9.14}$$

where $K$ is the matrix defined by

$$K[i, j] := \beta \exp\left[a + (1-\gamma) x[i] + \frac{\sigma_d^2 + \gamma^2 \sigma_c^2}{2}\right] P[i, j] \tag{9.15}$$
Rewriting (9.14) in vector form yields

$$v = K(\mathbb{1} + v) \tag{9.16}$$

Notice that (9.16) can be written as $(I - K) v = K \mathbb{1}$.

The Neumann series lemma tells us that $I - K$ is invertible and the solution is

$$v = (I - K)^{-1} K \mathbb{1} \tag{9.17}$$

whenever $r(K)$, the spectral radius of $K$, is strictly less than one.
Once we specify $P$ and all the parameters, we can

1. obtain $K$,
2. check the spectral radius condition $r(K) < 1$, and,
3. assuming the condition holds, compute the solution via (9.17).

The small example below illustrates this procedure.
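Here the matrix is arbitrary, scaled so that its spectral radius is below one, and the truncated Neumann series $\sum_{n \geq 1} K^n \mathbb{1}$ is compared with the direct solve:

np.random.seed(0)
K_ex = 0.1 * np.random.rand(5, 5)                    # arbitrary matrix with r(K) < 1
assert np.max(np.abs(np.linalg.eigvals(K_ex))) < 1

# Direct solution v = (I - K)^{-1} K 1
v_direct = np.linalg.solve(np.eye(5) - K_ex, K_ex @ np.ones(5))

# Truncated Neumann series: v = K 1 + K^2 1 + ... + K^200 1
v_series, term = np.zeros(5), K_ex @ np.ones(5)
for _ in range(200):
    v_series += term
    term = K_ex @ term

print(np.allclose(v_direct, v_series))               # True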
9.6. Code#
We will use the power iteration algorithm to check the spectral radius condition.
The function below estimates the spectral radius of the matrix A via power iteration.
def power_iteration_sr(A, num_iterations=15, seed=1234):
" Estimates the spectral radius of A via power iteration. "
# Initialize
key = jax.random.PRNGKey(seed)
b_k = jax.random.normal(key, (A.shape[1],))
sr = 0
for _ in range(num_iterations):
# calculate the matrix-by-vector product Ab
b_k1 = jnp.dot(A, b_k)
# calculate the norm
b_k1_norm = jnp.linalg.norm(b_k1)
# Record the current estimate of the spectral radius
sr = jnp.sum(b_k1 * b_k)/jnp.sum(b_k * b_k)
# re-normalize the vector and continue
b_k = b_k1 / b_k1_norm
return sr
power_iteration_sr = jax.jit(power_iteration_sr)
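As a quick check of the estimate (the test matrix is arbitrary), we can compare it with the exact spectral radius computed from NumPy’s eigenvalue routine:

key = jax.random.PRNGKey(0)
A_test = 0.01 * jax.random.uniform(key, (50, 50))              # arbitrary test matrix
print(power_iteration_sr(A_test))                              # power iteration estimate
print(np.max(np.abs(np.linalg.eigvals(np.asarray(A_test)))))   # exact spectral radius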
The next function verifies that the spectral radius of a given matrix is less than 1.
def test_stability(Q):
"""
Assert that the spectral radius of matrix Q is < 1.
"""
sr = power_iteration_sr(Q)
assert sr < 1, f"Spectral radius condition failed with radius = {sr}"
In what follows we assume that $\{X_t\}$ is a discretization of the AR(1) process

$$X_{t+1} = \rho X_t + \sigma \eta_{t+1}$$

where $\rho$ and $\sigma$ are parameters and $\{\eta_t\}$ is IID and standard normal.

To discretize this process we use QuantEcon.py’s tauchen function.
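For readers who have not used tauchen before, here is a glimpse of its output on a small example — a grid of state values plus a row-stochastic transition matrix:

mc = qe.tauchen(5, 0.9, 0.01)    # 5 states, ρ = 0.9, σ = 0.01
print(mc.state_values)           # evenly spaced grid for the state
print(mc.P.sum(axis=1))          # each row of the transition matrix sums to 1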
Below we write a function called create_model()
that returns a namedtuple storing the relevant parameters and arrays.
Model = namedtuple('Model',
('P', 'S', 'β', 'γ', 'μ_c', 'μ_d', 'σ_c', 'σ_d'))
def create_model(N=100, # size of state space for Markov chain
ρ=0.9, # persistence parameter for Markov chain
                 σ=0.01,    # volatility parameter for Markov chain
β=0.98, # discount factor
γ=2.5, # coefficient of risk aversion
μ_c=0.01, # mean growth of consumption
μ_d=0.01, # mean growth of dividends
σ_c=0.02, # consumption volatility
σ_d=0.04): # dividend volatility
# Create the state process
mc = qe.tauchen(N, ρ, σ)
S = mc.state_values
P = mc.P
# Shift arrays to the device
S, P = map(jax.device_put, (S, P))
# Return the namedtuple
return Model(P=P, S=S, β=β, γ=γ, μ_c=μ_c, μ_d=μ_d, σ_c=σ_c, σ_d=σ_d)
Our first step is to construct the matrix $K$ defined in (9.15).
Here’s a function that does this using loops.
def compute_K_loop(model):
# unpack
P, S, β, γ, μ_c, μ_d, σ_c, σ_d = model
N = len(S)
K = np.empty((N, N))
a = μ_d - γ * μ_c
for i, x in enumerate(S):
for j, y in enumerate(S):
e = np.exp(a + (1 - γ) * x + (σ_d**2 + γ**2 * σ_c**2) / 2)
K[i, j] = β * e * P[i, j]
return K
To exploit the parallelization capabilities of JAX, let’s also write a vectorized (i.e., loop-free) implementation.
def compute_K(model):
# unpack
P, S, β, γ, μ_c, μ_d, σ_c, σ_d = model
N = len(S)
# Reshape and multiply pointwise using broadcasting
x = np.reshape(S, (N, 1))
a = μ_d - γ * μ_c
e = np.exp(a + (1 - γ) * x + (σ_d**2 + γ**2 * σ_c**2) / 2)
K = β * e * P
return K
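The key step is NumPy broadcasting: x (and hence e) has shape (N, 1) while P has shape (N, N), so the product e * P scales row i of P by e[i]. Here is a minimal illustration with arbitrary numbers:

e_demo = np.array([[1.0], [2.0]])        # shape (2, 1)
P_demo = np.array([[0.5, 0.5],
                   [0.2, 0.8]])          # shape (2, 2)
print(e_demo * P_demo)                   # row i of P_demo scaled by e_demo[i]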
These two functions produce the same output:
model = create_model(N=10)
K1 = compute_K(model)
K2 = compute_K_loop(model)
np.allclose(K1, K2)
True
Now we can compute the price-dividend ratio:
def price_dividend_ratio(model, test_stable=True):
"""
Computes the price-dividend ratio of the asset.
Parameters
----------
model: an instance of Model
contains primitives
Returns
-------
v : array_like
price-dividend ratio
"""
K = compute_K(model)
N = len(model.S)
if test_stable:
test_stability(K)
# Compute v
I = np.identity(N)
ones_vec = np.ones(N)
v = np.linalg.solve(I - K, K @ ones_vec)
return v
Here’s a plot of $v$ as a function of the state, for several values of $\gamma$:
model = create_model()
S = model.S
γs = np.linspace(2.0, 3.0, 5)
fig, ax = plt.subplots()
for γ in γs:
model = create_model(γ=γ)
v = price_dividend_ratio(model)
ax.plot(S, v, lw=2, alpha=0.6, label=rf"$\gamma = {γ}$")
ax.set_ylabel("price-dividend ratio")
ax.set_xlabel("state")
ax.legend(loc='upper right')
plt.show()
Notice that $v$ is decreasing in the state in each case.
This is because, with a positively correlated state process, higher states indicate higher future consumption growth.
With the stochastic discount factor (9.8), higher growth decreases the discount factor, lowering the weight placed on future dividends.
9.7. An Extended Example#
One problem with the model above is that volatility is constant through time (i.e., $\sigma_c$ and $\sigma_d$ are constants).
In reality, financial markets and growth rates of macroeconomic variables exhibit bursts of volatility.
To accommodate this, we now develop a stochastic volatility model.
To begin, suppose that consumption and dividends grow as follows.

$$G^i_{t+1} = \mu_i + Z_t + \bar{\sigma} \exp(h^i_t) \, \epsilon^i_{t+1}, \qquad i \in \{c, d\} \tag{9.18}$$

where

- $\{Z_t\}$ is a finite Markov chain with transition matrix $R$,
- $\{h^c_t\}$ and $\{h^d_t\}$ are volatility processes, modeled as finite Markov chains with transition matrices $P$ and $Q$ respectively, and
- $\{\epsilon^c_t\}$ and $\{\epsilon^d_t\}$ are IID standard normal shocks.

We assume that the chains $\{Z_t\}$, $\{h^c_t\}$ and $\{h^d_t\}$, and the shocks, are all independent of each other.

Here $\bar{\sigma} > 0$ is a constant scaling parameter for volatility.

Let $X_t := (h^c_t, h^d_t, Z_t)$.

We call $\{X_t\}$ the state process and guess that the price-dividend ratio is a function of this state process, so that $V_t = v(X_t)$ for some fixed function $v$.
Modifying (9.10) to accommodate the new growth specifications,
we find that $v$ satisfies

$$v(X_t) = \beta \, \mathbb{E}_t \exp\left\{a + (1-\gamma) Z_t + \bar{\sigma} \exp(h^d_t) \epsilon^d_{t+1} - \gamma \bar{\sigma} \exp(h^c_t) \epsilon^c_{t+1}\right\} \left(1 + v(X_{t+1})\right) \tag{9.19}$$

where, as before, $a := \mu_d - \gamma \mu_c$.
Conditioning on state $X_t = (h^c, h^d, z)$, the left-hand side of (9.19) becomes $v(h^c, h^d, z)$, while the expectation on the right runs over both the shocks and the next-period state.

As before, we integrate out the independent shocks and use the rules for expectations of lognormals to obtain

$$v(h^c, h^d, z) = \beta \, \mathbb{E}_t \exp\left\{a + (1-\gamma) z + \frac{\bar{\sigma}^2 \left[\exp(2 h^d) + \gamma^2 \exp(2 h^c)\right]}{2}\right\} \left(1 + v(X_{t+1})\right) \tag{9.20}$$
Let

$$\kappa(h^c, h^d, z) := \exp\left\{a + (1-\gamma) z + \frac{\bar{\sigma}^2 \left[\exp(2 h^d) + \gamma^2 \exp(2 h^c)\right]}{2}\right\}$$

where $(h^c, h^d, z)$ is a point in the state space of $\{X_t\}$.

With this notation, we can write (9.20) more explicitly as

$$v(h^c, h^d, z) = \beta \sum_{h^{c\prime}, h^{d\prime}, z'} \kappa(h^c, h^d, z) \left(1 + v(h^{c\prime}, h^{d\prime}, z')\right) P(h^c, h^{c\prime}) \, Q(h^d, h^{d\prime}) \, R(z, z') \tag{9.21}$$
Let’s now write the state using indices, with $(i, j, k)$ being the index corresponding to the state $(h^c[i], h^d[j], z[k])$.

Then (9.21) becomes

$$v[i, j, k] = \beta \sum_{i', j', k'} \kappa[i, j, k] \left(1 + v[i', j', k']\right) P[i, i'] \, Q[j, j'] \, R[k, k'] \tag{9.22}$$
One way to understand this is to reshape $v$ into an $N$-vector, where $N = I \times J \times K$, and collect the coefficients $\beta \, \kappa[i, j, k] \, P[i, i'] \, Q[j, j'] \, R[k, k']$ into an $N \times N$ matrix $A$.

Then we can write (9.22) as

$$v = A(\mathbb{1} + v) \tag{9.23}$$

Provided that the spectral radius condition $r(A) < 1$ holds, the solution is given by $v = (I - A)^{-1} A \mathbb{1}$, just as before.
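To see how this flattening works, here is a small illustration of how np.reshape maps the multi-index $(i, j, k)$ into a single index (the array values are arbitrary):

I, J, K = 2, 3, 4
v_demo = np.arange(I * J * K).reshape(I, J, K)       # an array v[i, j, k]
v_flat = np.reshape(v_demo, I * J * K)               # flattened into an N-vector

# Under NumPy's row-major ordering, (i, j, k) maps to i*J*K + j*K + k
i, j, k = 1, 2, 3
print(v_demo[i, j, k] == v_flat[i * J * K + j * K + k])   # True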
9.8. Numpy Version#
Our first implementation will be in NumPy.
Once we have a NumPy version working, we will convert it to JAX and check the difference in the run times.
The code block below provides a function called create_sv_model()
that returns a namedtuple containing arrays and other data that form the primitives of the problem.
It assumes that each of the Markov chains $\{h^c_t\}$, $\{h^d_t\}$ and $\{Z_t\}$ is a Tauchen discretization of an AR(1) process.
SVModel = namedtuple('SVModel',
('P', 'hc_grid',
'Q', 'hd_grid',
'R', 'z_grid',
'β', 'γ', 'bar_σ', 'μ_c', 'μ_d'))
def create_sv_model(β=0.98, # discount factor
γ=2.5, # coefficient of risk aversion
I=14, # size of state space for h_c
ρ_c=0.9, # persistence parameter for h_c
σ_c=0.01, # volatility parameter for h_c
J=14, # size of state space for h_d
ρ_d=0.9, # persistence parameter for h_d
σ_d=0.01, # volatility parameter for h_d
K=14, # size of state space for z
bar_σ=0.01, # volatility scaling parameter
ρ_z=0.9, # persistence parameter for z
                    σ_z=0.01,    # volatility parameter for z
μ_c=0.001, # mean growth of consumption
μ_d=0.005): # mean growth of dividends
mc = qe.tauchen(I, ρ_c, σ_c)
hc_grid = mc.state_values
P = mc.P
mc = qe.tauchen(J, ρ_d, σ_d)
hd_grid = mc.state_values
Q = mc.P
mc = qe.tauchen(K, ρ_z, σ_z)
z_grid = mc.state_values
R = mc.P
return SVModel(P=P, hc_grid=hc_grid,
Q=Q, hd_grid=hd_grid,
R=R, z_grid=z_grid,
β=β, γ=γ, bar_σ=bar_σ, μ_c=μ_c, μ_d=μ_d)
Now we provide a function to compute the matrix $A$.
def compute_A(sv_model):
# Set up
P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model
I, J, K = len(hc_grid), len(hd_grid), len(z_grid)
N = I * J * K
# Reshape and broadcast over (i, j, k, i', j', k')
hc = np.reshape(hc_grid, (I, 1, 1, 1, 1, 1))
hd = np.reshape(hd_grid, (1, J, 1, 1, 1, 1))
z = np.reshape(z_grid, (1, 1, K, 1, 1, 1))
P = np.reshape(P, (I, 1, 1, I, 1, 1))
Q = np.reshape(Q, (1, J, 1, 1, J, 1))
R = np.reshape(R, (1, 1, K, 1, 1, K))
# Compute A and then reshape to create a matrix
a = μ_d - γ * μ_c
b = bar_σ**2 * (np.exp(2 * hd) + γ**2 * np.exp(2 * hc)) / 2
κ = np.exp(a + (1 - γ) * z + b)
A = β * κ * P * Q * R
A = np.reshape(A, (N, N))
return A
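The reshapes above place the product $P[i, i'] \, Q[j, j'] \, R[k, k']$ at position $(i, j, k, i', j', k')$, which is a Kronecker-product structure. On a tiny example (with arbitrary stochastic matrices) we can confirm that flattening this product reproduces np.kron:

P2 = np.array([[0.9, 0.1],
               [0.2, 0.8]])
Q2 = np.array([[0.7, 0.3],
               [0.4, 0.6]])
prod = np.reshape(P2, (2, 1, 2, 1)) * np.reshape(Q2, (1, 2, 1, 2))
print(np.allclose(np.reshape(prod, (4, 4)), np.kron(P2, Q2)))   # True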
Here’s our function to compute the price-dividend ratio for the stochastic volatility model.
def sv_pd_ratio(sv_model, test_stable=True):
"""
Computes the price-dividend ratio of the asset for the stochastic volatility
model.
Parameters
----------
    sv_model: an instance of SVModel
contains primitives
Returns
-------
v : array_like
price-dividend ratio
"""
# unpack
P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model
I, J, K = len(hc_grid), len(hd_grid), len(z_grid)
N = I * J * K
A = compute_A(sv_model)
# Make sure that a unique solution exists
if test_stable:
test_stability(A)
# Compute v
ones_array = np.ones(N)
Id = np.identity(N)
v = scipy.linalg.solve(Id - A, A @ ones_array)
# Reshape into an array of the form v[i, j, k]
v = np.reshape(v, (I, J, K))
return v
Let’s create an instance of the model and solve it.
sv_model = create_sv_model()
P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model
Let’s run it to compile.
start = time()
v = sv_pd_ratio(sv_model)
numpy_with_compile = time() - start
print("Numpy compile plus execution time = ", numpy_with_compile)
Numpy compile plus execution time = 0.8721084594726562
Let’s run it again to remove the compile.
start = time()
v = sv_pd_ratio(sv_model)
numpy_without_compile = time() - start
print("Numpy execution time = ", numpy_without_compile)
Numpy execution time = 0.28946471214294434
Here are some plots of the solution
fig, ax = plt.subplots()
ax.plot(hc_grid, v[:, 0, 0], lw=2, alpha=0.6, label="$v$ as a function of $h^c$")
ax.set_ylabel("price-dividend ratio")
ax.set_xlabel("state")
ax.legend()
plt.show()
fig, ax = plt.subplots()
ax.plot(hd_grid, v[0, :, 0], lw=2, alpha=0.6, label="$v$ as a function of $h^d$")
ax.set_ylabel("price-dividend ratio")
ax.set_xlabel("state")
ax.legend()
plt.show()
9.9. JAX Version#
Now let’s write a JAX version that is a simple transformation of the NumPy version.
(Below we will write a more efficient version using JAX’s ability to work with linear operators.)
def create_sv_model_jax(sv_model):
# Take the contents of a NumPy sv_model instance
P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model
# Shift the arrays to the device (GPU if available)
hc_grid, hd_grid, z_grid = map(jax.device_put, (hc_grid, hd_grid, z_grid))
P, Q, R = map(jax.device_put, (P, Q, R))
# Create a new instance and return it
return SVModel(P=P, hc_grid=hc_grid,
Q=Q, hd_grid=hd_grid,
R=R, z_grid=z_grid,
β=β, γ=γ, bar_σ=bar_σ, μ_c=μ_c, μ_d=μ_d)
Here’s a function to compute the matrix $A$.
We include the extra argument shapes
to help the compiler understand the size of the arrays.
This is important when we JIT-compile the function below.
def compute_A_jax(sv_model, shapes):
# Set up
P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model
I, J, K = shapes
N = I * J * K
# Reshape and broadcast over (i, j, k, i', j', k')
hc = jnp.reshape(hc_grid, (I, 1, 1, 1, 1, 1))
hd = jnp.reshape(hd_grid, (1, J, 1, 1, 1, 1))
z = jnp.reshape(z_grid, (1, 1, K, 1, 1, 1))
P = jnp.reshape(P, (I, 1, 1, I, 1, 1))
Q = jnp.reshape(Q, (1, J, 1, 1, J, 1))
R = jnp.reshape(R, (1, 1, K, 1, 1, K))
# Compute A and then reshape to create a matrix
a = μ_d - γ * μ_c
b = bar_σ**2 * (jnp.exp(2 * hd) + γ**2 * jnp.exp(2 * hc)) / 2
κ = jnp.exp(a + (1 - γ) * z + b)
A = β * κ * P * Q * R
A = jnp.reshape(A, (N, N))
return A
Here’s the function that computes the solution.
def sv_pd_ratio_jax(sv_model, shapes):
"""
Computes the price-dividend ratio of the asset for the stochastic volatility
model.
Parameters
----------
    sv_model: an instance of SVModel
contains primitives
Returns
-------
v : array_like
price-dividend ratio
"""
# unpack
P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model
I, J, K = len(hc_grid), len(hd_grid), len(z_grid)
shapes = I, J, K
N = I * J * K
A = compute_A_jax(sv_model, shapes)
# Compute v, reshape and return
ones_array = jnp.ones(N)
Id = jnp.identity(N)
v = jax.scipy.linalg.solve(Id - A, A @ ones_array)
return jnp.reshape(v, (I, J, K))
Now let’s target these functions for JIT-compilation, while using static_argnums
to indicate that the function will need to be recompiled when shapes
changes.
compute_A_jax = jax.jit(compute_A_jax, static_argnums=(1,))
sv_pd_ratio_jax = jax.jit(sv_pd_ratio_jax, static_argnums=(1,))
sv_model = create_sv_model()
sv_model_jax = create_sv_model_jax(sv_model)
P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model_jax
shapes = len(hc_grid), len(hd_grid), len(z_grid)
Let’s see how long it takes to run with compile time included.
start = time()
v_jax = sv_pd_ratio_jax(sv_model_jax, shapes).block_until_ready()
jnp_with_compile = time() - start
print("JAX compile plus execution time = ", jnp_with_compile)
JAX compile plus execution time = 0.6433587074279785
And now let’s see without compile time.
start = time()
v_jax = sv_pd_ratio_jax(sv_model_jax, shapes).block_until_ready()
jnp_without_compile = time() - start
print("JAX execution time = ", jnp_without_compile)
JAX execution time = 0.11757493019104004
Here’s the ratio of times:
jnp_without_compile / numpy_without_compile
0.4061805299879829
Let’s check that the NumPy and JAX versions realize the same solution.
v = jax.device_put(v)
print(jnp.allclose(v, v_jax))
True
9.10. A memory-efficient JAX version#
One problem with the code above is that we instantiate a matrix of size $N \times N$, where $N = I \times J \times K$.

This quickly becomes impossible as $I$, $J$ and $K$ increase.

Fortunately, JAX makes it possible to solve for the price-dividend ratio without instantiating this large matrix.

The first step is to think of $A$ not as an array of numbers, but as the linear operator that sends an array $g$ into the array $Ag$, as implemented below.
def A(g, sv_model, shapes):
# Set up
P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model
I, J, K = shapes
# Reshape and broadcast over (i, j, k, i', j', k')
hc = jnp.reshape(hc_grid, (I, 1, 1, 1, 1, 1))
hd = jnp.reshape(hd_grid, (1, J, 1, 1, 1, 1))
z = jnp.reshape(z_grid, (1, 1, K, 1, 1, 1))
P = jnp.reshape(P, (I, 1, 1, I, 1, 1))
Q = jnp.reshape(Q, (1, J, 1, 1, J, 1))
R = jnp.reshape(R, (1, 1, K, 1, 1, K))
g = jnp.reshape(g, (1, 1, 1, I, J, K))
a = μ_d - γ * μ_c
b = bar_σ**2 * (jnp.exp(2 * hd) + γ**2 * jnp.exp(2 * hc)) / 2
κ = jnp.exp(a + (1 - γ) * z + b)
A = β * κ * P * Q * R
Ag = jnp.sum(A * g, axis=(3, 4, 5))
return Ag
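As a sanity check on a small instance, we can confirm that applying this operator to an array g agrees with multiplying the flattened g by the instantiated matrix from compute_A_jax:

small_model = create_sv_model_jax(create_sv_model(I=3, J=3, K=3))
small_shapes = 3, 3, 3
A_mat = compute_A_jax(small_model, small_shapes)        # instantiated 27 x 27 matrix
g = jnp.ones(small_shapes)
Ag_op = A(g, small_model, small_shapes)                 # operator applied to g
Ag_mat = jnp.reshape(A_mat @ jnp.reshape(g, 27), small_shapes)
print(jnp.allclose(Ag_op, Ag_mat))                      # True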
Now we write a version of the solution function for the price-dividend ratio that acts directly on the linear operator A.
def sv_pd_ratio_linop(sv_model, shapes):
    P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model
    I, J, K = shapes
    ones_array = jnp.ones((I, J, K))
    # Set up the map g -> (I - A) g
    I_minus_A = lambda g: g - A(g, sv_model, shapes)
    # Right-hand side of (I - A) v = A 1
    A1 = A(ones_array, sv_model, shapes)
    # Solve v = (I - A)^{-1} A 1 with an iterative solver that
    # requires only matrix-vector products
    v = jax.scipy.sparse.linalg.bicgstab(I_minus_A, A1)[0]
    return v
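Note that bicgstab accepts a function in place of a matrix, so the system is solved using only applications of the operator. Here is the same pattern on an arbitrary 2 × 2 system:

M_demo = jnp.array([[3.0, 1.0],
                    [1.0, 2.0]])
matvec = lambda x: M_demo @ x                     # the matrix enters only as a map
b_demo = jnp.array([9.0, 8.0])
x_sol, info = jax.scipy.sparse.linalg.bicgstab(matvec, b_demo)
print(jnp.allclose(M_demo @ x_sol, b_demo))       # True, up to solver tolerance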
Let’s target these functions for JIT compilation.
A = jax.jit(A, static_argnums=(2,))
sv_pd_ratio_linop = jax.jit(sv_pd_ratio_linop, static_argnums=(1,))
Let’s time the solution with compile time included.
start = time()
v_jax_linop = sv_pd_ratio_linop(sv_model, shapes).block_until_ready()
jnp_linop_with_compile = time() - start
print("JAX compile plus execution time = ", jnp_linop_with_compile)
JAX compile plus execution time = 0.5621523857116699
And now let’s see without compile time.
start = time()
v_jax_linop = sv_pd_ratio_linop(sv_model, shapes).block_until_ready()
jnp_linop_without_compile = time() - start
print("JAX execution time = ", jnp_linop_without_compile)
JAX execution time = 0.006590127944946289
Let’s verify the solution again:
print(jnp.allclose(v, v_jax_linop))
True
Here’s the ratio of times between memory-efficient and direct version:
jnp_linop_without_compile / jnp_without_compile
0.05605045169270701
The operator-based version is substantially faster and, moreover, we can now work with much larger grids.
Here’s a moderately large example, where the state space has 15,625 elements.
sv_model = create_sv_model(I=25, J=25, K=25)
sv_model_jax = create_sv_model_jax(sv_model)
P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model_jax
shapes = len(hc_grid), len(hd_grid), len(z_grid)
%time _ = sv_pd_ratio_linop(sv_model_jax, shapes).block_until_ready()
%time _ = sv_pd_ratio_linop(sv_model_jax, shapes).block_until_ready()
CPU times: user 642 ms, sys: 12.2 ms, total: 654 ms
Wall time: 775 ms
CPU times: user 172 ms, sys: 26 μs, total: 172 ms
Wall time: 171 ms
The solution is computed relatively quickly and without memory issues.
Readers will find that they can push these numbers further, although we refrain from doing so here.