// Copyright (C) 2008 Soren Hauberg
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 3
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, see <http://www.gnu.org/licenses/>.
#include
// Return the larger of two ints; when they compare equal, b is returned
// (which has the same value as a).
inline
int MAX(const int a, const int b)
{
  return (b < a) ? a : b;
}
template
octave_value
convn(const MatrixType &A, const MatrixType &B)
{
// Get sizes
const octave_idx_type ndims = A.ndims();
const octave_idx_type B_numel = B.numel();
const dim_vector A_size = A.dims();
const dim_vector B_size = B.dims();
// Check input
if (ndims != B.ndims())
{
error("__convn__: first and second argument must have same dimensionality");
return octave_value();
}
// Allocate output
dim_vector out_size(A_size);
for (octave_idx_type n = 0; n < ndims; n++)
out_size(n) = MAX(A_size(n) - B_size(n) + 1, 0);
MatrixType out = MatrixType(out_size);
const octave_idx_type out_numel = out.numel();
// Iterate over every element of 'out'.
dim_vector idx_dim(ndims);
Array A_idx(idx_dim);
Array B_idx(idx_dim);
Array out_idx(idx_dim, 0);
for (octave_idx_type i = 0; i < out_numel; i++)
{
// For each neighbour
SumType sum = 0;
for (octave_idx_type n = 0; n < ndims; n++)
B_idx(n) = 0;
for (octave_idx_type j = 0; j < B_numel; j++)
{
for (octave_idx_type n = 0; n < ndims; n++)
A_idx(n) = out_idx(n) + (B_size(n)-1-B_idx(n));
sum += A(A_idx)*B(B_idx);
B.increment_index(B_idx, B_size);
}
// Compute filter result
out(out_idx) = sum;
// Prepare for next iteration
out.increment_index(out_idx, out_size);
OCTAVE_QUIT;
}
return octave_value(out);
}
// Octave entry point __convn__ (A, B).  It checks the argument count,
// selects real or complex arithmetic from the operand types, and returns
// the valid part of the N-dimensional convolution computed by convn ().
// If either operand is complex, both are promoted to complex.  Any other
// operand type is reported with error ().
DEFUN_DLD(__convn__, args, , "\
-*- texinfo -*-\n\
@deftypefn {Loadable Function} __convn__(@var{A}, @var{B})\n\
N-dimensional convolution. Only the valid part is computed. This is an internal\n\
function, and should not be called directly. Use @code{convn} instead.\n\
@seealso{convn}\n\
@end deftypefn\n\
")
{
  octave_value_list retval;

  if (args.length () != 2)
    {
      print_usage ();
      return retval;
    }

  const octave_value A_arg = args(0);
  const octave_value B_arg = args(1);

  // An operand is acceptable if it is a real or a complex array.
  const bool A_ok = A_arg.is_real_matrix () || A_arg.is_complex_matrix ();
  const bool B_ok = B_arg.is_real_matrix () || B_arg.is_complex_matrix ();

  // Take action depending on input type
  if (A_arg.is_real_matrix () && B_arg.is_real_matrix ())
    {
      const NDArray A = A_arg.array_value ();
      const NDArray B = B_arg.array_value ();
      retval(0) = convn<NDArray, double> (A, B);
    }
  else if (A_ok && B_ok)
    {
      // At least one operand is complex, so promote both to complex.
      // complex_array_value () keeps every dimension.  By contrast,
      // complex_matrix_value () returns a 2-D ComplexMatrix and cannot
      // represent N-dimensional input.
      const ComplexNDArray A = A_arg.complex_array_value ();
      const ComplexNDArray B = B_arg.complex_array_value ();
      retval(0) = convn<ComplexNDArray, Complex> (A, B);
    }
  else
    {
      error ("__convn__: first and second input should be real, or complex arrays");
      return retval;
    }

  return retval;
}