#include "mex.h" #include "math.h" /* * oasis.c - implement the OASIS algorithm * * * This is a MEX-file for MATLAB. * Copyright Gal Chechik 2009 * * * I use a sparse representation as in pamir: * [num_paires, index,value, index,value, index,value... ] * * */ int CountNonzerosOfFullVec(double *full, int dim) { int count=0; int i; for (i=0; i%f\n", sparse_entry, i, full[i]);*/ } } return; } /* Translate a given full vector into a sparse one */ double* SparsifyAFullVec(double *full, int dim) { int count = CountNonzerosOfFullVec(full, dim); int i; /* printf("\t\t\t dim = %d ", dim);*/ /* printf("\t\t\t SparsifyAFullVec: Full="); for (i=0; i 2) { mexErrMsgTxt("Too many output arguments"); } else if (2 > nlhs) { mexErrMsgTxt("Too few output arguments"); } /* W */ if (!mxIsDouble(prhs[0]) || mxIsComplex(prhs[0])){ mexErrMsgTxt("Oasis: Input #1 must be a real vector."); } mrows_W = mxGetM(prhs[0]); ncols_W = mxGetN(prhs[0]); if (mrows_W != ncols_W ) { sprintf(msgbuf, "Oasis: Input #1 must be a square matrix, but dim=[%d %d]", mrows_W, ncols_W); mexErrMsgTxt(msgbuf); } printf("\t\t Good size W %d x %d\n", mrows_W, ncols_W); /* images */ if (!mxIsDouble(prhs[1]) || mxIsComplex(prhs[1])){ mexErrMsgTxt( "Oasis: Input #2 must be a real vector."); } mrows_images = mxGetM(prhs[1]); ncols_images = mxGetN(prhs[1]); if (mrows_images!=mrows_W) { sprintf(msgbuf, "Oasis: Input #2 and #1 must have same n-rows, but %d!=%d", ncols_images, ncols_W); mexErrMsgTxt(msgbuf); } printf("\t\t Good size images %d features x %d images\n", mrows_images, ncols_images); /* class_labels */ if (!mxIsDouble(prhs[2]) || mxIsComplex(prhs[2])){ mexErrMsgTxt( "Oasis: Input #2 must be a real vector."); } mrows_class_labels = mxGetM(prhs[2]); ncols_class_labels = mxGetN(prhs[2]); if (mrows_class_labels!=ncols_images) { sprintf(msgbuf, "Oasis: Input #3 and #2 must have same n-rows, but %d!=%d", mrows_class_labels,ncols_images); mexErrMsgTxt(msgbuf); } if (ncols_class_labels!=1) { sprintf(msgbuf, "Oasis: Input #3 is not a column vector, size = %d x %d", mrows_class_labels, ncols_class_labels); mexErrMsgTxt(msgbuf); } printf("\t\t Good size class_labels: %d x %d\n", mrows_class_labels, ncols_class_labels); /* class_starts */ if (!mxIsDouble(prhs[3]) || mxIsComplex(prhs[3])){ mexErrMsgTxt( "Oasis: Input #4 must be a real vector."); } mrows_class_start = mxGetM(prhs[3]); ncols_class_start = mxGetN(prhs[3]); if (ncols_class_start!=1) { sprintf(msgbuf, "Oasis: Input #4 is not a column vector, size = %d x %d", mrows_class_start, ncols_class_start); mexErrMsgTxt(msgbuf); } printf("\t\t Good size class_starts %d x %d \n", mrows_class_start, ncols_class_start); /* class_sizes */ if (!mxIsDouble(prhs[4]) || mxIsComplex(prhs[4])){ mexErrMsgTxt( "Oasis: Input #5 must be a real vector."); } mrows_class_size = mxGetM(prhs[4]); ncols_class_size = mxGetN(prhs[4]); if (ncols_class_size!=1) { sprintf(msgbuf, "Oasis: Input #5 is not a column vector, size = %d x %d", mrows_class_size, ncols_class_size); mexErrMsgTxt(msgbuf); } if (mrows_class_size!= mrows_class_start) { sprintf(msgbuf, "Oasis: Input #4,#5 should have same number of rows, %d,%d", mrows_class_size, ncols_class_size); mexErrMsgTxt(msgbuf); } printf("\t\t Good size class_sizes %d x %d \n", mrows_class_size, ncols_class_size); /* num_steps */ if (!mxIsDouble(prhs[5]) || mxIsComplex(prhs[5])){ mexErrMsgTxt( "Oasis: Input #6 must be real."); } mrows_numsteps = mxGetM(prhs[5]); ncols_numsteps = mxGetN(prhs[5]); if (mrows_numsteps!=1 || ncols_numsteps!=1) { sprintf(msgbuf, "Oasis: Input #6 
/*
 * The helpers below are called from mexFunction, but their original bodies did
 * not survive in this copy of the file.  They are minimal reconstructions,
 * based on their call sites and on the OASIS update rule
 *   score   S_W(a,b)  = a' * W * b
 *   V                 = q * (p_pos - p_neg)'      (gradient of the hinge loss w.r.t. W)
 *   ||V||_F^2         = ||q||^2 * ||p_pos - p_neg||^2
 *   W                <- W + tau * V
 * They assume the sparse [num_pairs, index, value, ...] layout above with
 * 0-based indices, and W stored in MATLAB's column-major order.
 */

/* Uniformly sample an integer in {0, ..., n-1}. */
int Sample(int n)
{
    return (int) (((double) rand() / ((double) RAND_MAX + 1.0)) * n);
}

/* Sample an image from a class whose first image is class_start (0-based)
 * and which holds class_size images. */
int SampleImageForClass(double class_start, double class_size)
{
    return (int) class_start + Sample((int) class_size);
}

/* Dot product of two sparse vectors (simple O(na*nb) scan). */
double SparseDotSparse(double *a, double *b)
{
    double s = 0;
    int ia, ib;
    int na = (int) a[0];
    int nb = (int) b[0];
    for (ia = 0; ia < na; ia++) {
        for (ib = 0; ib < nb; ib++) {
            if ((int) a[1 + 2*ia] == (int) b[1 + 2*ib]) {
                s += a[2 + 2*ia] * b[2 + 2*ib];
            }
        }
    }
    return s;
}

/* Bilinear score a' * W * b for two sparse vectors; W is dim x dim. */
double ComputeScore(double *a, double *W, double *b, int dim)
{
    double score = 0;
    int ia, ib;
    int na = (int) a[0];
    int nb = (int) b[0];
    for (ia = 0; ia < na; ia++) {
        int row   = (int) a[1 + 2*ia];
        double va = a[2 + 2*ia];
        for (ib = 0; ib < nb; ib++) {
            int col   = (int) b[1 + 2*ib];
            double vb = b[2 + 2*ib];
            score += va * W[row + dim*col] * vb;
        }
    }
    return score;
}

/* Squared Frobenius norm of V = q * (p_pos - p_neg)'.  The argument order
 * (query, negative, positive) matches the call in mexFunction. */
double ComputeSquareNormGradW(double *q, double *p_neg, double *p_pos)
{
    double norm_q_sq    = SparseDotSparse(q, q);
    double norm_diff_sq = SparseDotSparse(p_pos, p_pos)
                        + SparseDotSparse(p_neg, p_neg)
                        - 2.0 * SparseDotSparse(p_pos, p_neg);
    return norm_q_sq * norm_diff_sq;
}

/* Passive-aggressive update W <- W + tau * q * (p_pos - p_neg)'. */
void UpdateW(double *W, double tau, int dim, double *q, double *p_pos, double *p_neg)
{
    int iq, ip;
    int nq = (int) q[0];
    for (iq = 0; iq < nq; iq++) {
        int row   = (int) q[1 + 2*iq];
        double vq = q[2 + 2*iq];
        int np    = (int) p_pos[0];
        for (ip = 0; ip < np; ip++) {
            W[row + dim * (int) p_pos[1 + 2*ip]] += tau * vq * p_pos[2 + 2*ip];
        }
        np = (int) p_neg[0];
        for (ip = 0; ip < np; ip++) {
            W[row + dim * (int) p_neg[1 + 2*ip]] -= tau * vq * p_neg[2 + 2*ip];
        }
    }
}

/* Check the number, type and sizes of the inputs and outputs. */
void CheckInputs(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    int mrows_W, ncols_W;
    int mrows_images, ncols_images;
    int mrows_class_labels, ncols_class_labels;
    int mrows_class_start, ncols_class_start;
    int mrows_class_size, ncols_class_size;
    int mrows_numsteps, ncols_numsteps;
    int mrows_aggress, ncols_aggress;
    int mrows_rseed, ncols_rseed;
    int mrows_call_no, ncols_call_no;
    char msgbuf[256];

    /* Nine inputs and two outputs are expected (see mexFunction below) */
    if (nrhs != 9) {
        mexErrMsgTxt("Oasis: nine input arguments are required");
    }
    if (nlhs > 2) {
        mexErrMsgTxt("Too many output arguments");
    } else if (2 > nlhs) {
        mexErrMsgTxt("Too few output arguments");
    }

    /* W */
    if (!mxIsDouble(prhs[0]) || mxIsComplex(prhs[0])) {
        mexErrMsgTxt("Oasis: Input #1 must be a real matrix.");
    }
    mrows_W = mxGetM(prhs[0]);
    ncols_W = mxGetN(prhs[0]);
    if (mrows_W != ncols_W) {
        sprintf(msgbuf, "Oasis: Input #1 must be a square matrix, but dim=[%d %d]",
                mrows_W, ncols_W);
        mexErrMsgTxt(msgbuf);
    }
    printf("\t\t Good size W %d x %d\n", mrows_W, ncols_W);

    /* images */
    if (!mxIsDouble(prhs[1]) || mxIsComplex(prhs[1])) {
        mexErrMsgTxt("Oasis: Input #2 must be a real matrix.");
    }
    mrows_images = mxGetM(prhs[1]);
    ncols_images = mxGetN(prhs[1]);
    if (mrows_images != mrows_W) {
        sprintf(msgbuf, "Oasis: Inputs #2 and #1 must have the same number of rows, but %d!=%d",
                mrows_images, mrows_W);
        mexErrMsgTxt(msgbuf);
    }
    printf("\t\t Good size images %d features x %d images\n", mrows_images, ncols_images);

    /* class_labels */
    if (!mxIsDouble(prhs[2]) || mxIsComplex(prhs[2])) {
        mexErrMsgTxt("Oasis: Input #3 must be a real vector.");
    }
    mrows_class_labels = mxGetM(prhs[2]);
    ncols_class_labels = mxGetN(prhs[2]);
    if (mrows_class_labels != ncols_images) {
        sprintf(msgbuf, "Oasis: Input #3 must have one label per image (column of #2), but %d!=%d",
                mrows_class_labels, ncols_images);
        mexErrMsgTxt(msgbuf);
    }
    if (ncols_class_labels != 1) {
        sprintf(msgbuf, "Oasis: Input #3 is not a column vector, size = %d x %d",
                mrows_class_labels, ncols_class_labels);
        mexErrMsgTxt(msgbuf);
    }
    printf("\t\t Good size class_labels: %d x %d\n", mrows_class_labels, ncols_class_labels);

    /* class_start */
    if (!mxIsDouble(prhs[3]) || mxIsComplex(prhs[3])) {
        mexErrMsgTxt("Oasis: Input #4 must be a real vector.");
    }
    mrows_class_start = mxGetM(prhs[3]);
    ncols_class_start = mxGetN(prhs[3]);
    if (ncols_class_start != 1) {
        sprintf(msgbuf, "Oasis: Input #4 is not a column vector, size = %d x %d",
                mrows_class_start, ncols_class_start);
        mexErrMsgTxt(msgbuf);
    }
    printf("\t\t Good size class_starts %d x %d \n", mrows_class_start, ncols_class_start);

    /* class_size */
    if (!mxIsDouble(prhs[4]) || mxIsComplex(prhs[4])) {
        mexErrMsgTxt("Oasis: Input #5 must be a real vector.");
    }
    mrows_class_size = mxGetM(prhs[4]);
    ncols_class_size = mxGetN(prhs[4]);
    if (ncols_class_size != 1) {
        sprintf(msgbuf, "Oasis: Input #5 is not a column vector, size = %d x %d",
                mrows_class_size, ncols_class_size);
        mexErrMsgTxt(msgbuf);
    }
    if (mrows_class_size != mrows_class_start) {
        sprintf(msgbuf, "Oasis: Inputs #4,#5 should have the same number of rows, %d,%d",
                mrows_class_start, mrows_class_size);
        mexErrMsgTxt(msgbuf);
    }
    printf("\t\t Good size class_sizes %d x %d \n", mrows_class_size, ncols_class_size);

    /* num_steps */
    if (!mxIsDouble(prhs[5]) || mxIsComplex(prhs[5])) {
        mexErrMsgTxt("Oasis: Input #6 must be real.");
    }
    mrows_numsteps = mxGetM(prhs[5]);
    ncols_numsteps = mxGetN(prhs[5]);
    if (mrows_numsteps != 1 || ncols_numsteps != 1) {
        sprintf(msgbuf, "Oasis: Input #6 (num_steps) is not a scalar, size = %dx%d",
                mrows_numsteps, ncols_numsteps);
        mexErrMsgTxt(msgbuf);
    }
    printf("\t\t Good size num_steps %d x %d \n", mrows_numsteps, ncols_numsteps);

    /* aggress */
    if (!mxIsDouble(prhs[6]) || mxIsComplex(prhs[6])) {
        mexErrMsgTxt("Oasis: Input #7 must be real.");
    }
    mrows_aggress = mxGetM(prhs[6]);
    ncols_aggress = mxGetN(prhs[6]);
    if (mrows_aggress != 1 || ncols_aggress != 1) {
        sprintf(msgbuf, "Oasis: Input #7 (aggress) is not a scalar, size = %dx%d",
                mrows_aggress, ncols_aggress);
        mexErrMsgTxt(msgbuf);
    }
    printf("\t\t Good size aggress %d x %d \n", mrows_aggress, ncols_aggress);

    /* rseed */
    if (!mxIsDouble(prhs[7]) || mxIsComplex(prhs[7])) {
        mexErrMsgTxt("Oasis: Input #8 must be real.");
    }
    mrows_rseed = mxGetM(prhs[7]);
    ncols_rseed = mxGetN(prhs[7]);
    if (mrows_rseed != 1 || ncols_rseed != 1) {
        sprintf(msgbuf, "Oasis: Input #8 (seed) is not a scalar, size = %dx%d",
                mrows_rseed, ncols_rseed);
        mexErrMsgTxt(msgbuf);
    }

    /* call_no */
    if (!mxIsDouble(prhs[8]) || mxIsComplex(prhs[8])) {
        mexErrMsgTxt("Oasis: Input #9 must be real.");
    }
    mrows_call_no = mxGetM(prhs[8]);
    ncols_call_no = mxGetN(prhs[8]);
    if (mrows_call_no != 1 || ncols_call_no != 1) {
        sprintf(msgbuf, "Oasis: Input #9 (call_no) is not a scalar, size = %dx%d",
                mrows_call_no, ncols_call_no);
        mexErrMsgTxt(msgbuf);
    }
    /* printf("\t\t Good size call_no %d x %d \n", mrows_call_no, ncols_call_no); */

    printf("\t\t Finished checking parameters\n");
}
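/*
 * Example (not part of the original file, never called): the contiguous class
 * layout that class_labels / class_start / class_size are assumed to describe,
 * and the shift trick mexFunction uses below to sample a negative image.
 * With class_labels = [1 1 1 2 2 3]', class_start = [0 3 5]' and
 * class_size = [3 2 1]', a negative for a class-2 query is drawn uniformly
 * from images {0, 1, 2, 5}, i.e. everything outside the class-2 block.
 * Labels are 1-based (MATLAB style); image indices and class_start are 0-based.
 */
static int SampleNegativeExample(int num_images, double *class_start,
                                 double *class_size, int query_label)
{
    /* sample from a range that excludes the query-class block ... */
    int image_neg = Sample(num_images - (int) class_size[query_label - 1]);
    /* ... then shift past that block so it is skipped over */
    if (image_neg >= (int) class_start[query_label - 1]) {
        image_neg += (int) class_size[query_label - 1];
    }
    return image_neg;
}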
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    /*
     * Assumes a call of the form:
     *   [W, loss_vec] = oasis(W0, images, class_labels, class_start, class_size,
     *                         num_steps, aggress, rseed, call_no);
     *
     * Inputs:
     *   0 W0           - initial weight matrix
     *   1 images       - a matrix with one image per COLUMN!
     *   2 class_labels - a column vector; for each image, the class it belongs to
     *   3 class_start  - a column vector; which image is first in each class
     *   4 class_size   - a column vector; number of images per class
     *   5 num_steps    - number of training steps
     *   6 aggress      - the aggressiveness parameter
     *   7 rseed        - the seed for the random sampling
     *   8 call_no      - a flag that indicates whether to initialize the random seed
     * Outputs:
     *   W        - the updated matrix
     *   loss_vec - binary vector of loss events
     */
    double *images_vec;    /* the images (dense, one column per image) */
    double *class_labels;  /* the labels of the images */
    double **images;       /* the images in sparse representation */
    double *W0, *W;        /* the weight matrices (input and output) */
    double *loss_vec;      /* vector of loss occurrences */
    double *class_start;   /* per class, the index of its first image (0-based) */
    double *class_size;    /* per class, the number of images */
    int i, j;
    int dim;               /* vocabulary size (dimension of W) */
    static char msgbuf[256];
    int image_query, image_pos, image_neg;
    int num_images;
    int i_step, num_steps;
    int rseed;
    int call_no;
    double aggres;         /* aggressiveness parameter */
    double loss;           /* loss per triplet */
    double norm_grad_w;    /* norm of V = dloss/dW */
    double tau;            /* learning rate */
    int query_label;

    /* Check input dimensionality */
    CheckInputs(nlhs, plhs, nrhs, prhs);

    /* Extract inputs */
    W0           = mxGetPr(prhs[0]);
    dim          = (int) mxGetM(prhs[0]);
    images_vec   = mxGetPr(prhs[1]);
    num_images   = (int) mxGetN(prhs[1]);
    class_labels = mxGetPr(prhs[2]);
    class_start  = mxGetPr(prhs[3]);
    class_size   = mxGetPr(prhs[4]);
    num_steps    = (int) *mxGetPr(prhs[5]);
    aggres       = *mxGetPr(prhs[6]);
    rseed        = (int) *mxGetPr(prhs[7]);
    call_no      = (int) *mxGetPr(prhs[8]);

    /* Create matrices for the return arguments */
    plhs[0] = mxCreateDoubleMatrix(dim, dim, mxREAL);
    W = mxGetPr(plhs[0]);
    plhs[1] = mxCreateDoubleMatrix(1, num_steps, mxREAL);
    loss_vec = mxGetPr(plhs[1]);
    /* printf("\t\t dim=%d, num_images=%d, num_steps=%d, aggres=%g, seed=%d\n",
              dim, num_images, num_steps, aggres, rseed); */

    /* Copy W0 to W */
    for (i = 0; i < dim; i++) {
        for (j = 0; j < dim; j++) {
            W[i + dim*j] = W0[i + dim*j];
        }
    }

    /* Translate every image (a column of images_vec) into the sparse
     * [num_pairs, index, value, ...] representation */
    images = (double**) mxCalloc(num_images, sizeof(double*));
    for (i = 0; i < num_images; i++) {
        images[i] = SparsifyAFullVec(images_vec + i*dim, dim);
    }

    /* Initialize the random seed.  (The exact condition on call_no did not
     * survive in this copy; seeding on the first call only is assumed.) */
    if (call_no == 1) {
        srand(rseed);
    }

    /* Loop over training steps; each step samples a (query, pos, neg) triplet */
    for (i_step = 0; i_step < num_steps; i_step++) {

        image_query = Sample(num_images);
        query_label = class_labels[image_query]; /* query labels are in matlab notation 1..n */

        /* image_neg is sampled from all the images without those of the same
         * class as the query, so the block of images from the query class
         * should be yanked out */
        image_neg = Sample(num_images - class_size[query_label-1]);
        if (image_neg >= class_start[query_label-1]) {
            image_neg = image_neg + class_size[query_label-1];
        }
        image_pos = SampleImageForClass(class_start[query_label-1],
                                        class_size[query_label-1]);
        /* printf("Main: Sampled images qry=%d pos=%d neg=%d (num_images=%d)\n",
                  image_query, image_pos, image_neg, num_images); */

        loss = 1.0 - ComputeScore(images[image_query], W, images[image_pos], dim)
                   + ComputeScore(images[image_query], W, images[image_neg], dim);
        /* double dd = ComputeScore(images[image_query], W, images[image_query], dim);
           printf("\t\t loss: %f \n", dd); */

        if (loss > 0) {
            /* printf("Main: loss = %f>0: Update W\n", loss); */
            norm_grad_w = ComputeSquareNormGradW(images[image_query],
                                                 images[image_neg],
                                                 images[image_pos]);
            tau = loss / norm_grad_w;
            if (tau > aggres) {
                tau = aggres;
            }
            UpdateW(W, tau, dim, images[image_query],
                    images[image_pos], images[image_neg]);
            loss_vec[i_step] = 1;
        } else {
            loss_vec[i_step] = 0;
        }
    } /* Loop over steps */

    /* Cleanup memory */
    for (i = 0; i < num_images; i++) {
        mxFree(images[i]);
    }
    mxFree(images);
}
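/*
 * Worked example (not part of the original file, never called): one OASIS step
 * on a toy triplet, mirroring the loss / tau / update logic in mexFunction and
 * using the reconstructed helpers above.  With dim = 3, q = e0, p_pos = e1,
 * p_neg = e2 and W = I:
 *   loss    = 1 - W(0,1) + W(0,2) = 1
 *   ||V||^2 = ||q||^2 * ||p_pos - p_neg||^2 = 1 * 2 = 2
 *   tau     = min(aggress, loss/||V||^2) = min(aggress, 0.5)
 * and the update adds tau to W(0,1) and subtracts tau from W(0,2), pulling the
 * positive image closer to the query than the negative one.
 */
static void OasisToyStep(double aggress)
{
    /* sparse triplets: [num_pairs, index, value, ...], 0-based indices */
    double q[]     = {1, 0, 1.0};
    double p_pos[] = {1, 1, 1.0};
    double p_neg[] = {1, 2, 1.0};
    double W[9]    = {1, 0, 0,  0, 1, 0,  0, 0, 1};  /* 3x3 identity, column-major */
    double loss, tau;

    loss = 1.0 - ComputeScore(q, W, p_pos, 3) + ComputeScore(q, W, p_neg, 3);
    if (loss > 0) {
        tau = loss / ComputeSquareNormGradW(q, p_neg, p_pos);
        if (tau > aggress) {
            tau = aggress;
        }
        UpdateW(W, tau, 3, q, p_pos, p_neg);   /* W(0,1) += tau, W(0,2) -= tau */
    }
}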