Bug 7895 - Error with user defined split function in rpart
Summary: Error with user defined split function in rpart
Status: NEW
Alias: None
Product: R
Classification: Unclassified
Component: Add-ons (show other bugs)
Version: old
Hardware: ix86 (32-bit) Windows 32-bit
: P5 normal
Assignee: Jitterbug compatibility account
URL:
Depends on:
Blocks:
 
Reported: 2005-05-25 19:15 UTC by Jitterbug compatibility account
Modified: 2005-05-25 19:15 UTC (History)
0 users

See Also:


Attachments

Note You need to log in before you can comment on or make changes to this bug.
Description Jitterbug compatibility account 2005-05-25 19:15:19 UTC
From: wheelerb@imsweb.com
Full_Name: Bill Wheeler
Version: 2.0.1
OS: Windows 2000
Submission from: (NULL) (67.130.36.229)


The program to reproduce the error is below. I am calling rpart with a
user-defined split function for a binary response variable and one continuous
independent variable. The split function works for some datasets but not
others.
The error is:
Error in "$<-.data.frame"(`*tmp*`, "yval2", value = c(0, 15, 10, 0.6,  : 
        replacement has 5 rows, data has 1


#
# Test out the "user mode" functions, with a binary response
#
# Note: the workspace is deliberately NOT cleared here (no rm(list = ls())),
# so running this script does not destroy the caller's objects.

options(warn = 1)  # report warnings as soon as they occur
library(rpart)

set.seed(7)

# Simulated data: a 0/1 response `y` and one continuous predictor `x` in
# [0, 2). `y` is drawn before `x` so the random stream is consumed in the
# same order as the original reproduction, giving identical data.
nobs <- 25
mydata <- data.frame(y = floor(runif(nobs, min = 0, max = 2)))
mydata$x <- runif(nobs, min = 0, max = 2)

################################################################
# The 'evaluation' function.  Called once per node.
#  Produce a label (1 or more elements long) for labeling each node,
#  and a deviance.  The latter is
#	- of length 1
#       - equal to 0 if the node is "pure" in some sense (unsplittable)
#       - does not need to be a deviance: any measure that gets larger
#            as the node is less acceptable is fine.
#       - the measure underlies cost-complexity pruning, however
temp1 <- function(y, wt, parms) {
  # rpart 'eval' callback for a 0/1 response; called once per node.
  #
  # Arguments:
  #   y     - response values in the node (expected to be 0 or 1)
  #   wt    - case weights, same length as y
  #   parms - unused here (init returns parms = 0)
  #
  # Returns list(label, deviance):
  #   label    - numeric vector of length 5:
  #                [1] fitted class (majority by weight; ties go to 0)
  #                [2] weighted count of y == 0
  #                [3] weighted count of y == 1
  #                [4] weighted proportion of y == 0
  #                [5] weighted proportion of y == 1
  #   deviance - weight of the minority class, i.e. the weighted number of
  #              observations the fitted class misclassifies (0 when pure)
  print("***** START: TEMP1 *****")

  # Weighted size of each response category; previously `wt` was ignored.
  w0 <- sum(wt[y == 0])
  w1 <- sum(wt[y == 1])
  total <- sum(wt)

  # Compare the weights directly (not the proportions) so a zero-weight node
  # does not hit `if (NaN >= NaN)`.
  majority_is_zero <- w0 >= w1
  fitted <- if (majority_is_zero) 0 else 1
  dev <- if (majority_is_zero) w1 else w0

  labels <- c(fitted, w0, w1, w0 / total, w1 / total)

  ret <- list(label = labels, deviance = dev)

  print("***** END: TEMP1 *****")
  ret
}

# The split function, where most of the work occurs.
#   Called once per split variable per node.
# If continuous=T
#   The actual x variable is ordered
#   y is supplied in the sort order of x, with no missings,
#   return two vectors of length (n-1):
#      goodness = goodness of the split, larger numbers are better.
#                 0 = couldn't find any worthwhile split
#        the ith value of goodness evaluates splitting obs 1:i vs (i+1):n
#      direction= -1 = send "x < cutpoint" to the left side of the tree
#                  1 = send "x < cutpoint" to the right
#         this is not a big deal, but making larger "mean y's" move towards
#         the right of the tree, as we do here, seems to make it easier to
#         read
# If continuous=FALSE, x is a set of integers defining the groups for an
#   unordered predictor.  In this case:
#       direction = a vector of length m= "# groups".  It asserts that the
#           best split can be found by lining the groups up in this order
#           and going from left to right, so that only m-1 splits need to
#           be evaluated rather than 2^(m-1)
#       goodness = m-1 values, as before.
#
# The reason for returning a vector of goodness is that the C routine
#   enforces the "minbucket" constraint. It selects the best return value
#   that is not too close to an edge.
temp2 <- function(y, wt, x, parms, continuous) {
  # rpart 'split' callback for a 0/1 response and a continuous predictor.
  #
  # Assumes `continuous` is TRUE: y and wt arrive sorted by x, and the i-th
  # candidate split separates observations 1..i from (i+1)..n.
  # NOTE(review): the continuous = FALSE (categorical x) case is not handled;
  # rpart would then expect one direction entry per group, not per obs.
  #
  # Returns list(goodness, direction), each of length n - 1:
  #   goodness  - weighted P(y = 0 | left) + weighted P(y = 1 | right)
  #   direction - -1 where the left term >= the right term, otherwise 1
  print("***** START: TEMP2 *****")

  n <- length(y)

  # Weighted count of ones and total weight on each side of every cut.
  # The right-hand count is now weighted, consistent with the left-hand one
  # and with `right_wt` it is divided by; it is also O(n) instead of O(n^2).
  y_wt       <- y * wt
  left_ones  <- cumsum(y_wt)[-n]
  left_wt    <- cumsum(wt)[-n]
  right_ones <- sum(y_wt) - left_ones
  right_wt   <- sum(wt) - left_wt

  lprop <- 1 - left_ones / left_wt  # proportion of zeros on the left
  rprop <- right_ones / right_wt    # proportion of ones on the right

  direc    <- ifelse(lprop >= rprop, -1, 1)
  goodness <- lprop + rprop

  ret <- list(goodness = goodness, direction = direc)

  print("***** END: TEMP2 *****")
  ret
}
	
# The init function:
#   fix up y to deal with offsets
#   return a dummy parms list
#   numresp is the number of values produced by the eval routine's "label"
#   numy is the number of columns for y
#   summary is a function used to print one line in summary.rpart
#   text is a function used to put text on the plot in text.rpart
# In general, this function would also check for bad data, see rpart.poisson
#   for instance.
temp3 <- function(y, offset, parms, wt) {
  # rpart 'init' callback.
  #
  # Subtracts `offset` from `y` when one is supplied and declares a 5-value
  # node label (numresp = 5) over a single response column (numy = 1).
  # Also supplies `summary` and `text` formatters for the fitted tree.
  print("***** START: TEMP3 *****")

  if (!is.null(offset)) y <- y - offset

  # Format to `digits` significant digits using the C "%g" conversion.
  # Local stand-in for rpart's formatg(), which is an internal (unexported)
  # rpart helper and so is not reliably found from a closure created in the
  # global environment.
  fmt_g <- function(v, digits) sprintf(paste0("%.", digits, "g"), v)

  ret <- list(
    y = y, parms = 0, numresp = 5, numy = 1,
    summary = function(yval, dev, wt, ylevel, digits) {
      paste0("  mean=", format(signif(yval, digits)),
             ", MSE=", format(signif(dev / wt, digits)))
    },
    text = function(yval, dev, wt, ylevel, digits, n, use.n) {
      if (use.n) {
        paste0(fmt_g(yval, digits), "\nn=", n)
      } else {
        fmt_g(yval, digits)
      }
    }
  )

  print("***** END: TEMP3 *****")
  ret
}

# Package the three user callbacks into the list form that rpart accepts
# as a custom `method`.
alist <- list(
  eval  = temp1,
  split = temp2,
  init  = temp3
)

# Grow the tree with the complexity parameter set to zero.
# NOTE(review): the reported error ("replacement has 5 rows, data has 1")
# suggests the fitted frame has a single row (a root-only tree) while the
# 5-value label is not kept as a 1 x 5 matrix — unconfirmed.
fit1 <- rpart(
  y ~ .,
  data    = mydata,
  method  = alist,
  control = list(cp = 0)
)



Comment 1 Jitterbug compatibility account 2008-08-10 00:02:00 UTC
NOTES:
 Non-maintainer report on contributed package
User error?
Comment 2 Jitterbug compatibility account 2008-08-10 02:02:21 UTC
Audit (from Jitterbug):
Fri Jun 10 12:47:29 2005	ripley	changed notes
Fri Jun 10 12:47:29 2005	ripley	moved from incoming to Add-ons
Sat Aug  9 21:02:21 2008	ripley	changed notes