Source code for pyrtools.pyramids.WaveletPyramid

import numpy as np
from .pyramid import Pyramid
from .filters import parse_filter
from .c.wrapper import corrDn, upConv


[docs] class WaveletPyramid(Pyramid): """Multiscale wavelet pyramid Parameters ---------- image : `array_like` 1d or 2d image upon which to construct to the pyramid. height : 'auto' or `int`. The height of the pyramid. If 'auto', will automatically determine based on the size of `image`. filter_name : {'binomN', 'haar', 'qmf8', 'qmf12', 'qmf16', 'daub2', 'daub3', 'daub4', 'qmf5', 'qmf9', 'qmf13'} name of filter to use when constructing pyramid. All scaled so L-2 norm is 1.0 * `'binomN'` - binomial coefficient filter of order N-1 * `'haar'` - Haar wavelet * `'qmf8'`, `'qmf12'`, `'qmf16'` - Symmetric Quadrature Mirror Filters [1]_ * `'daub2'`, `'daub3'`, `'daub4'` - Daubechies wavelet [2]_ * `'qmf5'`, `'qmf9'`, `'qmf13'` - Symmetric Quadrature Mirror Filters [3]_, [4]_ edge_type : {'circular', 'reflect1', 'reflect2', 'repeat', 'zero', 'extend', 'dont-compute'} Specifies how to handle edges. Options are: * `'circular'` - circular convolution * `'reflect1'` - reflect about the edge pixels * `'reflect2'` - reflect, doubling the edge pixels * `'repeat'` - repeat the edge pixels * `'zero'` - assume values of zero outside image boundary * `'extend'` - reflect and invert * `'dont-compute'` - zero output when filter overhangs imput boundaries. Attributes ---------- image : `array_like` The input image used to construct the pyramid. image_size : `tuple` The size of the input image. pyr_type : `str` or `None` Human-readable string specifying the type of pyramid. For base class, is None. edge_type : `str` Specifies how edges were handled. pyr_coeffs : `dict` Dictionary containing the coefficients of the pyramid. Keys are `(level, band)` tuples and values are 1d or 2d numpy arrays (same number of dimensions as the input image) pyr_size : `dict` Dictionary containing the sizes of the pyramid coefficients. Keys are `(level, band)` tuples and values are tuples. is_complex : `bool` Whether the coefficients are complex- or real-valued. Only `SteerablePyramidFreq` can have a value of True, all others must be False. References ---------- .. [1] J D Johnston, "A filter family designed for use in quadrature mirror filter banks", Proc. ICASSP, pp 291-294, 1980. .. [2] I Daubechies, "Orthonormal bases of compactly supported wavelets", Commun. Pure Appl. Math, vol. 42, pp 909-996, 1988. .. [3] E P Simoncelli, "Orthogonal sub-band image transforms", PhD Thesis, MIT Dept. of Elec. Eng. and Comp. Sci. May 1988. Also available as: MIT Media Laboratory Vision and Modeling Technical Report #100. .. [4] E P Simoncelli and E H Adelson, "Subband image coding", Subband Transforms, chapter 4, ed. John W Woods, Kluwer Academic Publishers, Norwell, MA, 1990, pp 143--192. """ def __init__(self, image, height='auto', filter_name='qmf9', edge_type='reflect1'): super().__init__(image=image, edge_type=edge_type) self.pyr_type = 'Wavelet' self.filters = {} self.filters['lo_filter'] = parse_filter(filter_name, normalize=False) self.filters["hi_filter"] = WaveletPyramid._modulate_flip(self.filters['lo_filter']) assert self.filters['lo_filter'].shape == self.filters['hi_filter'].shape # Stagger sampling if filter is odd-length self.stagger = (self.filters['lo_filter'].size + 1) % 2 self._set_num_scales('lo_filter', height) # compute the number of channels per level if min(self.image.shape) == 1: self.num_orientations = 1 else: self.num_orientations = 3 self._build_pyr() def _modulate_flip(lo_filter): '''construct QMF/Wavelet highpass filter from lowpass filter modulate by (-1)^n, reverse order (and shift by one, which is handled by the convolution routines). This is an extension of the original definition of QMF's (e.g., see Simoncelli90). Parameters ---------- lo_filter : `array_like` one-dimensional array (or effectively 1d array) containing the lowpass filter to convert into the highpass filter. Returns ------- hi_filter : `np.array` The highpass filter constructed from the lowpass filter, same shape as the lowpass filter. ''' # check lo_filter is effectively 1D lo_filter_shape = lo_filter.shape assert lo_filter.size == max(lo_filter_shape) lo_filter = lo_filter.flatten() ind = np.arange(lo_filter.size, 0, -1) - (lo_filter.size + 1) // 2 hi_filter = lo_filter[::-1] * (-1.0) ** ind return hi_filter.reshape(lo_filter_shape) def _build_next(self, image): """Build the next level fo the Wavelet pyramid Should not be called by users directly, this is a helper function to construct the pyramid. Parameters ---------- image : `array_like` image to use to construct next level. Returns ------- lolo : `array_like` This is the result of applying the lowpass filter once if `image` is 1d, twice if it's 2d. It's downsampled by a factor of two from the original `image`. hi_tuple : `tuple` If `image` is 1d, this just contains `hihi`, the result of applying the highpass filter . If `image` is 2d, it is `(lohi, hilo, hihi)`, the result of applying the lowpass then the highpass, the highpass then the lowpass, and the highpass twice. All will be downsampled by a factor of two from the original `image`. """ if image.shape[1] == 1: lolo = corrDn(image=image, filt=self.filters['lo_filter'], edge_type=self.edge_type, step=(2, 1), start=(self.stagger, 0)) hihi = corrDn(image=image, filt=self.filters['hi_filter'], edge_type=self.edge_type, step=(2, 1), start=(1, 0)) return lolo, (hihi, ) elif image.shape[0] == 1: lolo = corrDn(image=image, filt=self.filters['lo_filter'].T, edge_type=self.edge_type, step=(1, 2), start=(0, self.stagger)) hihi = corrDn(image=image, filt=self.filters['hi_filter'].T, edge_type=self.edge_type, step=(1, 2), start=(0, 1)) return lolo, (hihi, ) else: lo = corrDn(image=image, filt=self.filters['lo_filter'], edge_type=self.edge_type, step=(2, 1), start=(self.stagger, 0)) hi = corrDn(image=image, filt=self.filters['hi_filter'], edge_type=self.edge_type, step=(2, 1), start=(1, 0)) lolo = corrDn(image=lo, filt=self.filters['lo_filter'].T, edge_type=self.edge_type, step=(1, 2), start=(0, self.stagger)) lohi = corrDn(image=hi, filt=self.filters['lo_filter'].T, edge_type=self.edge_type, step=(1, 2), start=(0, self.stagger)) hilo = corrDn(image=lo, filt=self.filters['hi_filter'].T, edge_type=self.edge_type, step=(1, 2), start=(0, 1)) hihi = corrDn(image=hi, filt=self.filters['hi_filter'].T, edge_type=self.edge_type, step=(1, 2), start=(0, 1)) return lolo, (lohi, hilo, hihi) def _build_pyr(self): im = self.image for lev in range(self.num_scales): im, higher_bands = self._build_next(im) for j, band in enumerate(higher_bands): self.pyr_coeffs[(lev, j)] = band self.pyr_size[(lev, j)] = band.shape self.pyr_coeffs['residual_lowpass'] = im self.pyr_size['residual_lowpass'] = im.shape def _recon_prev(self, image, lev, recon_keys, output_size, lo_filter, hi_filter, edge_type, stagger): """Reconstruct the previous level of the pyramid. Should not be called by users directly, this is a helper function for reconstructing the input image using pyramid coefficients. """ if self.num_orientations == 1: if output_size[0] == 1: recon = upConv(image=image, filt=lo_filter.T, edge_type=edge_type, step=(1, 2), start=(0, stagger), stop=output_size) if (lev, 0) in recon_keys: recon += upConv(image=self.pyr_coeffs[(lev, 0)], filt=hi_filter.T, edge_type=edge_type, step=(1, 2), start=(0, 1), stop=output_size) elif output_size[1] == 1: recon = upConv(image=image, filt=lo_filter, edge_type=edge_type, step=(2, 1), start=(stagger, 0), stop=output_size) if (lev, 0) in recon_keys: recon += upConv(image=self.pyr_coeffs[(lev, 0)], filt=hi_filter, edge_type=edge_type, step=(2, 1), start=(1, 0), stop=output_size) else: lo_size = ([self.pyr_size[(lev, 1)][0], output_size[1]]) hi_size = ([self.pyr_size[(lev, 0)][0], output_size[1]]) tmp_recon = upConv(image=image, filt=lo_filter.T, edge_type=edge_type, step=(1, 2), start=(0, stagger), stop=lo_size) recon = upConv(image=tmp_recon, filt=lo_filter, edge_type=edge_type, step=(2, 1), start=(stagger, 0), stop=output_size) bands_recon_dict = { 0: [{'filt': lo_filter.T, 'start': (0, stagger), 'stop': hi_size}, {'filt': hi_filter, 'start': (1, 0)}], 1: [{'filt': hi_filter.T, 'start': (0, 1), 'stop': lo_size}, {'filt': lo_filter, 'start': (stagger, 0)}], 2: [{'filt': hi_filter.T, 'start': (0, 1), 'stop': hi_size}, {'filt': hi_filter, 'start': (1, 0)}], } for band in range(self.num_orientations): if (lev, band) in recon_keys: tmp_recon = upConv(image=self.pyr_coeffs[(lev, band)], edge_type=edge_type, step=(1, 2), **bands_recon_dict[band][0]) recon += upConv(image=tmp_recon, edge_type=edge_type, step=(2, 1), stop=output_size, **bands_recon_dict[band][1]) return recon
[docs] def recon_pyr(self, filter_name=None, edge_type=None, levels='all', bands='all'): """Reconstruct the input image using pyramid coefficients. This function reconstructs the input image using pyramid coefficients. Parameters ---------- filter_name : {None, 'binomN', 'haar', 'qmf8', 'qmf12', 'qmf16', 'daub2', 'daub3', 'daub4', 'qmf5', 'qmf9', 'qmf13'} name of filter to use for reconstruction. All scaled so L-2 norm is 1.0 * None (default) - use `self.filter_name`, the filter used to construct the pyramid. * `'binomN'` - binomial coefficient filter of order N-1 * `'haar'` - Haar wavelet * `'qmf8'`, `'qmf12'`, `'qmf16'` - Symmetric Quadrature Mirror Filters [1]_ * `'daub2'`, `'daub3'`, `'daub4'` - Daubechies wavelet [2]_ * `'qmf5'`, `'qmf9'`, `'qmf13'` - Symmetric Quadrature Mirror Filters [3]_, [4]_ edge_type : {None, 'circular', 'reflect1', 'reflect2', 'repeat', 'zero', 'extend', 'dont-compute'} Specifies how to handle edges. Options are: * None (default) - use `self.edge_type`, the edge_type used to construct the pyramid * `'circular'` - circular convolution * `'reflect1'` - reflect about the edge pixels * `'reflect2'` - reflect, doubling the edge pixels * `'repeat'` - repeat the edge pixels * `'zero'` - assume values of zero outside image boundary * `'extend'` - reflect and inverts * `'dont-compute'` - zero output when filter overhangs imput boundaries. levels : `list`, `int`, or {`'all'`, `'residual_highpass'`} If `list` should contain some subset of integers from `0` to `self.num_scales-1` (inclusive) and `'residual_lowpass'`. If `'all'`, returned value will contain all valid levels. Otherwise, must be one of the valid levels. bands : `list`, `int`, or `'all'`. If list, should contain some subset of integers from `0` to `self.num_orientations-1`. If `'all'`, returned value will contain all valid orientations. Otherwise, must be one of the valid orientations. Returns ------- recon : `np.array` The reconstructed image. """ # Optional args if filter_name is None: lo_filter = self.filters['lo_filter'] hi_filter = self.filters['hi_filter'] stagger = self.stagger else: lo_filter = parse_filter(filter_name, normalize=False) hi_filter = WaveletPyramid._modulate_flip(lo_filter) stagger = (lo_filter.size + 1) % 2 if edge_type is None: edges = self.edge_type else: edges = edge_type recon_keys = self._recon_keys(levels, bands) # initialize reconstruction if 'residual_lowpass' in recon_keys: recon = self.pyr_coeffs['residual_lowpass'] else: recon = np.zeros_like(self.pyr_coeffs['residual_lowpass']) for lev in reversed(range(self.num_scales)): if self.num_orientations == 1: if lev == 0: output_size = self.image.shape else: output_size = self.pyr_size[(lev-1, 0)] else: output_size = (self.pyr_size[(lev, 0)][0] + self.pyr_size[(lev, 1)][0], self.pyr_size[(lev, 0)][1] + self.pyr_size[(lev, 1)][1]) recon = self._recon_prev(recon, lev, recon_keys, output_size, lo_filter, hi_filter, edges, stagger) return recon