// Copyright (C) 2008 Soren Hauberg // // This program is free software; you can redistribute it and/or // modify it under the terms of the GNU General Public License // as published by the Free Software Foundation; either version 3 // of the License, or (at your option) any later version. // // This program is distributed in the hope that it will be useful, but // WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU // General Public License for more details. // // You should have received a copy of the GNU General Public License // along with this program; if not, see . #include inline int MAX(const int a, const int b) { if (a > b) return a; else return b; } template octave_value convn(const MatrixType &A, const MatrixType &B) { // Get sizes const octave_idx_type ndims = A.ndims(); const octave_idx_type B_numel = B.numel(); const dim_vector A_size = A.dims(); const dim_vector B_size = B.dims(); // Check input if (ndims != B.ndims()) { error("__convn__: first and second argument must have same dimensionality"); return octave_value(); } // Allocate output dim_vector out_size(A_size); for (octave_idx_type n = 0; n < ndims; n++) out_size(n) = MAX(A_size(n) - B_size(n) + 1, 0); MatrixType out = MatrixType(out_size); const octave_idx_type out_numel = out.numel(); // Iterate over every element of 'out'. dim_vector idx_dim(ndims); Array A_idx(idx_dim); Array B_idx(idx_dim); Array out_idx(idx_dim, 0); for (octave_idx_type i = 0; i < out_numel; i++) { // For each neighbour SumType sum = 0; for (octave_idx_type n = 0; n < ndims; n++) B_idx(n) = 0; for (octave_idx_type j = 0; j < B_numel; j++) { for (octave_idx_type n = 0; n < ndims; n++) A_idx(n) = out_idx(n) + (B_size(n)-1-B_idx(n)); sum += A(A_idx)*B(B_idx); B.increment_index(B_idx, B_size); } // Compute filter result out(out_idx) = sum; // Prepare for next iteration out.increment_index(out_idx, out_size); OCTAVE_QUIT; } return octave_value(out); } DEFUN_DLD(__convn__, args, , "\ -*- texinfo -*-\n\ @deftypefn {Loadable Function} __convn__(@var{A}, @var{B})\n\ N-dimensional convolution. Only the valid part is computed. This is an internal\n\ function, and should not be called directly. Use @code{convn} instead.\n\ @seealso{convn}\n\ @end deftypefn\n\ ") { octave_value_list retval; if (args.length() != 2) { print_usage (); return retval; } // Take action depending on input type if (args(0).is_real_matrix() && args(1).is_real_matrix()) { const NDArray A = args(0).array_value(); const NDArray B = args(1).array_value(); retval(0) = convn(A, B); } else if (args(0).is_complex_matrix() && args(1).is_complex_matrix()) { const ComplexNDArray A = args(0).complex_matrix_value(); const ComplexNDArray B = args(1).complex_matrix_value(); retval(0) = convn(A, B); } else { error("__convn__: first and second input should be real, or complex arrays"); return retval; } return retval; }