# HG changeset patch
# User Jason Riedy
# Date 1236736479 14400
# Node ID 12958b7148992e19c0b41268ce2f0ea1ae081d26
# Parent 02034a30a13a1a962fff2ca9b7fc055657f57140
Add an overload of Octave's find() for permutation matrices.
Because of find()'s count-limiting and direction arguments, this is
slightly more complicated than just copying the permutation vector.
I suspect this is a common operation for people who don't know about
the 'vector' option to lu().
diff --git a/src/ChangeLog b/src/ChangeLog
--- a/src/ChangeLog
+++ b/src/ChangeLog
@@ -1,3 +1,10 @@
+2009-03-10 Jason Riedy
+
+ * DLD-FUNCTIONS/find.cc (find_nonzero_elem_idx): New override
+ for find on PermMatrix.
+ (find): Add a branch testing arg.is_perm_matrix () and calling the
+ above override.
+
2009-03-10 John W. Eaton
* c-file-ptr-stream.cc, dynamic-ld.cc, error.cc, lex.l, pager.cc,
diff --git a/src/DLD-FUNCTIONS/find.cc b/src/DLD-FUNCTIONS/find.cc
--- a/src/DLD-FUNCTIONS/find.cc
+++ b/src/DLD-FUNCTIONS/find.cc
@@ -333,6 +333,113 @@
template octave_value_list find_nonzero_elem_idx (const Sparse&, int,
octave_idx_type, int);
+octave_value_list
+find_nonzero_elem_idx (const PermMatrix& v, int nargout,
+                       octave_idx_type n_to_find, int direction)
+// find() for a PermMatrix.  Every column holds exactly one nonzero (a 1),
+// so the column-major search reduces to a window of COUNT consecutive
+// columns: all of them when N_TO_FIND is negative or >= the column count,
+// otherwise the first (DIRECTION > 0) or last N_TO_FIND columns.
+// NARGOUT selects the outputs: 0/1 -> linear indices; 2 -> row and column
+// subscripts; 3 adds the values; 4 and 5 add the dimension.  All indices
+// returned are 1-based.
+{
+  octave_value_list retval ((nargout == 0 ? 1 : nargout), Matrix ());
+
+  const octave_idx_type nc = v.cols ();
+
+  // Pick the window of columns [start_nc, start_nc + count) to report.
+  octave_idx_type start_nc = 0;
+  octave_idx_type count = nc;
+  if (n_to_find >= 0 && n_to_find < nc)
+    {
+      count = n_to_find;
+      if (direction <= 0)
+        start_nc = nc - n_to_find;
+    }
+
+  bool scalar_arg = (v.rows () == 1 && v.cols () == 1);
+
+  Matrix idx (count, 1);
+  Matrix i_idx (count, 1);
+  Matrix j_idx (count, 1);
+  // Every value is 1.
+  ArrayN<double> val (dim_vector (count, 1), 1.0);
+
+  if (count > 0)
+    {
+      const octave_idx_type* p = v.data ();
+      if (v.is_col_perm ())
+        // p[j] is the row of the nonzero in column j: walk the window.
+        for (octave_idx_type k = 0; k < count; k++)
+          {
+            OCTAVE_QUIT;
+            const octave_idx_type j = start_nc + k;
+            const octave_idx_type i = p[j];
+            i_idx(k) = static_cast<double> (1+i);
+            j_idx(k) = static_cast<double> (1+j);
+            idx(k) = j * nc + i + 1;
+          }
+      else
+        // p[i] is the column of the nonzero in row i.  Visit every row and
+        // scatter by column so the output stays in column-major order.
+        // A row whose column lies outside the window must be skipped:
+        // walking only COUNT rows would write past the COUNT-element
+        // outputs and leave other slots unset whenever COUNT < nc.
+        for (octave_idx_type i = 0; i < nc; i++)
+          {
+            OCTAVE_QUIT;
+            const octave_idx_type j = p[i];
+            const octave_idx_type koff = j - start_nc;
+            if (koff < 0 || koff >= count)
+              continue;
+            i_idx(koff) = static_cast<double> (1+i);
+            j_idx(koff) = static_cast<double> (1+j);
+            idx(koff) = j * nc + i + 1;
+          }
+    }
+  else if (scalar_arg)
+    {
+      // Same odd compatibility case as the other overrides.
+      idx.resize (0, 0);
+      i_idx.resize (0, 0);
+      j_idx.resize (0, 0);
+      val.resize (dim_vector (0, 0));
+    }
+
+  // Fill the outputs requested by NARGOUT; higher cases fall through.
+  switch (nargout)
+    {
+    case 0:
+    case 1:
+      retval(0) = idx;
+      break;
+
+    case 5:
+      retval(4) = nc;
+      // Fall through
+
+    case 4:
+      retval(3) = nc;
+      // Fall through
+
+    case 3:
+      retval(2) = val;
+      // Fall through!
+
+    case 2:
+      retval(1) = j_idx;
+      retval(0) = i_idx;
+      break;
+
+    default:
+      panic_impossible ();
+      break;
+    }
+
+  return retval;
+}
+
DEFUN_DLD (find, args, nargout,
"-*- texinfo -*-\n\
@deftypefn {Loadable Function} {} find (@var{x})\n\
@@ -462,6 +569,12 @@
else
gripe_wrong_type_arg ("find", arg);
}
+ else if (arg.is_perm_matrix ()) {
+ PermMatrix P = arg.perm_matrix_value ();
+
+ if (! error_state)
+ retval = find_nonzero_elem_idx (P, nargout, n_to_find, direction);
+ }
else
{
if (arg.is_single_type ())
@@ -542,6 +655,24 @@
%! assert(j, [1; 2; 3]);
%! assert(v, single([-1; 3; 2]));
+%!test
+%! pcol = [5 1 4 3 2];
+%! P = eye (5) (:, pcol);
+%! [i, j, v] = find (P);
+%! [ifull, jfull, vfull] = find (full (P));
+%! assert (i, ifull);
+%! assert (j, jfull);
+%! assert (all (v == 1));
+
+%!test
+%! prow = [5 1 4 3 2];
+%! P = eye (5) (prow, :);
+%! [i, j, v] = find (P);
+%! [ifull, jfull, vfull] = find (full (P));
+%! assert (i, ifull);
+%! assert (j, jfull);
+%! assert (all (v == 1));
+
%!error find ();
*/