help-octave
[Top][All Lists]
Advanced

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

Patch for sparse LHS indexing in assignment [Was: indexed assignment for sparse matrices in ov-2.9.4]


From: David Bateman
Subject: Patch for sparse LHS indexing in assignment [Was: indexed assignment for sparse matrices in ov-2.9.4]
Date: Sun, 22 Jan 2006 01:17:24 +0100
User-agent: Mozilla Thunderbird 1.0.2 (Windows/20050317)

Ok, it's a real bug, and a pain in the arse. The problem is that the octave sparse type wants the elements of the matrix strictly increasing in terms of their row and column numbering. Therefore when you disorder them in this fashion in an assign statement, you have to maintain an index of the correspondence between the LHS and RHS indexing. That is, you need the equivalent of "[y,idx]=sort(x)" for the octave idx_vector class. This doesn't currently exist. I can create a special sorting class that does this, and have done so. The problem is then that something like "a(1,:) = b(:);" requires that the column indexing by ":" has a vector the length of the number of columns in order to maintain the correspondence between the LHS and RHS indexing. So you pay for this behaviour in terms of memory consumption for all cases, to treat a case that is unusual and easily written in a different manner. That is, if you can write "a([3,2,1],:)=b", you might just as well write "a([1,2,3],:)=b([3,2,1],:)". Grrr, ok we have to allow this behaviour, so consider the attached patch. I presume John will apply this in due course...

Regards
David


2006-01-21  David Bateman  <address@hidden>

        * sparse-sort.cc (bool octave_sparse_sidxl_comp): 64-bit fix.
        (bool octave_idx_vector_comp): New function.
        (template class octave_sort<octave_idx_vector_sort *>): Instantiate
        indexed idx_vector sorting function.
        * sparse-sort.h (class octave_sparse_sort_idxl): 64-bit fix.
        (class octave_idx_vector_sort): New class for indexed idx_vector
        sorting.
        (bool octave_idx_vector_comp): Declaration.
        * Sparse.cc (int assign1(Sparse<LT>&, Sparse<RT>&)): Treat cases of
        unordered LHS indexes in assignment using new octave_idx_vector_sort
        class.
        (int assign(Sparse<LT>&, Sparse<RT>&)): ditto.

2006-01-21  David Bateman  <address@hidden>

        * build_sparse_tests.sh: Add new un-ordered indexing, assignment and
        deletion tests.
*** ./liboctave/Sparse.cc.orig5 2006-01-21 23:52:26.000000000 +0100
--- ./liboctave/Sparse.cc       2006-01-22 01:11:07.353089145 +0100
***************
*** 1938,1944 ****
    octave_idx_type nc = lhs.cols ();
    octave_idx_type nz = lhs.nnz ();
  
