Imaging with R2D2
Description
The R2D2 algorithm has a hybrid structure between a Plug-and-Play algorithm and a learned version of the well-known Matching Pursuit algorithm. Its reconstruction is formed as a series of residual images, iteratively estimated as outputs of iteration-specific Deep Neural Networks (DNNs), each taking the previous iteration’s image estimate and associated data residual as inputs. R2D2's primary application is to solve inverse problems in radio interferometry (RI). The details of R2D2 are discussed in the following paper.
[1] Aghabiglou, A., Chu, C. S., Dabbech, A. & Wiaux, Y., Towards a robust R2D2 paradigm for radio-interferometric imaging: revisiting DNN training and architecture, ArXiv:2503.02554v1.
In this tutorial, we focus on a Python implementation of R2D2 for small-scale monochromatic intensity imaging in RI from the environment setup to imaging.
RI Inverse Problem
The RI imaging inverse problem can be formulated as:
$$ \boldsymbol{y} = \boldsymbol{\Phi \bar{x}} + \boldsymbol{n},$$
where $\boldsymbol{y} \in \mathbb{C}^M$ is the measurement vector, $\boldsymbol{\bar{x}} \in \mathbb{R}^N$ is the unknown radio image, $\boldsymbol{\Phi} \in \mathbb{C}^{M \times N}$ is the measurement operator corresponding to incomplete Fourier sampling, and $\boldsymbol{n} \in \mathbb{C}^M$ is a realisation of white Gaussian noise with mean $0$ and standard deviation $\tau > 0$.
The RI inverse problem can be formulated in the image domain via back-projection of the measurements using the adjoint of the measurement operator $\boldsymbol{\Phi}^\dagger$. The image-domain data, known as the dirty image, reads:
$$ \boldsymbol{x}_{\textrm{d}} = \kappa \textrm{Re}\{ \boldsymbol{\Phi}^\dagger \boldsymbol{y}\} = \kappa \textrm{Re}\{ \boldsymbol{\Phi}^\dagger \boldsymbol{\Phi} \boldsymbol{\bar{x}} + \boldsymbol{\Phi}^\dagger \boldsymbol{n}\},$$
where $\textrm{Re}\{ \cdot \} $ denotes the real part of its argument, and $\kappa>0 $ is a normalisation factor ensuring that the maximum value of the point spread function (PSF), given by $\kappa \textrm{Re} \{\boldsymbol{\Phi}^\dagger \boldsymbol{\Phi} \boldsymbol{\delta}\} $, is equal to $1$, with $\boldsymbol{\delta} $ denoting the image with value $1$ at its center pixel and $0$ otherwise.
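To make the roles of $\boldsymbol{\Phi}$, $\boldsymbol{\Phi}^\dagger$ and $\kappa$ concrete, the following is a minimal sketch on a toy problem. It assumes a measurement operator made of a unitary 2D FFT followed by a random on-grid selection of frequencies; this is only a stand-in for the non-uniform Fourier sampling used in RI, and the names phi, phi_adj and mask are illustrative, not part of the R2D2 code base.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 32                                   # toy image side, so N = n * n pixels
mask = rng.random((n, n)) < 0.3          # hypothetical on-grid sampling pattern

def phi(x):
    """Toy measurement operator: unitary 2D FFT, then keep the sampled frequencies."""
    return np.fft.fft2(x, norm="ortho")[mask]

def phi_adj(y):
    """Adjoint of phi: zero-fill unsampled frequencies, then inverse unitary FFT."""
    grid = np.zeros((n, n), dtype=complex)
    grid[mask] = y
    return np.fft.ifft2(grid, norm="ortho")

# kappa normalises the PSF, Re{Phi^dagger Phi delta}, to a peak value of 1
delta = np.zeros((n, n))
delta[n // 2, n // 2] = 1.0
psf_unnormalised = np.real(phi_adj(phi(delta)))
kappa = 1.0 / psf_unnormalised.max()
psf = kappa * psf_unnormalised

# Simulate y = Phi x + n for two point sources, then back-project to the dirty image
x_true = np.zeros((n, n))
x_true[10, 12] = 1.0
x_true[20, 18] = 0.5
tau = 1e-3                               # noise standard deviation
m = int(mask.sum())                      # number of measurements M
noise = tau * (rng.standard_normal(m) + 1j * rng.standard_normal(m)) / np.sqrt(2)
y = phi(x_true) + noise
x_dirty = kappa * np.real(phi_adj(y))

print(f"M = {m}, N = {n * n}, PSF peak = {psf.max():.3f}")
```

In this toy setting the PSF peaks at the centre pixel, and the dirty image is the ground truth blurred by the PSF plus back-projected noise.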
R2D2's Iteration Structure
The R2D2 algorithm consists of a series of $I$ DNNs interleaved with $I-1$ data consistency layers. Each DNN $\boldsymbol{N}_{\widehat{\boldsymbol{\theta}}^{(i)}}$, parametrised by $\widehat{\boldsymbol{\theta}}^{(i)}$, takes the previous iteration's image estimate and associated residual dirty image as input. The image estimate and residual dirty image are initialised such that $\boldsymbol{x}^{(0)} = \boldsymbol{0}$ and $\boldsymbol{r}^{(0)} = \boldsymbol{x}_{\textrm{d}}$.
At each iteration $i \leq I$, the residual dirty image is given by:
$\boldsymbol{r}^{(i-1)} = \boldsymbol{x}_{\textrm{d}} - \kappa \textrm{Re} \{ \boldsymbol{\Phi}^\dagger\boldsymbol{\Phi} \} \boldsymbol{x}^{(i-1)}$.
The image estimate is then updated from the output of the DNN $\boldsymbol{N}_{\widehat{\boldsymbol{\theta}}^{(i)}}$ as:
$\boldsymbol{x}^{(i)} = \boldsymbol{x}^{(i-1)} + \boldsymbol{N}_{\widehat{\boldsymbol{\theta}}^{(i)}}(\boldsymbol{x}^{(i-1)}, \boldsymbol{r}^{(i-1)})$.
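The two update rules above can be sketched as a short loop. This is an illustration under strong assumptions, not the R2D2 implementation: the trained DNNs are replaced by a hypothetical placeholder_network that simply returns a scaled copy of the residual dirty image (which turns the loop into a plain gradient-descent-like scheme), and the measurement operator is a toy masked FFT.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 16
mask = rng.random((n, n)) < 0.5          # hypothetical on-grid sampling pattern

def dirty_op(x):
    """Apply Re{Phi^dagger Phi} for a toy masked-FFT measurement operator."""
    return np.real(np.fft.ifft2(mask * np.fft.fft2(x)))

delta = np.zeros((n, n))
delta[n // 2, n // 2] = 1.0
kappa = 1.0 / dirty_op(delta).max()      # PSF peak normalisation

def placeholder_network(x_prev, r_prev):
    """Stand-in for a trained DNN N^(i): returns a scaled residual dirty image."""
    return r_prev / kappa

def r2d2_like_loop(x_dirty, networks):
    """Run x^(i) = x^(i-1) + N^(i)(x^(i-1), r^(i-1)) with residual updates in between."""
    x = np.zeros_like(x_dirty)           # x^(0) = 0
    r = x_dirty.copy()                   # r^(0) = x_d
    residual_norms = [np.linalg.norm(r)]
    for net in networks:
        x = x + net(x, r)                            # image update
        r = x_dirty - kappa * dirty_op(x)            # data-consistency step
        residual_norms.append(np.linalg.norm(r))
    return x, r, residual_norms

# Noiseless toy dirty image from two point sources
x_true = np.zeros((n, n))
x_true[4, 5] = 1.0
x_true[9, 11] = 0.3
x_dirty = kappa * dirty_op(x_true)

x, r, norms = r2d2_like_loop(x_dirty, [placeholder_network] * 5)
print([f"{v:.3e}" for v in norms])
```

In the actual algorithm each entry of networks would be a different trained DNN; the loop structure is the only part this sketch is meant to convey.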
The R2D2 algorithm is underpinned by two DNN architectures:
- R2D2$_\mathcal{A1}$, a DNN series whose DNNs take the well-known U-Net architecture.
- R2D2$_\mathcal{A2}$, a DNN series whose DNNs take the proposed U-WDSR architecture.
R2D2 Repository & Dependencies
To start, create a virtual environment and install Jupyter Notebook from your preferred terminal. First, create the virtual environment, named here r2d2_env:
python -m venv r2d2_env
Then, activate the virtual environment by running the command below in your terminal, outside the Jupyter Notebook:
- On macOS/Linux:
source r2d2_env/bin/activate
- On Windows:
.\r2d2_env\Scripts\activate
Install Jupyter using:
pip install jupyter ipykernel
Add the virtual environment to Jupyter as a new kernel:
python -m ipykernel install --user --name=r2d2_env --display-name "R2D2_env"
Start Jupyter Notebook by running:
jupyter notebook
Select the new kernel in Jupyter:
- Open your Jupyter notebook.
- Go to Kernel -> Change kernel -> R2D2_env.
Clone the repository to the current directory by running the commands below in a notebook cell (in a terminal, drop the leading ! and %):
!git clone https://github.com/basp-group/R2D2-RI.git
%cd ./R2D2-RI
PyTorch and torchvision should be installed as per the guidelines of PyTorch to ensure the latest version compatible with your CUDA version is installed. If your CUDA version is older than 11.8, refer to these guidelines.
- For example, to install PyTorch with CUDA version 11.6 (OS: Linux/Windows), use the command below:
!pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
- On macOS:
!pip install torch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1
All other required Python packages are listed in the requirements file. Python version 3.10 or higher is required.
Install the packages using the command below:
!pip install -r requirements.txt
Pre-trained R2D2 realisations
The pre-trained realisations of both the R2D2$_\mathcal{A1}$ and R2D2$_\mathcal{A2}$ algorithms are available at the DOI:10.17861/e3060b95-4fe6-4b61-9f72-d77653c305bb.
They are trained using Fourier sampling patterns from the VLA (combining its configurations A and C) for the formation of images of size $N = 512 \times 512$, under the imaging settings explained in [1].
In this tutorial, we focus on R2D2$_\mathcal{A2}$ and introduce methods to study its epistemic uncertainty using different realisations. The same instructions apply for R2D2$_\mathcal{A1}$.
First, create the directory storing the pre-trained DNNs in the project directory:
!mkdir ./ckpt
Then, download and unzip the compressed file of R2D2 DNN series:
urls = {
    1: "https://researchportal.hw.ac.uk/files/146314175/R2D2_A2_T2_Realisation1.zip",
    2: "https://researchportal.hw.ac.uk/files/146289537/R2D2_A2_T2_Realisation2.zip",
    3: "https://researchportal.hw.ac.uk/files/146314176/R2D2_A2_T2_Realisation3.zip",
    4: "https://researchportal.hw.ac.uk/files/146314177/R2D2_A2_T2_Realisation4.zip",
    5: "https://researchportal.hw.ac.uk/files/146314178/R2D2_A2_T2_Realisation5.zip"
}
# Download, unpack and clean up each of the five realisations
for i in range(1, 6):
    !mkdir -p ./ckpt/R2D2_A2/V{i}
    url = urls[i]
    zip_path = f"./ckpt/R2D2_A2_T2_Realisation{i}.zip"
    !curl -L -o {zip_path} -H "User-Agent: Mozilla/5.0" -e https://researchportal.hw.ac.uk/ {url}
    !unzip -j {zip_path} -d ./ckpt/R2D2_A2/V{i}/
    !rm -f {zip_path}
Note: the curl and unzip commands must be installed beforehand.
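Where curl or unzip is unavailable, the same two steps can be approximated with the Python standard library. The sketch below is an assumption-laden alternative, not part of the repository: download and unzip_flat are hypothetical helper names, the headers mirror those passed to curl above, and whether the server accepts such a request has not been verified here.

```python
import urllib.request
import zipfile
from pathlib import Path

def download(url, zip_path, referer="https://researchportal.hw.ac.uk/"):
    """Fetch url to zip_path with the same User-Agent and Referer headers as the curl call."""
    request = urllib.request.Request(
        url, headers={"User-Agent": "Mozilla/5.0", "Referer": referer}
    )
    with urllib.request.urlopen(request) as response, open(zip_path, "wb") as out:
        out.write(response.read())

def unzip_flat(zip_path, dest_dir):
    """Extract all files into dest_dir, discarding archive sub-folders (like `unzip -j`)."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    extracted = []
    with zipfile.ZipFile(zip_path) as archive:
        for info in archive.infolist():
            if info.is_dir():
                continue                          # skip folder entries
            target = dest / Path(info.filename).name
            target.write_bytes(archive.read(info))
            extracted.append(target.name)
    return extracted

# Possible usage, with `urls` as defined in the cell above:
# for i, url in urls.items():
#     zip_path = f"./ckpt/R2D2_A2_T2_Realisation{i}.zip"
#     download(url, zip_path)
#     unzip_flat(zip_path, f"./ckpt/R2D2_A2/V{i}")
#     Path(zip_path).unlink()
```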
Test dataset
In this tutorial, we provide a test dataset, located in ./data/3c353/, composed of a ground truth in .fits format and its associated measurement file in .mat format. The ground truth is a post-processed image of the radio galaxy 3c353 of size $N = 512 \times 512$, which can be found in ./data/3c353/3c353_GTfits.fits. The measurements are simulated using VLA antenna configurations A and C.
First, let us read and visualise the ground truth $\bar{\boldsymbol{x}}$ using the following Python code:
from astropy.io import fits
import matplotlib.pyplot as plt
# Load ground truth image
gdth = fits.getdata('data/3c353_gdth.fits')
# Plot ground truth image
fig, ax = plt.subplots()
pcm = ax.imshow(gdth, cmap='afmhot')
cb = plt.colorbar(pcm, ax=ax)
title = ax.set_title(r'Ground Truth (linear scale), $\overline{\boldsymbol{x}}$', y=-0.15)
Given the high dynamic range of the ground truth $\bar{\boldsymbol{x}}$, only bright features are visible in linear scale. For a better visualisation, we map its pixel intensities to the logarithmic scale using the function:
$\textrm{rlog}(\bar{\boldsymbol{x}}) = \bar{\boldsymbol{x}}_{\textrm{max}} \log_{a}(\frac{a}{\bar{\boldsymbol{x}}_{\textrm{max}}} \bar{\boldsymbol{x}} + \boldsymbol{1}),$
where $a$ is its dynamic range, and $\bar{\boldsymbol{x}}_{\textrm{max}}$ is its maximum intensity value. For the considered image, $a = 16429.94$ and $\bar{\boldsymbol{x}}_{\textrm{max}} = 1.0$.
import numpy as np
import matplotlib.colors as colors
# Define rlog
class rlog(colors.Normalize):
    def __init__(self, a):
        self.a = a
        super().__init__(vmin=None, vmax=None, clip=False)

    def __call__(self, im):
        if len(im) == 0:
            return im
        im_max = np.max(im)
        if im_max < 1/self.a:
            return im
        else:
            return im_max * np.log10(self.a * im / im_max + 1) / np.log10(self.a)

    def inverse(self, im):
        return (self.a**(im) - 1) / self.a
# Image dynamic range
a = 16429.94
# Plot image
fig, ax = plt.subplots()
pcm = ax.imshow(gdth, norm= rlog(a), cmap='hot')
cb = plt.colorbar(pcm, ax=ax)
ticks = [np.log10(a * tick + 1)/np.log10(a) for tick in [0, 1e-4, 1e-3, 1e-2, 1e-1, 1]]
ticks[-1] = 1
cb.set_ticks(ticks)
cb.set_ticklabels([0, '$10^{-4}$', '$10^{-3}$', '$10^{-2}$', '$10^{-1}$', '1'])
title = ax.set_title(r'Ground Truth (log scale), $\text{rlog}(\overline{\boldsymbol{x}})$', y=-0.15)
The test measurement file ./data/data_3c353.mat associated with the above ground truth can be loaded using the following Python code:
from scipy.io import loadmat
# Load data
meas = loadmat('data/data_3c353.mat')
# Print scalars by value and arrays by shape, skipping MATLAB metadata keys
for k, v in meas.items():
    if not k.startswith('_'):
        if max(v.shape) == 1:
            val = v.item()
            if isinstance(val, (float, int)):
                print(f"{k}: {val:.4e}, {type(val)}")
            else:
                print(f"{k}: {val}, {type(val)}")
        else:
            print(f"{k}: {v.shape}, {type(v)}")
sigma: 6.0864e-05, <class 'float'>
super_resolution: 1.5189e+00, <class 'float'>
weight_robustness: 2.5303e-01, <class 'float'>
u: (3556332, 1), <class 'numpy.ndarray'>
v: (3556332, 1), <class 'numpy.ndarray'>
frequency: 1.0000e+09, <class 'float'>
nFreqs: 1.0000e+00, <class 'int'>
y: (2120879, 1), <class 'numpy.ndarray'>
unit: m, <class 'str'>
flag: (2, 1, 3556332), <class 'numpy.ndarray'>
tau: (1, 7), <class 'numpy.ndarray'>
nW: (1, 7), <class 'numpy.ndarray'>
tau_index: (1, 7), <class 'numpy.ndarray'>
expo_factor: 3.7284e+03, <class 'float'>
noise: (512, 512), <class 'numpy.ndarray'>
true_noise_norm: 2.9365e-02, <class 'float'>
The expected input variables are:
- y, the measurement/data vector.
- u, the $\boldsymbol{u}$ coordinates in units of the observation wavelength.
- v, the $\boldsymbol{v}$ coordinates in units of the observation wavelength.
- nW, the noise-whitening vector, known as the natural weights, corresponding to the inverse of the noise standard deviation. During imaging, these weights are applied to the measurements and injected into the measurement operator model.
- frequency, the observation frequency in Hz.
- maxProjBaseline, the maximum projected baseline in units of the observation wavelength, formally defined as $\max\{\sqrt{\boldsymbol{u}^2 + \boldsymbol{v}^2}\}$. It represents the spatial bandwidth of the Fourier sampling.
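As an illustration of the definition of maxProjBaseline, the sketch below computes it from arrays of $u$ and $v$ coordinates. It assumes the coordinates are already expressed in units of the observation wavelength; the helper metres_to_wavelengths is a hypothetical addition showing the standard conversion for coordinates stored in metres, and neither function is taken from the R2D2 code base.

```python
import numpy as np

SPEED_OF_LIGHT = 299792458.0  # m/s

def metres_to_wavelengths(coord_m, frequency_hz):
    """Convert baseline coordinates from metres to units of the observation wavelength."""
    return np.asarray(coord_m, dtype=float) * frequency_hz / SPEED_OF_LIGHT

def max_proj_baseline(u, v):
    """Return max sqrt(u^2 + v^2) over all sampled points."""
    u = np.asarray(u, dtype=float).ravel()
    v = np.asarray(v, dtype=float).ravel()
    return float(np.max(np.hypot(u, v)))

# Toy coordinates shaped like the (M, 1) arrays stored in the .mat file
u = np.array([[3.0], [-6.0], [0.0]])
v = np.array([[4.0], [8.0], [1.0]])
print(max_proj_baseline(u, v))  # 10.0
```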
The collection of the points $(\boldsymbol{u}, \boldsymbol{v})$ constitutes the 2D Fourier sampling pattern, also known as the $uv$-coverage, which can be visualised using the following Python code:
import matplotlib.pyplot as plt
# plot the 2D Fourier sampling pattern
fig, ax = plt.subplots(figsize=(6, 6))
pcm = ax.scatter(meas['u'], meas['v'], s=1e-3)
title = ax.set_title('uv coverage', y=-0.15)
plt.ticklabel_format(style='sci', axis='both', scilimits=(0,0))
Scripts to generate synthetic measurement files as well as to convert real Measurement Sets (MS) to .mat files readable by this repository are available.
Imaging with R2D2
The R2D2 imager is launched via the Python file ./src/imager.py. This file takes as input a configuration file in .yaml format, where all parameters involved in the imaging process, including the path to the input measurement file, are defined and set to default values where relevant.
Readers are directed to this README for more information on the content of the configuration file.
For this tutorial, we use the example configuration file ./config/inference/R2D2.yaml. The imager file also accepts optional name-argument pairs to overwrite corresponding fields in the configuration file.
The super-resolution factor is defined as the ratio between the spatial Fourier bandwidth of the target/ground truth image and the bandwidth of the Fourier sampling pattern (i.e. the maximum projected baseline). The corresponding pixel size in arcseconds is then obtained as:
$\textrm{imPixelsize} = \frac{180 \times 3600}{\pi} \times \frac{1}{\textrm{superresolution} \times 2 \times \textrm{maxProjBaseline}}$
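The pixel-size formula translates directly into a few lines of Python. The function name pixel_size_arcsec and the example baseline value are illustrative assumptions; only the super-resolution factor of 1.5189 comes from the measurement file printed earlier.

```python
import math

def pixel_size_arcsec(super_resolution, max_proj_baseline):
    """Pixel size in arcseconds, given the super-resolution factor and the
    maximum projected baseline in units of the observation wavelength."""
    rad_to_arcsec = 180.0 * 3600.0 / math.pi
    return rad_to_arcsec / (super_resolution * 2.0 * max_proj_baseline)

# Hypothetical maximum projected baseline of 1e5 wavelengths
print(f"{pixel_size_arcsec(1.5189, 1e5):.4f} arcsec")
```

Doubling either the super-resolution factor or the maximum projected baseline halves the pixel size.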
Under these settings, the R2D2 imager can be run by calling the following command from a notebook cell (in a terminal, drop the leading !):
!python ./src/imager.py --config ./config/inference/R2D2.yaml
R2D2 Results
The R2D2 imager provides as output both the PSF and the dirty image, saved in .fits format. They can be visualised using the following Python code:
from astropy.io import fits
import matplotlib.pyplot as plt
import numpy as np
dirty = fits.getdata('results/data_3c353/V1/dirty_normalised.fits')
dirty_upper_lim = np.percentile(dirty, 99.75)
dirty_lower_lim = np.percentile(dirty, 0.25)
dirty = np.clip(dirty, dirty_lower_lim, dirty_upper_lim)
fig, ax = plt.subplots()
pcm = ax.imshow(dirty, cmap='hot')
cb = plt.colorbar(pcm, ax=ax)
title = ax.set_title(r'Dirty image (linear scale), $\boldsymbol{x}_{\text{d}}$', y = -0.15)
R2D2 reconstructed images are saved in .fits format:
- The estimated model image $\widetilde{\boldsymbol{x}}$.
- The associated residual dirty image $\widetilde{\boldsymbol{r}}$.
They can be visualised using the following Python code:
from astropy.io import fits
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import numpy as np

class rlog(colors.Normalize):
    def __init__(self, a):
        self.a = a
        super().__init__(vmin=None, vmax=None, clip=False)

    def __call__(self, im):
        if len(im) == 0:
            return im
        im_max = np.max(im)
        if im_max < 1/self.a:
            return im
        else:
            return im_max * np.log10(self.a * im / im_max + 1) / np.log10(self.a)

    def inverse(self, im):
        return (self.a**(im) - 1) / self.a
rec_paths = [
'results/data_3c353/V1/R2D2_model_image.fits',
'results/data_3c353/V2/R2D2_model_image.fits',
'results/data_3c353/V3/R2D2_model_image.fits',
'results/data_3c353/V4/R2D2_model_image.fits',
'results/data_3c353/V5/R2D2_model_image.fits',
'results/data_3c353/R2D2_mean_model_image.fits'
]
images = [fits.getdata(path) for path in rec_paths]
titles = [f'V{i+1}' for i in range(5)] + ['Mean']
fig, axes = plt.subplots(nrows=2, ncols=6, figsize=(18, 6))
a = 16429.94 # dynamic range factor
for i, img in enumerate(images):
    # Linear scale
    ax_lin = axes[0, i]
    im_lin = ax_lin.imshow(img, cmap='hot')
    ax_lin.set_title(f'Realisation {titles[i]} (linear)', fontsize=10)
    ax_lin.axis('off')
    fig.colorbar(im_lin, ax=ax_lin, fraction=0.046, pad=0.04)
    # Log scale
    ax_log = axes[1, i]
    norm = rlog(a)
    im_log = ax_log.imshow(img, cmap='hot', norm=norm)
    ax_log.set_title(f'Realisation {titles[i]} (log)', fontsize=10)
    ax_log.axis('off')
    cbar = fig.colorbar(im_log, ax=ax_log, fraction=0.046, pad=0.04)
    ticks = [np.log10(a * t + 1)/np.log10(a) for t in [0, 1e-4, 1e-3, 1e-2, 1e-1, 1]]
    ticks[-1] = 1
    cbar.set_ticks(ticks)
    cbar.set_ticklabels([0, '$10^{-4}$', '$10^{-3}$', '$10^{-2}$', '$10^{-1}$', '1'])
plt.tight_layout()
plt.show()
We can also evaluate the reconstructions with the following metrics, computed in the Python code below:
$\textrm{SNR}(\bar{x}, \widetilde{x}) = 20 \log_{10}\left(\frac{\|\bar{x}\|_2}{\|\bar{x} - \widetilde{x}\|_2}\right)$,
$\textrm{logSNR}(\bar{x}, \widetilde{x}) = \textrm{SNR}(\textrm{rlog}(\bar{x}), \textrm{rlog}(\widetilde{x}))$.
# Define SNR
def snr(true, y, e=1e-9):
    return 20 * np.log10(np.linalg.norm(true.flatten()) / (np.linalg.norm(true.flatten() - y.flatten()) + e))

# Define logSNR
def log_im(im, a):
    im_cur = np.copy(im)
    im_cur[im_cur < 0] = 0  # clip negative pixels before the log mapping
    xmax = im_cur.max()
    return xmax * np.log((a / xmax) * im_cur + 1) / np.log(a)
rec_paths = [
'results/data_3c353/V1/R2D2_model_image.fits',
'results/data_3c353/V2/R2D2_model_image.fits',
'results/data_3c353/V3/R2D2_model_image.fits',
'results/data_3c353/V4/R2D2_model_image.fits',
'results/data_3c353/V5/R2D2_model_image.fits',
'results/data_3c353/R2D2_mean_model_image.fits'
]
images = [fits.getdata(path) for path in rec_paths]
titles = [f'V{i+1}' for i in range(5)] + ['Mean']
# Compute and report SNRs
print("Reconstruction metrics:")
for title, rec in zip(titles, images):
    SNR_val = snr(gdth, rec)
    logSNR_val = snr(log_im(gdth, a), log_im(rec, a))
    print(f'{title:>5}: SNR = {SNR_val:.6f} dB; logSNR = {logSNR_val:.5f} dB')
Reconstruction metrics:
   V1: SNR = 37.712857 dB; logSNR = 32.86923 dB
   V2: SNR = 37.683392 dB; logSNR = 32.96706 dB
   V3: SNR = 37.890141 dB; logSNR = 33.40377 dB
   V4: SNR = 38.427880 dB; logSNR = 33.43043 dB
   V5: SNR = 37.905982 dB; logSNR = 33.49756 dB
 Mean: SNR = 38.446797 dB; logSNR = 33.53424 dB
res_paths = [
'results/data_3c353/V1/R2D2_residual_dirty_image.fits',
'results/data_3c353/V2/R2D2_residual_dirty_image.fits',
'results/data_3c353/V3/R2D2_residual_dirty_image.fits',
'results/data_3c353/V4/R2D2_residual_dirty_image.fits',
'results/data_3c353/V5/R2D2_residual_dirty_image.fits'
]
residuals = [fits.getdata(p) for p in res_paths]
titles = [f'Residual V{i+1}' for i in range(5)]
# Plot
fig, axes = plt.subplots(nrows=1, ncols=5, figsize=(20, 4))
for i, (res, ax) in enumerate(zip(residuals, axes)):
    im = ax.imshow(res, vmin=-0.0015, vmax=0.0015, cmap='hot')
    ax.set_title(titles[i], fontsize=10)
    ax.axis('off')
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
plt.tight_layout()
plt.show()
The relative uncertainty image can be visualised as:
def log_im(im, a):
    out = np.full_like(im, np.nan)  # initialize with NaNs
    valid_mask = np.isfinite(im) & (im >= 0)
    out[valid_mask] = np.log10(a * im[valid_mask] + 1) / np.log10(a)
    return out
RUI = fits.getdata('results/data_3c353/R2D2_std_over_mean_image.fits')
fig, ax = plt.subplots()
pcm = ax.imshow(log_im(RUI,1e4), cmap='turbo')
cb = plt.colorbar(pcm, ax=ax)
title = ax.set_title('relative uncertainty image ', y = -0.15)
print(f'MRU:{np.nanmean(RUI)*1e3:.2f}x1e-3')
MRU:10.18x1e-3
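As its file name suggests, the relative uncertainty image is a pixel-wise ratio of the standard deviation to the mean across realisations. The sketch below shows one way such statistics could be formed from a list of model images. It is an assumption, not the imager's code: the function name ensemble_statistics, the use of the population standard deviation, and the choice to leave pixels with a non-positive mean as NaN are all illustrative.

```python
import numpy as np

def ensemble_statistics(realisations, eps=1e-12):
    """Pixel-wise mean, standard deviation and std/mean ratio over a list of images.
    Pixels whose mean does not exceed eps are left as NaN in the ratio."""
    stack = np.stack(realisations, axis=0)
    mean = stack.mean(axis=0)
    std = stack.std(axis=0)                 # population standard deviation (assumed)
    rel = np.full_like(mean, np.nan, dtype=float)
    valid = mean > eps
    rel[valid] = std[valid] / mean[valid]
    return mean, std, rel

# Toy example with two 2x2 "realisations"
v1 = np.array([[1.0, 0.0], [2.0, 4.0]])
v2 = np.array([[3.0, 0.0], [2.0, 4.0]])
mean_im, std_im, rel_im = ensemble_statistics([v1, v2])
print(rel_im)
print(f"mean relative uncertainty: {np.nanmean(rel_im):.4f}")
```

With real outputs, the list would hold the five R2D2_model_image.fits arrays loaded earlier, and np.nanmean of the ratio gives a single summary value comparable to the MRU printed above.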