3x3 Median Filter

A 3x3 median filter is applied to a NumPy array on the GPU using PyCUDA; NumPy holds the data and Matplotlib displays it. The example creates a random, non-square matrix and filters it 20 times, then plots the matrix before and after filtering with Matplotlib.

License of this example: Public Domain. Date: 2010-11-29. PyCUDA version: 0.94.2.

```python
#
# 3x3 Median Filter ported to PyCUDA by Nick Hilton.
#

import numpy
from matplotlib import pylab

# Importing pycuda.autoinit initializes CUDA and creates a context as a side
# effect, so it must stay even though the name is never referenced directly.
import pycuda.autoinit
import pycuda.driver
from pycuda.compiler import SourceModule

kernel_cu = """

// Tile edge length in threads. The kernel must be launched with
// blockDim == (BLOCK_X, BLOCK_Y, 1); the halo handling below relies on it.
#define BLOCK_X  16
#define BLOCK_Y  16

// Compare/exchange network used to peel off minima and maxima.
// s2 orders a pair in place so that a <= b afterwards.
#define s2(a,b)            { float tmp = a; a = fminf(a,b); b = fmaxf(tmp,b); }
#define mn3(a,b,c)         s2(a,b); s2(a,c);    // a becomes the minimum
#define mx3(a,b,c)         s2(b,c); s2(a,c);    // c becomes the maximum

// mnmxK moves the minimum into the first and the maximum into the last
// argument; the remaining arguments end up in unspecified order.
#define mnmx3(a,b,c)       mx3(a,b,c); s2(a,b);                               // 3 exchanges
#define mnmx4(a,b,c,d)     s2(a,b); s2(c,d); s2(a,c); s2(b,d);                // 4 exchanges
#define mnmx5(a,b,c,d,e)   s2(a,b); s2(c,d); mn3(a,c,e); mx3(b,d,e);          // 6 exchanges
#define mnmx6(a,b,c,d,e,f) s2(a,d); s2(b,e); s2(c,f); mn3(a,b,c); mx3(d,e,f); // 7 exchanges

// Shared tile access with a one-cell halo: valid arguments are -1 .. BLOCK.
#define SMEM(x,y)  smem[(x)+1][(y)+1]

// Row-major image access with 1-based (row, column) coordinates.
#define  IN(x,y)    d_in[((y)-1) + ((x)-1) * NY]
#define OUT(x,y)   d_out[((y)-1) + ((x)-1) * NY]

//////////////////////////////////////////////////////////////////////////////
// Returns the input pixel at 1-based (x, y), or 0 when (x, y) lies outside
// the NX x NY image. This provides the zero padding and keeps every global
// read inside d_in.
__device__ inline
float
fetchPixel(
        const float * d_in,
        const int     x,
        const int     y,
        const int     NX,
        const int     NY)
{
    return (x >= 1 && x <= NX && y >= 1 && y <= NY) ? IN(x, y) : 0.0f;
}

//////////////////////////////////////////////////////////////////////////////
// 3x3 median filter with zero padding outside the image.
//
//   d_out  NX x NY row-major output image
//   d_in   NX x NY row-major input image (must not alias d_out)
//
// Launch layout: blockDim = (BLOCK_X, BLOCK_Y, 1). blockIdx.x/threadIdx.x
// walk the rows, blockIdx.y/threadIdx.y walk the columns. Any grid with at
// least ceil(NX/BLOCK_X) x ceil(NY/BLOCK_Y) blocks covers the image; surplus
// threads only contribute zeros to the tile and write nothing.
//
// Static shared memory: (BLOCK_X+2) * (BLOCK_Y+2) floats.
//
// NOTE: consecutive threadIdx.x values address elements NY floats apart in
// global memory (see the IN/OUT macros).
__global__
void
medianFilter(
        float *       d_out,
        const float * d_in,
        const int     NX,       // Number of rows
        const int     NY)       // Number of cols
{
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;

    // x,y are 1 based pixel indices, the macros IN, OUT subtract 1
    const int x = blockIdx.x * blockDim.x + tx + 1;
    const int y = blockIdx.y * blockDim.y + ty + 1;

    // Threads on the rim of the block additionally load the halo cells.
    const bool is_x_top = (tx == 0);
    const bool is_x_bot = (tx == BLOCK_X-1);
    const bool is_y_top = (ty == 0);
    const bool is_y_bot = (ty == BLOCK_Y-1);

    __shared__ float smem[BLOCK_X+2][BLOCK_Y+2];

    // Fill the tile. Every cell of smem, halo included, is written exactly
    // once; cells that map outside the image receive 0 from fetchPixel.
                            SMEM(tx  , ty  ) = fetchPixel(d_in, x  , y  , NX, NY); // self
    if (is_x_top)           SMEM(tx-1, ty  ) = fetchPixel(d_in, x-1, y  , NX, NY);
    else if (is_x_bot)      SMEM(tx+1, ty  ) = fetchPixel(d_in, x+1, y  , NX, NY);
    if (is_y_top) {         SMEM(tx  , ty-1) = fetchPixel(d_in, x  , y-1, NX, NY);
        if (is_x_top)       SMEM(tx-1, ty-1) = fetchPixel(d_in, x-1, y-1, NX, NY);
        else if (is_x_bot)  SMEM(tx+1, ty-1) = fetchPixel(d_in, x+1, y-1, NX, NY);
    } else if (is_y_bot) {  SMEM(tx  , ty+1) = fetchPixel(d_in, x  , y+1, NX, NY);
        if (is_x_top)       SMEM(tx-1, ty+1) = fetchPixel(d_in, x-1, y+1, NX, NY);
        else if (is_x_bot)  SMEM(tx+1, ty+1) = fetchPixel(d_in, x+1, y+1, NX, NY);
    }

    // Neighbours are written by other threads: wait until the tile is
    // complete. All threads of the block reach this barrier.
    __syncthreads();

    // Pull top six values from shared memory
    float v[6] =
    {
        SMEM(tx-1, ty-1),    //  NW     (North West neighbor)
        SMEM(tx  , ty-1),    //   W
        SMEM(tx+1, ty-1),    //  SW
        SMEM(tx-1, ty  ),    //  N
        SMEM(tx  , ty  ),    //     self
        SMEM(tx+1, ty  )     //  S
    };

    // With each pass the current min lands in the lowest slot still in play
    // and the current max in v[5]. Neither can be the median of nine, so the
    // min slot is dropped and v[5] is overwritten with the next neighbour.
    mnmx6(v[0], v[1], v[2], v[3], v[4], v[5]);

    v[5] = SMEM(tx-1, ty+1);    // NE

    mnmx5(v[1], v[2], v[3], v[4], v[5]);

    v[5] = SMEM(tx  , ty+1);    //  E

    mnmx4(v[2], v[3], v[4], v[5]);

    v[5] = SMEM(tx+1, ty+1);    // SE

    mnmx3(v[3], v[4], v[5]);

    // v[4] now contains the middle value.

    // Guard against indices out of range.
    if (x >= 1 && x <= NX && y >= 1 && y <= NY)
    {
        OUT(x,y) = v[4];
    }
}

"""

# Host driver: build a random non-square image, run the 3x3 median filter on
# it NUM_PASSES times on the GPU, and plot the image before and after.

SIZE_M = 16*2-1     # number of rows    (kernel argument NX)
SIZE_N = 16*2+1     # number of columns (kernel argument NY)

BLOCK = 16          # threads per block edge; must equal BLOCK_X/BLOCK_Y in kernel_cu
NUM_PASSES = 20     # total number of filter passes

gpu = SourceModule(kernel_cu)

medianFilter = gpu.get_function("medianFilter")

x = numpy.random.random((SIZE_M,SIZE_N)).astype(numpy.float32)

pylab.figure()
pylab.imshow(x, interpolation = "nearest", cmap = pylab.cm.gray_r)
pylab.title("before")
pylab.axis("tight")

y = numpy.zeros_like(x)

# Launch floor(SIZE / BLOCK) + 1 blocks per axis. That always covers at least
# SIZE + 1 threads, which leaves room for the kernel's 1-based pixel
# numbering; the kernel's own range check discards surplus threads.
# Integer arithmetic avoids round(), whose tie-breaking differs between
# Python 2 and Python 3.
grid_m = SIZE_M // BLOCK + 1
grid_n = SIZE_N // BLOCK + 1

print("grid = %dx%d" % (grid_m, grid_n))

# Each pass reads `src` and writes `y`. The result is copied before the next
# pass so that input and output are always distinct host arrays, and `x`
# (shown in the "before" figure) is never modified.
src = x
for i in range(NUM_PASSES):
    medianFilter(
            pycuda.driver.Out(y),
            pycuda.driver.In(src),
            numpy.int32(SIZE_M),    # NX: the kernel needs the image size on every launch
            numpy.int32(SIZE_N),    # NY
            block=(BLOCK, BLOCK, 1),
            grid=(grid_m, grid_n))
    src = y.copy()

pylab.figure()
pylab.imshow(y, interpolation = "nearest", cmap = pylab.cm.gray_r)
pylab.title("after")
pylab.axis("tight")

pylab.show()
```

MedianFilter (last edited 2010-11-29 17:54:41 by 66-146-167-66)