Basic Ray Tracing

!pip install pillow --quiet
from qmcpy import *
from numpy import *
from time import time
from PIL import Image
from matplotlib import pyplot
from threading import Thread
%matplotlib inline
# module-wide constants and helpers
EPS = 1e-8 # intersections closer than this to a ray origin are treated as numerical noise
def norm(v):
    """ Return the Euclidean length of the vector v. """
    return sqrt(dot(v,v))
e = zeros(3,dtype=float) # eye location at the origin
class Camera(object):
    """ An object to render the scene onto a square px-by-px image. """
    def __init__(self, ax, scene, px=8, parallel_x_blocks=2, parallel_y_blocks=2, image_angle=90, image_dist=1):
        """
        Args:
            ax (axes): matplotlib ax to plot image on
            scene (object): object with scene.render_px(p) method returning an (r,g,b) tuple.
                See examples later in the notebook.
            px (int): number of pixels in height and width.
                Resolution = px*px.
            parallel_x_blocks (int): number of cuts along the x axis in which to make parallel.
            parallel_y_blocks (int): number of cuts along the y axis in which to make parallel.
                parallel_x/y_blocks must be positive divisors of px.
                Number of threads = parallel_y_blocks * parallel_x_blocks.
            image_angle (float): full viewing angle in RADIANS (it is passed to tan() unconverted).
                NOTE(review): the default of 90 looks like it was meant to be degrees;
                pass pi/2 for a 90 degree view. The default is kept for backward compatibility.
            image_dist (float): distance from the eye to the image plane.

        Raises:
            Exception: if parallel_x_blocks or parallel_y_blocks is not a positive divisor of px.
        """
        # validate up front: a block count of 0 used to surface as a ZeroDivisionError
        # and a negative one as a silently empty render
        if parallel_x_blocks < 1 or parallel_y_blocks < 1 \
                or px % parallel_x_blocks != 0 or px % parallel_y_blocks != 0:
            raise Exception('parallel_x/y_blocks must divide px')
        self.ax = ax
        self.scene = scene
        self.px = px
        self.fl = image_dist # distance from eye to image
        self.i_hw = tan(image_angle/2)*self.fl # half width of the image
        self.px_hw = self.i_hw/px # half width of a pixel
        # set parallelization constants
        self.p_xb = parallel_x_blocks; self.p_yb = parallel_y_blocks
        self.bs_x = int(px//parallel_x_blocks); self.bs_y = int(px//parallel_y_blocks) # block size for x,y
    def render(self):
        """ Render the image, draw it on self.ax and print the elapsed time. """
        t0 = time() # start a timer
        img = Image.new('RGB',(self.px,self.px),(0,0,0))
        if self.p_xb==1 and self.p_yb==1:
            # use non-parallel processing (helpful for debugging)
            self.block_render(img,0,self.px,0,self.px)
        else:
            # one thread per (x block, y block) tile of the image
            threads = [
                Thread(target=self.block_render, args=(img,xb,xb+self.bs_x,yb,yb+self.bs_y))
                for xb in range(0,self.px,self.bs_x)
                for yb in range(0,self.px,self.bs_y)]
            for thread in threads: thread.start()
            for thread in threads: thread.join() # wait for all threads to complete
        self.ax.axis('off')
        self.ax.imshow(asarray(img))
        print('Render took %.1f seconds'%(time()-t0))
    def block_render(self, img, px_x_start, px_x_end, px_y_start, px_y_end):
        """
        Render a block of the image.

        Args:
            img (PIL.Image): the image to color pixels of.
            px_x_start (int): x index of pixel to start rendering at (inclusive).
            px_x_end (int): x index of pixel to end rendering at (exclusive).
            px_y_start (int): y index of pixel to start rendering at (inclusive).
            px_y_end (int): y index of pixel to end rendering at (exclusive).
        """
        for p_x in range(px_x_start,px_x_end):
            # cast the ray through the CENTER of the pixel: left edge of pixel p_x plus half a pixel.
            # Sampling the pixel corner shifts the whole image by half a pixel and makes it asymmetric.
            x = (2*p_x+1)*self.px_hw - self.i_hw
            for p_y in range(px_y_start,px_y_end):
                y = (2*p_y+1)*self.px_hw - self.i_hw
                p = array([x,y,self.fl])
                color = self.scene.render_px(p)
                # image rows grow downwards while scene y grows upwards --> flip the row index
                img.putpixel((p_x,self.px-p_y-1),color)
class Plane(object):
    """ An infinite plane perpendicular to one of the coordinate axes. """
    def __init__(self, norm_axis, position, color):
        """
        Args:
            norm_axis (str): axis the plane is perpendicular to: 'x', 'y', or 'z'.
            position (float): coordinate of the plane along norm_axis.
            color (tuple): length 3 rgb values.
        """
        self.d = {'x':0,'y':1,'z':2}[norm_axis] # index of the normal axis
        self.pos = position # coordinate of the plane along axis self.d
        self.color = color
    def hit(self, o, u):
        """
        Test if the beam o+tu hits the plane.

        Args:
            o (ndarray): length 3 origin point.
            u (ndarray): length 3 direction vector.

        Returns:
            tuple:
                - hit (bool): was the plane hit?
                - hit_p (ndarray): point where beam intersects the plane, None on a miss.
        """
        slope = u[self.d] # speed of the beam along the normal axis
        if slope == 0:
            return False, None # beam runs parallel to the plane
        t = (self.pos - o[self.d]) / slope
        if t > EPS:
            return True, o+t*u # plane lies ahead of the beam origin
        return False, None # plane is behind (or numerically at) the origin
    def normal(self, o, u):
        """
        Get the unit normal vector to the plane, oriented towards the light.

        Args:
            o (ndarray): length 3 point on the plane (not used, the normal is constant).
            u (ndarray): length 3 unit vector in direction of light.

        Returns:
            ndarray: length three unit normal vector (integer valued).
        """
        axis = zeros(3,dtype=int)
        axis[self.d] = 1
        # of the two possible normals pick the one that does not point away from u
        return -axis if dot(axis,u)<0 else axis
class Ball(object):
    """ A sphere defined by its center and radius. """
    def __init__(self, center, radius, color):
        """
        Args:
            center (ndarray): length 3 center position of the ball.
            radius (float): radius of the ball.
            color (tuple): length 3 tuple of rgb values.
        """
        self.c = center
        self.r = radius
        self.color = color
    def hit(self, o, u):
        """
        Test if the beam o+tu hits the ball.

        Solves |o+tu-c|^2 = r^2, i.e. a*t^2 + b*t + c = 0, for the smallest t > EPS.

        Args:
            o (ndarray): length 3 origin point.
            u (ndarray): length 3 direction vector (does not need to be unit length).

        Returns:
            tuple:
                - hit (bool): was the ball hit?
                - hit_p (ndarray): point where beam intersects the ball, None on a miss.
        """
        q = o - self.c
        a = dot(u,u)
        b = 2*dot(u,q)
        c = dot(q,q) - self.r**2
        d = b**2 - 4*a*c # discriminant; a zero direction vector gives d == 0 --> miss
        if d > 0: # ray intersects sphere
            # quadratic formula: the denominator is 2*a (2**a only agrees when |u| == 1)
            tt = (-b + array([1,-1],dtype=float)*sqrt(d)) / (2*a)
            tt = tt[tt>EPS] # only want intersection from rays moving in positive direction
            if len(tt) >= 1: # at least one positive intersection
                # beam going forward intersects ball --> nearest intersection wins
                t = tt.min()
                return True, o+t*u
        return False, None # ray does not intersect sphere or only intersects in opposite direction
    def normal(self, o, u):
        """
        Get the unit normal vector to the sphere at this point.

        Args:
            o (ndarray): length 3 point on the sphere.
            u (ndarray): length 3 unit vector in direction of light (not used).

        Returns:
            ndarray: length three unit normal vector pointing from the center through o.
        """
        v = (o-self.c)
        v_u = v/norm(v)
        return v_u
class PointLight(object):
    """ A lamp that is a single point and emits light in all directions. """
    def __init__(self, position, intensity):
        """
        Args:
            position (ndarray): length 3 coordinate of the lamp.
            intensity (float): brightness of the lamp, between 0 and 1.
        """
        self.p, self.i = position, intensity
class CustomScene(object):
    """ A set of objects lit by one point light, shaded by recursive (Quasi-)Monte Carlo ray tracing. """
    def __init__(self, objs, light, n, d, mc_type):
        """
        Args:
            objs (list): objects in the scene. Each must provide hit(o,u), normal(o,u) and a color attribute.
            light (PointLight): the lamp, with position light.p and intensity light.i.
            n (int): number of secondary rays cast from a point that cannot see the lamp.
            d (int): max bounces of any given beam.
            mc_type (str): sample generator, 'IID' or 'SOBOL' (case-insensitive).

        Raises:
            Exception: if mc_type is neither 'IID' nor 'SOBOL'.
        """
        self.objs = objs
        self.light = light
        self.black = array([0,0,0],dtype=float) # the color black
        self.n = n
        self.d = d
        self.mc_type = mc_type
        # the error message advertises "Sobol", so do not insist on upper case
        sampler = str(mc_type).upper()
        # generate constant samples to be used for every ray tracing event:
        # n rows (one per secondary ray), 2 columns (two angles) per bounce
        if sampler == 'IID':
            self.pts = IIDStdUniform(2*self.d).gen_samples(self.n)
        elif sampler == 'SOBOL':
            self.pts = Sobol(2*self.d,graycode=True).gen_samples(self.n)
        else:
            raise Exception("mc_type must be IID or Sobol")
    def find_closest_obj(self,o,v):
        """
        Find the closest object to point o heading in direction v

        Args:
            o (ndarray): length 3 coordinate of point we will try and find closest object to
            v (ndarray): length 3 direction of travel from o

        Returns:
            tuple:
                hit (bool): whether any objects were hit.
                hit_p (ndarray): length 3 coordinates of where obj was hit (None if nothing hit).
                hit_dist (float): distance from hit_p to o (inf if nothing hit).
                hit_obj (object): the object that was hit (None if nothing hit).
        """
        hit,hit_p,hit_dist,hit_obj = False,None,inf,None
        for obj in self.objs:
            obj_hit,obj_p = obj.hit(o,v)
            if obj_hit:
                v2 = obj_p-o # vector from o to object position
                obj_dist = sqrt(dot(v2,v2))
                if obj_dist < hit_dist:
                    hit,hit_p,hit_dist,hit_obj = True,obj_p,obj_dist,obj
        return hit,hit_p,hit_dist,hit_obj
    def get_obj_color(self,obj,p,l):
        """
        Get the objects color at point p with light in direction l.

        The object color is scaled by the lamp intensity and by the cosine of the
        angle between the surface normal and the direction to the lamp.

        Args:
            obj (object): object on which p lies
            p (ndarray): length 3 coordinate of point on the object
            l (ndarray): length 3 vector of direction from p to light source

        Returns:
            ndarray: length 3 RGB color
        """
        n_v = obj.normal(p,l) # normal vector to obj at point p
        color = obj.color*self.light.i*dot(n_v,l) / (norm(n_v)*norm(l))
        return color
    @staticmethod
    def _random_direction(u_0,u_1):
        """ Map two samples from [0,1) to a unit vector uniformly distributed on the sphere. """
        phi = 2*pi*u_0 # azimuth, uniform on [0,2*pi)
        z = 1-2*u_1 # height, uniform on (-1,1] --> uniform surface density
        r = sqrt(1-z*z) # radius of the circle at height z
        return array([r*cos(phi),r*sin(phi),z],dtype=float)
    def beam_r(self,o,v,n,d,pts,nidx,didx):
        """
        Recursive (Quasi-)Monte Carlo simulation of a light beam

        Args:
            o (ndarray): length 3 coordinate of current light beam position
            v (ndarray): length 3 vector of light beam direction
            n (int): number of rays to cast when it cannot find light directly
            d (int): remaining bounces before beam gives up
            pts (ndarray): n samples x 2*d dimension ndarray of samples generated by QMCPy
            nidx (int): row of pts holding the first sample to use
            didx (int): column of pts holding the first of the two coordinates to use

        Returns:
            ndarray: length 3 RGB color seen along the beam
        """
        hit,hit_p,hit_dist,hit_obj = self.find_closest_obj(o,v)
        if not hit:
            return self.black # nothing hit --> return black
        l = self.light.p-hit_p # vector from position where beam hit to the lamp
        l_dist = norm(l) # distance from hit location to lamp
        l_u = l/l_dist # unit vector of l
        itw,itw_p,itw_dist,itw_obj = self.find_closest_obj(hit_p,l_u) # find any object in the way
        if not (itw and itw_dist<=l_dist): # nothing between the object and the light source
            # get the color based on point, normal to obj, and direction to light
            return self.get_obj_color(hit_obj,hit_p,l_u)
        # an object sits between the hit point and the lamp
        if d==0:
            # no remaining bounces --> return black (give up)
            return self.black
        # beam has remaining bounces --> average the light gathered by n secondary rays.
        # The normal and the surface color do not depend on the secondary ray: compute them once.
        ho_n = hit_obj.normal(hit_p,l_u)
        obj_color = self.get_obj_color(hit_obj,hit_p,l_u)
        color_total = self.black.copy()
        for i in range(n):
            v_rand = self._random_direction(pts[nidx+i,didx],pts[nidx+i,didx+1]) # random unit direction
            if dot(v_rand,ho_n) < 0: v_rand = -v_rand # flip direction to correct hemisphere
            color_total += obj_color*self.beam_r(hit_p,v_rand,n=1,d=d-1,pts=pts,nidx=nidx+i,didx=didx+2)
        return color_total/n # take the average of many simulations
    def render_px(self,p):
        """
        Get pixel value for the scene

        Args:
            p (ndarray): length 3 array coordinates of center of pixel to render

        Returns:
            tuple: (r,g,b) integer channel values, each within 0..255
        """
        u = (p-e)/norm(p-e) # unit vector in direction of eye to pixel
        color_0_1 = self.beam_r(e,u,n=self.n,d=self.d,pts=self.pts,nidx=0,didx=0)
        # an 8 bit channel holds 0..255: a full intensity of 1.0 would otherwise become 256
        # and a surface facing away from the lamp would become negative
        color = clip(color_0_1*256,0,255).astype(int)
        return color[0],color[1],color[2]
# build the scene: a grey box (floor, ceiling, right wall, left wall, back wall) holding three balls
walls = [('y',-50),('y',50),('x',50),('x',-50),('z',150)] # (normal axis, position) of each wall
objs = [Plane(norm_axis=axis, position=pos, color=array([.75,.75,.75],dtype=float)) for axis,pos in walls]
objs += [
    Ball(center=array([-25,25,75],dtype=float), radius=20, color=array([1,0,0],dtype=float)), # red ball
    Ball(center=array([25,25,75],dtype=float),  radius=20, color=array([0,1,0],dtype=float)), # green ball
    Ball(center=array([0,-25,75],dtype=float),  radius=20, color=array([0,0,1],dtype=float)), # blue ball
]
light = PointLight(position=array([0,25,0],dtype=float), intensity=1)
# parameters
n = 16 # number of beams
d = 8 # max bounces of any given beam
px = 128 # image is px by px pixels
# render the same scene twice, side by side: IID (MC) on the left, Sobol (QMC) on the right
fig,ax = pyplot.subplots(ncols=2,nrows=1,figsize=(20,10))
for panel,mc_type in zip(ax,('IID','SOBOL')):
    scene = CustomScene(objs,light,n,d,mc_type=mc_type)
    camera = Camera(panel, scene, px=px, parallel_x_blocks=1, parallel_y_blocks=1, image_angle=pi/2, image_dist=1)
    camera.render()
Render took 4.6 seconds
Render took 5.1 seconds
../_images/ray_tracing_9_1.png