function [k_hat,angles]=phi_func_wrapper(X,s,k_seed,scalescale,subs,k_true)
%PHI_FUNC_WRAPPER  wrapper function for phi_func
%
% Centers the input data X, sets up default local variables, whitens the
% data and calls phi_func once per entry of subs (useful for coarse-to-fine
% iterations, for example). Also sets up some display information.
%
%inputs: X,s as in phi_func.m, below (each row of X is an observation; the
%          first s rows correspond to spikes, the rest to no spikes)
%        k_seed = seed value of estimate. Must be nonzero; must have same
%          number of rows as X has columns. Number of columns should equal
%          dimension of desired projection space.
%        scalescale = an overall scale factor for the kernel width
%        subs = a non-empty vector (>=1) of subsampling factors - the higher
%          the sub, the faster (but less accurate) the code. phi_func will
%          operate on a randomly selected 1/subs(i) of the full data on
%          each call
%        k_true is an optional input, used for debugging / performance
%          testing; typically k_true = true k used to generate data
%
%outs:   k_hat = final estimate for k (orthonormalized, in the original,
%          un-whitened coordinates)
%        angles = a vector of some performance info:
%          [subspace(k_true,k_hat) subspace(k_seed,k_hat) subspace(k_true,k_seed)]
%          when k_true is supplied, otherwise just subspace(k_seed,k_hat)

if(nargin<6)
  k_true=[];
end;

%abbreviations used throughout phi_func code
N=size(X,1);
dim_X=size(X,2);
m=size(k_seed,2);

%fail early, with a readable message, on inputs that violate the documented
%contract (these would otherwise surface as obscure errors further down)
if(size(k_seed,1)~=dim_X)
  error('phi_func_wrapper: k_seed must have %d rows (one per column of X); it has %d.',dim_X,size(k_seed,1));
end;
if(~any(k_seed(:)))
  error('phi_func_wrapper: k_seed must be nonzero.');
end;
if(isempty(subs))
  error('phi_func_wrapper: subs must contain at least one subsampling factor.');
end;

p1=s/N;   %fraction of rows that are spikes
p0=1-p1;  %fraction of rows that are not

fig_num=5464;
if(m>1)
  fig_num=0; %2D figs are too time-consuming
end;

%assorted input vars; these could also be included at the input line, if preferred
%see comments of phi_func.m for definitions
n_ints=min(200,N);
n_rand_dirs=30;
n_samps=20;
jk=0;

%random subset of rows handed to box_func for the display calls below,
%split into its spike (s_i) and non-spike (ns_i) members
int=randperm(N);
int=int(1:n_ints);
s_i=find(int<=s);
ns_i=find(int>s);

%one iteration count per subsampling level
if(dim_X>2)
  n_iters=3*dim_X*m*ones(length(subs),1);
else
  n_iters=ones(length(subs),1);
end;

%one kernel width per subsampling level
scales=scalescale*ones(size(subs)).*(N*subs.^-1).^-(1/(m*3));

%center data (column means are removed; no rescaling happens here)
X=X-repmat(mean(X,1),N,1);

if(~isempty(k_true)) %display true conditional dists
  [M,px,px1,p1_x,dk]=box_func(X,k_true,int,m,std(X*k_true)*scales(1),s,s_i,ns_i,p0,p1,fig_num,1);
end;

if(fig_num) %set up display
  figure(fig_num);
  subplot(2,5,6); ylabel('p(kx)');
  subplot(2,5,1); ylabel('p(s|kx)'); title('true k');
end;

%display seed conditional dists prior to whitening
[M,px,px1,p1_x,dk]=box_func(X,k_seed,int,m,std(X*k_seed)*scales(1),s,s_i,ns_i,p0,p1,fig_num,2);
if(fig_num)
  figure(fig_num); subplot(2,5,2); title('seed k');
end;

