
[Nexus] Optimal Tile Matrix Refactor and Documentation #5697

@brockdyer03

Description

After a chat with @jtkrogel and subsequently with @prckent, I took a look at the function optimal_tilematrix() in structure.py, and I think it could use some improvement and documentation.

As I do not work in condensed matter very often, I'm not entirely sure what this function does, but as someone who writes a lot of Python, I've noticed that it is not written in a Pythonic fashion and I think there is much room for improvement.

The Code

To start, here is the full section of structure.py that defines the variables and components optimal_tilematrix() needs in order to work (unless I have missed something):

opt_tm_matrices    = obj()
opt_tm_wig_indices = obj()

def trivial_filter(T):
    return True
#end def trivial_filter

class MaskFilter(DevBase):
    def set(self,mask,dim=3):
        omask = np.array(mask)
        mask  = np.array(mask,dtype=bool)
        if mask.size==dim:
            mvec = mask.ravel()
            mask = np.empty((dim,dim),dtype=bool)
            i=0
            for mi in mvec:
                j=0
                for mj in mvec:
                    mask[i,j] = mi==mj
                    j+=1
                #end for
                i+=1
            #end for
        elif mask.shape!=(dim,dim):
            error('shape of mask array must be {0},{0}\nshape received: {1},{2}\nmask array received: {3}'.format(dim,mask.shape[0],mask.shape[1],omask),'optimal_tilematrix')
        #end if
        self.mask = ~mask  # elementwise NOT; a bare `not` raises ValueError on a multi-element array
    #end def set

    def __call__(self,T):
        return (T[self.mask]==0).all()
    #end def __call__
#end class MaskFilter
mask_filter = MaskFilter()
            

def optimal_tilematrix(axes,volfac,dn=1,tol=1e-3,filter=trivial_filter,mask=None,nc=5,Tref=None):
    if mask is not None:
        mask_filter.set(mask)
        filter = mask_filter
    #end if
    dim = 3
    if isinstance(axes,Structure):
        axes = axes.axes
    else:
        axes = np.array(axes,dtype=float)
    #end if
    if not isinstance(volfac,int):
        volfac = int(np.around(volfac))
    #end if
    volume = np.abs(det(axes))*volfac
    axinv  = inv(axes)
    cube   = volume**(1./3)*np.identity(dim)
    if Tref is None:
        Tref = np.array(np.around(dot(cube,axinv)),dtype=int)
    else:
        Tref = np.asarray(Tref)
    #end if
    # calculate and store all tiling matrix variations
    if dn not in opt_tm_matrices:
        mats = []
        rng = tuple(range(-dn,dn+1))
        for n1 in rng:
            for n2 in rng:
                for n3 in rng:
                    for n4 in rng:
                        for n5 in rng:
                            for n6 in rng:
                                for n7 in rng:
                                    for n8 in rng:
                                        for n9 in rng:
                                            mats.append((n1,n2,n3,n4,n5,n6,n7,n8,n9))
                                        #end for
                                    #end for
                                #end for
                            #end for
                        #end for
                    #end for
                #end for
            #end for
        #end for
        mats = np.array(mats,dtype=int)
        mats.shape = (2*dn+1)**(dim*dim),dim,dim
        opt_tm_matrices[dn] = mats
    else:
        mats = opt_tm_matrices[dn]
    #end if
    # calculate and store all wigner image indices
    if nc not in opt_tm_wig_indices:
        inds = []
        rng = tuple(range(-nc,nc+1))
        for k in rng:
            for j in rng:
                for i in rng:
                    if i!=0 or j!=0 or k!=0:
                        inds.append((i,j,k))
                    #end if
                #end for
            #end for
        #end for
        inds = np.array(inds,dtype=int)
        opt_tm_wig_indices[nc] = inds
    else:
        inds = opt_tm_wig_indices[nc]
    #end if
    # track counts of tiling matrices
    ntilings        = len(mats)
    nequiv_volume   = 0
    nfilter         = 0
    nequiv_inscribe = 0
    nequiv_wigner   = 0
    nequiv_cubicity = 0
    nequiv_shape    = 0
    # try a faster search for cells w/ target volume
    det_inds_p = [
        [(0,0),(1,1),(2,2)],
        [(0,1),(1,2),(2,0)],
        [(0,2),(1,0),(2,1)]
        ]
    det_inds_m = [
        [(0,0),(1,2),(2,1)],
        [(0,1),(1,0),(2,2)],
        [(0,2),(1,1),(2,0)]
        ]
    volfacs = np.zeros((len(mats),),dtype=int)
    for (i1,j1),(i2,j2),(i3,j3) in det_inds_p:
        volfacs += (Tref[i1,j1]+mats[:,i1,j1])*(Tref[i2,j2]+mats[:,i2,j2])*(Tref[i3,j3]+mats[:,i3,j3])
    #end for
    for (i1,j1),(i2,j2),(i3,j3) in det_inds_m:
        volfacs -= (Tref[i1,j1]+mats[:,i1,j1])*(Tref[i2,j2]+mats[:,i2,j2])*(Tref[i3,j3]+mats[:,i3,j3])
    #end for
    Tmats = mats[np.abs(volfacs)==volfac]
    nequiv_volume = len(Tmats)    
    # find the set of cells with maximal inscribing radius
    inscribe_tilings = []
    rmax = -1e99
    for mat in Tmats:
        T = Tref + mat
        if filter(T):
            nfilter+=1
            Taxes = dot(T,axes)
            rc1 = norm(cross(Taxes[0],Taxes[1]))
            rc2 = norm(cross(Taxes[1],Taxes[2]))
            rc3 = norm(cross(Taxes[2],Taxes[0]))
            r   = 0.5*volume/max(rc1,rc2,rc3) # inscribing radius
            if r>rmax or np.abs(r-rmax)<tol:
                inscribe_tilings.append((r,T,Taxes))
                rmax = r
            #end if
        #end if
    #end for
    # find the set of cells w/ maximal wigner radius out of the inscribing set
    wigner_tilings = []
    rwmax = -1e99
    for r,T,Taxes in inscribe_tilings:
        if np.abs(r-rmax)<tol:
            nequiv_inscribe+=1
            rw = 1e99
            for ind in inds:
                rw = min(rw,0.5*norm(dot(ind,Taxes)))
            #end for
            if rw>rwmax or np.abs(rw-rwmax)<tol:
                wigner_tilings.append((rw,T,Taxes))
                rwmax = rw
            #end if
        #end if
    #end for
    # find the set of cells w/ maximal cubicity
    # (minimum cube_deviation)
    cube_tilings = []            
    cmin = 1e99
    for rw,T,Ta in wigner_tilings:
        if np.abs(rw-rwmax)<tol:
            nequiv_wigner+=1
            dc = volume**(1./3)*sqrt(2.)
            d1 = np.abs(norm(Ta[0]+Ta[1])-dc)
            d2 = np.abs(norm(Ta[1]+Ta[2])-dc)
            d3 = np.abs(norm(Ta[2]+Ta[0])-dc)
            d4 = np.abs(norm(Ta[0]-Ta[1])-dc)
            d5 = np.abs(norm(Ta[1]-Ta[2])-dc)
            d6 = np.abs(norm(Ta[2]-Ta[0])-dc)
            cube_dev = (d1+d2+d3+d4+d5+d6)/(6*dc)
            if cube_dev<cmin or np.abs(cube_dev-cmin)<tol:
                cube_tilings.append((cube_dev,rw,T,Ta))
                cmin = cube_dev
            #end if
        #end if
    #end for
    # prioritize selection by "shapeliness" of tiling matrix
    #   prioritize positive diagonal elements
    #   penalize off-diagonal elements
    #   penalize negative off-diagonal elements
    shapely_tilings = []
    smax = -1e99
    for cd,rw,T,Taxes in cube_tilings:
        if np.abs(cd-cmin)<tol:
            nequiv_cubicity+=1
            d = np.diag(T)
            o = (T-np.diag(d)).ravel()
            s = np.sign(d).sum()-(np.abs(o)>0).sum()-(o<0).sum()
            if s>smax or np.abs(s-smax)<tol:
                shapely_tilings.append((s,rw,T,Taxes))
                smax = s
            #end if
        #end if
    #end for
    # prioritize selection by symmetry of tiling matrix
    ropt   = -1e99
    Topt   = None
    Taxopt = None
    diagonal      = []
    symmetric     = []
    antisymmetric = []
    other         = []
    for s,rw,T,Taxes in shapely_tilings:
        if np.abs(s-smax)<tol:
            nequiv_shape+=1
            Td = np.diag(np.diag(T))
            if np.abs(Td-T).sum()==0:
                diagonal.append((rw,T,Taxes))
            elif np.abs(T.T-T).sum()==0:
                symmetric.append((rw,T,Taxes))
            elif np.abs(T.T+T-2*Td).sum()==0:
                antisymmetric.append((rw,T,Taxes))
            else:
                other.append((rw,T,Taxes))
            #end if
        #end if
    #end for
    s = 1
    if len(diagonal)>0:
        cells = diagonal
    elif len(symmetric)>0:
        cells = symmetric
    elif len(antisymmetric)>0:
        cells = antisymmetric
        s = -1
    elif len(other)>0:
        cells = other
    else:
        cells = []
    #end if
    skew_min = 1e99
    if len(cells)>0:
        for rw,T,Taxes in cells:
            Td = np.diag(np.diag(T))
            skew = np.abs(T.T-s*T-(1-s)*Td).sum()
            if skew<skew_min:
                ropt = rw
                Topt = T
                Taxopt = Taxes
                skew_min = skew
            #end if
        #end for
    #end if
    if Taxopt is None:
        error('optimal tilematrix for volfac={0} not found with tolerance {1}\ndifference range (dn): {2}\ntiling matrices searched: {3}\ncells with target volume: {4}\ncells that passed the filter: {5}\ncells with equivalent inscribing radius: {6}\ncells with equivalent wigner radius: {7}\ncells with equivalent cubicity: {8}\nmatrices with equivalent shapeliness: {9}\nplease try again with dn={10}'.format(volfac,tol,dn,ntilings,nequiv_volume,nfilter,nequiv_inscribe,nequiv_wigner,nequiv_cubicity,nequiv_shape,dn+1))
    #end if
    if det(Taxopt)<0:
        Topt = -Topt
    #end if
    return Topt,ropt
#end def optimal_tilematrix
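
For reference, the volume step in the excerpt (det_inds_p / det_inds_m) applies the rule of Sarrus to det(Tref + M) for every candidate M at once. numpy.linalg.det also operates on stacks of matrices, so one possible simplification is sketched below. The helper names sarrus_volume_factors and volume_factors are hypothetical. The rounding is needed because numpy.linalg.det works in floating point.

```python
import numpy as np

def sarrus_volume_factors(Tref, mats):
    """det(Tref + M) for each M in mats (shape (n, 3, 3)), expanded as in the excerpt."""
    det_inds_p = [[(0, 0), (1, 1), (2, 2)], [(0, 1), (1, 2), (2, 0)], [(0, 2), (1, 0), (2, 1)]]
    det_inds_m = [[(0, 0), (1, 2), (2, 1)], [(0, 1), (1, 0), (2, 2)], [(0, 2), (1, 1), (2, 0)]]
    T = np.asarray(Tref, dtype=int) + np.asarray(mats, dtype=int)  # broadcasts over n
    volfacs = np.zeros(len(T), dtype=int)
    for (i1, j1), (i2, j2), (i3, j3) in det_inds_p:
        volfacs += T[:, i1, j1] * T[:, i2, j2] * T[:, i3, j3]
    for (i1, j1), (i2, j2), (i3, j3) in det_inds_m:
        volfacs -= T[:, i1, j1] * T[:, i2, j2] * T[:, i3, j3]
    return volfacs

def volume_factors(Tref, mats):
    """Same result via numpy.linalg.det, which acts on the last two axes."""
    T = np.asarray(Tref, dtype=int) + np.asarray(mats, dtype=int)
    return np.rint(np.linalg.det(T)).astype(int)

# Hypothetical check on random integer candidates around a 2x2x2 reference tiling
rng = np.random.default_rng(0)
mats = rng.integers(-1, 2, size=(500, 3, 3))
Tref = 2 * np.identity(3, dtype=int)
assert np.array_equal(volume_factors(Tref, mats), sarrus_volume_factors(Tref, mats))
Tmats = mats[np.abs(volume_factors(Tref, mats)) == 8]  # candidates with volfac = 8
```

The explicit expansion uses exact integer arithmetic, which is a fair reason to keep it. The floating-point version trades that exactness for brevity and relies on rounding being safe for the small integers involved.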

Quirks

There are a few things about this function that immediately stand out to me.

  1. The variables opt_tm_matrices and opt_tm_wig_indices are defined in the global scope, but they are only ever referenced inside optimal_tilematrix(). There they cache the generated tiling matrices (keyed by dn) and Wigner image indices (keyed by nc) between calls.
  2. The class MaskFilter is defined globally, but it is only referenced on the very next line, where a global instance, mask_filter, is created. That instance is only referenced inside optimal_tilematrix().
  3. The function trivial_filter() is also defined globally, but its only use is as the default value of optimal_tilematrix()'s filter argument. More importantly, it returns True no matter what is passed to it. It seems like this should just be handled by a default argument value instead.
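
Quirks 2 and 3 could both be handled without module-level objects. The sketch below assumes the mask semantics of MaskFilter as written above. A mask with 3 entries is expanded to mask[i, j] = (m[i] == m[j]), and the filter accepts T only when T is zero wherever the mask is False. The names make_mask_filter and resolve_filter are hypothetical, not existing Nexus functions.

```python
import numpy as np

def make_mask_filter(mask, dim=3):
    """Return a filter accepting T only if T is zero wherever mask is False.

    Follows MaskFilter.set/__call__ from the excerpt: a mask with dim entries
    is expanded to mask[i, j] = (m[i] == m[j]); otherwise it must be dim x dim.
    """
    mask = np.asarray(mask, dtype=bool)
    if mask.size == dim:
        mvec = mask.ravel()
        mask = np.equal.outer(mvec, mvec)  # replaces the double loop over mvec
    elif mask.shape != (dim, dim):
        raise ValueError(f"mask must have shape ({dim}, {dim}), got {mask.shape}")
    forbidden = ~mask  # elementwise NOT of the boolean mask

    def mask_filter(T):
        return bool((np.asarray(T)[forbidden] == 0).all())

    return mask_filter


def resolve_filter(filter=None, mask=None):  # `filter` kept from the original signature
    """Choose the filter a refactored optimal_tilematrix might use (hypothetical).

    filter=None stands in for trivial_filter (accept every candidate); as in
    the excerpt, a mask takes precedence over any filter passed in.
    """
    if mask is not None:
        return make_mask_filter(mask)
    if filter is None:
        return lambda T: True
    return filter
```

For example, make_mask_filter([1, 1, 0]) accepts only tilings with no off-diagonal coupling between the third axis and the first two. Formatting mask.shape directly in the error also avoids an IndexError. The original message indexes mask.shape[1], so it would raise that error for a 1-D mask of the wrong length.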

Beyond these, I have a sneaking suspicion there are more things that could be done better (e.g. the 9-deep nested for-loops), but I don't know how this function works yet so I won't comment on those.
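
Both the nine-deep loop and the three-deep loop that builds the Wigner image indices enumerate a Cartesian product. itertools.product can therefore generate the same tuples in the same order. The globals from quirk 1 act as per-dn and per-nc caches, so moving them into the function as-is would recompute the tables on every call. functools.lru_cache is one way to keep that memoization without module-level state. The sketch below rests on those assumptions. The names tiling_deltas and wigner_image_indices are hypothetical.

```python
from functools import lru_cache
import itertools
import numpy as np

@lru_cache(maxsize=None)
def tiling_deltas(dn, dim=3):
    """All dim x dim integer matrices with entries in [-dn, dn], cached per dn.

    itertools.product(rng, repeat=dim*dim) yields tuples in the same order
    as the excerpt's nine nested loops (n1 outermost, n9 innermost).
    """
    rng = range(-dn, dn + 1)
    mats = np.array(list(itertools.product(rng, repeat=dim * dim)), dtype=int)
    mats = mats.reshape(-1, dim, dim)
    mats.flags.writeable = False  # the cached table is shared between calls
    return mats

@lru_cache(maxsize=None)
def wigner_image_indices(nc):
    """All nonzero integer triples (i, j, k) with entries in [-nc, nc].

    Ordered like the excerpt: k varies slowest and i fastest.
    """
    rng = range(-nc, nc + 1)
    inds = [(i, j, k) for k, j, i in itertools.product(rng, repeat=3)
            if (i, j, k) != (0, 0, 0)]
    inds = np.array(inds, dtype=int)
    inds.flags.writeable = False
    return inds
```

The matrix table grows as (2*dn+1)**9: 19,683 matrices for dn=1 and 1,953,125 for dn=2. That growth is why keeping the cache matters. Marking the cached arrays read-only guards against a caller mutating a table that later calls will reuse.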

To Do

The current list of things that I think should be done (and I will likely take care of at least the first one of these) is as follows:

  1. All variables declared in the global scope but only referenced inside the function should be moved into the function scope.
  2. Documentation should be added explaining what the function is for and how to use it.
  3. I have a gut feeling, just from a cursory look at the code, that the checks in this function could be performed more efficiently. Once the function is documented in its current state, the algorithm should be reviewed for any further improvements.
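
As one example of what that review might look at, the inscribing-radius and Wigner-radius checks loop over candidates in Python, and the Wigner-radius check also loops over image indices. Below is a minimal sketch that computes both quantities for a whole stack of candidate cells with NumPy broadcasting. It reuses the formulas from the excerpt. The names inscribing_radii and wigner_radii are hypothetical.

```python
import numpy as np

def inscribing_radii(Taxes, volume):
    """Radius of the largest sphere inside each cell of a stack (n, 3, 3).

    Uses the excerpt's formula 0.5 * volume / max face area, with face areas
    |a0 x a1|, |a1 x a2|, |a2 x a0| for the supercell vectors a0, a1, a2.
    """
    a0, a1, a2 = Taxes[:, 0], Taxes[:, 1], Taxes[:, 2]
    areas = np.stack([
        np.linalg.norm(np.cross(a0, a1), axis=1),
        np.linalg.norm(np.cross(a1, a2), axis=1),
        np.linalg.norm(np.cross(a2, a0), axis=1),
    ], axis=1)
    return 0.5 * volume / areas.max(axis=1)


def wigner_radii(Taxes, inds):
    """Half the shortest image vector dot(ind, Taxes) over all ind, per cell.

    inds has shape (m, 3); einsum builds all m image vectors for all n cells
    in one call instead of looping over inds in Python.
    """
    images = np.einsum("mi,nij->nmj", inds, Taxes)  # shape (n, m, 3)
    return 0.5 * np.linalg.norm(images, axis=2).min(axis=1)


# Hypothetical example: two cells of volume 8, a 2x2x2 cube and a 1x2x4 box
Taxes = np.array([2.0 * np.identity(3), np.diag([1.0, 2.0, 4.0])])
nc_range = range(-1, 2)
inds = np.array([(i, j, k) for k in nc_range for j in nc_range for i in nc_range
                 if (i, j, k) != (0, 0, 0)])
print(inscribing_radii(Taxes, 8.0))  # [1.  0.5]
print(wigner_radii(Taxes, inds))     # [1.  0.5]
```

Near-ties could then be selected with a mask such as np.abs(r - r.max()) < tol. That is not exactly the excerpt's behavior. There, rmax (and likewise rwmax, cmin, smax) is reset to each newly accepted value, including values slightly worse than the current best. Near-ties are therefore judged against a running value rather than the final extremum, so a vectorized rewrite would need to decide which behavior is intended.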
