Source code for spark_rma.median_polish

#! /usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2018 Eli Lilly and Company
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Carry out median polish on a grouping of probes, returning a value for each
sample in that grouping.
"""

from __future__ import print_function, unicode_literals
import argparse
import logging
import sys

from pyspark.sql import SparkSession
from pyspark.conf import SparkConf
# pylint: disable=no-name-in-module
# collect_list and col are not being found in path but are legitimate
from pyspark.sql.functions import udf, concat_ws, collect_list, \
    explode, log2, col
# pylint: enable=no-name-in-module
from pyspark.sql.types import StringType, ArrayType


def probe_summarization(grouped_values):
    """
    Summarization step to be pickled by Spark as a UDF.

    Receives a grouping's data in a list, unpacks it, performs median polish,
    calculates the expression values from the median polish matrix results,
    packs it back up, and returns it to a new Spark DataFrame.

    :param grouped_values: a list of strings, because Spark concatenated all
        the values into one string for each sample. Each item is in
        sample,probe,value format and all the rows in the input belong to a
        grouping key that Spark handled (transcript_cluster or probeset)
    :return: a list of lists where each item is length two with
        (sample, value)
    """
    # need to do all the imports here so they can be pickled and sent to
    # executors
    import pandas as pd
    import numpy as np

    def split_and_dedup(input_data):
        """
        Build a matrix of probes as columns and samples as rows. There cannot
        be duplicate probe IDs within a sample here to pivot, but typically
        there will be many when summarizing to transcript level due to
        multiple probesets with shared probes that map to the same
        transcript. Keep these duplicates, but rename the probes while
        keeping the matrix shape.

        :param input_data: a list given by Spark that groups all the
            sample, probe, value tuples for a single group/target (probeset
            or transcript).
        :return: pivoted data frame, or matrix of probes as columns and
            samples as rows. The whole matrix represents one target
            (probeset or transcript cluster).
        """
        # we get our data, we need to unpack as sample, probe, value
        samples = []
        probes = []
        values = []
        for each in input_data:
            sample, probe, value = each.split(',')
            samples.append(sample)
            probes.append(int(probe))  # was passed as a string from spark
            values.append(float(value))  # passed as a string from spark
        data = pd.DataFrame(
            {'SAMPLE': samples, 'PROBE': probes, 'VALUE': values})
        # we may have duplicate probes, let's rename them but keep them
        # matched so we don't have duplicate indices in the pivoted data
        # frame. sort values, group by sample and within groups replace
        # probe with row number.
        data['PROBE'] = data.sort_values(['SAMPLE', 'PROBE', 'VALUE']).groupby(
            'SAMPLE').cumcount()
        return data.pivot(index='SAMPLE', columns='PROBE', values='VALUE')

    def median_polish(matrix, max_iterations=10, eps=0.01):
        """
        Median polish of a pandas data frame, sweeping columns first and then
        rows. This will stop at convergence or max iterations, whichever
        comes first. Convergence is defined as the sum of all absolute values
        of residuals changing by no more than eps (relative to the current
        sum) between iterations.

        :param matrix: pandas data frame with all values in cells organized
            by rows x columns; for gene expression, rows are samples/arrays
            and columns are probes. The whole matrix should be all the probes
            mapping to the same transcript cluster and/or probeset region.
        :param max_iterations: The maximum iterations before stopping if
            convergence is not met. Must be 1 or greater.
        :param eps: The tolerance for convergence. Set to 0.0 to get full
            convergence, but this may be costly or never converge. Should be
            between 0 and 1.
        :return: a tuple of the residual matrix, column effects, row effects,
            grand/overall effect, and the convergence status
        """
        # create row effect, initialize at 0.0 value for each sample in long
        # format
        row_effect = pd.Series(0.0, index=matrix.index.values,
                               dtype=np.float64)  # pylint: disable=no-member
        # create column effect: initialize at 0.0 value for each probe in long
        # format
        column_effect = pd.Series(
            0.0, index=matrix.columns.values,
            dtype=np.float64  # pylint: disable=no-member
        )
        grand_effect = float(0.0)
        sum_of_absolute_residuals = 0
        converged = False

        if max_iterations < 1 or not isinstance(max_iterations, int):
            raise ValueError("max_iterations must be a positive, non-zero "
                             "integer value")

        for i in range(max_iterations):
            # columns
            col_median = matrix.median(0)
            matrix = matrix.subtract(col_median, axis=1)
            column_effect = column_effect + col_median
            diff = row_effect.median(0)
            row_effect -= diff
            grand_effect += diff

            # rows
            row_median = matrix.median(1)
            matrix = matrix.subtract(row_median, axis=0)
            row_effect = row_effect + row_median
            diff = column_effect.median()
            column_effect -= diff
            grand_effect += diff

            # Convergence check
            # check if the sum of all the absolute values of residuals has
            # changed less than the eps between iterations or is 0.
            new_sar = np.absolute(matrix).sum().sum()
            if abs(new_sar - sum_of_absolute_residuals) <= eps * new_sar or \
                    new_sar == 0:
                converged = True
                logging.debug(
                    "Convergence reached after %s iterations", i + 1)
                break
            sum_of_absolute_residuals = new_sar
        if not converged:
            logging.debug("No convergence, reached maximum iterations.")
        return matrix, column_effect, row_effect, grand_effect, converged

    def gene_expression(matrix):
        """
        Use median polish to get values and calculate gene expression values.

        :param matrix: sample, probe, value matrix to median polish
        :type matrix: pd.DataFrame()
        :return: expression values in pandas data frame
        :rtype: pd.DataFrame()
        """
        results = median_polish(matrix)
        row_effect = results[2]
        grand_effect = results[3]
        # expression is equal to the row effect + the overall effect
        expression = row_effect + grand_effect
        return expression

    data_frame = split_and_dedup(grouped_values)
    result = gene_expression(data_frame)
    index = result.index.values.tolist()  # get the sample names
    result = result.values.tolist()  # convert the pandas Series to a list,
    # and coerce values into a list of [('sample', value), ] structure.
    result = list(zip(index, result))
    return result
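

# A minimal local sketch (not part of the original module) of the
# "SAMPLE,PROBE,VALUE" strings that Spark hands to probe_summarization and
# the [(sample, value), ...] structure it returns. The sample names, probe
# IDs, and intensities below are hypothetical.
def _probe_summarization_example():
    """Illustrative only; needs pandas and numpy available on the driver."""
    grouped_values = [
        "sampleA,1001,7.25",  # hypothetical log2-normalized intensities
        "sampleA,1002,7.90",
        "sampleB,1001,6.80",
        "sampleB,1002,7.10",
    ]
    # expected shape: [('sampleA', <value>), ('sampleB', <value>)]
    return probe_summarization(grouped_values)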


class Summary(object):
    """
    Summarize gene expression values via median polish.
    """
    def __init__(self, spark, input_data, num_samples, repartition_number,
                 **kwargs):
        self.spark = spark
        self.input_data = input_data
        self.num_samples = num_samples
        self.repartition_number = repartition_number
        self.group_keys = kwargs.get('grouping')

    def udaf(self, data):
        """
        Apply median polish to groupBy keys and return a value for each
        sample within that grouping.

        This is a workaround for a user-defined aggregate function (UDAF):
        the grouped data is passed to Python for median polish and the
        result is returned back to the dataframe.

        :returns: spark dataframe
        """
        # register the median polish function as a UDF
        medpol = udf(probe_summarization, ArrayType(ArrayType(StringType())))
        # repartition by our grouping keys
        if self.group_keys not in [
                ['TRANSCRIPT_CLUSTER'],
                ['PROBESET']
        ]:
            raise Exception("Invalid grouping keys.")
        data = data.withColumnRenamed('NORMALIZED_INTENSITY_VALUE', 'VALUE')
        data = data.repartition(self.repartition_number, self.group_keys)
        # log2 the values
        data = data.withColumn('VALUE', log2(data['VALUE']).alias('VALUE'))
        # group the data while concatenating the rest of the columns into one
        # value, so collect_list yields one value (list) per row and a list of
        # lists for the whole grouping; that lets us hand the UDF one item and
        # get one item back (an array of arrays)
        data = data.withColumn(
            'data', concat_ws(',', 'SAMPLE', 'PROBE', 'VALUE')) \
            .groupBy(self.group_keys) \
            .agg(collect_list('data').alias('data')) \
            .withColumn('data', medpol('data'))

        def gen_cols(other_cols):
            """
            Create a list for select(). select() can take one list, or *args.
            Generate the grouping keys as columns and add the other column
            selections to the same list.

            :param other_cols: list of other column selections
            :type other_cols: list
            :returns: single list of columns, expressions, etc. for select()
            """
            cols = [col(s) for s in self.group_keys]
            cols += other_cols
            return cols

        # unpack the first level of nesting vertically, so each array in the
        # array is a new row (per sample)
        data = data.select(gen_cols([explode(data['data']).alias(
            "SAMPLEVALUE")]))
        # unpack the final nesting laterally, into two new columns
        data = data.select(gen_cols([
            data['SAMPLEVALUE'].getItem(0).alias('SAMPLE'),
            data['SAMPLEVALUE'].getItem(1).alias("VALUE")]))
        data = data.repartition(int(self.num_samples))
        return data
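
    # Data-shape sketch (hypothetical values, not part of the original
    # module): after groupBy/collect_list each grouping key holds something
    # like ["sampleA,1001,7.25", "sampleB,1001,6.80", ...]; the medpol UDF
    # turns that into [["sampleA", "8.1"], ["sampleB", "7.9"]], and the
    # explode/getItem calls above unpack it back into SAMPLE and VALUE
    # columns alongside the grouping key.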

    def summarize(self):
        """
        Summarize results across samples with median polish within defined
        groups.
        """
        result = self.udaf(self.input_data)
        return result


class TranscriptSummary(Summary):
    """Summarize probes with transcript cluster groupings"""
    def __init__(self, spark, input_data, num_samples, repartition_number):
        super(TranscriptSummary, self).__init__(
            spark, input_data, num_samples, repartition_number,
            grouping=['TRANSCRIPT_CLUSTER'])


class ProbesetSummary(Summary):
    """Summarize probes with probeset region grouping"""
    def __init__(self, spark, input_data, num_samples, repartition_number):
        super(ProbesetSummary, self).__init__(
            spark, input_data, num_samples, repartition_number,
            grouping=['PROBESET'])
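

# A minimal usage sketch (not part of the original module) showing how one of
# the summary classes might be driven directly from a SparkSession. The
# parquet path and the sample/partition counts are hypothetical.
def _transcript_summary_sketch():
    """Illustrative only; assumes a quantile-normalized parquet input."""
    spark = SparkSession.builder.appName("Median Polish sketch").getOrCreate()
    normalized = spark.read.parquet("normalized.parquet")  # hypothetical path
    summary = TranscriptSummary(spark, normalized,
                                num_samples=24, repartition_number=24)
    # returns a dataframe with TRANSCRIPT_CLUSTER, SAMPLE and VALUE columns
    return summary.summarize()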


def infer_grouping_and_summarize(spark, input_file, output_file, num_samples,
                                 repartition_number):
    """
    Read the input file to infer grouping type and select appropriate
    summarization class.
    """
    logging.info("Reading file at: %s", input_file)
    data_frame = spark.read.parquet(input_file)
    headers = data_frame.columns

    def matched(columns_to_check):
        """Check columns against file."""
        return not bool(list(set(columns_to_check) - set(headers)))

    def inference():
        """Determine grouping and return appropriate class"""
        if matched(['TRANSCRIPT_CLUSTER']):
            return TranscriptSummary
        elif matched(['PROBESET']):
            return ProbesetSummary
        else:
            raise KeyError('Bad Input, no valid headers found.')

    summarization = inference()
    summarization_object = summarization(spark, data_frame, num_samples,
                                         repartition_number)
    summarized_data = summarization_object.summarize()
    logging.info("Writing to file: %s", output_file)
    summarized_data.write.parquet(output_file)


def command_line():
    """Collect and validate command line arguments."""

    class MyParser(argparse.ArgumentParser):
        """
        Override default behavior, print the whole help message for any CLI
        error.
        """

        def error(self, message):
            print('error: {}\n'.format(message), file=sys.stderr)
            self.print_help()
            sys.exit(2)

    parser = MyParser(description="""
    Summarization step of RMA using Median Polish.

    Median Polish summarizes probes within a group. That group can be probes
    with the same transcript cluster and/or probeset region. The input to
    this should be the parquet output from quantile normalization. The
    grouping is inferred from the input format as generated by the annotation
    during the background correction step. Specify the summarization type in
    the background correction step to determine the grouping choice here.

    The input headers from quantile normalization:
        - SAMPLE
        - PROBE
        - PROBESET or TRANSCRIPT_CLUSTER
        - NORMALIZED_INTENSITY_VALUE

    The output format:
        - PROBESET or TRANSCRIPT_CLUSTER
        - SAMPLE
        - VALUE
    """, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('-v', '--verbose', help="Enable verbose logging",
                        action="store_const", dest="loglevel",
                        const=logging.DEBUG, default=logging.INFO)
    parser.add_argument('-i', '--input', help="Input path", required=True)
    parser.add_argument('-o', '--output', help="Output filename",
                        default='expression.parquet')
    parser.add_argument('-ns', '--number_samples',
                        help="Number of samples. Explicitly stating it here "
                             "is much faster than making Spark count the "
                             "records. This helps set partitions.",
                        type=int, required=True)
    parser.add_argument('-rn', '--repartition_number',
                        help="Number of partitions to use when running "
                             "median polish. This determines how many "
                             "groupings get collected into each task, "
                             "which impacts processing time and memory "
                             "consumption. The default is the number of "
                             "samples if left unset.",
                        type=int)
    arguments = parser.parse_args()
    if not arguments.repartition_number:
        arguments.repartition_number = arguments.number_samples
    return arguments
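
# Example invocation (hypothetical file names), matching the arguments
# defined above. Cluster-specific spark-submit options such as --master and
# executor memory are omitted:
#
#   spark-submit median_polish.py \
#       --input normalized.parquet \
#       --output expression.parquet \
#       --number_samples 24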


def main():
    """
    Collect command-line arguments and start the Spark session when using
    spark-submit.
    """
    arguments = command_line()
    logging.basicConfig(level=arguments.loglevel,
                        format='%(asctime)s %(name)-12s %(levelname)-8s %('
                               'message)s')
    logging.info("Starting Spark...")
    spark_session = SparkSession.builder \
        .config(conf=SparkConf()
                .setAppName("Median Polish")).getOrCreate()
    infer_grouping_and_summarize(
        spark=spark_session,
        input_file=arguments.input,
        num_samples=arguments.number_samples,
        output_file=arguments.output,
        repartition_number=arguments.repartition_number
    )
    logging.info("Complete.")
    spark_session.stop()


if __name__ == "__main__":
    main()