function [eMean, eVar] = LearnEpitome(x, eSize, patchSize, numIterations, varargin)
%LEARNEPITOME Learns the epitome of array data
%   [eMean, eVar] = LearnEpitome(x, eSize, patchSize, numIterations) learns
%   an epitome of size "eSize" of the data, x, using patches of size
%   "patchSize" with "numIterations" number of iterations. eSize >
%   patchSize.
%
%   x is array data with dimension no greater than 3, eg. videos, images,
%   audio, etc. The array can however have one additional trailing
%   dimension that is assumed to be independent, such as colour. The
%   default initialization and the M-step both keep the epitome mean in
%   the range [0, 1], so x is expected to be scaled to that range.
%
%   [...] = LearnEpitome(..., 'init', eMean, eVar) uses the provided
%   epitome mean and variance as the initialization of the epitome. In
%   this case the epitome size is taken from eMean and the eSize argument
%   is ignored. By default the epitome mean is initialized with random
%   noise having the per-channel sample mean and standard deviation of x
%   (restricted to [0, 1]) and the epitome variance is initialized to
%   unity.
%
%   [...] = LearnEpitome(..., 'patchPeriod', T) samples patches from x on
%   average [patchSize .* T] distance apart. Default value of T is 0.5. A
%   value of 0 means that every overlapping patch is used.
%
%   [...] = LearnEpitome(..., 'fullPost', P) uses the full posterior during
%   learning if P is true and uses the mode if P is false. Default value
%   of P is true.
%
%   [...] = LearnEpitome(..., 'verbose', V) displays progress and an
%   estimate of the remaining time if V is true. Default value of V is
%   true.
%
%   [...] = LearnEpitome(..., 'grayScale', G) identifies the data x as
%   being gray-scale, meaning that the trailing dimension is a singleton.
%   This argument disambiguates a 2D colour image from a 3D gray-scale
%   video as they both have 3 non-singleton dimensions, i.e.
%   length(size(x)) is 3 in both cases. Default value of G is false.
%
%   [...] = LearnEpitome(..., 'filename', F) saves partial results using
%   the filename provided. By default, no partial results are saved.
%
%   Note: 'fullPost', 'verbose' and 'grayScale' values must be of class
%   logical, 'patchPeriod' and 'init' values must be numeric and the
%   'filename' value must be a character array; an option whose value has
%   the wrong class is silently ignored.
%
%   See also EpitomeReconstruct

% Reference:
%
% 1. V. Cheung, B. J. Frey, and N. Jojic. Video epitomes. In Proc.
%    IEEE Conf. Computer Vision and Pattern Recognition (CVPR), 2005.
%
% 2. N. Jojic, B. J. Frey, and A. Kannan. Epitomic analysis of appearance
%    and shape. In Proc. IEEE Conf. Computer Vision (ICCV), 2003.