!   octave_idx_type n = lhs_idx.freeze (lhs_len, "vector", true, 
liboctave_wrore_flag);
  
    if (n != 0)
      {
--- 1938,1945 ----
    octave_idx_type nc = lhs.cols ();
    octave_idx_type nz = lhs.nnz ();
  
!   octave_idx_type n = lhs_idx.freeze (lhs_len, "vector", true, 
!                                     liboctave_wrore_flag);
  
    if (n != 0)
      {
***************
*** 1953,1958 ****
--- 1954,1996 ----
        {
          octave_idx_type new_nnz = lhs.nnz ();
  
+         OCTAVE_LOCAL_BUFFER (octave_idx_type, rhs_idx, n);
+         if (! lhs_idx.is_colon ())
+           {
+             // Ok here we have to be careful with the indexing,
+             // to treat cases like "a([3,2,1]) = b", and still 
+             // handle the need for strict sorting of the sparse 
+             // elements.
+             OCTAVE_LOCAL_BUFFER (octave_idx_vector_sort *, sidx, n);
+             OCTAVE_LOCAL_BUFFER (octave_idx_vector_sort, sidxX, n);
+ 
+             for (octave_idx_type i = 0; i < n; i++)
+               {
+                 sidx[i] = &sidxX[i];
+                 sidx[i]->i = lhs_idx.elem(i);
+                 sidx[i]->idx = i;
+               }
+                         
+             OCTAVE_QUIT;
+             octave_sort<octave_idx_vector_sort *> 
+               sort (octave_idx_vector_comp);
+ 
+             sort.sort (sidx, n);
+ 
+             intNDArray<octave_idx_type> new_idx (dim_vector (n,1));
+ 
+             for (octave_idx_type i = 0; i < n; i++)
+               {
+                 new_idx.xelem(i) = sidx[i]->i + 1;
+                 rhs_idx[i] = sidx[i]->idx;
+               }
+ 
+             lhs_idx = idx_vector (new_idx);
+           }
+         else
+           for (octave_idx_type i = 0; i < n; i++)
+             rhs_idx[i] = i;
+ 
          // First count the number of non-zero elements
          for (octave_idx_type i = 0; i < n; i++)
            {
***************
*** 1961,1967 ****
              octave_idx_type ii = lhs_idx.elem (i);
              if (ii < lhs_len && c_lhs.elem(ii) != LT ())
                new_nnz--;
!             if (rhs.elem(i) != RT ())
                new_nnz++;
            }
  
--- 1999,2005 ----
              octave_idx_type ii = lhs_idx.elem (i);
              if (ii < lhs_len && c_lhs.elem(ii) != LT ())
                new_nnz--;
!             if (rhs.elem(rhs_idx[i]) != RT ())
                new_nnz++;
            }
  
***************
*** 1992,1998 ****
                    }
                  else
                    {
!                     RT rtmp = rhs.elem (j);
                      if (rtmp != RT ())
                        {
                          tmp.xdata (kk) = rtmp;
--- 2030,2036 ----
                    }
                  else
                    {
!                     RT rtmp = rhs.elem (rhs_idx[j]);
                      if (rtmp != RT ())
                        {
                          tmp.xdata (kk) = rtmp;
***************
*** 2031,2036 ****
--- 2069,2075 ----
                      while (ic <= ii)
                        tmp.xcidx (ic++) = kk;
                      tmp.xdata (kk) = c_lhs.data (i);
+                     tmp.xridx (kk++) = 0;
                      i++;
                      while (ii < nc && c_lhs.cidx(ii+1) <= i)
                        ii++;
***************
*** 2040,2048 ****
                      while (ic <= jj)
                        tmp.xcidx (ic++) = kk;
  
!                     RT rtmp = rhs.elem (j);
                      if (rtmp != RT ())
!                       tmp.xdata (kk) = rtmp;
                      if (ii == jj)
                        {
                          i++;
--- 2079,2090 ----
                      while (ic <= jj)
                        tmp.xcidx (ic++) = kk;
  
!                     RT rtmp = rhs.elem (rhs_idx[j]);
                      if (rtmp != RT ())
!                       {
!                         tmp.xdata (kk) = rtmp;
!                         tmp.xridx (kk++) = 0;
!                       }
                      if (ii == jj)
                        {
                          i++;
***************
*** 2053,2059 ****
                      if (j < n)
                        jj = lhs_idx.elem(j);
                    }
-                 tmp.xridx (kk++) = 0;
                }
  
              for (octave_idx_type iidx = ic; iidx < max_idx+1; iidx++)
--- 2095,2100 ----
***************
*** 2067,2072 ****
--- 2108,2114 ----
          octave_idx_type new_nnz = lhs.nnz ();
          RT scalar = rhs.elem (0);
          bool scalar_non_zero = (scalar != RT ());
+         lhs_idx.sort (true);
  
          // First count the number of non-zero elements
          if (scalar != RT ())
***************
*** 2260,2271 ****
  
    if (n_idx == 2)
      {
!       octave_idx_type n = idx_i.freeze (lhs_nr, "row", true, 
liboctave_wrore_flag);
!       idx_i.sort (true);
! 
!       octave_idx_type m = idx_j.freeze (lhs_nc, "column", true, 
liboctave_wrore_flag);
!       idx_j.sort (true);
! 
  
        int idx_i_is_colon = idx_i.is_colon ();
        int idx_j_is_colon = idx_j.is_colon ();
--- 2302,2311 ----
  
    if (n_idx == 2)
      {
!       octave_idx_type n = idx_i.freeze (lhs_nr, "row", true, 
!                                       liboctave_wrore_flag);
!       octave_idx_type m = idx_j.freeze (lhs_nc, "column", true, 
!                                       liboctave_wrore_flag);
  
        int idx_i_is_colon = idx_i.is_colon ();
        int idx_j_is_colon = idx_j.is_colon ();
***************
*** 2291,2304 ****
  
                  if (n > 0 && m > 0)
                    {
                      octave_idx_type max_row_idx = idx_i_is_colon ? rhs_nr : 
                        idx_i.max () + 1;
                      octave_idx_type max_col_idx = idx_j_is_colon ? rhs_nc : 
                        idx_j.max () + 1;
!                     octave_idx_type new_nr = max_row_idx > lhs_nr ? 
max_row_idx : 
!                       lhs_nr;
!                     octave_idx_type new_nc = max_col_idx > lhs_nc ? 
max_col_idx : 
!                       lhs_nc;
                      RT scalar = rhs.elem (0, 0);
  
                      // Count the number of non-zero terms
--- 2331,2347 ----
  
                  if (n > 0 && m > 0)
                    {
+                     idx_i.sort (true);
+                     idx_j.sort (true);
+ 
                      octave_idx_type max_row_idx = idx_i_is_colon ? rhs_nr : 
                        idx_i.max () + 1;
                      octave_idx_type max_col_idx = idx_j_is_colon ? rhs_nc : 
                        idx_j.max () + 1;
!                     octave_idx_type new_nr = max_row_idx > lhs_nr ? 
!                       max_row_idx : lhs_nr;
!                     octave_idx_type new_nc = max_col_idx > lhs_nc ? 
!                       max_col_idx : lhs_nc;
                      RT scalar = rhs.elem (0, 0);
  
                      // Count the number of non-zero terms
***************
*** 2399,2408 ****
                        idx_i.max () + 1;
                      octave_idx_type max_col_idx = idx_j_is_colon ? rhs_nc : 
                        idx_j.max () + 1;
!                     octave_idx_type new_nr = max_row_idx > lhs_nr ? 
max_row_idx : 
!                       lhs_nr;
!                     octave_idx_type new_nc = max_col_idx > lhs_nc ? 
max_col_idx : 
!                       lhs_nc;
  
                      // Count the number of non-zero terms
                      octave_idx_type new_nnz = lhs.nnz ();
--- 2442,2529 ----
                        idx_i.max () + 1;
                      octave_idx_type max_col_idx = idx_j_is_colon ? rhs_nc : 
                        idx_j.max () + 1;
!                     octave_idx_type new_nr = max_row_idx > lhs_nr ?
!                       max_row_idx : lhs_nr;
!                     octave_idx_type new_nc = max_col_idx > lhs_nc ? 
!                       max_col_idx : lhs_nc;
! 
!                     OCTAVE_LOCAL_BUFFER (octave_idx_type, rhs_idx_i, n);
!                     if (! idx_i.is_colon ())
!                       {
!                         // Ok here we have to be careful with the indexing,
!                         // to treat cases like "a([3,2,1],:) = b", and still 
!                         // handle the need for strict sorting of the sparse 
!                         // elements.
!                         OCTAVE_LOCAL_BUFFER (octave_idx_vector_sort *,
!                                              sidx, n);
!                         OCTAVE_LOCAL_BUFFER (octave_idx_vector_sort,
!                                              sidxX, n);
! 
!                         for (octave_idx_type i = 0; i < n; i++)
!                           {
!                             sidx[i] = &sidxX[i];
!                             sidx[i]->i = idx_i.elem(i);
!                             sidx[i]->idx = i;
!                           }
!                         
!                         OCTAVE_QUIT;
!                         octave_sort<octave_idx_vector_sort *> 
!                           sort (octave_idx_vector_comp);
! 
!                         sort.sort (sidx, n);
! 
!                         intNDArray<octave_idx_type> new_idx (dim_vector 
(n,1));
! 
!                         for (octave_idx_type i = 0; i < n; i++)
!                           {
!                             new_idx.xelem(i) = sidx[i]->i + 1;
!                             rhs_idx_i[i] = sidx[i]->idx;
!                           }
! 
!                         idx_i = idx_vector (new_idx);
!                       }
!                     else
!                       for (octave_idx_type i = 0; i < n; i++)
!                         rhs_idx_i[i] = i;
! 
!                     OCTAVE_LOCAL_BUFFER (octave_idx_type, rhs_idx_j, m);
!                     if (! idx_j.is_colon ())
!                       {
!                         // Ok here we have to be careful with the indexing,
!                         // to treat cases like "a([3,2,1],:) = b", and still 
!                         // handle the need for strict sorting of the sparse 
!                         // elements.
!                         OCTAVE_LOCAL_BUFFER (octave_idx_vector_sort *,
!                                              sidx, m);
!                         OCTAVE_LOCAL_BUFFER (octave_idx_vector_sort,
!                                              sidxX, m);
! 
!                         for (octave_idx_type i = 0; i < m; i++)
!                           {
!                             sidx[i] = &sidxX[i];
!                             sidx[i]->i = idx_j.elem(i);
!                             sidx[i]->idx = i;
!                           }
!                         
!                         OCTAVE_QUIT;
!                         octave_sort<octave_idx_vector_sort *> 
!                           sort (octave_idx_vector_comp);
! 
!                         sort.sort (sidx, m);
! 
!                         intNDArray<octave_idx_type> new_idx (dim_vector 
(m,1));
! 
!                         for (octave_idx_type i = 0; i < m; i++)
!                           {
!                             new_idx.xelem(i) = sidx[i]->i + 1;
!                             rhs_idx_j[i] = sidx[i]->idx;
!                           }
! 
!                         idx_j = idx_vector (new_idx);
!                       }
!                     else
!                       for (octave_idx_type i = 0; i < m; i++)
!                         rhs_idx_j[i] = i;
  
                      // Count the number of non-zero terms
                      octave_idx_type new_nnz = lhs.nnz ();
***************
*** 2430,2436 ****
                                    }
                                }
                              
!                             if (rhs.elem(i,j) != RT ())
                                new_nnz++;
                            }
                        }
--- 2551,2557 ----
                                    }
                                }
                              
!                             if (rhs.elem(rhs_idx_i[i],rhs_idx_j[j]) != RT ())
                                new_nnz++;
                            }
                        }
