# HG changeset patch
# User Jason Riedy
# Date 1236736479 14400
# Node ID 12958b7148992e19c0b41268ce2f0ea1ae081d26
# Parent 02034a30a13a1a962fff2ca9b7fc055657f57140
Add an override to Octave's find() for permutation matrices.

Because of find()'s count-limiting and direction arguments, this is
slightly more complicated than just copying the permutation vector.
I suspect this is a common operation for people who don't know about
the 'vector' option to lu().

diff --git a/src/ChangeLog b/src/ChangeLog
--- a/src/ChangeLog
+++ b/src/ChangeLog
@@ -1,3 +1,10 @@
+2009-03-10  Jason Riedy
+
+	* DLD-FUNCTIONS/find.cc (find_nonzero_elem_idx): New override
+	for find on PermMatrix.
+	(find): Add a branch testing arg.is_perm_matrix () and calling the
+	above override.
+
 2009-03-10  John W. Eaton
 
 	* c-file-ptr-stream.cc, dynamic-ld.cc, error.cc, lex.l, pager.cc,
diff --git a/src/DLD-FUNCTIONS/find.cc b/src/DLD-FUNCTIONS/find.cc
--- a/src/DLD-FUNCTIONS/find.cc
+++ b/src/DLD-FUNCTIONS/find.cc
@@ -333,6 +333,131 @@
 template octave_value_list find_nonzero_elem_idx (const Sparse<bool>&, int,
                                                   octave_idx_type, int);
 
+// find() for a PermMatrix.  A permutation matrix has exactly one
+// nonzero (equal to 1) in every row and column, so the result can be
+// built directly from the permutation vector.  N_TO_FIND and DIRECTION
+// select the first or last N nonzeros in column-major order; NARGOUT
+// selects between the linear-index and [i, j, v, nr, nc] forms.
+octave_value_list
+find_nonzero_elem_idx (const PermMatrix& v, int nargout,
+                       octave_idx_type n_to_find, int direction)
+{
+  octave_value_list retval ((nargout == 0 ? 1 : nargout), Matrix ());
+
+  octave_idx_type nc = v.cols ();
+  octave_idx_type start_nc, end_nc, count;
+
+  // Determine the range of columns to report.  find() orders its
+  // results column-major and a permutation matrix has exactly one
+  // nonzero per column, so the first (or last) N nonzeros are exactly
+  // the ones in the first (or last) N columns.
+  if (n_to_find < 0 || n_to_find >= nc)
+    {
+      start_nc = 0;
+      end_nc = nc;
+      count = nc;
+    }
+  else if (direction > 0)
+    {
+      start_nc = 0;
+      end_nc = n_to_find;
+      count = n_to_find;
+    }
+  else
+    {
+      start_nc = nc - n_to_find;
+      end_nc = nc;
+      count = n_to_find;
+    }
+
+  bool scalar_arg = (v.rows () == 1 && v.cols () == 1);
+
+  Matrix idx (count, 1);
+  Matrix i_idx (count, 1);
+  Matrix j_idx (count, 1);
+  // Every value is 1.
+  ArrayN<double> val (dim_vector (count, 1), 1.0);
+
+  if (count > 0)
+    {
+      const octave_idx_type* p = v.data ();
+      if (v.is_col_perm ())
+        {
+          // Column permutation: the nonzero of column j is in row
+          // p[j], so the requested columns are read off directly.
+          for (octave_idx_type k = 0; k < count; k++)
+            {
+              OCTAVE_QUIT;
+              const octave_idx_type j = start_nc + k;
+              const octave_idx_type i = p[j];
+              i_idx(k) = static_cast<double> (1+i);
+              j_idx(k) = static_cast<double> (1+j);
+              idx(k) = j * nc + i + 1;
+            }
+        }
+      else
+        {
+          // Row permutation: the nonzero of row i is in column p[i].
+          // [start_nc, end_nc) is a range of *columns*, so walk every
+          // row and scatter only the entries whose column lies in that
+          // range into their column-major slot.  Walking just rows
+          // start_nc .. start_nc+count-1 would report the wrong
+          // entries and write outside the COUNT-element arrays.
+          for (octave_idx_type i = 0; i < nc; i++)
+            {
+              OCTAVE_QUIT;
+              const octave_idx_type j = p[i];
+              if (j < start_nc || j >= end_nc)
+                continue;
+              const octave_idx_type koff = j - start_nc;
+              i_idx(koff) = static_cast<double> (1+i);
+              j_idx(koff) = static_cast<double> (1+j);
+              idx(koff) = j * nc + i + 1;
+            }
+        }
+    }
+  else if (scalar_arg)
+    {
+      // Same odd compatibility case as the other overrides.
+      idx.resize (0, 0);
+      i_idx.resize (0, 0);
+      j_idx.resize (0, 0);
+      val.resize (dim_vector (0, 0));
+    }
+
+  switch (nargout)
+    {
+    case 0:
+    case 1:
+      retval(0) = idx;
+      break;
+
+    case 5:
+      retval(4) = nc;
+      // Fall through
+
+    case 4:
+      // nr == nc: permutation matrices are square.
+      retval(3) = nc;
+      // Fall through
+
+    case 3:
+      retval(2) = val;
+      // Fall through
+
+    case 2:
+      retval(1) = j_idx;
+      retval(0) = i_idx;
+      break;
+
+    default:
+      panic_impossible ();
+      break;
+    }
+
+  return retval;
+}
+
 DEFUN_DLD (find, args, nargout,
   "-*- texinfo -*-\n\
 @deftypefn {Loadable Function} {} find (@var{x})\n\
@@ -462,6 +587,13 @@
       else
         gripe_wrong_type_arg ("find", arg);
     }
+  else if (arg.is_perm_matrix ())
+    {
+      PermMatrix P = arg.perm_matrix_value ();
+
+      if (! error_state)
+        retval = find_nonzero_elem_idx (P, nargout, n_to_find, direction);
+    }
   else
     {
       if (arg.is_single_type ())
@@ -542,6 +674,40 @@
 %! assert(j, [1; 2; 3]);
 %! assert(v, single([-1; 3; 2]));
 
+%!test
+%! pcol = [5 1 4 3 2];
+%! P = eye (5) (:, pcol);
+%! [i, j, v] = find (P);
+%! [ifull, jfull, vfull] = find (full (P));
+%! assert (i, ifull);
+%! assert (j, jfull);
+%! assert (all (v == 1));
+
+%!test
+%! prow = [5 1 4 3 2];
+%! P = eye (5) (prow, :);
+%! [i, j, v] = find (P);
+%! [ifull, jfull, vfull] = find (full (P));
+%! assert (i, ifull);
+%! assert (j, jfull);
+%! assert (all (v == 1));
+
+%!test
+%! pcol = [5 1 4 3 2];
+%! P = eye (5) (:, pcol);
+%! assert (find (P, 2), find (full (P), 2));
+%! assert (find (P, 2, "last"), find (full (P), 2, "last"));
+
+%!test
+%! prow = [5 1 4 3 2];
+%! P = eye (5) (prow, :);
+%! assert (find (P, 2), find (full (P), 2));
+%! assert (find (P, 2, "last"), find (full (P), 2, "last"));
+%! [i, j] = find (P, 2);
+%! [ifull, jfull] = find (full (P), 2);
+%! assert (i, ifull);
+%! assert (j, jfull);
+
 %!error find ();
 
 */