Source code for pyrtools.pyramids.SteerablePyramidFreq

import warnings
import numpy as np
from scipy.special import factorial
from .pyramid import SteerablePyramidBase
from .c.wrapper import pointOp
from ..tools.utils import rcosFn


[docs] class SteerablePyramidFreq(SteerablePyramidBase): """Steerable frequency pyramid. Construct a steerable pyramid on matrix IM, in the Fourier domain. This is similar to Spyr, except that: + Reconstruction is exact (within floating point errors) + It can produce any number of orientation bands. - Typically slower, especially for non-power-of-two sizes. - Boundary-handling is circular. The squared radial functions tile the Fourier plane with a raised-cosine falloff. Angular functions are cos(theta- k*pi/order+1)^(order). Note that reconstruction will not be exact if the image has an odd shape (due to boundary-handling issues) or if the pyramid is complex with order=0. Notes ----- Transform described in [1]_, filter kernel design described in [2]_. Parameters ---------- image : `array_like` 2d image upon which to construct to the pyramid. height : 'auto' or `int`. The height of the pyramid. If 'auto', will automatically determine based on the size of `image`. If an int, must be non-negative. When height=0, only returns the residuals. order : `int`. The Gaussian derivative order used for the steerable filters. Default value is 3. Note that to achieve steerability the minimum number of orientation is `order` + 1, and is used here. To get more orientations at the same order, use the method `steer_coeffs` twidth : `int` The width of the transition region of the radial lowpass function, in octaves is_complex : `bool` Whether the pyramid coefficients should be complex or not. If True, the real and imaginary parts correspond to a pair of odd and even symmetric filters. If False, the coefficients only include the real part / odd symmetric filter. Attributes ---------- image : `array_like` The input image used to construct the pyramid. image_size : `tuple` The size of the input image. pyr_type : `str` or `None` Human-readable string specifying the type of pyramid. For base class, is None. pyr_coeffs : `dict` Dictionary containing the coefficients of the pyramid. Keys are `(level, band)` tuples and values are 1d or 2d numpy arrays (same number of dimensions as the input image), running from fine to coarse. pyr_size : `dict` Dictionary containing the sizes of the pyramid coefficients. Keys are `(level, band)` tuples and values are tuples. is_complex : `bool` Whether the coefficients are complex- or real-valued. References ---------- .. [1] E P Simoncelli and W T Freeman, "The Steerable Pyramid: A Flexible Architecture for Multi-Scale Derivative Computation," Second Int'l Conf on Image Processing, Washington, DC, Oct 1995. .. [2] A Karasaridis and E P Simoncelli, "A Filter Design Technique for Steerable Pyramid Image Transforms", ICASSP, Atlanta, GA, May 1996. """ def __init__(self, image, height='auto', order=3, twidth=1, is_complex=False): # in the Fourier domain, there's only one choice for how do edge-handling: circular. to # emphasize that thisisn'ta choice, we use None here. super().__init__(image=image, edge_type=None) self.pyr_type = 'SteerableFrequency' self.is_complex = is_complex # SteerablePyramidFreq doesn't have filters, they're constructed in the frequency space self.filters = {} self.order = int(order) if (image.shape[0] % 2 != 0) or (image.shape[1] % 2 != 0): warnings.warn("Reconstruction will not be perfect with odd-sized images") if self.order == 0 and self.is_complex: raise ValueError( "Complex pyramid cannot have order=0! See " "https://github.com/plenoptic-org/plenoptic/issues/326 " "for an explanation." ) # we can't use the base class's _set_num_scales method because the max height is calculated # slightly differently max_ht = np.floor(np.log2(min(self.image.shape))) - 2 if height == 'auto' or height is None: self.num_scales = int(max_ht) elif height > max_ht: raise ValueError("Cannot build pyramid higher than %d levels." % (max_ht)) elif height < 0: raise ValueError("Height must be a non-negative int.") else: self.num_scales = int(height) if self.order > 15 or self.order < 0: raise ValueError("order must be an integer in the range [0,15].") self.num_orientations = int(order + 1) if twidth <= 0: raise ValueError("twidth must be positive.") twidth = int(twidth) dims = np.asarray(self.image.shape) ctr = np.ceil((np.asarray(dims)+0.5)/2).astype(int) (xramp, yramp) = np.meshgrid(np.linspace(-1, 1, dims[1]+1)[:-1], np.linspace(-1, 1, dims[0]+1)[:-1]) angle = np.arctan2(yramp, xramp) log_rad = np.sqrt(xramp**2 + yramp**2) log_rad[ctr[0]-1, ctr[1]-1] = log_rad[ctr[0]-1, ctr[1]-2] log_rad = np.log2(log_rad) # Radial transition function (a raised cosine in log-frequency): (Xrcos, Yrcos) = rcosFn(twidth, (-twidth/2.0), np.asarray([0, 1])) Yrcos = np.sqrt(Yrcos) YIrcos = np.sqrt(1.0 - Yrcos**2) lo0mask = pointOp(log_rad, YIrcos, Xrcos[0], Xrcos[1]-Xrcos[0]) self._lo0mask = lo0mask imdft = np.fft.fftshift(np.fft.fft2(self.image)) hi0mask = pointOp(log_rad, Yrcos, Xrcos[0], Xrcos[1]-Xrcos[0]) self._hi0mask = hi0mask hi0dft = imdft * hi0mask.reshape(imdft.shape[0], imdft.shape[1]) hi0 = np.fft.ifft2(np.fft.ifftshift(hi0dft)) self.pyr_coeffs['residual_highpass'] = np.real(hi0) self.pyr_size['residual_highpass'] = hi0.shape lo0mask = lo0mask.reshape(imdft.shape[0], imdft.shape[1]) lodft = imdft * lo0mask self._anglemasks = [] self._himasks = [] self._lomasks = [] for i in range(self.num_scales): Xrcos -= np.log2(2) lutsize = 1024 Xcosn = np.pi * np.arange(-(2*lutsize+1), (lutsize+2)) / lutsize const = (2**(2*self.order))*(factorial(self.order, exact=True)**2)/ float(self.num_orientations*factorial(2*self.order, exact=True)) if self.is_complex: # TODO clean that up and give comments alfa = ((np.pi+Xcosn) % (2.0*np.pi)) - np.pi Ycosn = (2.0 * np.sqrt(const) * (np.cos(Xcosn) ** self.order) * (np.abs(alfa) < np.pi/2.0).astype(int)) else: Ycosn = np.sqrt(const) * (np.cos(Xcosn))**self.order log_rad_test = np.reshape(log_rad, (1, log_rad.shape[0] * log_rad.shape[1])) himask = pointOp(log_rad_test, Yrcos, Xrcos[0], Xrcos[1]-Xrcos[0]) himask = himask.reshape((lodft.shape[0], lodft.shape[1])) self._himasks.append(himask) anglemasks = [] for b in range(self.num_orientations): angle_tmp = np.reshape(angle, (1, angle.shape[0] * angle.shape[1])) anglemask = pointOp(angle_tmp, Ycosn, Xcosn[0]+np.pi*b/self.num_orientations, Xcosn[1]-Xcosn[0]) anglemask = anglemask.reshape(lodft.shape[0], lodft.shape[1]) anglemasks.append(anglemask) # that (-1j)**order term in the beginning will be 1, -j, -1, j for order 0, 1, 2, # 3, and will then loop again banddft = (-1j) ** self.order * lodft * anglemask * himask band = np.fft.ifft2(np.fft.ifftshift(banddft)) if not self.is_complex: self.pyr_coeffs[(i, b)] = np.real(band.copy()) else: self.pyr_coeffs[(i, b)] = band.copy() self.pyr_size[(i, b)] = band.shape self._anglemasks.append(anglemasks) dims = np.asarray(lodft.shape) ctr = np.ceil((dims+0.5)/2).astype(int) lodims = np.ceil((dims-0.5)/2).astype(int) loctr = np.ceil((lodims+0.5)/2).astype(int) lostart = ctr - loctr loend = lostart + lodims log_rad = log_rad[lostart[0]:loend[0], lostart[1]:loend[1]] angle = angle[lostart[0]:loend[0], lostart[1]:loend[1]] lodft = lodft[lostart[0]:loend[0], lostart[1]:loend[1]] YIrcos = np.abs(np.sqrt(1.0 - Yrcos**2)) log_rad_tmp = np.reshape(log_rad, (1, log_rad.shape[0] * log_rad.shape[1])) lomask = pointOp(log_rad_tmp, YIrcos, Xrcos[0], Xrcos[1]-Xrcos[0]) lomask = lomask.reshape(lodft.shape[0], lodft.shape[1]) self._lomasks.append(lomask) lodft = lodft * lomask lodft = np.fft.ifft2(np.fft.ifftshift(lodft)) self.pyr_coeffs['residual_lowpass'] = np.real(np.asarray(lodft).copy()) self.pyr_size['residual_lowpass'] = lodft.shape
[docs] def recon_pyr(self, levels='all', bands='all', twidth=1): """Reconstruct the image, optionally using subset of pyramid coefficients. Parameters ---------- levels : `list`, `int`, or {`'all'`, `'residual_highpass'`} If `list` should contain some subset of integers from `0` to `self.num_scales-1` (inclusive) and `'residual_lowpass'`. If `'all'`, returned value will contain all valid levels. Otherwise, must be one of the valid levels. bands : `list`, `int`, or `'all'`. If list, should contain some subset of integers from `0` to `self.num_orientations-1`. If `'all'`, returned value will contain all valid orientations. Otherwise, must be one of the valid orientations. twidth : `int` The width of the transition region of the radial lowpass function, in octaves Returns ------- recon : `np.array` The reconstructed image. """ if twidth <= 0: raise ValueError("twidth must be positive.") recon_keys = self._recon_keys(levels, bands) # make list of dims and bounds bound_list = [] dim_list = [] # we go through pyr_sizes from smallest to largest for dims in sorted(self.pyr_size.values()): if dims in dim_list: continue dim_list.append(dims) dims = np.asarray(dims) ctr = np.ceil((dims+0.5)/2).astype(int) lodims = np.ceil((dims-0.5)/2).astype(int) loctr = np.ceil((lodims+0.5)/2).astype(int) lostart = ctr - loctr loend = lostart + lodims bounds = (lostart[0], lostart[1], loend[0], loend[1]) bound_list.append(bounds) bound_list.append((0, 0, dim_list[-1][0], dim_list[-1][1])) dim_list.append((dim_list[-1][0], dim_list[-1][1])) # matlab code starts here dims = np.asarray(self.pyr_size['residual_highpass']) ctr = np.ceil((dims+0.5)/2.0).astype(int) (xramp, yramp) = np.meshgrid((np.arange(1, dims[1]+1)-ctr[1]) / (dims[1]/2.), (np.arange(1, dims[0]+1)-ctr[0]) / (dims[0]/2.)) angle = np.arctan2(yramp, xramp) log_rad = np.sqrt(xramp**2 + yramp**2) log_rad[ctr[0]-1, ctr[1]-1] = log_rad[ctr[0]-1, ctr[1]-2] log_rad = np.log2(log_rad) # Radial transition function (a raised cosine in log-frequency): (Xrcos, Yrcos) = rcosFn(twidth, (-twidth/2.0), np.asarray([0, 1])) Yrcos = np.sqrt(Yrcos) YIrcos = np.sqrt(1.0 - Yrcos**2) # from reconSFpyrLevs lutsize = 1024 Xcosn = np.pi * np.arange(-(2*lutsize+1), (lutsize+2)) / lutsize const = (2**(2*self.order))*(factorial(self.order, exact=True)**2) / float(self.num_orientations*factorial(2*self.order, exact=True)) Ycosn = np.sqrt(const) * (np.cos(Xcosn))**self.order # lowest band # initialize reconstruction if 'residual_lowpass' in recon_keys: nresdft = np.fft.fftshift(np.fft.fft2(self.pyr_coeffs['residual_lowpass'])) else: nresdft = np.zeros_like(self.pyr_coeffs['residual_lowpass']) resdft = np.zeros(dim_list[1]) + 0j bounds = (0, 0, 0, 0) for idx in range(len(bound_list)-2, 0, -1): diff = (bound_list[idx][2]-bound_list[idx][0], bound_list[idx][3]-bound_list[idx][1]) bounds = (bounds[0]+bound_list[idx][0], bounds[1]+bound_list[idx][1], bounds[0]+bound_list[idx][0] + diff[0], bounds[1]+bound_list[idx][1] + diff[1]) Xrcos -= np.log2(2.0) nlog_rad = log_rad[bounds[0]:bounds[2], bounds[1]:bounds[3]] nlog_rad_tmp = np.reshape(nlog_rad, (1, nlog_rad.shape[0]*nlog_rad.shape[1])) lomask = pointOp(nlog_rad_tmp, YIrcos, Xrcos[0], Xrcos[1]-Xrcos[0]) lomask = lomask.reshape(nresdft.shape[0], nresdft.shape[1]) lomask = lomask + 0j resdft[bound_list[1][0]:bound_list[1][2], bound_list[1][1]:bound_list[1][3]] = nresdft * lomask # middle bands for idx in range(1, len(bound_list)-1): bounds1 = (0, 0, 0, 0) bounds2 = (0, 0, 0, 0) for boundIdx in range(len(bound_list) - 1, idx - 1, -1): diff = (bound_list[boundIdx][2]-bound_list[boundIdx][0], bound_list[boundIdx][3]-bound_list[boundIdx][1]) bound2tmp = bounds2 bounds2 = (bounds2[0]+bound_list[boundIdx][0], bounds2[1]+bound_list[boundIdx][1], bounds2[0]+bound_list[boundIdx][0] + diff[0], bounds2[1]+bound_list[boundIdx][1] + diff[1]) bounds1 = bound2tmp nlog_rad1 = log_rad[bounds1[0]:bounds1[2], bounds1[1]:bounds1[3]] nlog_rad2 = log_rad[bounds2[0]:bounds2[2], bounds2[1]:bounds2[3]] dims = dim_list[idx] nangle = angle[bounds1[0]:bounds1[2], bounds1[1]:bounds1[3]] YIrcos = np.abs(np.sqrt(1.0 - Yrcos**2)) if idx > 1: Xrcos += np.log2(2.0) nlog_rad2_tmp = np.reshape(nlog_rad2, (1, nlog_rad2.shape[0]*nlog_rad2.shape[1])) lomask = pointOp(nlog_rad2_tmp, YIrcos, Xrcos[0], Xrcos[1]-Xrcos[0]) lomask = lomask.reshape(bounds2[2]-bounds2[0], bounds2[3]-bounds2[1]) lomask = lomask + 0j nresdft = np.zeros(dim_list[idx]) + 0j nresdft[bound_list[idx][0]:bound_list[idx][2], bound_list[idx][1]:bound_list[idx][3]] = resdft * lomask resdft = nresdft.copy() # reconSFpyrLevs if idx != 0 and idx != len(bound_list)-1: for b in range(self.num_orientations): nlog_rad1_tmp = np.reshape(nlog_rad1, (1, nlog_rad1.shape[0]*nlog_rad1.shape[1])) himask = pointOp(nlog_rad1_tmp, Yrcos, Xrcos[0], Xrcos[1]-Xrcos[0]) himask = himask.reshape(nlog_rad1.shape) nangle_tmp = np.reshape(nangle, (1, nangle.shape[0]*nangle.shape[1])) anglemask = pointOp(nangle_tmp, Ycosn, Xcosn[0]+np.pi*b/self.num_orientations, Xcosn[1]-Xcosn[0]) anglemask = anglemask.reshape(nangle.shape) # either the coefficients will already be real-valued (if # self.is_complex=False) or complex (if self.is_complex=True). in the # former case, this np.real() does nothing. in the latter, we want to only # reconstruct with the real portion curLev = self.num_scales-1 - (idx-1) band = np.real(self.pyr_coeffs[(curLev, b)]) if (curLev, b) in recon_keys: banddft = np.fft.fftshift(np.fft.fft2(band)) else: banddft = np.zeros(band.shape) resdft += ((np.power(-1+0j, 0.5))**(self.num_orientations-1) * banddft * anglemask * himask) # apply lo0mask Xrcos += np.log2(2.0) lo0mask = pointOp(log_rad, YIrcos, Xrcos[0], Xrcos[1]-Xrcos[0]) lo0mask = lo0mask.reshape(dims[0], dims[1]) resdft = resdft * lo0mask # residual highpass subband hi0mask = pointOp(log_rad, Yrcos, Xrcos[0], Xrcos[1]-Xrcos[0]) hi0mask = hi0mask.reshape(resdft.shape[0], resdft.shape[1]) if 'residual_highpass' in recon_keys: hidft = np.fft.fftshift(np.fft.fft2(self.pyr_coeffs['residual_highpass'])) else: hidft = np.zeros_like(self.pyr_coeffs['residual_highpass']) resdft += hidft * hi0mask outresdft = np.real(np.fft.ifft2(np.fft.ifftshift(resdft))) return outresdft