%whiten: sC is the matrix square root of the sample second-moment matrix.
%Solve with / and \ rather than forming inv(sC) explicitly - same result
%mathematically, but more accurate and cheaper.
sC=sqrtm((X'*X)/N);
wX=X/sC;
k=orth(sC*k_seed);

for(sub=1:length(subs)) %main loop; call to phi_func
  disp(sprintf('.....................data fraction = 1/%i .......................',subs(sub)));

  %draw a fresh random 1/subs(sub) of the spike rows and of the non-spike rows
  sub1=randperm(s);
  sub0=randperm(N-s);
  sub1s=ceil(s/subs(sub));
  sub0s=ceil((N-s)/subs(sub));

  %spike rows are stacked first, as phi_func expects; k is carried over from
  %one level to the next
  [k,M]=phi_func([wX(sub1(1:sub1s),:); wX(s+sub0(1:sub0s),:)],sub1s,k,n_ints,n_iters(sub),n_rand_dirs,n_samps,jk,scales(sub),p0,p1,fig_num);
end;

%outputs: map the estimate back from whitened to original coordinates
k_hat=orth(sC\k);

%view cond dists in original space, at the kernel width of the last level
%(scales(end) instead of relying on the loop variable surviving the loop)
[M,px,px1,p1_x,dk]=box_func(X,k_hat,int,m,std(X*k_hat)*scales(end),s,s_i,ns_i,p0,p1,fig_num,5);
if(fig_num)
  figure(fig_num); subplot(2,5,5); title('final k');
end;

%display some performance info
disp(sprintf('final M = %1.4f',M));
if(~isempty(k_true))
  angles=[subspace(k_true,k_hat) subspace(k_seed,k_hat) subspace(k_true,k_seed)];
else
  angles=[subspace(k_seed,k_hat)];
end;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [k,M]=phi_func(X,s,k,n_ints,n_iters,n_rand_dirs,n_samps,jk,scale,p0,p1,fig_num);
%
% main function for computing phi estimator; LP 12/02/02
% see Paninski '03 paper on spike-triggered analysis techniques for more details
%
%in: X= stimulus data matrix. Each row is a different observation;
%       each column is a different dimension. Assumed centered and white.
%       The first s rows correspond to spikes; the rest, to no spikes.
%    k= seed value of estimate. Must be nonzero; must have same number of rows as X has columns.
%       Number of columns should equal dimension of desired projection space.
%    n_ints= number of points to be used for Monte Carlo integration (see below).
% n_iters= number of iterations % n_rand_dirs,n_samps; see below % jk; jackknife (not really used in this version of the code) % scale= kernel width % p0,p1 = s/N (fraction of spiking stimuli) and 1-s/N, respectively % fig_num= figure handle used for display during progress of algorithm % %out: k= the final estimate % M= M(k), the value of the function to be maximized at k %local variables N=size(X,1); dim_X=size(X,2); m=size(k,2); box_scale=scale*2; box_scale2=box_scale^2; KK=0; II=0; local_min=0; old_e_0s=zeros(dim_X,2,n_iters*m); i_old_e_0s=0; last_info_cov_iter=1; doh=10^-4; %this is a noise-floor term %these vars are not used in this version if(jk>0) fj=floor(N/jk); else fj=0; end; r_jk=randperm(N); %int selects a subset, int, of the data, which will serve as the Monte Carlo samples % used for estimating densities, divergence functions, etc. n_ints=min(n_ints,N); int_1=randperm(s); int_0=randperm(N-s); int=[int_1(1:min(s,floor(n_ints/2))) s+int_0(1:min(N-s,ceil(n_ints/2)))]; n_ints=length(int); s_i=find(int<=s); ns_i=find(int>s); %initialize M [M_seed,px,px1,p1_x,dk]=box_func(X,k,int,m,box_scale,s,s_i,ns_i,p0,p1,fig_num,3); if(fig_num) figure(fig_num); subplot(2,5,3); title('whitened seed k'); end; disp(sprintf('M_seed = %1.4f',M_seed)); M=M_seed; %main loop for(iter=1:n_iters) for(dim=1:m) %choose line-search direction e_0 if(local_min<=2) msg_str='choosing e_0 by gradient method'; %uses gradient information, using a gaussian kernel if(local_min) disp([msg_str ' (random scale, orthogonalized)']); rs=(1-log(rand))*scale; %uses a random scale and subtracts off old, ineffective search directions else disp([msg_str ' (true scale)']); rs=scale; %uses true scale if this hasn't been tried before end; [dummy_M,dk,px,px1]=gaus_func(0,k,zeros(dim_X,1),rs,X,int,s,dim,N,m,0,r_jk,fj,p0,p1); e_0=zeros(dim_X,1); %compute gradient for(e=1:dim_X) de=shiftdim(repmat(X(:,e),[1 1 n_ints]),2)-shiftdim(repmat(X(int,e)',[1 1 N]),1); dedk=de.*dk(:,:,dim); 
e_0(e)=sum((-px1(s_i).*sum(dedk(s_i,s+1:end),2))./px(s_i))*p1/length(s_i) - sum((-px1(ns_i).*sum(dedk(ns_i,1:s),2))./px(ns_i))*p0/length(ns_i); end; sqo=squeeze(old_e_0s(:,2,i_old_e_0s-(0:local_min-1))); if(~isempty(sqo)) e_0=e_0-sqo*(sqo'*e_0); end; else %if gradient-based approaches have failed to increase M twice in a row if(iter<=dim_X+last_info_cov_iter) disp('choosing e_0 by maximum-distance method'); %chooses e_0 to be as far away from previous searches as possible rand_e_0s=randn(dim_X,n_rand_dirs); rand_e_0s=rand_e_0s-k*(k'*rand_e_0s); angs=zeros(n_rand_dirs,i_old_e_0s); for(rd=1:n_rand_dirs) rand_e_0s(:,rd)=rand_e_0s(:,rd)/norm(rand_e_0s(:,rd)); for(od=1:i_old_e_0s) angs(rd,od)=min(svd(old_e_0s(:,:,od)'*[k(:,dim) rand_e_0s(:,rd)])); end; end; [y,i]=min(max(angs,[],2)); if(local_min>20) disp(sprintf('minimum dot product approximately %1.4f',y)); end; e_0=rand_e_0s(:,i); else disp('choosing e_0 by info-covariance method'); %this is the PCA-type method; only used after dim(X) previous tries [v,d]=eig(II,KK); [y,i]=sort(abs(diag(d))); e_0=orth(v(:,i(end-(rand<.3*(last_info_cov_iter>1))))); last_info_cov_iter=iter-dim_X+7; end; end; %make sure e_0 is orthogonal to k e_0=orth(e_0-k*(k'*e_0)); %maximize M(t) t=0; oM=M; %set up "crossing times" vector, ts ddm=shiftdim(repmat(X*e_0,[1 1 n_ints]),2)-shiftdim(repmat((X(int,:)*e_0)',[1 1 N]),1); if(m>1) all_in=all(abs(dk(:,:,setdiff(1:m,dim)))box_scale2)); sq=sqrt((dkd(f2).^2+ddm(f2).^2-box_scale2))*box_scale; %these lines solve a quadratic problem involved in finding the crossing times ts1=(-dkd(f2).*ddm(f2)+sq)./(ddm(f2).^2-box_scale2); ts2=(-dkd(f2).*ddm(f2)-sq)./(ddm(f2).^2-box_scale2); [ts,f3]=sort([ts1 ts2]); f2_2=[f2 f2]; f4=f2_2(f3); if(length(ts)>n_samps) %samp_t is used to grab samples for the info-cov method of picking e_0 samp_t=[unique(ceil(rand(n_samps,1)*length(ts))); length(ts)+10]; else samp_t=[1:length(ts) length(ts)+10]'; end; %compute M(t) at infinity in=(abs(ddm)<=box_scale).*all_in; 
px=sum(in,2); px1=sum(in(:,1:s),2); p1_x=px1./px; m0=sum(p1_x(s_i))*p1/length(s_i)+sum(1-p1_x(ns_i))*p0/length(ns_i); %optimize over crossing times in=in*2-1; [best_t,M,tti,t_samps]=box_D_func(f4,[ts ts(end)+1],n_ints,px1,px,in,int,s,length(s_i)/p1,length(ns_i)/p0,m0,samp_t-1); %figure(93234); plot(t_samps,tti); disp(best_t); disp(M); pause disp(sprintf('M=%1.4f',M)); if(M>=oM+doh) %if new M is sufficiently bigger than old M if(M==m0) k(:,dim)=e_0; else k(:,dim)=(k(:,dim)+best_t*e_0)/sqrt(1+best_t^2); end; local_min=0; else local_min=local_min+1; end; %display [direct_m,px,px1,p1_x,dk]=box_func(X,k,int,m,box_scale,s,s_i,ns_i,p0,p1,fig_num,4); if(fig_num) figure(fig_num); subplot(2,5,4); title('current k'); end; %some error checking if(abs(M-direct_m)>10^-10) disp(M-direct_m); disp('doh; true M doesn''t match stepped M'); end; if(M