% Copyright (C) 2005 Vincent Cheung (vincent@psi.toronto.edu, http://www.psi.toronto.edu/~vincent/)
%
% This program is free software; you can redistribute it and/or
% modify it under the terms of the GNU General Public License
% as published by the Free Software Foundation; either version 2
% of the License, or (at your option) any later version.
%
% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
% GNU General Public License for more details.
%
% $Revision: 0.9 $ $Date: Sept. 30, 2005 $

if(nargin < 4)
    error('MATLAB:LearnEpitome:NotEnoughInputs', 'Requires at least 4 inputs.')
end

% default values for optional arguments
patchPeriod = 0.5;
fullPost = true;
verbose = true;
grayScale = false;
filename = '';
eMean = [];
eVar = [];

% parse the variable number of arguments; every option is a name followed
% by one value (two values for 'init'), so the last element of varargin
% can never be an option name
numArgs = length(varargin);
i = 1;
while(i < numArgs)
    if(ischar(varargin{i}))
        if(strcmpi(varargin{i}, 'init') && i + 2 <= numArgs)
            if(isnumeric(varargin{i+1}) && isnumeric(varargin{i+2}))
                eMean = varargin{i+1};
                eVar = varargin{i+2};
                i = i + 2;
            end
        elseif(strcmpi(varargin{i}, 'patchPeriod'))
            if(isnumeric(varargin{i+1}))
                patchPeriod = varargin{i+1};
                i = i + 1;
            end
        elseif(strcmpi(varargin{i}, 'fullPost'))
            if(islogical(varargin{i+1}))
                fullPost = varargin{i+1};
                i = i + 1;
            end
        elseif(strcmpi(varargin{i}, 'verbose'))
            if(islogical(varargin{i+1}))
                verbose = varargin{i+1};
                i = i + 1;
            end
        elseif(strcmpi(varargin{i}, 'grayScale'))
            if(islogical(varargin{i+1}))
                grayScale = varargin{i+1};
                i = i + 1;
            end
        elseif(strcmpi(varargin{i}, 'filename'))
            if(ischar(varargin{i+1}))
                filename = varargin{i+1};
                i = i + 1;
            end
        end
    end

    i = i + 1;
end

% accept the size arguments as either row or column vectors
eSize = eSize(:)';
patchSize = patchSize(:)';
patchPeriod = patchPeriod(:)';

% the trailing dimension of colour data is not a data dimension
dataNumDim = ndims(x) - ~grayScale;

% handle things differently depending on whether or not an initialization
% is provided for the epitome
if(~isempty(eMean) && ~isempty(eVar))
    eNumDim = ndims(eMean) - ~grayScale;

    % use the size of the initialization as the size of the epitome
    eSize = size(eMean);

    if(~grayScale)
        eSize = eSize(1:end-1);
    end

    % permute the dimensions of 2D and 1D data to simulate 3D data
    if(eNumDim == 2)
        eMean = permute(eMean, [4 1 2 3]);
        eVar = permute(eVar, [4 1 2 3]);
    end

    if(eNumDim == 1)
        eMean = permute(eMean, [3 4 1 2]);
        eVar = permute(eVar, [3 4 1 2]);
    end
else
    eNumDim = length(eSize);
end

if(eNumDim > 3 || dataNumDim > 3)
    error('MATLAB:LearnEpitome:TooManyDimensions', 'The dimension of the data cannot exceed 3.');
end

% permute the dimensions of 2D and 1D data to simulate 3D data
if(dataNumDim == 2)
    x = permute(x, [4 1 2 3]);
end

if(dataNumDim == 1)
    x = permute(x, [3 4 1 2]);
end

% pad the epitome and patch size until they describe 3D data
while(length(eSize) < 3)
    eSize = [1 eSize];
end

while(length(patchSize) < 3)
    patchSize = [1 patchSize];
end

% pad only if the patchPeriod is of length 2
if(length(patchPeriod) == 2)
    patchPeriod = [1 patchPeriod];
end

% the minimum epitome variance
minVar = 1e-6;

% numerical precision tolerance
tol = 1e-10;

% the size of the data as [time rows cols channels]. size(x, k) is queried
% per dimension because size(x) drops trailing singletons, which would
% otherwise give a vector with fewer than 4 elements (e.g. for gray-scale
% column data) and corrupt both the spatial extent and the channel count
xSize = [size(x, 1), size(x, 2), size(x, 3), size(x, 4)];
numChannels = xSize(4);

% initialize
if(isempty(eMean) || isempty(eVar))
    % pixel sample statistics (per channel)
    numPixels = prod(xSize(1:3));
    sumX = sum(sum(sum(x, 1), 2), 3);
    sumXX = sum(sum(sum(x.^2, 1), 2), 3);
    pixelMean = sumX ./ numPixels;

    % round-off can make the sample variance slightly negative
    pixelStd = sqrt(max(sumXX ./ numPixels - pixelMean.^2, 0));

    % initialize the mean to be normally distributed using the sample mean
    % and standard deviation, making sure to keep the pixels between 0 and
    % 1 by re-drawing every sample that falls outside of that range

    % put the assumed independent dimension first for now, and permute the
    % array after initialization
    eMean = zeros([numChannels, eSize]) - 1;

    % the re-drawing is bounded so that data outside of [0, 1] cannot make
    % this loop spin forever
    maxInitRounds = 1000;
    initRound = 0;
    done = false;

    while(~done && initRound < maxInitRounds)
        done = true;
        initRound = initRound + 1;

        for k = 1 : numChannels
            idx = eMean(k, :, :, :) < 0 | eMean(k, :, :, :) > 1;
            N = sum(idx(:));

            if(N > 0)
                done = false;
                eMean(k, idx) = randn(N, 1) .* pixelStd(k) + pixelMean(k);
            end
        end
    end

    % fall back to clamping if the re-drawing did not converge
    if(~done)
        eMean = min(max(eMean, 0), 1);
    end

    eMean = permute(eMean, [2 3 4 1]);

    % initialize the variance to unity
    eVar = ones(size(eMean));
end

% last possible location of patches in x
xEndIdx = xSize(1:3) - patchSize + 1;

if(any(xEndIdx < 1))
    error('MATLAB:LearnEpitome:PatchTooLarge', 'The patch size cannot exceed the size of the data.');
end

% the distance between sampled patches
patchDisp = max(floor(patchSize .* patchPeriod), 1);

% patches can wiggle from their initial position up to half of the patch
% displacement in each direction
patchWiggle = floor(patchDisp/2);

% set-up a grid of patch locations
leftOver = mod(xEndIdx-1, patchDisp);
leftOver(leftOver == 0) = patchDisp(leftOver == 0);

% the first patch location not on an edge
startIdx = floor((leftOver + patchDisp)/2) + 1;

% the grid. unique only has an effect when a dimension has a single valid
% patch location (xEndIdx == 1, e.g. the padded singleton dimensions of 1D
% and 2D data), where the first and last locations coincide and would
% otherwise be listed twice
tPatchIdx = unique([1, startIdx(1):patchDisp(1):xEndIdx(1)-1, xEndIdx(1)]);
rPatchIdx = unique([1, startIdx(2):patchDisp(2):xEndIdx(2)-1, xEndIdx(2)]);
cPatchIdx = unique([1, startIdx(3):patchDisp(3):xEndIdx(3)-1, xEndIdx(3)]);

numPatches = length(tPatchIdx) * length(rPatchIdx) * length(cPatchIdx);
patchFixedIdx = zeros(numPatches, 3);
count = 1;

% set-up the grid, where the patch locations are the rows in the
% patchFixedIdx matrix
for t = tPatchIdx
    for r = rPatchIdx
        for c = cPatchIdx
            patchFixedIdx(count, :) = [t r c];
            count = count + 1;
        end
    end
end

% number of patches between progress reports
progressStep = ceil(numPatches / 10);

% cumulative sum matrices
eeCumSum = zeros(eSize+patchSize);

% the size of the FFTs
FFTSize = eSize + patchSize - 1;

% set the FFT optimization strategy
% fftw('planner','exhaustive');

% the all ones FFT (for summing over the extent of a patch)
onesFFT = localFFT3(ones(patchSize), FFTSize);

% subscripts to be used to obtain the valid areas of the convolution
% note that the epitome is circularly extended by the size of the patch
% during convolution
convSubs = cell(3, 1);

for d = 1 : length(convSubs)
    convSubs{d} = patchSize(d) : eSize(d) + patchSize(d) - 1;
end

% subscripts that circularly extend the epitome by patchSize - 1
extSubs = cell(3, 1);

for d = 1 : length(extSubs)
    extSubs{d} = [1:size(eMean, d), 1:patchSize(d)-1];
end

% subscripts into eeCumSum of the upper and lower corners of every patch
hiSubs = cell(3, 1);
loSubs = cell(3, 1);

for d = 1 : 3
    hiSubs{d} = patchSize(d)+1 : size(eeCumSum, d);
    loSubs{d} = 1 : size(eeCumSum, d)-patchSize(d);
end

% replicates a 3D array along the 4th (channel) dimension
chanRep = ones(numChannels, 1);

% temporary matrices for collecting sufficient statistics
sumQ = zeros(eSize);
sumQX = zeros(size(eMean));
sumQXX = zeros(size(eMean));

tic

for iteration = 1 : numIterations
    % the training patches
    patchIdx = patchFixedIdx;

    % allow the patches not on a "face" to move from their initial position,
    % but ensure that the patches fall between 1 and xEndIdx
    for d = 1 : 3
        movable = patchIdx(:, d) ~= 1 & patchIdx(:, d) ~= xEndIdx(d);
        wiggle = floor(rand(sum(movable), 1) * (patchWiggle(d)*2 + 1)) - patchWiggle(d);
        patchIdx(movable, d) = min(max(patchIdx(movable, d) + wiggle, 1), xEndIdx(d));
    end

    % circularly wrap the epitome
    eMeanExt = eMean(extSubs{:}, :);
    eVarExt = eVar(extSubs{:}, :);

    % cumulative sum (summed-area table) of eMean.^2 ./ eVar
    eeCumSum(2:end, 2:end, 2:end) = cumsum(cumsum(cumsum(sum(eMeanExt.^2 ./ eVarExt, 4), 1), 2), 3);

    % the sum of eMean.^2 ./ eVar for each patch in the epitome, obtained
    % by inclusion-exclusion on the cumulative sum. This does not depend
    % on the training patch, so it is computed once per iteration
    eeSum = eeCumSum(hiSubs{1}, hiSubs{2}, hiSubs{3}) ...
        - eeCumSum(hiSubs{1}, hiSubs{2}, loSubs{3}) ...
        - eeCumSum(hiSubs{1}, loSubs{2}, hiSubs{3}) ...
        + eeCumSum(hiSubs{1}, loSubs{2}, loSubs{3}) ...
        - (eeCumSum(loSubs{1}, hiSubs{2}, hiSubs{3}) ...
        - eeCumSum(loSubs{1}, hiSubs{2}, loSubs{3}) ...
        - eeCumSum(loSubs{1}, loSubs{2}, hiSubs{3}) ...
        + eeCumSum(loSubs{1}, loSubs{2}, loSubs{3}));

    eMeanOverVarFFT = localFFT3(eMeanExt ./ eVarExt, FFTSize);
    eInvVarFFT = localFFT3(1 ./ eVarExt, FFTSize);

    % compute the sum of the log(eVar) for each patch in the epitome
    eLogVarFFT = localFFT3(sum(log(eVarExt), 4), FFTSize);

    % ifft can give a small imaginary component because of round-off, so
    % only the real part of the result is kept
    eLogVarFullConv = localRealIFFT3(eLogVarFFT .* onesFFT);
    eLogVarSum = eLogVarFullConv(convSubs{:});

    % clear the matrices used for collecting sufficient statistics
    sumQ(:) = 0;
    sumQX(:) = 0;
    sumQXX(:) = 0;

    % for each training case
    for i = 1 : numPatches
        % display progress and estimation of time to completion
        if(verbose && mod(i-1, progressStep) == 0)
            disp(['Iteration ' num2str(iteration) ', ' num2str(round((i-1)/numPatches*100)) '% Complete']);

            % no estimate is possible before any patch has been processed
            if(iteration > 1 || i > 1)
                disp(['Time Remaining: ', num2str(toc * (1 / (((iteration-1) * numPatches + (i-1)) / (numPatches * numIterations)) - 1)) ' seconds']);
            end

            disp(' ');
        end

        % the location of the current patch
        time = patchIdx(i, 1);
        row = patchIdx(i, 2);
        col = patchIdx(i, 3);

        % ============================
        % E-Step
        % ============================
        xPatch = x(time:time+patchSize(1)-1, row:row+patchSize(2)-1, col:col+patchSize(3)-1, :);

        % compute correlations using convolutions via FFT
        % flip x when computing the FFT so that correlations are performed
        xPatchFlipped = xPatch(end:-1:1, end:-1:1, end:-1:1, :);
        xPatchFFT = localFFT3(xPatchFlipped, FFTSize);
        xxPatchFFT = localFFT3(xPatchFlipped.^2, FFTSize);

        % sum over the channels in the frequency domain so that only one
        % inverse transform is needed per term
        xeFullConv = localRealIFFT3(sum(eMeanOverVarFFT .* xPatchFFT, 4));
        xxFullConv = localRealIFFT3(sum(eInvVarFFT .* xxPatchFFT, 4));

        % only look at the part of the convolution that does not assume
        % zero-padded arrays.
        xeSum = xeFullConv(convSubs{:});
        xxSum = xxFullConv(convSubs{:});

        % the unnormalized log posterior
        post = -1/2 * (eeSum - 2*xeSum + xxSum + eLogVarSum);

        % ============================
        % M-Step
        % ============================
        if(fullPost)
            % logsum to normalize; alpha shifts the exponents so that
            % exp() does not overflow
            alpha = max(post(:)) - log(realmax)/2 + 2*log(numel(post));
            normalizingConst = log(sum(exp(post(:) - alpha))) + alpha;
            post = exp(post - normalizingConst);

            % use convolution to compute sufficient statistics
            % flip the posterior when computing the FFT because xPatchFFT and
            % xxPatchFFT are flipped since this time we need convolutions, not
            % correlations
            postFFT = localFFT3(post(end:-1:1, end:-1:1, end:-1:1), FFTSize);
            postFFTRep = postFFT(:, :, :, chanRep);

            % denominator for the mean and variance
            sumQ = sumQ + localWrapStatistic(localRealIFFT3(postFFT .* onesFFT), patchSize, eSize);

            % mean
            sumQX = sumQX + localWrapStatistic(localRealIFFT3(postFFTRep .* xPatchFFT), patchSize, eSize);

            % variance
            sumQXX = sumQXX + localWrapStatistic(localRealIFFT3(postFFTRep .* xxPatchFFT), patchSize, eSize);

        % just use the mode of the posterior for the M-step
        else
            % find the maximum log posterior and the corresponding index
            [maxPost, modeIdx] = max(post(:));
            [t, r, c] = ind2sub(size(post), modeIdx);

            % collect sufficient statistics, taking into consideration that the
            % epitome wraps around
            eWrapIdx = {mod(t-1:t+patchSize(1)-2, eSize(1)) + 1, mod(r-1:r+patchSize(2)-2, eSize(2)) + 1, mod(c-1:c+patchSize(3)-2, eSize(3)) + 1};

            sumQ(eWrapIdx{:}) = sumQ(eWrapIdx{:}) + 1;
            sumQX(eWrapIdx{:}, :) = sumQX(eWrapIdx{:}, :) + xPatch;
            sumQXX(eWrapIdx{:}, :) = sumQXX(eWrapIdx{:}, :) + xPatch.^2;
        end
    end

    % avoid numerical problems: epitome locations that received (almost)
    % no posterior mass keep their previous mean and variance
    zeroIdx = sumQ(:, :, :, chanRep) <= tol;
    sumQX(zeroIdx) = eMean(zeroIdx);
    sumQXX(zeroIdx) = eVar(zeroIdx) + eMean(zeroIdx).^2;
    sumQ(sumQ <= tol) = 1;

    % compute the new epitome
    eMean = sumQX ./ sumQ(:, :, :, chanRep);
    eVar = sumQXX ./ sumQ(:, :, :, chanRep) - eMean.^2;

    % make sure that the mean is within the range 0 - 1 (may not because of
    % numerical issues)
    eMean = min(max(eMean, 0), 1);

    % enforce a minimum variance
    eVar = max(eVar, minVar);

    % save partial data if filename is provided (this saves every variable
    % in the workspace of this function)
    if(~isempty(filename))
        save([filename '_' num2str(iteration)]);
    end
end

% undo the permutations that were used to simulate 3D data
if(eNumDim == 2)
    eMean = permute(eMean, [2 3 4 1]);
    eVar = permute(eVar, [2 3 4 1]);
end

if(eNumDim == 1)
    eMean = permute(eMean, [3 4 1 2]);
    eVar = permute(eVar, [3 4 1 2]);
end


function F = localFFT3(a, fftSize)
%LOCALFFT3 FFT along the first three dimensions of a, each transform
%   using the length given by the corresponding element of fftSize.
F = fft(fft(fft(a, fftSize(1), 1), fftSize(2), 2), fftSize(3), 3);


function a = localRealIFFT3(F)
%LOCALREALIFFT3 Real part of the inverse FFT along the first three
%   dimensions of F.
a = real(ifft(ifft(ifft(F, [], 3), [], 2), [], 1));


function s = localWrapStatistic(fullConv, patchSize, eSize)
%LOCALWRAPSTATISTIC Flips a full convolution along its first three
%   dimensions, adds the last patchSize - 1 entries of each of those
%   dimensions onto the first patchSize - 1 entries (circular wrap) and
%   returns the leading eSize block.
s = fullConv(end:-1:1, end:-1:1, end:-1:1, :);

s(1:patchSize(1)-1, :, :, :) = s(1:patchSize(1)-1, :, :, :) + s(end-patchSize(1)+2:end, :, :, :);
s(:, 1:patchSize(2)-1, :, :) = s(:, 1:patchSize(2)-1, :, :) + s(:, end-patchSize(2)+2:end, :, :);
s(:, :, 1:patchSize(3)-1, :) = s(:, :, 1:patchSize(3)-1, :) + s(:, :, end-patchSize(3)+2:end, :);

s = s(1:eSize(1), 1:eSize(2), 1:eSize(3), :);