/*
 * C implementation of the MEX function mlp_cgd.
 *
 * Compile it from the MATLAB command line with:
 *     mex mlp_cgd.c
 *
 * Study the gateway function and try to understand what is done in the
 * different stages of mexFunction.
 *
 * The purpose of this MEX function is to compute the value of the error
 * function and its gradient. It performs exactly the same operations as
 * the function mlp_cgd.m from Exercise 9&10.
 *
 * The value of the functional is stored in f and the gradient in fg.
 */
#include "mex.h"
/* NOTE(review): the header name of the following include is missing; it
   looks like angle-bracketed text was stripped from this copy of the
   file. Restore it from an intact copy. */
#include
/* fillve takes a buffer v, a count n and a value val. Its only visible use
   is fillve(fg,v_rows,0.e0) further down, which is expected to clear the
   gradient buffer; presumably it sets the first n entries of v to val --
   TODO confirm, the body is not present in this copy. */
void fillve (double *v, int n, double val) {
  int i;
  /* NOTE(review): the file is damaged inside the next statement. Everything
     between "i" and "2)" is missing: the rest of fillve, any further helper
     definitions, the gateway signature and the declarations of all locals
     used below. What remains of the condition guards the "at most two
     outputs" error, so it is presumably a check on nlhs -- confirm against
     an intact copy. */
  for (i=0; i 2) {
    mexErrMsgTxt("At most two outputs in the form [f,fg] allowed.");
  }

  /* Fetch the inputs v=prhs[0], k=prhs[1], np=prhs[2], x=prhs[3],
     y=prhs[4] and validate their sizes. */

  /* v must have exactly one column; v_rows is its length. */
  v = mxGetPr(prhs[0]);
  v_cols = mxGetN(prhs[0]);
  v_rows = mxGetM(prhs[0]);
  if (v_cols != 1) { mexErrMsgTxt("Input v must be a column vector.");}

  /* np is read as a scalar and converted to the integer n1. */
  np = mxGetScalar(prhs[2]);
  n1 = (int)np;

  /* k must have exactly one column and either 1 or n1 rows. */
  k = mxGetPr(prhs[1]);
  if (mxGetN(prhs[1]) != 1) { mexErrMsgTxt("Input k must be a column vector.");}
  nk = mxGetM(prhs[1]);
  if (nk != 1 && nk != n1) { mexErrMsgTxt("Input k has invalid length.");}

  /* x has n0 rows and N columns. */
  x = mxGetPr(prhs[3]);
  N = mxGetN(prhs[3]);
  n0 = mxGetM(prhs[3]);

  /* y has n2 rows and must have the same number of columns as x. */
  y = mxGetPr(prhs[4]);
  cols = mxGetN(prhs[4]);
  n2 = mxGetM(prhs[4]);
  if (N != cols) { mexErrMsgTxt("Inputs x and y have incompatible sizes.");}

  /* The length of v must equal n1*(n0+1) + n2*(n1+1). */
  if (n1*(n0+1) + n2*(n1+1) != v_rows) { mexErrMsgTxt("Inputs v and x and y have incompatible sizes.");}

  /* Create the outputs: f is always a 1-by-1 matrix; fg is created only
     when a second output is requested and has the same size as v. */
  plhs[0] = mxCreateDoubleMatrix(1,1,mxREAL);
  f = mxGetPr(plhs[0]);
  if (nlhs > 1) {
    plhs[1] = mxCreateDoubleMatrix(v_rows,v_cols,mxREAL);
    fg = mxGetPr(plhs[1]);
  }

  /* Zero-initialised work buffers: o and e hold n2 doubles; o1, d1 and tmp
     hold n1 doubles. No matching mxFree call is visible in this copy. */
  o = mxCalloc(n2,sizeof(double));
  e = mxCalloc(n2,sizeof(double));
  o1 = mxCalloc(n1,sizeof(double));
  d1 = mxCalloc(n1,sizeof(double));
  tmp = mxCalloc(n1,sizeof(double));

  /* Call the C subroutines. c = 1/N is the scaling factor passed to
     ext_outprod and applied to f at the end; f starts at zero and, when
     the gradient is requested, fg is cleared as well. */
  c = 1.e0/N;
  *(f) = 0.e0;
  if (nlhs > 1) { fillve(fg,v_rows,0.e0); }

  /* NOTE(review): the file is damaged inside the next statement. The loop
     bound and the per-iteration code that fills o, o1, e and d1 and
     accumulates *f are missing; presumably n runs over the N columns of
     x -- confirm against an intact copy. What remains of the condition is
     presumably the "nlhs > 1" guard of the gradient part below. */
  for (n=0; n 1) {
      /* Gradient part. v+n1*(n0+1) and fg+n1*(n0+1) skip the first
         n1*(n0+1) entries of v and fg; x+n*n0 is the offset of column n
         of x (n0 rows per column). transmul, diag_mul and ext_outprod are
         defined outside the visible part of this copy. */
      transmul(v+n1*(n0+1)+n2,e,tmp,n2,n1);
      diag_mul(d1,tmp,n1);
      ext_outprod(fg,tmp,x+n*n0,n1,n0,c);
      ext_outprod(fg+n1*(n0+1),e,o1,n2,n1,c);
    }
  }

  /* Scale the accumulated value by 0.5*c. */
  *(f) = 0.5e0*c * *(f);
}