***************
*** 2453,2459 ****
  
                                  if (iii < n && ii == i)
                                    {
!                                     RT rtmp = rhs.elem (iii, jji);
                                      if (rtmp != RT ())
                                        {
                                          stmp.data(kk) = rtmp;
--- 2574,2581 ----
  
                                  if (iii < n && ii == i)
                                    {
!                                     RT rtmp = rhs.elem (rhs_idx_i[iii], 
!                                                         rhs_idx_j[jji]);
                                      if (rtmp != RT ())
                                        {
                                          stmp.data(kk) = rtmp;
***************
*** 2529,2536 ****
        {
          octave_idx_type lhs_len = lhs.length ();
  
!         octave_idx_type n = idx_i.freeze (lhs_len, 0, true, 
liboctave_wrore_flag);
!         idx_i.sort (true);
  
          if (idx_i)
            {
--- 2651,2658 ----
        {
          octave_idx_type lhs_len = lhs.length ();
  
!         octave_idx_type n = idx_i.freeze (lhs_len, 0, true, 
!                                           liboctave_wrore_flag);
  
          if (idx_i)
            {
***************
*** 2570,2576 ****
        else if (lhs_nr == 1)
        {
          idx_i.freeze (lhs_nc, "vector", true, liboctave_wrore_flag);
-         idx_i.sort (true);
  
          if (idx_i)
            {
--- 2692,2697 ----
***************
*** 2584,2590 ****
        else if (lhs_nc == 1)
        {
          idx_i.freeze (lhs_nr, "vector", true, liboctave_wrore_flag);
-         idx_i.sort (true);
  
          if (idx_i)
            {
--- 2705,2710 ----
***************
*** 2608,2614 ****
          octave_idx_type lhs_len = lhs.length ();
  
          octave_idx_type len = idx_i.freeze (lhs_nr * lhs_nc, "matrix");
-         idx_i.sort (true);
  
          if (idx_i)
            {
--- 2728,2733 ----
***************
*** 2628,2633 ****
--- 2747,2791 ----
              else if (len == rhs_nr * rhs_nc)
                {
                  octave_idx_type new_nnz = lhs_nz;
+                 OCTAVE_LOCAL_BUFFER (octave_idx_type, rhs_idx, len);
+                 
+                 if (! idx_i.is_colon ())
+                   {
+                     // Ok here we have to be careful with the indexing, to
+                     // treat cases like "a([3,2,1]) = b", and still handle
+                     // the need for strict sorting of the sparse elements.
+ 
+                     OCTAVE_LOCAL_BUFFER (octave_idx_vector_sort *, sidx, 
+                                          len);
+                     OCTAVE_LOCAL_BUFFER (octave_idx_vector_sort, sidxX, 
+                                          len);
+ 
+                     for (octave_idx_type i = 0; i < len; i++)
+                       {
+                         sidx[i] = &sidxX[i];
+                         sidx[i]->i = idx_i.elem(i);
+                         sidx[i]->idx = i;
+                       }
+ 
+                     OCTAVE_QUIT;
+                     octave_sort<octave_idx_vector_sort *> 
+                       sort (octave_idx_vector_comp);
+ 
+                     sort.sort (sidx, len);
+ 
+                     intNDArray<octave_idx_type> new_idx (dim_vector (len,1));
+ 
+                     for (octave_idx_type i = 0; i < len; i++)
+                       {
+                         new_idx.xelem(i) = sidx[i]->i + 1;
+                         rhs_idx[i] = sidx[i]->idx;
+                       }
+ 
+                     idx_i = idx_vector (new_idx);
+                   }
+                 else
+                   for (octave_idx_type i = 0; i < len; i++)
+                     rhs_idx[i] = i;
  
                  // First count the number of non-zero elements
                  for (octave_idx_type i = 0; i < len; i++)
***************
*** 2637,2643 ****
                      octave_idx_type ii = idx_i.elem (i);
                      if (ii < lhs_len && c_lhs.elem(ii) != LT ())
                        new_nnz--;
!                     if (rhs.elem(i) != RT ())
                        new_nnz++;
                    }
  
--- 2795,2801 ----
                      octave_idx_type ii = idx_i.elem (i);
                      if (ii < lhs_len && c_lhs.elem(ii) != LT ())
                        new_nnz--;
!                     if (rhs.elem(rhs_idx[i]) != RT ())
                        new_nnz++;
                    }
  
***************
*** 2679,2685 ****
                        {
                          while (kc <= jc)
                            stmp.xcidx (kc++) = kk;
!                         RT rtmp = rhs.elem (j);
                          if (rtmp != RT ())
                            {
                              stmp.xdata (kk) = rtmp;
--- 2837,2843 ----
                        {
                          while (kc <= jc)
                            stmp.xcidx (kc++) = kk;
!                         RT rtmp = rhs.elem (rhs_idx[j]);
                          if (rtmp != RT ())
                            {
                              stmp.xdata (kk) = rtmp;
***************
*** 2704,2711 ****
                    }
  
                  for (octave_idx_type iidx = kc; iidx < lhs_nc+1; iidx++)
!                   stmp.xcidx(iidx) = kk;
!                 
  
                  lhs = stmp;
                }
--- 2862,2868 ----
                    }
  
                  for (octave_idx_type iidx = kc; iidx < lhs_nc+1; iidx++)
!                   stmp.xcidx(iidx) = kk; 
  
                  lhs = stmp;
                }
***************
*** 2713,2718 ****
--- 2870,2876 ----
                {
                  RT scalar = rhs.elem (0, 0);
                  octave_idx_type new_nnz = lhs_nz;
+                 idx_i.sort (true);
  
                  // First count the number of non-zero elements
                  if (scalar != RT ())
*** ./liboctave/sparse-sort.cc.orig5    2006-01-21 23:52:38.000000000 +0100
--- ./liboctave/sparse-sort.cc  2006-01-21 22:25:24.587972101 +0100
***************
*** 39,45 ****
  octave_sparse_sidxl_comp (octave_sparse_sort_idxl* i, 
                          octave_sparse_sort_idxl* j)
  {
!   int tmp = i->c - j->c;
    if (tmp < 0)
      return true;
    else if (tmp > 0)
--- 39,45 ----
  octave_sparse_sidxl_comp (octave_sparse_sort_idxl* i, 
                          octave_sparse_sort_idxl* j)
  {
!   octave_idx_type tmp = i->c - j->c;
    if (tmp < 0)
      return true;
    else if (tmp > 0)
***************
*** 50,55 ****
--- 50,67 ----
  // Instantiate the sparse sorting class
  template class octave_sort<octave_sparse_sort_idxl *>;
  
+ // Need to know the original order of the sorted indexes in
+ // sparse assignments, and this class does that
+ bool
+ octave_idx_vector_comp (octave_idx_vector_sort* i,
+                       octave_idx_vector_sort* j)
+ {
+   return (i->i < j->i);
+ }
+ 
+ // Instantiate the sparse index sorting class
+ template class octave_sort<octave_idx_vector_sort *>;
+ 
  // Instantiate the sorting class of octave_idx_type, need in MUL macro
  template class octave_sort<octave_idx_type>;
  
*** ./liboctave/sparse-sort.h.orig5     2006-01-21 23:52:49.000000000 +0100
--- ./liboctave/sparse-sort.h   2006-01-21 22:25:27.000000000 +0100
***************
*** 28,42 ****
  class
  octave_sparse_sort_idxl
  {
!  public:
!   unsigned int r;
!   unsigned int c;
!   unsigned int idx; 
  };
  
  bool octave_sparse_sidxl_comp (octave_sparse_sort_idxl* i,
                               octave_sparse_sort_idxl* j);
  
  #endif
  
  /*
--- 28,53 ----
  class
  octave_sparse_sort_idxl
  {
! public:
!   octave_idx_type r;
!   octave_idx_type c;
!   octave_idx_type idx; 
  };
  
  bool octave_sparse_sidxl_comp (octave_sparse_sort_idxl* i,
                               octave_sparse_sort_idxl* j);
  
+ class
+ octave_idx_vector_sort
+ {
+ public:
+   octave_idx_type i;
+   octave_idx_type idx;
+ };
+ 
+ bool octave_idx_vector_comp (octave_idx_vector_sort* i,
+                            octave_idx_vector_sort* j);
+ 
  #endif
  
  /*
*** ./test/build_sparse_tests.sh.orig5  2006-01-21 23:53:06.000000000 +0100
--- ./test/build_sparse_tests.sh        2006-01-22 00:04:06.000000000 +0100
***************
*** 772,777 ****
--- 772,778 ----
  %!assert(sparse(as(idx),true),sparse(af(idx),true));
  %!assert(as(idx),sparse(af(idx),true));
  %!assert(as(idx'),sparse(af(idx'),true));
+ %!assert(as(flipud(idx(:))),sparse(af(flipud(idx(:))),true))
  %!assert(as([idx,idx]),sparse(af([idx,idx]),true));
  %!error(as(reshape([idx;idx],[1,length(idx),2])));
  
***************
*** 780,785 ****
--- 781,810 ----
  %!assert(as(ridx,:), sparse(af(ridx,:),true))
  %!assert(as(:,cidx), sparse(af(:,cidx),true))
  %!assert(as(:,:), sparse(af(:,:),true))
+ %!assert(as((size(as,1):-1:1),:),sparse(af((size(af,1):-1:1),:),true))
+ %!assert(as(:,(size(as,2):-1:1)),sparse(af(:,(size(af,2):-1:1)),true))
+ 
+ %% Assignment test
+ %!test
+ %! ts(:,:)=as(fliplr(1:size(as,1)),:);tf(:,:)=af(fliplr(1:size(af,1)),:);
+ %! assert(ts,sparse(tf,true));
+ %!test
+ %! ts(fliplr(1:size(as,1)),:)=as;tf(fliplr(1:size(af,1)),:)=af;
+ %! assert(ts,sparse(tf,true));
+ %!test
+ %! ts(:,fliplr(1:size(as,2)))=as;tf(:,fliplr(1:size(af,2)))=af;
+ %! assert(ts,sparse(tf,true));
+ %!test
+ %! ts(fliplr(1:size(as,1)))=as(:,1);tf(fliplr(1:size(af,1)))=af(:,1);
+ %! assert(ts,sparse(tf,true));
+ 
+ %% Deletion tests
+ %!test
+ %! ts=as;ts(1,:)=[];tf=af;tf(1,:)=[];
+ %! assert(ts,sparse(tf,true));
+ %!test
+ %! ts=as;ts(:,1)=[];tf=af;tf(:,1)=[];
+ %! assert(ts,sparse(tf,true));
  
  %% Test 'end' keyword
  %!assert(as(end),af(end))

reply via email to

[Prev in Thread] Current Thread [Next in Thread]