root/trunk/src/random.py

Revision 475, 4.4 KB (checked in by emuller, 18 months ago)

Get rid of annoying 'scipy missing' message when using NeuroTools.parameters

  • Property svn:eol-style set to native
Line 
1"""
2NeuroTools.random
3=====================
4
5A set of classes representing statistical distributions, with an interface that
6is compatible with the ParameterSpace class in the parameters module.
7
8Classes
9-------
10
11GammaDist   - gamma.pdf(x,a,b) = x**(a-1)*exp(-x/b)/gamma(a)/b**a
12NormalDist  - normal distribution
13UniformDist - uniform distribution
14
15"""
16
17from NeuroTools import check_dependency
18
19import numpy, numpy.random
20
21 
22class ParameterDist(object):
23
24    def __init__(self,**params):
25        self.params = params
26        self.dist_name = 'ParameterDist'
27   
28    def __repr__(self):
29        if len(self.params)==0:
30            return '%s()'% (self.dist_name,)
31        s = '%s('% (self.dist_name,)
32        for key in self.params:
33            s+='%s=%s,' % (key,str(self.params[key]))
34        return s[:-1]+')'
35
36    def next(self,n=1):
37        raise NotImplementedError('This is an abstract base class and cannot be used directly')
38
39    def from_stats(self,vals,bias=0.0,expand=1.0):
40        self.__init__(mean=numpy.mean(vals)+bias, std=numpy.std(vals)*expand)
41
42    def __eq__(self, o):
43        # should we track the state of the rng and return False if it is different between self and o?
44        if (type(self) == type(o) and
45            self.dist_name == o.dist_name and
46            self.params == o.params):
47            return True
48        else:
49            return False
50
51class GammaDist(ParameterDist):
52    """
53    gamma.pdf(x,a,b) = x**(a-1)*exp(-x/b)/gamma(a)/b**a
54
55    Yields strictly positive numbers.
56    Generally the distribution is implemented by scipy.stats.gamma.pdf(x/b,a)/b
57    For more info, in ipython type:
58    >>> ? scipy.stats.gamma
59
60    """
61   
62    def __init__(self,mean=None,std=None,repr_mode='ms',**params):
63        """
64        repr_mode specifies how the dist is displayed,
65        either mean,var ('ms', the default) or a,b ('ab')
66        """
67
68        if check_dependency('scipy'):
69            self.next = self._next_scipy
70
71        self.repr_mode = repr_mode
72        if 'm' in params and mean==None:
73            mean = params['m']
74        if 's' in params and std==None:
75            std = params['s']
76
77        # both mean and std not specified
78        if (mean,std)==(None,None):
79            if 'a' in params:
80                a = params['a']
81            else:
82                a = 1.0
83            if 'b' in params:
84                b = params['b']
85            else:
86                b = 1.0
87        else:
88            if mean==None:
89                mean = 0.0
90            if std==None:
91                std=1.0
92            a = mean**2/std**2
93            b = mean/a   
94        ParameterDist.__init__(self,a=a,b=b)
95        self.dist_name = 'GammaDist'
96
97    def _next_scipy(self,n=1):
98        import scipy.stats
99        return scipy.stats.gamma.rvs(self.params['a'],size=n)*self.params['b']
100    def _next_no_scipy(self,n=1):
101        raise Exception('Error scipy was not found at import time.  GammaDist realization disabled.')
102
103    next = _next_no_scipy
104   
105    def mean(self):
106        return self.params['a']*self.params['b']
107
108    def std(self):
109        return self.params['b']*numpy.sqrt(self.params['a'])
110
111    def __repr__(self):
112        if self.repr_mode == 'ms':
113            return '%s(m=%f,s=%f)' % (self.dist_name,self.mean(),self.std())
114        else:
115            return '%s(a=%f,b=%f)' % (self.dist_name,self.params['a'],self.params['b'])
116       
117
118class NormalDist(ParameterDist):
119    """
120    normal distribution with parameters
121    mean + std
122   
123    """
124   
125    def __init__(self,mean=0.0,std=1.0):
126        ParameterDist.__init__(self,mean=mean,std=std)
127        self.dist_name = 'NormalDist'
128       
129    def next(self,n=1):
130        return numpy.random.normal(loc=self.params['mean'],scale=self.params['std'],size=n)
131       
132
133class UniformDist(ParameterDist):
134    """
135    uniform distribution with min,max
136    """
137
138    def __init__(self,min=0.0,max=1.0, return_type=float):
139        ParameterDist.__init__(self,min=min,max=max)
140        self.dist_name = 'UniformDist'
141        self.return_type = return_type
142       
143    def next(self,n=1):
144        vals = numpy.random.uniform(low=self.params['min'],high=self.params['max'],size=n)
145        if self.return_type != float:
146            vals = vals.astype(self.return_type)
147        return vals
148
149    def from_stats(self,vals,bias=0.0,expand=1.0):
150        mn = numpy.min(vals)
151        mx = numpy.max(vals)
152        center = 0.5*(mx+mn)+bias
153        hw = 0.5*(mx-mn)*expand
154        self.__init__(min=center-hw,max=center+hw)
Note: See TracBrowser for help on using the browser.